Pytorch迁移学习

    xiaoxiao2023-10-06  136

    在实际应用中,很少有人从头开始训练整个卷积网络,因为很难获得足够多的数据。因此,常用的做法是使用在庞大数据集上训练好的模型作为预训练模型,用来初始化网络,或者提取特征。

    迁移学习的主要应用场景有以下两种:

    微调模型(finetuning):不再随机初始化,而是用预训练模型的权重初始化网络,然后在新数据集上继续训练网络的全部参数。

    预训练模型作为特征提取器:除最后一层全连接层之外,固定网络中其他层的权重;最后的全连接层权重随机初始化,只有这一层的参数会得到训练。

    导包

    from __future__ import print_function, division import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler import numpy as np import torchvision from torchvision import datasets, models, transforms import matplotlib.pyplot as plt import time import os import copy plt.ion() # interactive mode

    数据导入

    使用torchvision和torch.utils.data来导入数据。

    训练一个模型对蚂蚁和蜜蜂进行分类。每类约有120张训练图片和75张验证图片。数据集非常小,从头训练很难得到好的效果,因此使用迁移学习。

    # Data pipelines per phase:
    #   train -> random resized crop + random horizontal flip (augmentation)
    #   val   -> deterministic resize to 256 then center crop to 224
    # Both finish with ToTensor() and the same per-channel Normalize(mean, std).
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
    }

    # presumably laid out as <data_dir>/{train,val}/<class_name>/ for
    # ImageFolder — confirm against the downloaded dataset
    data_dir = 'data/hymenoptera_data'
    image_datasets = {
        phase: datasets.ImageFolder(os.path.join(data_dir, phase),
                                    data_transforms[phase])
        for phase in ('train', 'val')
    }
    # Shuffled mini-batches of 4 images, loaded by 4 worker processes.
    dataloaders = {
        phase: torch.utils.data.DataLoader(dataset, batch_size=4,
                                           shuffle=True, num_workers=4)
        for phase, dataset in image_datasets.items()
    }
    dataset_sizes = {phase: len(dataset) for phase, dataset in image_datasets.items()}
    class_names = image_datasets['train'].classes

    # First GPU when CUDA is available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    训练图片展示

    def imshow(inp, title=None):
        """Draw a normalized image tensor with matplotlib.

        Converts ``inp`` to a NumPy array, moves its first axis last, undoes
        the per-channel normalization, clips to [0, 1] and shows the result
        with ``plt.imshow``. ``title``, when given, goes to ``plt.title``.
        """
        # assumes inp is a CPU tensor laid out (3, H, W) and normalized with
        # the same mean/std as the data transforms — TODO confirm at call sites
        hwc = inp.numpy().transpose((1, 2, 0))
        channel_mean = np.array([0.485, 0.456, 0.406])
        channel_std = np.array([0.229, 0.224, 0.225])
        restored = np.clip(channel_std * hwc + channel_mean, 0, 1)
        plt.imshow(restored)
        if title is not None:
            plt.title(title)
        # Brief pause so the interactive figure actually redraws.
        plt.pause(0.001)

    # Preview one (augmented) training batch as a single image grid, titled
    # with the class name of every image in the batch.
    inputs, classes = next(iter(dataloaders['train']))
    out = torchvision.utils.make_grid(inputs)
    imshow(out, title=[class_names[x] for x in classes])

    训练模型

    训练过程中按计划调整学习率,并保存在验证集上表现最好的模型

    下面的scheduler是torch.optim.lr_scheduler中的学习率调整器。

    def train_model(model, criterion, optimizer, scheduler, num_epochs=25, *,
                    loaders=None, sizes=None, names=None, dev=None):
        """Train ``model`` and return it loaded with its best validation weights.

        Each epoch runs a ``'train'`` phase (forward, backward, optimizer step)
        and then a ``'val'`` phase (forward only). After each phase the average
        loss, overall accuracy and per-class accuracy are printed. Whenever the
        validation accuracy improves, a deep copy of ``model.state_dict()`` is
        kept; those weights are loaded back into ``model`` before returning.

        Args:
            model: network to train. It is not moved here; the caller must have
                placed it on the same device as the data.
            criterion: loss, called as ``criterion(outputs, labels)``.
            optimizer: optimizer over the parameters that should be trained.
            scheduler: LR scheduler, stepped once per epoch after the training
                phase.
            num_epochs: number of train+val epochs.
            loaders: optional ``{'train': ..., 'val': ...}`` iterables of
                ``(inputs, labels)`` batches; defaults to the module-level
                ``dataloaders``.
            sizes: optional ``{'train': int, 'val': int}`` sample counts used to
                average loss/accuracy; defaults to ``dataset_sizes``.
            names: optional class names indexed by integer label; defaults to
                ``class_names``.
            dev: optional ``torch.device`` for the batches; defaults to ``device``.

        Returns:
            ``model`` itself, with the best-validation-accuracy weights loaded.
        """
        # Fall back to the script-level globals so existing calls keep working.
        if loaders is None:
            loaders = dataloaders
        if sizes is None:
            sizes = dataset_sizes
        if names is None:
            names = class_names
        if dev is None:
            dev = device
        num_classes = len(names)  # replaces the undefined NUM_CLASS

        since = time.time()
        best_model_wts = copy.deepcopy(model.state_dict())
        best_acc = 0.0

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0
                class_correct = [0] * num_classes
                class_total = [0] * num_classes

                # Iterate over data.
                for inputs, labels in loaders[phase]:
                    inputs = inputs.to(dev)
                    labels = labels.to(dev)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward; track history only in the training phase
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    hits = preds == labels
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += int(hits.sum().item())
                    # tolist() stays a list even for a batch of one sample; the
                    # old `.squeeze()` produced a 0-d tensor there, which cannot
                    # be iterated by zip().
                    for label, hit in zip(labels.tolist(), hits.tolist()):
                        class_total[label] += 1
                        class_correct[label] += int(hit)

                if phase == 'train':
                    # Since PyTorch 1.1 the scheduler must be stepped after the
                    # optimizer; stepping it first skips the initial LR value.
                    scheduler.step()

                epoch_loss = running_loss / sizes[phase]
                epoch_acc = running_corrects / sizes[phase]

                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
                for i in range(num_classes):
                    if class_total[i]:
                        print('Accuracy of %5s : %2d %%' % (
                            names[i], 100 * class_correct[i] / class_total[i]))
                    else:
                        # Avoid ZeroDivisionError when a class has no samples.
                        print('Accuracy of %5s : n/a (no samples)' % names[i])

                # deep copy the model
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

            print()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:4f}'.format(best_acc))

        # load best model weights
        model.load_state_dict(best_model_wts)
        return model

    可视化预测结果

    def visualize_model(model, num_images=6):
        """Show validation images titled with the model's predicted class.

        Puts ``model`` in eval mode and runs it without gradient tracking over
        ``dataloaders['val']``. Up to ``num_images`` images are drawn in a
        two-column subplot grid, each titled ``'predicted: <class name>'``.
        The model's previous train/eval mode is restored on every exit path.

        Args:
            model: classifier whose outputs are per-class scores; the prediction
                is the argmax over dimension 1.
            num_images: maximum number of images to draw. Nothing is drawn when
                it is not positive.
        """
        if num_images <= 0:
            return  # an empty grid would make plt.subplot raise
        was_training = model.training
        model.eval()
        # Round up so an odd num_images still has a cell for its last image
        # (num_images // 2 rows only had room for num_images - 1).
        n_rows = (num_images + 1) // 2
        images_so_far = 0
        plt.figure()
        try:
            with torch.no_grad():
                for inputs, _ in dataloaders['val']:
                    inputs = inputs.to(device)
                    _, preds = torch.max(model(inputs), 1)
                    cpu_inputs = inputs.cpu()  # imshow needs host memory
                    for j in range(inputs.size(0)):
                        images_so_far += 1
                        ax = plt.subplot(n_rows, 2, images_so_far)
                        ax.axis('off')
                        ax.set_title('predicted: {}'.format(class_names[preds[j].item()]))
                        imshow(cpu_inputs[j])
                        if images_so_far == num_images:
                            return
        finally:
            model.train(mode=was_training)
    微调(finetuning)网络

    加载预训练模型,重置最后一层全连接网络。

    # Fine-tuning: start from pretrained VGG19-BN weights, replace the last
    # classifier layer with a freshly initialized one, and train every layer.
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
    # of `weights=`; kept here for the torchvision version this targets — confirm.
    model_ft = models.vgg19_bn(pretrained=True)
    num_ftrs = model_ft.classifier[6].in_features
    # One output per dataset class (2 for ants/bees) rather than a literal 2,
    # so the head stays in sync with the ImageFolder classes.
    model_ft.classifier[6] = nn.Linear(num_ftrs, len(class_names))
    model_ft = model_ft.to(device)

    criterion = nn.CrossEntropyLoss()

    # Observe that all parameters are being optimized
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    训练并评估
    # Fine-tune all layers for 25 epochs; train_model returns the model with
    # the weights of its best validation epoch loaded.
    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                           num_epochs=25)
    # Output recorded in the original post (previously pasted inline, which made
    # this line invalid Python). Loss / accuracy per phase, per-class accuracy:
    #
    # epoch | train: loss    acc   ants  bees | val: loss    acc   ants  bees
    #  0/24 |      0.4654 0.7582   79 %  71 % |     0.2124 0.9150   82 % 100 %
    #  1/24 |      0.3709 0.8484   82 %  87 % |     0.1678 0.9281   95 %  91 %
    #  2/24 |      0.2731 0.8730   88 %  85 % |     0.1629 0.9085   95 %  87 %
    #  3/24 |      0.2576 0.8852   91 %  85 % |     0.1792 0.9216   87 %  97 %
    #  4/24 |      0.1734 0.9303   92 %  93 % |     0.1090 0.9542   95 %  96 %
    #  5/24 |      0.1955 0.9221   92 %  91 % |     0.1055 0.9477   95 %  95 %
    #  6/24 |      0.1743 0.9098   94 %  87 % |     0.1086 0.9412   94 %  95 %
    #  7/24 |      0.1719 0.9426   93 %  95 % |     0.1130 0.9477   95 %  95 %
    #  8/24 |      0.1358 0.9590   97 %  94 % |     0.1102 0.9412   95 %  93 %
    #  9/24 |      0.1995 0.9057   93 %  87 % |     0.1068 0.9477   94 %  96 %
    # 10/24 |      0.2080 0.8934   86 %  92 % |     0.0920 0.9477   97 %  93 %
    # 11/24 |      0.1665 0.9385   94 %  93 % |     0.0982 0.9477   94 %  96 %
    # 12/24 |      0.1357 0.9426   95 %  93 % |     0.0993 0.9412   94 %  95 %
    # 13/24 |      0.1355 0.9549   95 %  95 % |     0.1007 0.9477   95 %  95 %
    # 14/24 |      0.2216 0.9057   90 %  90 % |     0.0976 0.9477   95 %  95 %
    # 15/24 |      0.1494 0.9303   93 %  92 % |     0.1100 0.9412   95 %  93 %
    # 16/24 |      0.1786 0.9303   93 %  92 % |     0.1069 0.9412   95 %  93 %
    # 17/24 |      0.1210 0.9508   97 %  92 % |     0.0990 0.9477   95 %  95 %
    # 18/24 |      0.1851 0.9180   89 %  94 % |     0.0903 0.9608   95 %  97 %
    # 19/24 |      0.2010 0.9221   95 %  89 % |     0.1039 0.9412   95 %  93 %
    # 20/24 |      0.1388 0.9385   93 %  94 % |     0.1145 0.9477   97 %  93 %
    # 21/24 |      0.1390 0.9426   94 %  94 % |     0.1068 0.9477   94 %  96 %
    # 22/24 |      0.1254 0.9590   94 %  97 % |     0.1016 0.9477   97 %  93 %
    # 23/24 |      0.1194 0.9467   92 %  96 % |     0.1058 0.9412   95 %  93 %
    # 24/24 |      0.1457 0.9467   92 %  96 % |     0.1034 0.9477   94 %  96 %
    #
    # Training complete in 4m 24s
    # Best val Acc: 0.960784
    visualize_model(model_ft)

    预训练模型作为特征提取器

    在这里需要固定除最后一层外的其他所有层:设置 requires_grad = False 来固定参数,这样在反向传播的时候不会为这些参数计算梯度。

    # Feature extraction: freeze every pretrained ResNet-18 parameter and train
    # only a newly attached final fully-connected layer.
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
    # of `weights=`; kept here for the torchvision version this targets — confirm.
    model_conv = torchvision.models.resnet18(pretrained=True)
    for param in model_conv.parameters():
        param.requires_grad = False

    # Parameters of newly constructed modules have requires_grad=True by default
    num_ftrs = model_conv.fc.in_features
    # One output per dataset class (2 for ants/bees) rather than a literal 2.
    model_conv.fc = nn.Linear(num_ftrs, len(class_names))
    model_conv = model_conv.to(device)

    criterion = nn.CrossEntropyLoss()

    # Observe that only parameters of final layer are being optimized as
    # opposed to before. ResNet's head is `fc`; the previous
    # `model_conv.classifier[6]` (copied from the VGG section) does not exist
    # on this model and raised AttributeError.
    optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
    训练并评估
    # Train model_conv with optimizer_conv / exp_lr_scheduler for 25 epochs;
    # train_model returns it with the weights of its best validation epoch.
    model_conv = train_model(model_conv, criterion, optimizer_conv,
                             exp_lr_scheduler, num_epochs=25)
    # Output recorded in the original post (previously pasted inline, which made
    # this line invalid Python). Loss / accuracy per phase, per-class accuracy:
    #
    # epoch | train: loss    acc   ants  bees | val: loss    acc   ants  bees
    #  0/24 |      0.5984 0.7008   70 %  69 % |     0.1762 0.9281   87 %  98 %
    #  1/24 |      0.5020 0.7746   78 %  76 % |     0.2006 0.9346   94 %  93 %
    #  2/24 |      0.3177 0.8607   86 %  85 % |     0.1839 0.9281   94 %  92 %
    #  3/24 |      0.3501 0.8730   87 %  86 % |     0.1481 0.9542   95 %  96 %
    #  4/24 |      0.3446 0.8566   85 %  85 % |     0.3494 0.9020   81 %  98 %
    #  5/24 |      0.3998 0.8238   79 %  85 % |     0.2495 0.8693   97 %  79 %
    #  6/24 |      0.5243 0.7992   83 %  76 % |     0.1553 0.9542   92 %  98 %
    #  7/24 |      0.2630 0.8975   87 %  91 % |     0.1181 0.9673   97 %  97 %
    #  8/24 |      0.3254 0.8770   86 %  88 % |     0.1476 0.9608   95 %  97 %
    #  9/24 |      0.3085 0.8811   91 %  85 % |     0.1379 0.9673   97 %  97 %
    # 10/24 |      0.3302 0.8730   89 %  85 % |     0.1628 0.9673   95 %  98 %
    # 11/24 |      0.2677 0.8770   85 %  90 % |     0.1374 0.9673   97 %  97 %
    # 12/24 |      0.2701 0.8852   89 %  87 % |     0.1368 0.9673   97 %  97 %
    # 13/24 |      0.2253 0.8852   85 %  91 % |     0.1350 0.9673   97 %  97 %
    # 14/24 |      0.2009 0.9180   93 %  90 % |     0.1335 0.9608   97 %  96 %
    # 15/24 |      0.3129 0.8689   86 %  86 % |     0.1288 0.9673   97 %  97 %
    # 16/24 |      0.3808 0.8320   83 %  82 % |     0.1517 0.9477   97 %  93 %
    # 17/24 |      0.2780 0.8770   89 %  85 % |     0.1363 0.9673   97 %  97 %
    # 18/24 |      0.1963 0.9098   91 %  90 % |     0.1537 0.9673   97 %  97 %
    # 19/24 |      0.3126 0.8730   84 %  90 % |     0.1336 0.9608   95 %  97 %
    # 20/24 |      0.3448 0.8770   87 %  87 % |     0.1428 0.9673   95 %  98 %
    # 21/24 |      0.2615 0.9057   94 %  86 % |     0.1276 0.9673   97 %  97 %
    # 22/24 |      0.3990 0.8361   82 %  85 % |     0.1467 0.9673   95 %  98 %
    # 23/24 |      0.2309 0.9180   91 %  92 % |     0.1301 0.9673   97 %  97 %
    # 24/24 |      0.2305 0.9139   92 %  90 % |     0.1328 0.9673   97 %  97 %
    #
    # Training complete in 2m 45s
    # Best val Acc: 0.967320
    #
    # Visualize the model that was just trained; the original passed model_ft
    # here, which re-showed the fine-tuned VGG from the previous section.
    visualize_model(model_conv)

    扫码关注微信公众号:机器工匠,回复关键字”pytorch迁移学习“获取代码和数据。

    最新回复(0)