詳細分析Tensorflow,Tensorflow:estimator訓練

 2023-12-09 阅读 22 评论 0

摘要:學習流程:Estimator 封裝了對機器學習不同階段的控制,用戶無需不斷的為新機器學習任務重復編寫訓練、評估、預測的代碼。可以專注于對網絡結構的控制。 數據導入:Estimator 的數據導入也是由 input_fn 獨立定義的。例如,用戶可以非常方便的只通過

學習流程:Estimator 封裝了對機器學習不同階段的控制,用戶無需不斷的為新機器學習任務重復編寫訓練、評估、預測的代碼。可以專注于對網絡結構的控制。
數據導入:Estimator 的數據導入也是由 input_fn 獨立定義的。例如,用戶可以非常方便的只通過改變 input_fn 的定義,來使用相同的網絡結構學習不同的數據。
網絡結構:Estimator 的網絡結構是在 model_fn 中獨立定義的,用戶創建的任何網絡結構都可以在 Estimator 的控制下進行機器學習。這可以允許用戶很方便的使用別人定義好的 model_fn。model_fn模型函數必須要有features, mode兩個參數,可自己選擇加入labels(可以把label也放進features中)。最后要返回特定的tf.estimator.EstimatorSpec()。模型有三個階段都共用的正向傳播部分,和由mode值來控制返回不同tf.estimator.EstimatorSpec的三個分支。

?

?

訓練

輸出信息解析

[Tensorflow:模型訓練tensorflow.train]

在訓練或評估中利用Hook打印中間信息

hooks:如果不送值,則訓練過程中不會顯示字典中的數值。

詳細分析Tensorflow、steps:指定了訓練多少次,如果不送值,則訓練到dataset API遍歷完數據集為止。

max_steps:指定了最大訓練次數。

# 在訓練或評估的循環中,每50次print出一次字典中的數值
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=50)
mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])

?

early stopping

函數原型

tf.contrib.estimator.stop_if_no_increase_hook(
? ? estimator,
? ? metric_name,
? ? max_steps_without_increase,
? ? eval_dir=None,
? ? min_steps=0,
? ? run_every_secs=60,
? ? run_every_steps=None
)

'stop_if_no_decrease_hook'這個模塊在tf 1.10才加入。hook可以看作一個管理訓練過程的工具,比如說這里就是設置提前終止的條件,變量loss在100000步以內沒有下降即終止,實際上更廣泛的用法是用在對測試集的f1值上。

參數

tensorflow模型訓練、metric_name: str類型,比如loss或者accuracy.?hook中的參數metric_name='acc'就是tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)中的eval_metric_ops,即tf模塊代碼中通過的for step, metrics in read_eval_metrics(eval_dir).items()得到的。但是訓練好checkpoint后,就不能改,需要刪除之前訓練好的模型,重新訓練。

max_steps_without_increase: int,如果沒有增加的最大長是多少,如果超過了這個最大步長metric還是沒有增加那么就會停止。

eval_dir:默認是使用estimator.eval_dir目錄,用于存放評估的summary file。

min_steps:訓練的最小步長,如果訓練小于這個步長那么永遠都不會停止。

run_every_secs和run_every_steps:表示多長時間獲得步長調用一次should_stop_fn。

示例

? ? ? ? metrics = {
? ? ? ? ? ? 'acc': tf.metrics.accuracy(tf.argmax(labels), tf.argmax(pred_ids)),
? ? ? ? ? ? 'precision': tf.metrics.precision(tf.argmax(labels), tf.argmax(pred_ids)),
? ? ? ? ? ? 'precision_': tf_metrics.precision(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
? ? ? ? ? ? 'recall': tf.metrics.recall(tf.argmax(labels), tf.argmax(pred_ids)),
? ? ? ? ? ? 'recall_': tf_metrics.recall(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
? ? ? ? ? ? 'f1_': tf_metrics.f1(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
? ? ? ? ? ? 'auc': tf.metrics.auc(labels, pred_ids),
? ? ? ? }

tensorflow2中文教程、? ? ? ? for metric_name, op in metrics.items():
? ? ? ? ? ? tf.summary.scalar(metric_name, op[1])

? ? ? ? ''' train and evaluate '''
? ? ? ? if mode == tf.estimator.ModeKeys.EVAL:
? ? ? ? ? ? return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
? ? ? ? elif mode == tf.estimator.ModeKeys.TRAIN:
? ? ? ? ? ? train_op = tf.train.AdamOptimizer().minimize(loss=loss,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?global_step=tf.train.get_or_create_global_step())

...

? ? ? ? hook = tf.contrib.estimator.stop_if_no_increase_hook(estimator, 'f1', max_steps_without_increase=1000,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?min_steps=8000, run_every_secs=120)?
? ? ? ? train_spec = tf.estimator.TrainSpec(input_fn=train_inpf, hooks=[hook])

[簡書tf.estimate]

from:?-柚子皮-

Tensorflow lite、ref:

?

版权声明:本站所有资料均为网友推荐收集整理而来,仅供学习和研究交流使用。

原文链接:https://hbdhgg.com/2/194051.html

发表评论:

本站为非赢利网站,部分文章来源或改编自互联网及其他公众平台,主要目的在于分享信息,版权归原作者所有,内容仅供读者参考,如有侵权请联系我们删除!

Copyright © 2022 匯編語言學習筆記 Inc. 保留所有权利。

底部版权信息