用什么方法可以实现深度残差网络?
2020-12-01 18:01:54 +08:00
fanqieipnet
AlexNet,VGG,GoogLeNet 等网络模型的出现将神经网络的发展带入了几十层的阶段,研究人员发现网络的层数越深,越有可能获得更好的泛化能力。但是当模型加深以后,网络变得越来越难训练,这主要是由于梯度弥散现象造成的。在较深层数的神经网络中间,梯度信息由网络的末层逐层传向网络的首层时,传递的过程中会出现梯度接近于 0 的现象。网络层数越深,梯度弥散现象可能会越严重。用什么方法可以实现深度残差网络?今天番茄加速就来分析一下。
为了解决这个问题,研究人员尝试给深层神经网络添加一种回退到浅层神经网络的机制。当深层神经网络可以轻松地回退到浅层神经网络时,深层神经网络可以获得与浅层神经网络相当的模型性能,而不至于更糟糕。
通过在输入和输出之间添加一条直接连接的 Skip Connection 可以让神经网络具有回退的能力。以 VGG13 深度神经网络为例,假设观察到 VGG13 模型出现梯度弥散现象,而 10 层的网络模型并没有观测到梯度弥散现象,那么可以考虑在最后的两个卷积层添加 Skip Connection,通过这种方式网络模型可以自动选择是否经由这两个卷积层完成特征变换,还是直接跳过这两个卷积层而选择 Skip Connection,亦或结合两个卷积层和 Skip Connection 的输出。
ResNet 通过在卷积层的输入和输出之间添加 Skip Connection 实现层数回退机制,输入𝑥通过两个卷积层,得到特征变换后的输出ℱ(𝑥),与输入𝑥进行对应元素的相加运算,得到最终输出:H(x) = x + f(x)
ResNet 实现
1.定义残差模块
首先创建一个新类,在初始化阶段创建残差块中需要的卷积层,激活函数层等,首先新建 f(𝑥)卷积层:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential
class BasicBlock(layers.Layer):
# 残差模块
def __init__(self, filter_num, stride=1):
super(BasicBlock, self).__init__()
# 第一个卷积单元
self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
self.bn1 = layers.BatchNormalization()
self.relu = layers.Activation('relu')
# 第二个卷积单元
self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
self.bn2 = layers.BatchNormalization()
if stride != 1:# 通过 1x1 卷积完成 shape 匹配
self.downsample = Sequential()
self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))
else:# shape 匹配,直接连接
self.downsample = lambda x:x
def call(self, inputs, training=None):
# [b, h, w, c],通过第一个卷积单元
out = self.conv1(inputs)
out = self.bn1(out)
out = self.relu(out)
# 通过第二个卷积单元
out = self.conv2(out)
out = self.bn2(out)
# 通过 identity 模块
identity = self.downsample(inputs)
# 2 条路径输出直接相加
output = layers.add([out, identity])
output = tf.nn.relu(output) # 激活函数
return output
2.实现 ResNet 类
class ResNet(keras.Model):
# 通用的 ResNet 实现类
def __init__(self, layer_dims, num_classes=10): # [2, 2, 2, 2]
super(ResNet, self).__init__()
# 根网络,预处理
self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),
layers.BatchNormalization(),
layers.Activation('relu'),
layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')
])
# 堆叠 4 个 Block,每个 block 包含了多个 BasicBlock,设置步长不一样
self.layer1 = self.build_resblock(64, layer_dims[0])
self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)
# 通过 Pooling 层将高宽降低为 1x1
self.avgpool = layers.GlobalAveragePooling2D()
# 最后连接一个全连接层分类
self.fc = layers.Dense(num_classes)
def call(self, inputs, training=None):
# 通过根网络
x = self.stem(inputs)
# 一次通过 4 个模块
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# 通过池化层
x = self.avgpool(x)
# 通过全连接层
x = self.fc(x)
return x
def build_resblock(self, filter_num, blocks, stride=1):
# 辅助函数,堆叠 filter_num 个 BasicBlock
res_blocks = Sequential()
# 只有第一个 BasicBlock 的步长可能不为 1,实现下采样
res_blocks.add(BasicBlock(filter_num, stride))
for _ in range(1, blocks):#其他 BasicBlock 步长都为 1
res_blocks.add(BasicBlock(filter_num, stride=1))
return res_blocks
上面就是用 Tensorflow 实现深度残差网络的原理和方法,欢迎指正。
0 条回复
这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。
https://www.v2ex.com/t/731078
V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。
V2EX is a community of developers, designers and creative people.