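Based on TensorFlow. Deep networks have reached accuracy beyond that of the human eye on learning tasks, but experiments show that a model's performance does not grow in proportion to its depth. The explanation given here is that the model's expressive capacity becomes too strong, so performance on the test set drops instead. The core idea of ResNet is to let information flow through shortcut connections to reach the shallow layers, which guards against vanishing or exploding gradients.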
More formally, the input $x$ passes through convolutional layers to produce the transformed output $F(x)$, which is added element-wise to the input $x$ to give the final output $H(x)$:
$$H(x) = x + F(x)$$
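To see why the shortcut helps gradients reach the shallow layers, one can differentiate the residual form for a single block. This is only a sketch: the loss symbol $\mathcal{L}$ and the identity matrix $I$ are notation introduced here for illustration and do not appear in the original text.

```latex
\frac{\partial \mathcal{L}}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial H}\,\frac{\partial H}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial H}\left(I + \frac{\partial F}{\partial x}\right)
```

The identity term passes the upstream gradient through unchanged, so the gradient at $x$ does not rely solely on $\partial F / \partial x$ being well scaled.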
The VGG block and the residual block compare as follows:
[Figure: VGG block compared with residual block]
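For the input $x$ and the convolutional output $F(x)$ to be added, the shape of $x$ must match the shape of $F(x)$ exactly. When the shapes differ, $x$ is usually transformed with a Conv2D whose kernel is 1×1 and whose stride is 2, using as many filters as $F(x)$ has channels.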
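The shape requirement can be checked with plain arithmetic. The sketch below assumes 32×32 inputs (CIFAR-10), 16 initial filters, three stacks, and `padding='same'` convolutions, which is the configuration used by the implementation in this article. The helper names `conv_same_output` and `trace_shapes` are hypothetical and exist only for this illustration.

```python
import math


def conv_same_output(size, stride):
    """Spatial size after a convolution with padding='same'."""
    return math.ceil(size / stride)


def trace_shapes(height=32, width=32, num_filters=16, num_stacks=3):
    """Return (F(x) shape, x shape) at the first block of each stack.

    Assumes stride-2 downsampling and filter doubling at every stack
    after the first, as in the implementation in this article.
    """
    shapes = []
    channels = num_filters  # channels of x after the first convolution
    for stack in range(num_stacks):
        stride = 2 if stack > 0 else 1
        out_h = conv_same_output(height, stride)
        out_w = conv_same_output(width, stride)
        main_path = (out_h, out_w, num_filters)  # shape of F(x)
        shortcut = (height, width, channels)     # shape of x before projection
        shapes.append((main_path, shortcut))
        # The block output becomes the input of the next block.
        height, width, channels = out_h, out_w, num_filters
        num_filters *= 2
    return shapes


for main_path, shortcut in trace_shapes():
    action = "identity" if main_path == shortcut else "1x1 conv, stride 2"
    print(f"F(x): {main_path}  x: {shortcut}  shortcut: {action}")
```

Only the first block of the second and third stacks needs the projection. Every other block keeps the shape, so the shortcut can stay an identity.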
Implementing ResNet with TensorFlow 2.3
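The `resnet` function in the script rejects any depth that is not of the form 6n+2. The sketch below shows where that formula comes from: one initial convolution, three stacks of n residual blocks with two convolutions each, and one final dense layer. The helper names `blocks_per_stack` and `count_weighted_layers` are hypothetical, and the count deliberately leaves out the 1×1 projection convolutions on the shortcuts.

```python
def blocks_per_stack(depth):
    """Residual blocks per stack for a three-stack network of this depth."""
    if (depth - 2) % 6 != 0:
        raise ValueError('depth should be 6n+2')
    return (depth - 2) // 6


def count_weighted_layers(n, num_stacks=3, convs_per_block=2):
    """First conv + main-path convs of every block + final dense layer.

    The 1x1 projection convs on the shortcuts are not counted.
    """
    return 1 + num_stacks * n * convs_per_block + 1


for n in (3, 5, 9):
    depth = count_weighted_layers(n)
    print(n, depth, blocks_per_stack(depth))
```

With n = 3 this gives a 20-layer network, the smallest of the three settings printed by the loop.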
import numpy as np
import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
import os
import math

# Controls the number of layers in the model.
# Number of residual blocks per stack
n = 3
# resnet() requires depth = 6n + 2, which is 20 layers for n = 3
depth = n * 6 + 2


def resnet_layer(inputs,
                 num_filters=16,
                 kernel_size=3,
                 strides=1,
                 activation='relu',
                 batch_normalization=True,
                 conv_first=True):
    """2D Convolution-Batch Normalization-Activation stack builder

    Arguments:
        inputs (tensor): input tensor
        num_filters (int): number of convolution filters
        kernel_size (int): convolution kernel size
        strides (int): convolution stride
        activation (string): name of the activation
        batch_normalization (bool): whether to apply batch normalization
        conv_first (bool): layer order, conv-bn-activation (True)
            or bn-activation-conv (False)

    Returns:
        x (tensor): output tensor
    """
    conv = keras.layers.Conv2D(num_filters,
                               kernel_size=kernel_size,
                               strides=strides,
                               padding='same',
                               kernel_initializer='he_normal',
                               kernel_regularizer=keras.regularizers.l2(1e-4))
    x = inputs
    if conv_first:
        x = conv(x)
        if batch_normalization:
            x = keras.layers.BatchNormalization()(x)
        if activation is not None:
            x = keras.layers.Activation(activation)(x)
    else:
        if batch_normalization:
            x = keras.layers.BatchNormalization()(x)
        if activation is not None:
            x = keras.layers.Activation(activation)(x)
        x = conv(x)
    return x


def resnet(input_shape, depth, num_classes=10):
    """ResNet

    Arguments:
        input_shape (tuple): shape of the input image
        depth (int): number of layers in the network
        num_classes (int): number of classes to predict

    Returns:
        model (Model): Keras model
    """
    if (depth - 2) % 6 != 0:
        raise ValueError('depth should be 6n+2')
    # Hyperparameters
    num_filters = 16
    num_res_blocks = int((depth - 2) / 6)

    inputs = keras.layers.Input(shape=input_shape)
    x = resnet_layer(inputs=inputs)
    for stack in range(3):
        for res_block in range(num_res_blocks):
            strides = 1
            # The first block of every stack after the first downsamples
            if stack > 0 and res_block == 0:
                strides = 2
            y = resnet_layer(inputs=x,
                             num_filters=num_filters,
                             strides=strides)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters,
                             activation=None)
            if stack > 0 and res_block == 0:
                # 1x1 projection so the shortcut matches the shape of y
                x = resnet_layer(inputs=x,
                                 num_filters=num_filters,
                                 kernel_size=1,
                                 strides=strides,
                                 activation=None,
                                 batch_normalization=False)
            x = keras.layers.add([x, y])
            x = keras.layers.Activation('relu')(x)
        num_filters *= 2

    x = keras.layers.AveragePooling2D(pool_size=8)(x)
    x = keras.layers.Flatten()(x)
    outputs = keras.layers.Dense(num_classes,
                                 activation='softmax',
                                 kernel_initializer='he_normal')(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


# Load the data
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# Count the classes
num_labels = len(np.unique(y_train))
# Convert the labels to one-hot encoding
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# Preprocessing
input_shape = x_train.shape[1:]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

# Build the model once input_shape is known
model = resnet(input_shape=input_shape, depth=depth, num_classes=num_labels)
# Hyperparameters
batch_size = 64
epochs = 200
# Compile the model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
model.summary()
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)
scores = model.evaluate(x_test, y_test, batch_size=batch_size, verbose=0)
print('Test loss: ', scores[0])
print('Test accuracy: ', scores[1])
Epoch 104/200
782/782 [==============================] - ETA: 0s - loss: 0.2250 - acc: 0.9751
Epoch 00104: val_acc did not improve from 0.91140
782/782 [==============================] - 15s 19ms/step - loss: 0.2250 - acc: 0.9751 - val_loss: 0.4750 - val_acc: 0.9090
learning rate: 0.0001
Epoch 105/200
781/782 [============================>.] - ETA: 0s - loss: 0.2206 - acc: 0.9754
Epoch 00105: val_acc did not improve from 0.91140
782/782 [==============================] - 16s 20ms/step - loss: 0.2206 - acc: 0.9754 - val_loss: 0.4687 - val_acc: 0.9078
learning rate: 0.0001
Epoch 106/200
782/782 [==============================] - ETA: 0s - loss: 0.2160 - acc: 0.9769
Epoch 00106: val_acc did not improve from 0.91140
782/782 [==============================] - 15s 20ms/step - loss: 0.2160 - acc: 0.9769 - val_loss: 0.4886 - val_acc: 0.9053
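The log prints a `learning rate:` line and a `val_acc did not improve` message every epoch, which the script as listed would not produce by itself. Output of this kind usually comes from Keras callbacks such as `keras.callbacks.LearningRateScheduler` and `keras.callbacks.ModelCheckpoint`. The function below is a hypothetical step schedule, not taken from the original: `BASE_LR` and the epoch thresholds are assumptions, chosen only so that the value around epoch 104 agrees with the 0.0001 shown in the log.

```python
BASE_LR = 1e-3  # assumed starting rate, not stated in the original


def lr_schedule(epoch):
    """Hypothetical step-decay schedule; the thresholds are assumptions.

    epoch is zero-based, the way Keras passes it to a schedule function.
    """
    lr = BASE_LR
    if epoch > 160:
        lr *= 1e-3
    elif epoch > 120:
        lr *= 1e-2
    elif epoch > 80:
        lr *= 1e-1
    return lr


for epoch in (0, 103, 150, 170):
    print('epoch', epoch + 1, 'learning rate:', lr_schedule(epoch))
```

A function like this could be wrapped in `keras.callbacks.LearningRateScheduler(lr_schedule)` and passed to `model.fit` through its `callbacks` argument. It takes a single argument on purpose, so that Keras does not pass the current rate into a second parameter.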