使用预训练模型的 Alexnet 进行图片分类,准确率与网络数据不符,可能是什么原因导致的?

2021-11-21 01:48:54 +08:00
 Richard14

预训练的意思是用 torchvision 里写好的 alexnet (修改最后一层),不是指导入训练好的,尝试用 quickstart 里的代码训练 cifar10 ,但是网上普遍查到的实验数据,准确率大概在 80%,78%左右,我迭代到收敛也只能得到 70%的准确率,这个差异产生的原因是啥呢?

完整代码:

from utils import *
from pipeit import *
import os,sys,time,pickle,random
import matplotlib.pyplot as plt
import numpy as np 
import torch
from torch import nn
from torchvision import datasets, models
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision.transforms import ToTensor, Lambda, Resize, Compose, InterpolationMode

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
torch.backends.cudnn.benchmark=True

# Download training data from open datasets.
training_data = datasets.CIFAR10(
    root=".\\data\\cifar10",
    train=True,
    download=True,
    transform=Compose([
        Resize((64, 64), InterpolationMode.BICUBIC),
        ToTensor()
    ])
)

# Download test data from open datasets.
test_data = datasets.CIFAR10(
    root=".\\data\\cifar10",
    train=False,
    download=True,
    transform=Compose([
        Resize((64, 64), InterpolationMode.BICUBIC),
        ToTensor()
    ])
)

def imshow(training_data):
    labels_map = {
        0: "plane",
        1: "car",
        2: "bird",
        3: "cat",
        4: "deer",
        5: "dog",
        6: "frog",
        7: "horse",
        8: "ship",
        9: "truck",
    }
    cols, rows = 3, 3
    figure = plt.figure(figsize=(8,8))
    for i in range(1, cols * rows + 1):
        sample_idx = torch.randint(len(training_data), size=(1,)).item()
        img, label = training_data[sample_idx]
        img = img.swapaxes(0,1)
        img = img.swapaxes(1,2)
        figure.add_subplot(rows, cols, i)
        plt.title(labels_map[label])
        plt.axis("off")
        plt.imshow(img)
    plt.show()

# imshow(training_data)

def train_loop(dataloader, net, loss_fn, optimizer):
    size = len(dataloader)
    train_loss = 0
    for batch_idx, (X, tag) in enumerate(dataloader):
        X, tag = X.to(device), tag.to(device)
        pred = net(X)
        loss = loss_fn(pred, tag)
        train_loss += loss.item()

        # Back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_loss /= size 
    return train_loss

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    return test_loss, correct

net = models.alexnet().to(device)
net.classifier[6] = nn.Linear(4096, 10).to(device)

learning_rate = 0.01
batch_size = 128
weight_decay = 0

train_dataloader = DataLoader(training_data, batch_size = batch_size)
test_dataloader = DataLoader(test_data, batch_size = batch_size)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr = learning_rate)

epochs = 50
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    st_time = time.time()
    train_loss = train_loop(train_dataloader, net, loss_fn, optimizer)
    test_loss, correct = test_loop(test_dataloader, net, loss_fn)
    print(f"Train loss: {train_loss:>8f}, Test loss: {test_loss:>8f}, Accuracy: {(100*correct):>0.1f}%, Epoch time: {time.time() - st_time:.2f}s\n")
print("Done!")
torch.save(net.state_dict(), 'alexnet-pre1.model')

最后收敛时的数据在这样:

Epoch 52
-------------------------------
Train loss: 0.399347, Test loss: 0.970927, Accuracy: 70.3%, Epoch time: 17.20s
779 次点击
所在节点    问与答
1 条回复
KangolHsu
2021-11-21 23:53:55 +08:00
输入的图片 64*64 ?是不是有点小啊

这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。

https://www.v2ex.com/t/816868

V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。

V2EX is a community of developers, designers and creative people.

© 2021 V2EX