# test_pytorch_MNIST.py (11 KB)


  1. import matplotlib.pyplot as plt
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torchvision import datasets, transforms
  6. from torch.utils.data import DataLoader
  7. import numpy as np
  8. import tkinter as tk
  9. from PIL import Image, ImageDraw
  10. from tkinter import messagebox
  11. import cv2
  12. from torchsummary import summary
  13. from torchviz import make_dot
  14. import netron
  15. workmode = 1 #0:训练 1:加载模型
  16. # 数据预处理
  17. if workmode==0:
  18. transform = transforms.Compose([
  19. # transforms.RandomRotation(10), # 随机旋转 10 度
  20. # transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), # 随机平移
  21. transforms.ToTensor(),
  22. transforms.Normalize((0.1307,), (0.3081,))
  23. ])
  24. else:
  25. transform = transforms.Compose([
  26. transforms.ToTensor(),
  27. transforms.Normalize((0.1307,), (0.3081,))
  28. ])
  29. # 加载训练集和测试集
  30. train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
  31. test_dataset = datasets.MNIST('data', train=False, transform=transform)
  32. train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  33. test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
  34. test_loader_predicted=np.zeros(len(test_loader.dataset))
  35. def showfig(images,labels,totalnum,num):
  36. # 选择要显示的图片数量
  37. num_images = num
  38. # 创建一个子图布局
  39. fig, axes = plt.subplots(1, num_images, figsize=(15, 3))
  40. # 遍历数据集并显示图片和标签
  41. for i in range(num_images):
  42. # image, label = train_dataset[i]
  43. random_numbers = np.random.choice(totalnum+1, num_images, replace=False)
  44. # image = train_dataset.train_data[random_numbers[i]]
  45. # label = train_dataset.train_labels[random_numbers[i]]
  46. image = images[random_numbers[i]]
  47. label = labels[random_numbers[i]]
  48. # 将张量转换为numpy数组并调整维度
  49. image = image.squeeze().numpy()
  50. # 显示图片
  51. axes[i].imshow(image, cmap='gray')
  52. # 设置标题为标签
  53. axes[i].set_title(f'idx-{random_numbers[i]}-Label: {label}')
  54. axes[i].axis('off')
  55. # # # 显示图形
  56. # plt.show()
  57. # while 1:
  58. # pass
  59. showfig(train_dataset.data,train_dataset.targets,60000,4)
  60. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  61. # # 定义神经网络模型
  62. # class Net(nn.Module):
  63. # def __init__(self):
  64. # super(Net, self).__init__()
  65. # self.fc1 = nn.Linear(784, 128)
  66. # self.fc2 = nn.Linear(128, 64)
  67. # self.fc3 = nn.Linear(64, 10)
  68. #
  69. # def forward(self, x):
  70. # x = x.view(-1, 784)
  71. # x = torch.relu(self.fc1(x))
  72. # x = torch.relu(self.fc2(x))
  73. # x = self.fc3(x)
  74. # return x
  75. # model = Net()
  76. # criterion = nn.CrossEntropyLoss()
  77. # optimizer = optim.SGD(model.parameters(), lr=0.01)
  78. # # 定义修改后的 AlexNet 模型,适应 MNIST 数据集
  79. # class AlexNet(nn.Module):
  80. # def __init__(self, num_classes=10):
  81. # super(AlexNet, self).__init__()
  82. # self.features = nn.Sequential(
  83. # nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), # 修改卷积核大小和通道数
  84. # nn.ReLU(inplace=True),
  85. # nn.MaxPool2d(kernel_size=2, stride=2), # 修改池化核大小和步长
  86. # nn.Conv2d(32, 64, kernel_size=3, padding=1),
  87. # nn.ReLU(inplace=True),
  88. # nn.MaxPool2d(kernel_size=2, stride=2),
  89. # nn.Conv2d(64, 128, kernel_size=3, padding=1),
  90. # nn.ReLU(inplace=True),
  91. # nn.Conv2d(128, 128, kernel_size=3, padding=1),
  92. # nn.ReLU(inplace=True),
  93. # nn.Conv2d(128, 128, kernel_size=3, padding=1),
  94. # nn.ReLU(inplace=True),
  95. # nn.MaxPool2d(kernel_size=2, stride=2),
  96. # )
  97. # self.classifier = nn.Sequential(
  98. # nn.Dropout(),
  99. # nn.Linear(128 * 3 * 3, 128), # 修改全连接层输入维度
  100. # nn.ReLU(inplace=True),
  101. # nn.Dropout(),
  102. # nn.Linear(128, 128),
  103. # nn.ReLU(inplace=True),
  104. # nn.Linear(128, num_classes),
  105. # )
  106. #
  107. # def forward(self, x):
  108. # x = self.features(x)
  109. # x = x.view(x.size(0), 128 * 3 * 3) # 修改展平后的维度
  110. # x = self.classifier(x)
  111. # return x
  112. #
  113. #
  114. # # 初始化模型、损失函数和优化器
  115. # model = AlexNet().to(device)
  116. # LeNet-5模型定义
  117. class LeNet5(nn.Module):
  118. def __init__(self):
  119. super(LeNet5, self).__init__()
  120. self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
  121. self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
  122. self.fc1 = nn.Linear(16 * 4 * 4, 120)
  123. self.fc2 = nn.Linear(120, 84)
  124. self.fc3 = nn.Linear(84, 10)
  125. def forward(self, x):
  126. x = torch.relu(self.conv1(x))
  127. x = nn.MaxPool2d(2)(x)
  128. x = torch.relu(self.conv2(x))
  129. x = nn.MaxPool2d(2)(x)
  130. x = x.view(x.size(0), -1)
  131. x = torch.relu(self.fc1(x))
  132. x = torch.relu(self.fc2(x))
  133. x = self.fc3(x)
  134. return x
  135. # 实例化模型
  136. model = LeNet5()
  137. criterion = nn.CrossEntropyLoss()
  138. optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  139. # # 打印模型结构
  140. # summary(model, input_size=(1, 28, 28))
  141. # # 创建一个随机输入张量
  142. # x = torch.randn(1, 1, 28, 28).to(device)
  143. # # 前向传播
  144. # y = model(x).to(device)
  145. # # 使用 torchviz 生成计算图
  146. # dot = make_dot(y, params=dict(model.named_parameters()))
  147. # # 保存计算图为图像文件(这里保存为 PNG 格式)
  148. # dot.render('alexnet_model', format='png', cleanup=True, view=True)
  149. # 训练过程
  150. def train(model, train_loader, optimizer, criterion, epoch):
  151. model.train()
  152. for batch_idx, (data, target) in enumerate(train_loader):
  153. optimizer.zero_grad()
  154. output = model(data)
  155. loss = criterion(output, target)
  156. loss.backward()
  157. optimizer.step()
  158. if batch_idx % 100 == 0:
  159. print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  160. epoch, batch_idx * len(data), len(train_loader.dataset),
  161. 100. * batch_idx / len(train_loader), loss.item()))
  162. # 验证过程
  163. def test(model, test_loader):
  164. model.eval()
  165. correct = 0
  166. total = 0
  167. with torch.no_grad():
  168. for batch_idx ,(data, target) in enumerate(test_loader):
  169. output = model(data)
  170. _, predicted = torch.max(output.data, 1)
  171. test_loader_predicted[batch_idx*64+0:batch_idx*64+64]=predicted.numpy()
  172. total += target.size(0)
  173. correct += (predicted == target).sum().item()
  174. accuracy = 100 * correct / total
  175. print('Test Accuracy: {:.2f}%'.format(accuracy))
  176. # # 训练和验证模型
  177. if workmode==0: #0:训练
  178. for epoch in range(10):
  179. train(model, train_loader, optimizer, criterion, epoch)
  180. test(model, test_loader)
  181. torch.save(model, 'model.pth')
  182. print(f'save model as model.pth')
  183. else: #1:加载模型
  184. print(f'load model.pth')
  185. try:
  186. model = torch.load('model.pth', weights_only=False)
  187. model.eval()
  188. except Exception as e:
  189. print(f"加载模型时出现错误: {e}")
  190. test(model, test_loader)
  191. # netron.start('model.pth') # 输出网络结构图
  192. showfig(test_loader.dataset.data, test_loader_predicted, 10000, 4)
  193. plt.show()
  194. def save_drawing():
  195. global drawing_points
  196. # 创建一个空白图像
  197. image = Image.new("RGB", (canvas.winfo_width(), canvas.winfo_height()), "black")
  198. draw = ImageDraw.Draw(image)
  199. # 绘制线条
  200. for i in range(1, len(drawing_points)):
  201. x1,y1=drawing_points[i - 1]
  202. x2, y2 = drawing_points[i]
  203. if (x1 is not None) and (x2 is not None) and (y1 is not None) and (y2 is not None):
  204. draw.line((x1, y1, x2, y2), fill="white", width=20)
  205. image = image.convert('L')
  206. image1=image.resize((28,28))
  207. # # 4. 转换为 numpy 数组
  208. # image_array = np.array(image1)
  209. #
  210. # # 5. 二值化
  211. # _, binary_image = cv2.threshold(image_array, 127, 255, cv2.THRESH_BINARY)
  212. #
  213. # # 6. 居中处理
  214. # rows, cols = binary_image.shape
  215. # M = cv2.moments(binary_image)
  216. # if M["m00"] != 0:
  217. # cX = int(M["m10"] / M["m00"])
  218. # cY = int(M["m01"] / M["m00"])
  219. # else:
  220. # cX, cY = 0, 0
  221. # shift_x = cols / 2 - cX
  222. # shift_y = rows / 2 - cY
  223. # M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
  224. # centered_image = cv2.warpAffine(binary_image, M, (cols, rows))
  225. #
  226. # # 7. 归一化
  227. # normalized_image = centered_image / 255.0
  228. #
  229. # # 8. 调整维度以适应模型输入
  230. # final_image = normalized_image.reshape(28, 28)
  231. # image1 = Image.fromarray(final_image)
  232. # # 转换为numpy数组
  233. # img_array = np.array(image1)
  234. # # 中值滤波
  235. # filtered_img = cv2.medianBlur(img_array, 3)
  236. # # 转换回Image对象(如果需要的话)
  237. # image1 = Image.fromarray(filtered_img)
  238. tensor_image = transform(image1) #torch.Size([3, 28, 28])
  239. # gray_tensor = torch.mean(tensor_image, dim=0, keepdim=True)
  240. # pool = torch.nn.MaxPool2d(kernel_size=10, stride=10)
  241. # pooled_image = pool(gray_tensor.unsqueeze(0)).squeeze(0)
  242. # pooled_image = gray_tensor
  243. pooled_image = tensor_image.unsqueeze(0)
  244. # print(f'tensor_image :{gray_tensor.shape} -pooled_image:{pooled_image.shape}')
  245. # simage=pooled_image.view(28,28)
  246. # simage = (simage - simage.min()) / (simage.max() - simage.min())
  247. # np_array = (simage.numpy() * 255).astype('uint8')
  248. # image_f = Image.fromarray(np_array)
  249. # image_f.show()
  250. with torch.no_grad():
  251. output = torch.softmax(model(pooled_image),1)
  252. # print(f'output.data={output}')
  253. v, predicted = torch.max(output, 1)
  254. print(f'预测数字={predicted.numpy()[0]},概率:{(v*100).numpy()[0]:.2f}%')
  255. messagebox.showinfo('识别结果',f'predicted={predicted.numpy()[0]}')
  256. drawing_points=[]
  257. canvas.delete("all")
  258. # 保存图像
  259. image.save("drawing.png")
  260. image1.save("drawing28x28.png")
  261. # print("绘画已保存为 drawing.png")
  262. last_x,last_y=None,None
  263. # last_y=[]
  264. def on_mouse_move(event):
  265. global last_x,last_y
  266. # drawing_points.append((event.x, event.y))
  267. drawing_points.append((last_x, last_y))
  268. if (last_x is not None) and (last_y is not None) :
  269. canvas.create_line(last_x, last_y, event.x, event.y, fill="white", width=20, smooth=True, splinesteps=10)
  270. # canvas.create_line(last_x, last_y , event.x, event.y, fill="white", width=20)
  271. last_x, last_y = event.x, event.y
  272. def on_mouse_release(event):
  273. global last_x, last_y
  274. last_x, last_y = None, None
  275. # print("on_mouse_release")
  276. pass
  277. root = tk.Tk()
  278. canvas = tk.Canvas(root, width=280*2, height=280*2, bg="black")
  279. canvas.pack()
  280. # canvas_show = tk.Canvas(root, width=280, height=280, bg="black")
  281. # canvas_show.pack()
  282. button = tk.Button(root, text="识别", command=save_drawing)
  283. button.pack()
  284. drawing_points = []
  285. canvas.bind("<B1-Motion>", on_mouse_move)
  286. canvas.bind("<ButtonRelease-1>", on_mouse_release)
  287. root.mainloop()