tensorflow1,TensorFlow model

 2023-10-05 阅读 17 评论 0

摘要:兩種創建model的方式 1:鏈式函數創建 要創建輸入層inputs import tensorflow as tfinputs = tf.keras.Input(shape=(3,)) x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) outputs = tf.keras.layers.Dense(5, activation=tf.nn.softma

兩種創建model的方式
1:鏈式函數創建
要創建輸入層inputs

import tensorflow as tfinputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

2:使用對象創建

import tensorflow as tfclass MyModel(tf.keras.Model):def __init__(self):super(MyModel, self).__init__()self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)def call(self, inputs):x = self.dense1(inputs)return self.dense2(x)model = MyModel()

屬性

屬性描述
layers
metrics_names所有輸出的標簽
run_eagerly是否使用eagerly模式,默認False,靜態圖
sample_weights
state_updates

方法

  1. compile
compile(optimizer,loss=None,metrics=None,loss_weights=None,sample_weight_mode=None,weighted_metrics=None,target_tensors=None,distribute=None,**kwargs
)
參數描述
optimizer(string,Object)優化器
loss(String,Object,Function),如果模型有多個輸出,可以為不同的輸出指定不同的損失函數
metrics(List(String))衡量指標,比如[‘accuracy’,‘mse’]
loss_weights
sample_weight_mode
weighted_metrics
target_tensors
distribute
**kwargs

evaluate

evaluate(x=None,y=None,batch_size=None,verbose=1,sample_weight=None,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False
)
參數描述
x(numpy array;tensor;[tensor];dict;tf.data;keras.utils.Sequence)
y
batch_size(int)每一次梯度下降使用的樣本數量.默認為32,如果輸入數據已經指定了batch_size,則不要再次指定
verbose
sample_weight
steps(int)執行多少個batch之后打印日志信息,默認,一個epoch,打印一次
callbacks
max_queue_size
workers
use_multiprocessing

evaluate_generator

evaluate_generator(generator,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False,verbose=0
)

fit

fit_generator

get_layer

get_layer(name=None,index=None
)

load_weights

load_weights(filepath,by_name=False
)

predict

predict(x,batch_size=None,verbose=0,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False
)

predict_generator

predict_generator(generator,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False,verbose=0
)

predict_on_batch

predict_on_batch(x)

reset_metrics

reset_states

save

保存模型為HDF5文件

save_weights

summary

summary(line_length=None,positions=None,print_fn=None
)

test_on_batch

to_json

to_yaml

train_on_batch

tensorflow1,參考:
https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/keras/Model

版权声明:本站所有资料均为网友推荐收集整理而来,仅供学习和研究交流使用。

原文链接:https://hbdhgg.com/3/116087.html

发表评论:

本站为非赢利网站,部分文章来源或改编自互联网及其他公众平台,主要目的在于分享信息,版权归原作者所有,内容仅供读者参考,如有侵权请联系我们删除!

Copyright © 2022 匯編語言學習筆記 Inc. 保留所有权利。

底部版权信息