test_pytorch.py

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

print(f"torch.__version__:{torch.__version__}")  # PyTorch version
print(f"torch.version.cuda:{torch.version.cuda}")  # CUDA version PyTorch was built with
print(torch.cuda.is_available())  # check whether CUDA is available

# Use the GPU if available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"using device:{device}")
# Generate some random training data: y = 10x + 1 plus Gaussian noise (std 0.5)
np.random.seed(42)
x = np.random.rand(1000, 1)
y = 10 * x + 1 + 0.5 * np.random.randn(1000, 1)
# y[900] = 1000  # uncomment to inject a single outlier

# Convert the data to tensors
x_tensor = torch.from_numpy(x).float()
y_tensor = torch.from_numpy(y).float()
# Define the linear regression model
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(1, 1)  # one input feature, one output

    def forward(self, x):
        return self.linear(x)
# Create a model instance
model = LinearRegressionModel()

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
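# For reference: MSELoss computes mean((y_pred - y)^2) over the batch, and SGD
# updates every parameter p as p <- p - lr * dLoss/dp on each optimizer.step().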
# Train the model
num_epochs = 6000
dbug_epos = np.zeros(num_epochs)  # epoch index per step, for plotting
dbug_loss = np.zeros(num_epochs)  # training loss per step, for plotting
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(x_tensor)
    loss = criterion(outputs, y_tensor)
    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    dbug_epos[epoch] = epoch
    dbug_loss[epoch] = loss.item()
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}')
# Make predictions
with torch.no_grad():
    predicted = model(x_tensor)
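
# Sketch (added for illustration, not in the original script): inspect the learned
# parameters; they should be close to the generating values w = 10, b = 1.
learned_w = model.linear.weight.item()
learned_b = model.linear.bias.item()
print(f"learned fit: y ~= {learned_w:.3f} * x + {learned_b:.3f}")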
# Plot the original data and the predictions
# (earlier single-figure version, kept commented out)
# plt.scatter(x, y, label='Original Data')
# plt.plot(x, predicted.numpy(), color='red', label='Predicted Line')
# plt.xlabel('x')
# plt.ylabel('y')
# plt.legend()
# plt.show()
fig1, ax1 = plt.subplots()
ax1.scatter(x, y, label='Original Data')
ax1.plot(x, predicted.numpy(), color='red', label='Predicted Line')
ax1.set_xlabel('x')
ax1.set_ylabel('y')
ax1.legend()

fig2, ax2 = plt.subplots()
ax2.scatter(dbug_epos, dbug_loss, label='loss')
ax2.plot(dbug_epos, dbug_loss, color='red', label='loss')
ax2.set_xlabel('dbug_epos')
ax2.set_ylabel('dbug_loss')
ax2.legend()

plt.show()
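
# Sketch (assumption, not part of the original script): the trained weights could be
# saved and restored through the state dict; the filename is hypothetical.
#   torch.save(model.state_dict(), "linear_regression.pt")
#   model.load_state_dict(torch.load("linear_regression.pt"))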