PyTorch:用VGGNet实现手写字体识别数据集的训练与测试
首先,我们需要导入PyTorch库和手写字体识别数据集。 ```python import torch import torchvision import torchvision.transforms as transforms import torch.nn as nn import torch.optim as optim import matplotlib.pyplot as plt import numpy as np ``` 然后,我们加载手写字体识别数据集,并对数据进行预处理。 ```python transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) classes = tuple(str(i) for i in range(10)) ``` 接下来,我们定义一个VGG网络模型。 ```python class VGGNet(nn.Module): def __init__(self): super(VGGNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(1, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(256, 512, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2) ) self.classifier = nn.Sequential( nn.Linear(512 * 3 * 3, 4096), nn.ReLU(inplace=True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Dropout(), nn.Linear(4096, 10) ) def forward(self, x): x = self.features(x) x = x.view(x.size(0), -1) x = self.classifier(x) return x ``` 接下来,我们定义损失函数和优化器。 ```python net = VGGNet() criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) ``` 然后,我们开始训练网络。 ```python for epoch in range(2): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 2000 == 1999: # print every 2000 mini-batches print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0 print('Finished Training') ``` 我们测试训练好的网络。 ```python correct = 0 total = 0 with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('Accuracy of the network on the 10000 test images: %d %%' % ( 100 * correct / total)) ``` 通过以上代码,我们可以得到VGGNet网络在手写字体识别数据集上的训练和测试结果。 ######[AI写代码神器 | 2483点数解答 | 2024-10-21 23:42:05]
- Java代码实现从JSON数据中提取ID,涉及“数据”与“项目 - 知识库”结构(字节豆包 | 603点数解答 | 2025-04-15 22:48:51)255
- C++ 实现:计算一排鹅与最高、最矮鹅的身高差及代码详解(字节豆包 | 549点数解答 | 2025-11-16 20:19:06)105
- 编程揭秘:计算 𝑛! 在 𝑃 进制下末尾零个数的方法与代码实现(DeepSeek | 549点数解答 | 2026-01-11 17:49:54)63
- Python 实现球类:精准计算半径、表面积与体积,附输入验证与异常处理!(阿里通义 | 261点数解答 | 2024-11-28 21:19:39)378
- Matlab编程:血管机器人订购与生物学习,实现104周运营成本最低方案(字节豆包 | 3384点数解答 | 2025-06-22 13:25:13)196
- 51 单片机:定时器 0 实现 8 个 LED 循环点亮,附代码及优化建议(字节豆包 | 1193点数解答 | 2024-12-27 15:10:29)348
- Visual C++ 6.0:实现 n×n 矩阵最大最小元素所在行对调及输出(字节豆包 | 648点数解答 | 2024-11-02 10:23:57)322
- 洛谷:根据邮件重量与加急情况精准计算邮费(字节豆包 | 88点数解答 | 2024-11-09 15:33:30)422
- Python实现:根据邮件重量与加急选择精准计算邮费(阿里通义 | 554点数解答 | 2024-11-09 15:34:05)429
- GEE代码实现:三年Landsat逐月数据处理,计算NDVI、MNDWI及PWTMI指数(GPT | 1616点数解答 | 2024-11-26 19:55:20)253
- 学号_京东笔记本数据分析:爬取、存储与可视化全流程揭秘(阿里通义 | 968点数解答 | 2024-12-12 00:20:07)217
- 小学四年级综合测试卷:涵盖选择、填空等多题型,测测你的知识掌握度!(字节豆包 | 905点数解答 | 2024-12-20 14:11:43)250