现在的代码是从测试集里随机取图片来预测,能不能改成预测指定的图片?
比如说指定 ("2.jpg")
这样的图片文件。
"""
****************** 实现 MNIST 手写数字识别 ************************
****************************************************************
"""
# -*- coding: utf-8 -*-
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
# Number of digit images fetched (and predicted) per batch.
BATCH_SIZE = 4

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing pipeline: convert to a tensor, then normalise the single
# grayscale channel.
# NOTE(review): the commonly quoted MNIST mean is 0.1307; 0.1037 may be a
# transposed digit. Confirm against the training script before changing it,
# because the saved weights depend on the value used during training.
tsfrm = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1037,), (0.3081,)),
    ]
)

# MNIST test split (downloaded into ./data on first use), served in
# shuffled batches of BATCH_SIZE images.
_mnist_test_set = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=tsfrm,
)
test_loader = torch.utils.data.DataLoader(
    _mnist_test_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
# Image visualisation helper
def imshow(images, mean=0.1037, std=0.3081):
    """Display a batch of normalised images in an OpenCV window.

    The batch is tiled into a single grid, de-normalised back to the
    [0, 1] range, enlarged 5x and shown until a key is pressed.

    Args:
        images: Tensor batch of images (N, C, H, W) as produced by the
            data loader.
        mean: Mean that was used to normalise the images. A scalar or a
            per-channel sequence. Defaults to the value used by this
            script's preprocessing pipeline.
        std: Standard deviation that was used to normalise the images.
            A scalar or a per-channel sequence.
    """
    grid = torchvision.utils.make_grid(images)
    # detach/cpu so that CUDA or grad-tracking tensors can be converted too.
    # CHW -> HWC, which is the layout OpenCV expects.
    img = grid.detach().cpu().numpy().transpose(1, 2, 0)
    # Undo Normalize(mean, std): x_norm * std + mean. Using values other than
    # the ones applied during preprocessing leaves the pixels outside [0, 1],
    # which OpenCV renders as a washed-out / saturated image.
    img = img * np.asarray(std, dtype=img.dtype) + np.asarray(mean, dtype=img.dtype)
    # shape[0] is the height (rows), shape[1] the width (columns)
    height, width = img.shape[0:2]
    # Enlarge 5x; cv2.resize takes the target size as (width, height)
    enlarge_img = cv2.resize(img, (int(width * 5), int(height * 5)))
    cv2.imshow('image', enlarge_img)
    # Block until any key is pressed
    cv2.waitKey(0)
# A small LeNet-style network: two convolution layers (conv1, conv2) followed
# by two linear layers. The final layer emits 10 values, one per digit 0-9.
class ConvNet(nn.Module):
    """Convolutional classifier for 1x28x28 digit images."""

    def __init__(self):
        super().__init__()
        # Input: N x 1 x 28 x 28
        # 1 input channel -> 10 output channels, 5x5 kernel
        self.conv1 = nn.Conv2d(1, 10, 5)
        # 10 input channels -> 20 output channels, 3x3 kernel
        self.conv2 = nn.Conv2d(10, 20, 3)
        # Fully connected hidden layer and 10-way output layer
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (N, 10)."""
        batch = x.size(0)
        # conv1 + ReLU: N x 10 x 24 x 24, then 2x2 max-pool: N x 10 x 12 x 12
        features = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        # conv2 + ReLU: N x 20 x 10 x 10
        features = F.relu(self.conv2(features))
        # Flatten to N x 2000 for the linear layers
        flat = features.view(batch, -1)
        hidden = F.relu(self.fc1(flat))  # N x 500
        logits = self.fc2(hidden)  # N x 10
        return F.log_softmax(logits, dim=1)
def predict_digits(model, images, device):
    """Run ``model`` on a batch of images and return the predicted classes.

    Args:
        model: A ``torch.nn.Module`` that maps a batch to per-class scores.
        images: Input tensor batch accepted by ``model``.
        device: ``torch.device`` on which inference should run.

    Returns:
        A list of plain Python ints, one predicted class index per image.
    """
    # Model and inputs must live on the same device, otherwise the forward
    # pass fails whenever CUDA is selected.
    model.to(device)
    model.eval()
    # No gradients are needed for inference.
    with torch.no_grad():
        outputs = model(images.to(device))
    # Index of the highest score for each image
    _, preds = torch.max(outputs, 1)
    # Move back to the CPU before converting; tolist() yields Python ints,
    # which print as "7" rather than as a NumPy scalar repr.
    return preds.cpu().tolist()


# Script entry point
if __name__ == "__main__":
    model_eval = ConvNet()
    # Load the trained weights.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model_eval.load_state_dict(torch.load('./MNISTModel.pkl', map_location=DEVICE))
    # Take one (shuffled) batch of images from the test set
    images, _ = next(iter(test_loader))
    # Show the images being predicted
    imshow(images)
    # Predict and print the recognised digits
    numlist = predict_digits(model_eval, images, DEVICE)
    digits_text = ' '.join(str(digit) for digit in numlist)
    print('当前预测的数字为: ', digits_text)
V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。
V2EX is a community of developers, designers and creative people.