什麼是Tensorflow的Assign Operator:以tf.assign實作Counter
這篇文章會介紹 tf.assign
這個 operator,再使用tf.assign
實作簡單的計數器(counter)。
給Pytorch用戶的小前言
如果你也是從Pytorch起家的開發者,相信你一定也對Tensorflow的assign operation感到滿頭問號。「奇怪~要對變數賦值,不就直接等號就好嗎?」很可惜,在Tensorflow的世界劇情並不是這樣。
別忘了,Tensorflow是一套完整定義的可硬體加速計算系統,而非僅僅一個深度學習套件。
如何使用Assign Operator
考慮一個簡單的小問題:程式中需要一個變數,統計跑了某一區塊的計算幾次。
如果只是一般的python程式,我大部分人會這樣寫:就開個變數、寫個for loop、每次都+1就好。但假設把這一切如法炮製到Tensorflow會發生什麼事呢?
# Terminal Output
1
1
1
1
1
很可惜這樣寫法,實際上只會得到五個1而已。
正確的寫法
剛剛上面的寫法問題在於,不管你跑幾次,var
這個tf.Variable
的值都從未更新過,所以需要使用tf.assign
指定新的數值。
tf.assign
的語法很簡單(先忽略後三個args
),簡單來說就是把ref
的數值變成value
,而ref
需要是tf.Variable
。
所以只要在剛剛的程式上加上counter = tf.assign(var, counter)
,就可以正確地得到所要的功能了。
# Terminal Output
1
2
3
4
5
範例程式碼可以直接到我的Google Colaboratory執行。
這個簡單的 counter最實際的例子就是計算 learning step。
當然,一般的時候會直接使用tf.train.global_step
。
雖然這次只是用tf.assign
做一件非常非常簡單的事情,但tf.assign
與相關的Assign operator是Tensorflow計算系統中很根本的東西。很多high level的function都是用tf.assign
實作的,當然,學會Assign operator,別忘了順便學學Control dependency(tf.control_dependencies
)這個重要的配套。