Deep Residual Networks (ResNet) Explained and Implemented (TensorFlow 2.x)

2023-11-19



    • ResNet Principle
    • ResNet Implementation
      • Model Creation
      • Data Loading
      • Model Compilation
      • Model Training
      • Model Testing
      • Training Process

ResNet Principle

Deep networks have achieved accuracy surpassing the human eye on many learning tasks. However, experiments show that model performance is not proportional to model depth: beyond a certain depth, plain networks become harder to optimize and their accuracy degrades on the training set as well as the test set, so the problem is not simply overfitting. The core idea of ResNet is to add shortcut connections so that information, and gradients during backpropagation, can flow directly to shallower layers, preventing gradients from vanishing or exploding.
More formally, the input x passes through convolutional layers to produce the transformed output F(x), which is then added element-wise to the input x to give the final output H(x):
H(x) = x + F(x)
A VGG block and a residual block are compared below:

[Figure: VGG block vs. residual block]
For the input x to be added element-wise to the convolutional output F(x), the shape of x must match the shape of F(x) exactly. When the shapes differ, x is usually transformed on the shortcut path by a Conv2D with a 1×1 kernel and a stride of 2, as in the sketch below.
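
As a concrete illustration, here is a minimal sketch of a single residual block in Keras (the name residual_block and its exact layer choices are illustrative, not taken from the full implementation below):

from tensorflow import keras

def residual_block(x, num_filters, downsample=False):
    """A minimal residual block computing H(x) = x + F(x).

    When downsample=True the main path halves the spatial size,
    so the shortcut uses a 1x1 Conv2D with stride 2 to match shapes.
    """
    strides = 2 if downsample else 1
    # main path F(x): two 3x3 convolutions with batch normalization
    y = keras.layers.Conv2D(num_filters, 3, strides=strides, padding='same')(x)
    y = keras.layers.BatchNormalization()(y)
    y = keras.layers.Activation('relu')(y)
    y = keras.layers.Conv2D(num_filters, 3, strides=1, padding='same')(y)
    y = keras.layers.BatchNormalization()(y)
    # shortcut path: identity, or a 1x1 convolution when shapes differ
    if downsample or x.shape[-1] != num_filters:
        x = keras.layers.Conv2D(num_filters, 1, strides=strides, padding='same')(x)
    # H(x) = x + F(x), followed by the activation
    out = keras.layers.add([x, y])
    return keras.layers.Activation('relu')(out)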

ResNet Implementation

The following implements ResNet with TensorFlow 2.3.

Model Creation

import numpy as np
import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
import os
import math
"""
用于控制模型層數
"""
#殘差塊數
n = 3
depth = n * 9 + 1
def resnet_layer(inputs,num_filters=16,kernel_size=3,strides=1,activation='relu',batch_normalization=True,conv_first=True):"""2D Convolution-Batch Normalization-Activation stack builderArguments:inputs (tensor): 輸入num_filters (int): 卷積核個數kernel_size (int): 卷積核大小activation (string): 激活層batch_normalization (bool): 是否使用批歸一化conv_first (bool): conv-bn-active(True) or bn-active-conv (False)層堆疊次序Returns:x (tensor): 輸出"""conv = keras.layers.Conv2D(num_filters,kernel_size=kernel_size,strides=strides,padding='same',kernel_initializer='he_normal',kernel_regularizer=keras.regularizers.l2(1e-4))x = inputsif conv_first:x = conv(x)if batch_normalization:x = keras.layers.BatchNormalization()(x)if activation is not None:x = keras.layers.Activation(activation)(x)else:if batch_normalization:x = keras.layers.BatchNormalization()(x)if activation is not None:x = keras.layers.Activation(activation)(x)x = conv(x)return xdef resnet(input_shape,depth,num_classes=10):"""ResNetArguments:input_shape (tensor): 輸入尺寸depth (int): 網絡層數num_classes (int): 預測類別數Return:model (Model): 模型"""if (depth - 2) % 6 != 0:raise ValueError('depth should be 6n+2')#超參數num_filters = 16num_res_blocks = int((depth - 2) / 6)inputs = keras.layers.Input(shape=input_shape)x = resnet_layer(inputs=inputs)for stack in range(3):for res_block in range(num_res_blocks):strides = 1if stack > 0 and res_block == 0:strides = 2y = resnet_layer(inputs=x,num_filters=num_filters,strides=strides)y = resnet_layer(inputs=y,num_filters=num_filters,activation=None)if stack > 0 and res_block == 0:x = resnet_layer(inputs=x,num_filters=num_filters,kernel_size=1,strides=strides,activation=None,batch_normalization=False)x = keras.layers.add([x,y])x = keras.layers.Activation('relu')(x)num_filters *= 2x = keras.layers.AveragePooling2D(pool_size=8)(x)x = keras.layers.Flatten()(x)outputs = keras.layers.Dense(num_classes,activation='softmax',kernel_initializer='he_normal')(x)model = keras.Model(inputs=inputs,outputs=outputs)return modelmodel = resnet_v1(input_shape=input_shape,depth=depth)
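
As a quick sanity check (an illustrative addition, assuming CIFAR-10-sized 32×32×3 inputs), the constructed network should map an image batch to 10 class probabilities:

# n = 3 gives depth = 6 * 3 + 2 = 20, i.e. ResNet20
sanity_model = resnet(input_shape=(32, 32, 3), depth=20)
print(sanity_model.output_shape)  # expected: (None, 10)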

Data Loading

# load the data
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# compute the number of classes
num_labels = len(np.unique(y_train))
# convert labels to one-hot encoding
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# preprocessing: record the input shape and scale pixels to [0, 1]
input_shape = x_train.shape[1:]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
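
A quick shape check (illustrative; CIFAR-10 contains 50,000 training and 10,000 test images of size 32×32×3):

print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 10)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 10)
print(num_labels)                    # 10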

Model Compilation

# hyperparameters
batch_size = 64
epochs = 200
# compile the model
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['acc'])
model.summary()

Model Training

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test),
          shuffle=True)

Model Testing

scores = model.evaluate(x_test, y_test, batch_size=batch_size, verbose=0)
print('Test loss: ', scores[0])
print('Test accuracy: ', scores[1])

Training Process

Epoch 104/200
782/782 [==============================] - ETA: 0s - loss: 0.2250 - acc: 0.9751 
Epoch 00104: val_acc did not improve from 0.91140
782/782 [==============================] - 15s 19ms/step - loss: 0.2250 - acc: 0.9751 - val_loss: 0.4750 - val_acc: 0.9090
learning rate:  0.0001
Epoch 105/200
781/782 [============================>.] - ETA: 0s - loss: 0.2206 - acc: 0.9754 
Epoch 00105: val_acc did not improve from 0.91140
782/782 [==============================] - 16s 20ms/step - loss: 0.2206 - acc: 0.9754 - val_loss: 0.4687 - val_acc: 0.9078
learning rate:  0.0001
Epoch 106/200
782/782 [==============================] - ETA: 0s - loss: 0.2160 - acc: 0.9769 
Epoch 00106: val_acc did not improve from 0.91140
782/782 [==============================] - 15s 20ms/step - loss: 0.2160 - acc: 0.9769 - val_loss: 0.4886 - val_acc: 0.9053
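
The "val_acc did not improve" and "learning rate" lines show that the original run also used a model-checkpoint callback and a learning-rate schedule, which are not included in the code above. A minimal sketch of such callbacks, assuming a simple step decay and a hypothetical checkpoint file name, could look like this:

def lr_schedule(epoch):
    # illustrative step decay; the original schedule is not shown
    lr = 1e-3
    if epoch > 80:
        lr = 1e-4
    return lr

callbacks = [
    # save the best weights by validation accuracy (file name is hypothetical)
    keras.callbacks.ModelCheckpoint('resnet_cifar10.h5',
                                    monitor='val_acc',
                                    save_best_only=True,
                                    verbose=1),
    # apply the learning-rate schedule at the start of each epoch
    keras.callbacks.LearningRateScheduler(lr_schedule),
]
# these would be passed to training via model.fit(..., callbacks=callbacks)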
