ResNet18代码实现[通俗易懂]

ResNet18代码实现[通俗易懂]importtensorflowastffromtensorflowimportkerasfromtensorflow.kerasimportlayers,Sequential,Model,datasets,optimizers#自定义的预处理函数defpreprocess(x,y):#调用此函数时会自动传入x,y对象,shape为[b,28,28],[b]#标准化到0-1x=2*tf.cast(x,dtype=t…

大家好,又见面了,我是你们的朋友全栈君。

import tensorflow as tf

from tensorflow import keras

from tensorflow.keras import layers, Sequential, Model, datasets, optimizers

# 自定义的预处理函数

def preprocess(x, y):

    # 调用此函数时会自动传入x,y对象,shape为[b,28,28],[b]

    # 标准化到0-1

    x = 2*tf.cast(x, dtype=tf.float32) / 255.-1

    # 转成整型张量

    y = tf.cast(y, dtype=tf.int32)

    # 返回的x,y将替换传入的x,y参数,从而实现数据的预处理功能

    return x, y

# 在线下载,加载CIFAR10数据集

(x,y),(x_test,y_test)= datasets.cifar10.load_data()

# 删除y的一个不必要的维度,[b,1] → [b]

y= tf.squeeze(y,axis= 1)

y_test= tf.squeeze(y_test, axis= 1)

# 打印训练集和测试集的形状

# print(x.shape,y.shape, x_test.shape, y_test.shape)

# 构建训练集对象,随机打乱,预处理,批量化

train_db= tf.data.Dataset.from_tensor_slices((x,y))

train_db= train_db.shuffle(1000).map(preprocess).batch(512)

# 构建测试集对象,预处理,批量化

test_db= tf.data.Dataset.from_tensor_slices((x_test,y_test))

test_db= test_db.map(preprocess).batch(512)

# 从训练集中采样一个Batch,并观察

sample= next(iter(train_db))

# print(‘sample:’,sample[0].shape,sample[1].shape,tf.reduce_min(sample[0]),tf.reduce_max(sample[0]))

class BasicBlock(layers.Layer):

    # 残差模块

    def __init__(self, filter_num, stride= 1):

        super(BasicBlock, self).__init__()

        #第一个卷积单元

        self.conv1= layers.Conv2D(filter_num, kernel_size=(3,3), strides= stride, padding= ‘same’)

        self.bn1= layers.BatchNormalization()

        self.relu= layers.Activation(‘relu’)

        # 第二个卷积单元

        self.conv2= layers.Conv2D(filter_num, kernel_size=(3,3), strides= 1, padding= ‘same’ )

        self.bn2= layers.BatchNormalization()

        # 通过1*1卷积完成shape匹配

        if stride != 1:

            self.downsample= Sequential()

            self.downsample.add(layers.Conv2D(filter_num, kernel_size= (1,1), strides= stride))

        else:   # shape匹配,直接短接

            self.downsample= lambda x:x

   

    def call(self, inputs, training= None):

        # 前向计算函数

        # [b,h,w,c], 通过第一个卷积单元

        out= self.conv1(inputs)

        out= self.bn1(out)

        out= self.relu(out)

        # 通过第二个卷积单元

        out= self.conv2(out)

        out= self.bn2(out)

        # 通过identity模块

        identity= self.downsample(inputs)

        # 两条路径输出直接相加

        output= layers.add([out,identity])

        output= tf.nn.relu(output)

        return output

       

class ResNet(Model):

    def __init__(self, layer_dims, num_classes= 10): #[2,2,2,2]

        super(ResNet, self).__init__()

        # 根网络,预处理

        self.stem= Sequential([

            layers.Conv2D(64, kernel_size= (3,3), strides= (1,1)),

            layers.BatchNormalization(),

            layers.Activation(‘relu’),

            layers.MaxPool2D(pool_size=(2,2), strides=(1,1), padding= ‘same’)

        ])

        # 堆叠4个Block,每个Block包含了多个BasicBlock,设置步长不一样

        self.layer1= self.build_resblock(64, layer_dims[0])

        self.layer2= self.build_resblock(128, layer_dims[1], stride= 2)

        self.layer3= self.build_resblock(256, layer_dims[2], stride= 2)

        self.layer4= self.build_resblock(512, layer_dims[3], stride= 2)

        # 通过Pooling层将高宽降低为1*1

        self.avgpool= layers.GlobalAveragePooling2D()

        # 最后连接一个全连接层分类

        self.fc= layers.Dense(num_classes)

    def build_resblock(self, filter_num, blocks, stride= 1):

        # 辅助函数,堆叠filter_num个BasicBlock

        res_blocks= Sequential()

        # 只有第一个BasicBlock的步长可能不为1, 实现下采样

        res_blocks.add(BasicBlock(filter_num, stride))

        # 其他BasicBlock步长都为1

        for _ in range(1, blocks):

            res_blocks.add(BasicBlock(filter_num, stride= 1))

        return res_blocks

   

    def call(self, inputs, training= None):

        # 前向计算函数:通过根网络

        x= self.stem(inputs)

        # 一次通过4个模块

        x= self.layer1(x)

        x= self.layer2(x)

        x= self.layer3(x)

        x= self.layer4(x)

        # 通过池化层

        x= self.avgpool(x)

        # 通过全连接层

        x= self.fc(x)

        return x

def resnet18():

    # 通过调整模块内部BasicBlock的数量和配置实现不同的ResNet

    return ResNet([2,2,2,2])

# def resnet34():

#     # 通过调整模块内部BasicBlock的数量和配置实现不同的ResNet

#     return ResNet([3,4,6,3])

model = resnet18() # ResNet18网络

model.build(input_shape=(None, 32, 32, 3))

# model.summary() # 统计网络参数

def main():

    optimizer = optimizers.Adam(learning_rate=1e-4)

    for epoch in range(10):

        for step, (x,y) in enumerate(train_db):

            with tf.GradientTape() as tape:

                # [b, 32, 32, 3] => [b, 1, 1, 512]

                logits= model(x)

                y_onehot = tf.one_hot(y, depth=10)

                # compute loss

                loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)

                loss = tf.reduce_mean(loss)

            # 对所有参数求梯度

            grads= tape.gradient(loss, model.trainable_variables)

            # 自动更新

            optimizer.apply_gradients(zip(grads,model.trainable_variables))

            if step %10 == 0:

                print(epoch, step, ‘loss:’, float(loss))

        total_num = 0

        total_correct = 0

        for x,y in test_db:

            # out = model(x)

            # out = tf.reshape(out, [-1, 512])

            logits = model(x)

            prob = tf.nn.softmax(logits, axis=1)

            pred = tf.argmax(prob, axis=1)

            pred = tf.cast(pred, dtype=tf.int32)

            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)

            correct = tf.reduce_sum(correct)

            total_num += x.shape[0]

            total_correct += int(correct)

        acc = total_correct / total_num

        print(epoch, ‘acc:’, acc)

if __name__ == ‘__main__’:

    main()

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/141598.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 因果图方法是根据( )之间的因果关系来设计测试用例的_因果图法符号

    因果图方法是根据( )之间的因果关系来设计测试用例的_因果图法符号原标题:因果推断简介之五:因果图(CausalDiagram)编辑部于2019年10月在微信端开启《朝花夕拾》栏目,目的是推送2013年(含)之前主站发表的优秀文章,微信端与主站的同步始于2013年年初,然而初期用户量有限,故优质文章可能被埋没。这部分介绍JudeaPearl于1995年发表在Biometrika上的工作“Causaldiagramsforempirica…

    2022年8月14日
    5
  • 半监督之mixmatch

    半监督之mixmatch自洽正则化:以前遇到标记数据太少,监督学习泛化能力差的时候,人们一般进行训练数据增广,比如对图像做随机平移,缩放,旋转,扭曲,剪切,改变亮度,饱和度,加噪声等。数据增广能产生无数的修改过的新图像,扩大训练数据集。自洽正则化的思路是,对未标记数据进行数据增广,产生的新数据输入分类器,预测结果应保持自洽。即同一个数据增广产生的样本,模型预测结果应保持一致。此规则被加入到损失函数中,有如下形式,其中x是未标记数据,Augment(x)表示对x做随机增广产生的新数据,θ是模型参数,y是模型预测结.

    2025年8月9日
    2
  • IP地址、子网掩码、网络号、主机号、网络地址、主机地址以及ip段/数字-如192.168.0.1/24是什么意思?「建议收藏」

    IP地址、子网掩码、网络号、主机号、网络地址、主机地址以及ip段/数字-如192.168.0.1/24是什么意思?「建议收藏」背景知识IP地址IP地址被用来给Internet上的电脑一个编号。大家日常见到的情况是每台联网的PC上都需要有IP地址,才能正常通信。我们可以把“个人电脑”比作“一台电话”,那么“IP地址”就相当于“电话号码”,而Internet中的路由器,就相当于电信局的“程控式交换机”。IP地址是一个32位的二进制数,通常被分割为4个“8位二进制数”(也就是4个字节)。IP地址通常用“点分十进制”表示成(a

    2022年6月24日
    33
  • Android蓝牙开发(二)之蓝牙配对和蓝牙连接

    Android蓝牙开发(二)之蓝牙配对和蓝牙连接上篇文章:https://blog.csdn.net/huangliniqng/article/details/82185983讲解了打开蓝牙设备和搜索蓝牙设备,这篇文章来讲解蓝牙配对和蓝牙连接1.蓝牙配对搜索到蓝牙设备后,将设备信息填充到listview中,点击listiew则请求配对蓝牙配对有点击配对和自动配对,点击配对就是我们选择设备两个手机弹出配对确认框,点击确认…

    2022年6月29日
    113
  • idea打包教程[通俗易懂]

    idea打包教程[通俗易懂]然后点apply/ok

    2022年10月3日
    2
  • mysql语句和sql语句的区别_oracle和sqlserver的语法区别

    mysql语句和sql语句的区别_oracle和sqlserver的语法区别sql和mysql语法的区别有:mysql支持enum和set类型,sql不支持,mysql需要为表指定存储类型,mysqlL中text字段类型不允许有默认值,sql允许有等等方面都存在差异MySQL与SQLServer的语法区别1、MySQL支持enum,和set类型,SQLServer不支持2、MySQL不支持nchar,nvarchar,ntext类型3、MySQL的递增语句是AUTO_I…

    2022年10月2日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号