pytorch简单神经网络模型训练

目录

一、导入包

二、数据预处理

三、定义神经网络

四、训练模型和测试模型

五、程序入口


一、导入包

import torch
import torch.nn as nn
import torch.optim as optim # 导入优化器
from torchvision import datasets, transforms # 导入数据集和数据预处理库
from torch.utils.data import DataLoader # 数据加载库

二、数据预处理

def data_loader():
    '''数据的预处理'''

    # 定义数据预处理
    transform = transforms.Compose([

        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5))
    ])

    # 加载FashionMNIST数据集
    train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root='.data', train=False, download=True, transform=transform)

    # 数据集加载器
    train_loader = DataLoader( train_dataset,batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset,batch_size=64, shuffle=False)

    return train_loader, test_loader

        这是一个常用于机器学习和深度学习研究中的数据集,包含了10类不同时尚商品的图像,每类有6000张训练图像和1000张测试图像。使用了PyTorch框架中的torchvision库来下载和加载Fashion-MNIST数据集。代码中定义了一个transform,它会将图像转换为张量,并对其进行归一化处理。然后,分别创建训练集和测试集的数据加载器train_loadertest_loader,这些加载器会在训练过程中以批量的形式提供数据。 

三、定义神经网络

# 定义神经网络
class QYNN(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28,128)
        self.fc2 = nn.Linear(128,10)

    def forward(self, x):
        # 将数据展平
        x = torch.flatten(x,start_dim=1)
        # 激活x,方便数据全联接
        x = torch.relu(self.fc1(x))
        # 输出10分类
        x = self.fc2(x)

        return x

代码定义了一个简单的全连接神经网络,适用于Fashion-MNIST这样的图像分类任务。这个网络包含两个全连接层(fc1fc2),分别用于特征提取和分类。

这里是您定义的QYNN类的一些解释:

  • __init__方法定义了网络的结构。网络接受28x28像素的灰度图像作为输入,首先通过一个线性层fc1将784个像素值映射到128个特征,然后通过第二个线性层fc2将128个特征映射到10个输出,对应于10个类别

  • forward方法定义了数据通过网络的前向传播过程。输入数据首先被展平成一个一维向量,然后通过fc1层,接着是ReLU激活函数,最后通过fc2层输出每个类别的得分。

四、训练模型和测试模型

训练模型

def train(model, train_loader):
    '''训练模型'''

    # 训练轮数
    epochs = 10
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # 梯度清零
            optimizer.zero_grad()
            # 将图片塞进去
            outputs = model(inputs)
            # 计算损失
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 更新参数
            optimizer.step()
            # 损失值的累加
            running_loss += loss.item()
        print(f'Epoch:{epoch+1}/{epochs} | Loss: {running_loss/len(train_loader)}')

测试模型

test函数使用了torch.no_grad()来禁用梯度计算,因为在测试阶段我们不需要计算梯度。函数遍历测试数据加载器中的每个批次,将输入数据传递给模型以获取输出,然后使用torch.max函数来获取每个样本的最高得分类别作为预测结果。最后,函数计算预测正确的样本数量与总样本数量,从而得到准确率。

def test(model, test_loader):
    '''测试模型'''

    correct = 0 # 正确的数量
    total = 0 # 样本的总量

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)

            _, predicted = torch.max(outputs, 1)

            # 获取本次样本的数量
            total += labels.size(0)

            # 预测值 和 标签 相同则正确, 对预测值进行累加
            correct += (predicted == labels).sum().item()

    print(f'Test Accuracy: {correct / total:.2%}')

五、程序入口


if __name__ == '__main__':
    # 设置随机种子
    torch.manual_seed(21)

    # 实例化神经网络
    model = QYNN()
    # 交叉商
    criterion = nn.CrossEntropyLoss()
    # 优化器
    optimizer = optim.SGD(model.parameters(),lr = 0.01)
    # 数据集
    train_loader, test_loader = data_loader()
    # 训练样本
    train(model, train_loader)
    # 测设样本
    test(model, test_loader)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/591625.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux服务器常用命令总结

view查找日志关键词 注意日志级别,回车后等一会儿,因为文件可能比较大加载完需要时间 当内容显示出来后,使用“/关键词”搜索 回车就能搜到,n表示查找下一个,N表示查找上一个 find 查找 find Family -name book …

华为平板手机如何清理应用市场的存储空间

如何清理应用市场的存储空间 适用产品: 手机,平板 适用版本:不涉及系统版本 如果您的应用市场显示应用的数据较大,可能是下载的安装包没有安装成功,导致安装包未自动删除。(可参考:应用市场下…

Delta lake with Java--将数据保存到Minio

今天看了之前发的文章,居然有1条评论,看到我写的东西还是有点用。 今天要解决的问题是如何将 Delta产生的数据保存到Minio里面。 1、安装Minio,去官网下载最新版本的Minio,进入下载目录,运行如下命令,曾经…

2024年第11届生物信息学研究与应用国际会议(ICBRA 2024)即将召开!

2024年第11届生物信息学研究与应用国际会议(ICBRA 2024)将于2024年9月13-15日在意大利米兰举行。生物信息学,作为连接生物学与信息技术的桥梁,正日益成为探索生命奥秘、推动生命科学发展的重要力量。ICBRA 2024的召开,…

使用PyTorch从头实现Transformer

前言 本文使用Pytorch从头实现Transformer,原论文Attention is all you need paper,最佳解读博客,学习视频GitHub项目地址Some-Paper-CN。本项目是译者在学习长时间序列预测、CV、NLP和机器学习过程中精读的一些论文,并对其进行了…

突破传统 重新定义:3D医学影像PACS系统源码(包含RIS放射信息)实现三维重建与还原

突破传统,重新定义PACS/RIS服务,洞察用户需求,关注应用场景,新一代PACS/RIS系统,系统顶层设计采用集中分布式架构,满足医院影像全流程业务运行,同时各模块均可独立部署,满足医院未来影像信息化扩展新需求、…

爬虫自动化之drissionpage实现随时切换代理ip

目录 一、视频二、dp首次启动设置代理三、dp利用插件随时切换代理一、视频 视频直接点击学习SwitchyOmega插件使用其它二、dp首次启动设置代理 from DrissionPage import ChromiumPage, ChromiumOptions from loguru

成都旅游攻略

第一天 大熊猫基地(55一人) 切记要去早,否则只能看到熊猫屁股 文殊院(拜文殊菩萨) 杜甫草堂(50一人) 宽窄巷子(旅游打卡拍照) 奎星楼街吃晚饭 这里的饭菜很可口 第二天 东郊记忆(成都故事.川剧变脸)主要是拍照打卡 春熙路 IFS国金中心(打卡熊猫屁屁) 太…

【数据结构与算法】堆

定义 堆是是一个完全二叉树,其中每个节点的值都大于等于或小于等于其子节点的值。这取决于是最大堆还是最小堆。 小根堆:每个根都小于子节点。 大根堆:每个根都大于子节点。 以下部分图例说明来源:【从堆的定义到优先队列、堆排…

使用 TensorFlow 和 Keras 构建 U-Net

原文地址:building-a-u-net-with-tensorflow-and-keras 2024 年 4 月 11 日 计算机视觉有几个子学科,图像分割就是其中之一。如果您要分割图像,则需要在像素级别决定图像中可见的内容(执行分类时),或者从像…

模型 SOP(标准操作程序)

系列文章 分享 模型,了解更多👉 模型_思维模型目录。标准化流程,提质增效,保障合规。 1 SOP的应用 1.1 餐厅日常卫生清洁标准操作程序(SOP) 下面展示一个餐厅如何通过SOP确保清洁工作的标准化&#xff0c…

202209青少年软件编程(Python) 等级考试试卷(一级)

第 1 题 【单选题】 表达式 len(“学史明理增信 , 读史终生受益”) > len(" reading history will benefit you ") 的结果是? ( ) A :0 B :True C :False D :1 正确答案:C 试题解析: 第 2 题 【单选题】 在 turtle 画图中, 常常使用 turtle.color(co…

【doghead】mac构建

先构建libuv libuv ✘ zhangbin@zhangbin-mbp-2  ~/tet/Fargo/zhb-bifrost/Bifrost-202403/worker/third_party/libuv/build   main  cmake .. -DBUILD_TESTING=ON -- The C compiler identification is AppleClang 12.0.5.12050022 -- Check for working C compiler: …

Git的基本操作和使用

git分支指令 列出所有本地分支 git branchmaster是绿的 前面有个 表示当前分支是master* 列出所有远程分支 git branch -r列出所有本地分支和远程分支 git branch -a新建一个分支,但依然停留在当前分支 git branch [branch-name]新建一个分支,并切…

【全网首出】npm run serve报错 Expression: thread_id_key != 0x7777

总结 困扰了一天!!!一直以为是自己哪里配置错了, 结果最后发现是node.js官方的问题, Node.js v16.x版本的fibers.node被弃用 本文阅读大概:3min #npm run serve时就报错 #找了一天的文章,找不…

U盘到底要格式化成什么格式比较好?

前言 前段时间有小伙伴问我:U盘为啥无法粘贴超过4GB的压缩包。 相信这个问题很多人都会遇到,无论是压缩包、镜像文件还是电影,都会有超过4GB的时候。 如果文件超过了4GB,那么就会小伙伴遇到电脑提示:无法粘贴超过4G…

结构体介绍(1)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 结构体(1) 前言一、struct介绍结构体声明结构体创建和初始化struct 的特殊声明结构体自引用 二、结构体内存对齐2.1.对齐规则 总结 前言 结构体 属于…

npm install digital envelope routines::unsupported解决方法

目录 一、问题描述二、问题原因三、解决方法 一、问题描述 执行命令 npm install 报错:digital envelope routines::unsupported 二、问题原因 Node.js 17 版本引入了 OpenSSL 3.0,它在算法和密钥大小方面实施了更为严格的限制。这一变化导致 npm 的升…

✔ ★Java项目——设计一个消息队列(五)【虚拟主机设计】

虚拟主机设计 创建 VirtualHost实现构造⽅法和 getter创建交换机删除交换机创建队列删除队列创建绑定删除绑定发布消息 ★路由规则1) 实现 route ⽅法2) 实现 checkRoutingKeyValid3) 实现 checkBindingKeyValid4) 实现 routeTopic5) 匹配规则测试⽤例6) 测试 Router 订阅消息1…

idea 新建spring maven项目、ioc和依赖注入

文章目录 一、新建Spring-Maven项目二、在Spring-context使用IOC和依赖注入 一、新建Spring-Maven项目 在pom.xml文件中添加插件管理依赖 <build><plugins><plugin><artifactId>maven-compiler-plugin</artifactId><version>3.1</ver…
最新文章