兩種創建model
的方式
1:鏈式函數創建
要創建輸入層inputs
import tensorflow as tfinputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
2:使用對象創建
import tensorflow as tfclass MyModel(tf.keras.Model):def __init__(self):super(MyModel, self).__init__()self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)def call(self, inputs):x = self.dense1(inputs)return self.dense2(x)model = MyModel()
屬性 | 描述 |
---|---|
layers | 層 |
metrics_names | 所有輸出的標簽 |
run_eagerly | 是否使用eagerly模式,默認False ,靜態圖 |
sample_weights | |
state_updates |
compile(optimizer,loss=None,metrics=None,loss_weights=None,sample_weight_mode=None,weighted_metrics=None,target_tensors=None,distribute=None,**kwargs
)
參數 | 描述 |
---|---|
optimizer | (string,Object)優化器 |
loss | (String,Object,Function),如果模型有多個輸出,可以為不同的輸出指定不同的損失函數 |
metrics | (List(String))衡量指標,比如[‘accuracy’,‘mse’] |
loss_weights | |
sample_weight_mode | |
weighted_metrics | |
target_tensors | |
distribute | |
**kwargs |
evaluate(x=None,y=None,batch_size=None,verbose=1,sample_weight=None,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False
)
參數 | 描述 |
---|---|
x | (numpy array;tensor;[tensor];dict;tf.data;keras.utils.Sequence) |
y | |
batch_size | (int)每一次梯度下降使用的樣本數量.默認為32,如果輸入數據已經指定了batch_size,則不要再次指定 |
verbose | |
sample_weight | |
steps | (int)執行多少個batch之后打印日志信息,默認,一個epoch,打印一次 |
callbacks | |
max_queue_size | |
workers | |
use_multiprocessing |
evaluate_generator(generator,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False,verbose=0
)
get_layer(name=None,index=None
)
load_weights(filepath,by_name=False
)
predict(x,batch_size=None,verbose=0,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False
)
predict_generator(generator,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False,verbose=0
)
predict_on_batch(x)
保存模型為HDF5文件
summary(line_length=None,positions=None,print_fn=None
)
tensorflow1,參考:
https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/keras/Model
版权声明:本站所有资料均为网友推荐收集整理而来,仅供学习和研究交流使用。
工作时间:8:00-18:00
客服电话
电子邮件
admin@qq.com
扫码二维码
获取最新动态