tensorflow中提供了tf.train.ExponentialMovingAverage來實現滑動平均模型,他使用指數衰減來計算變量的移動平均值。
tf.train.ExponentialMovingAverage.__init__(self, decay, num_updates=None, zero_debias=False, name="ExponentialMovingAverage"):
tensorflow.js。decay是衰減率在創建ExponentialMovingAverage對象時,需指定衰減率(decay),用于控制模型的更新速度。影子變量的初始值與訓練變量的初始值相同。當運行變量更新時,每個影子變量都會更新為:
shadowvariable=decay?shadowvariable+(1?decay)?variable
num_updates是ExponentialMovingAverage提供用來動態設置decay的參數,當初始化時提供了參數,即不為none時,每次的衰減率是:
min{decay,(1+num_updates)/(10+num_updates)}
apply()方法添加了訓練變量的影子副本,并保持了其影子副本中訓練變量的移動平均值操作。在每次訓練之后調用此操作,更新移動平均值。
TensorFlow、average()和average_name()方法可以獲取影子變量及其名稱。
decay設置為接近1的值比較合理,通常為:0.999,0.9999等
實例代碼如下:
v1 = tf.Variable(0, dtype=tf.float32) # 定義一個變量,初始值為0
step = tf.Variable(0, trainable=False) # step為迭代輪數變量,控制衰減率
ema = tf.train.ExponentialMovingAverage(0.99, step) # 初始設定衰減率為0.99
maintain_averages_op = ema.apply([v1]) # 更新列表中的變量
with tf.Session() as sess:init_op = tf.global_variables_initializer() # 初始化所有變量
sess.run(init_op)
print(sess.run([v1, ema.average(v1)])) # 輸出初始化后變量v1的值和v1的滑動平均值
sess.run(tf.assign(v1, 5)) # 更新v1的值
sess.run(maintain_averages_op) # 更新v1的滑動平均值
print(sess.run([v1, ema.average(v1)]))
sess.run(tf.assign(step, 10000)) # 更新迭代輪轉數step
sess.run(tf.assign(v1, 10))
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))# 再次更新滑動平均值,
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))# 更新v1的值為15
sess.run(tf.assign(v1, 15))sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))
#
# [0.0, 0.0]
# [5.0, 4.5]
# [10.0, 4.5549998]
# [10.0, 4.6094499]
# [15.0, 4.7133551]
tensorflow框架。計算步驟如下:
滑動平均模型的作用是提高測試值上的健壯性。那它是如何實現這個功能的呢?其實滑動平均模型的原理就是一階滯后濾波法,其表達式如下:
上面的實例
**********************************************
輸入?0.0?
輸出計算:
decay = min(0.99,(1+0)/(10+0)) =0.1
輸出 = 0.1 * 0+(1-0.1)*0 = 0
**********************************************
輸入 5.0?
輸出計算:
decay = min(0.99,(1+0)/(10+0)) =0.1
輸出 = 0.1 * 0+(1-0.1)*5= 4.5
**********************************************
輸入 10.0?
輸出計算:
decay = min(0.99,(1+10000)/(10+10000)) =0.99
輸出 = 0.99 * 4.5+(1-0.99)*10= 4.555
**********************************************
輸入 10.0?
輸出計算:
decay = min(0.99,(1+10000)/(10+10000)) =0.99
輸出 = 0.99 * 4.555+(1-0.99)*15= 4.60945
**********************************************
輸入 15.0?
輸出計算:
decay = min(0.99,(1+10000)/(10+10000)) =0.99
輸出 = 0.99 * 4.60945+(1-0.99)*15= 4.713355
**********************************************
參考下面博客
https://blog.csdn.net/kuweicai/article/details/80517284
https://blog.csdn.net/qq_39521554/article/details/79028012
https://www.cnblogs.com/cloud-ken/p/7521609.html
?
版权声明:本站所有资料均为网友推荐收集整理而来,仅供学习和研究交流使用。
工作时间:8:00-18:00
客服电话
电子邮件
admin@qq.com
扫码二维码
获取最新动态