柠檬少年 | Transfer Learning: An Easy-to-Understand Introduction (Common Network Models Implemented in PyTorch) (Part 3)


Pick the pretrained model you want, copy its download link into your browser, and the weights will download; if the download is slow, you can use a download manager such as Thunder (迅雷).
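If you prefer to fetch the weights from a script instead of the browser, here is a minimal sketch. The URL is the classic torchvision model-zoo link for ResNet-34 and the output file name ./resnet34-pre.pth matches the training script below; both are assumptions about your setup, so verify them against your torchvision version.

# Sketch: download the ResNet-34 pretrained weights to the path train.py expects.
# The URL is assumed to be torchvision's classic ResNet-34 link; double-check it
# for your torchvision version before relying on it.
import torch

weights_url = "https://download.pytorch.org/models/resnet34-333f7ec4.pth"
state_dict = torch.hub.load_state_dict_from_url(weights_url, progress=True)
torch.save(state_dict, "./resnet34-pre.pth")  # file name used by train.py below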
A detailed walkthrough of ResNet itself is in this post: ResNet——CNN经典网络模型详解(pytorch实现).
# train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101
import torchvision.models.resnet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

data_transform = {
    # normalization values come from the official ImageNet-pretrained models
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    "val": transforms.Compose([transforms.Resize(256),      # scale the shorter side to 256
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)

net = resnet34()
# net = resnet34(num_classes=5)
# load pretrained weights; strict=False tolerates missing / unexpected keys
model_weight_path = "./resnet34-pre.pth"
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)
# for param in net.parameters():
#     param.requires_grad = False

# change fc layer structure: replace the 1000-class head with a 5-class one
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)
net.to(device)

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

best_acc = 0.0
save_path = './resNet34.pth'
for epoch in range(3):
    # train
    net.train()
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # print train progress
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss), end="")
    print()

    # validate
    net.eval()
    acc = 0.0  # accumulate correct predictions per epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))  # eval model only has the final output layer
            # loss = loss_function(outputs, test_labels)
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')
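The commented-out requires_grad loop in train.py hints at a second common transfer-learning setup: freezing the pretrained backbone and training only the new fully connected layer. Below is a minimal sketch of that variant, assuming the same resnet34 model file, resnet34-pre.pth weights, and 5-class flower dataset used above.

# Sketch: feature-extraction variant of train.py — freeze the backbone, train only the new fc layer.
import torch
import torch.nn as nn
import torch.optim as optim
from model import resnet34

net = resnet34()
net.load_state_dict(torch.load("./resnet34-pre.pth"), strict=False)

for param in net.parameters():        # freeze every pretrained parameter
    param.requires_grad = False

inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)      # the new layer is created with requires_grad=True

# hand the optimizer only the parameters that are still trainable (here: net.fc)
optimizer = optim.Adam((p for p in net.parameters() if p.requires_grad), lr=0.0001)

Freezing the backbone trains faster and tends to overfit less on a small dataset, while the full fine-tuning shown in train.py usually reaches somewhat higher accuracy when enough data is available.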

