1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
| import torch from torch import nn from torch.nn import functional as F from torch import optim import torchvision
from matplotlib import pyplot as plt import pandas as pd import numpy as np
from Util import plot_curve,plot_image,one_hot,pd_one_hot
# Mini-batch size shared by the train and test loaders.
batch_size = 512

# Preprocessing shared by both splits: ToTensor() converts images to float
# tensors in [0, 1]; Normalize(mean, std) then standardizes them with
# mean 0.1307 and std 0.3081 (the commonly quoted MNIST pixel statistics).
# Defining the pipeline once guarantees train and test are normalized alike.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split, stored under ./mnist_data (downloaded on first use) and
# reshuffled by the DataLoader at the start of every pass.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'mnist_data', train=True, download=True, transform=mnist_transform
    ),
    batch_size=batch_size,
    shuffle=True,
)

# Test split, kept in a fixed order so evaluation is reproducible.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'mnist_data', train=False, download=True, transform=mnist_transform
    ),
    batch_size=batch_size,
    shuffle=False,
)

# Sanity-check one training batch: print tensor shapes and the value range
# after normalization, then preview it with the project's plot_image helper.
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')
class Net(nn.Module):
    """Fully connected classifier: 784 -> 256 -> 64 -> 10.

    ReLU follows each of the first two linear layers; the last layer's
    output is returned as-is (no softmax), giving one raw score per class.
    """

    def __init__(self):
        super().__init__()
        # Creation order (fc1, fc2, fc3) determines both the sequence of
        # random initializations and the state_dict key order.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return 10 scores per row of ``x`` (last dimension must be 784)."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)
# Model plus SGD-with-momentum optimizer over all of its parameters.
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# Loss recorded every 5th batch; plotted as the training curve below.
loss_s = []
num_epochs = 3

for each in range(num_epochs):
    for location, (x, y) in enumerate(train_loader):
        # Net's first layer takes 28*28 inputs, so flatten each image.
        x = x.view(x.size(0), 28 * 28)
        out = net(x)

        # MSE regression of the 10 outputs onto one-hot targets.
        # F.one_hot with an explicit num_classes always gives exactly 10
        # columns and avoids a per-batch numpy round-trip.
        # NOTE(review): this replaces Util.pd_one_hot (not visible here); if
        # that helper is pandas-based it could emit fewer than 10 columns for
        # a batch missing some digit — confirm before reverting.
        y_onehot = F.one_hot(y, num_classes=10).float()
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if location % 5 == 0:
            loss_s.append(loss.item())
    # Message reads "epoch <n> complete".
    print('第', each + 1, '次迭代完成')

# Training curve: one point per 5 batches.
plt.plot(range(len(loss_s)), loss_s, 'y')
plt.show()
# Top-1 accuracy over the whole test split. Evaluation never calls
# backward(), so no_grad() skips building autograd graphs for each batch.
total_correct = 0
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # Predicted class = index of the highest score in each row.
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

# Message reads "accuracy:".
print('正确率:', total_correct / len(test_loader.dataset))

# Visual check: label one batch of test images with the model's predictions
# (not the ground-truth y) so the figure reflects what the network learned.
x, y = next(iter(test_loader))
with torch.no_grad():
    out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
|