柠檬少年 | Transfer Learning: An Accessible Introduction (Common Network Models Implemented in PyTorch) (Part 4)


Without transfer learning
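For comparison, the "without transfer learning" baseline trains the same architecture from randomly initialized weights. The setup sketch below is assembled from the commented-out lines of the training script further down (from model import vgg, model_name, the Adam optimizer over all parameters); it assumes the author's accompanying model.py, which is not shown here, builds a VGG16 whose classifier already matches the 5 flower classes.

# Baseline without transfer learning (sketch): random initialization, nothing frozen.
# Assumes model.py provides the author's vgg() constructor; the data loading and the
# training loop are the same as in train.py below.
import torch
import torch.nn as nn
import torch.optim as optim
from model import vgg  # the author's own VGG implementation (not shown in this part)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_name = "vgg16"
net = vgg(model_name=model_name, init_weights=True)  # random weights, trained from scratch

loss_function = nn.CrossEntropyLoss()
# every parameter is trainable, so the optimizer is given the whole network
optimizer = optim.Adam(net.parameters(), lr=0.0001)

net.to(device)
save_path = './{}Net.pth'.format(model_name)

With nothing pretrained and nothing frozen, every layer must be learned from the flower data alone, which is why the article contrasts this baseline with the fine-tuned VGG16 below.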
VGG16
# train.py
import torch.nn as nn
from torchvision import transforms, datasets
import json
import os
import torch.optim as optim
from model import vgg
import torch
import time
import torchvision.models.vgg
from torchvision import models

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# data preprocessing
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../../.."))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 20
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)

# test_data_iter = iter(validate_loader)
# test_image, test_label = test_data_iter.next()

# model = models.vgg16(pretrained=True)
# model_name = "vgg16"
# net = vgg(model_name=model_name, init_weights=True)

# load pretrained weights
net = models.vgg16(pretrained=False)
pre = torch.load("./vgg16.pth")
net.load_state_dict(pre)

# freeze all pretrained parameters so only the new classifier head is trained
for param in net.parameters():
    param.requires_grad = False

# replace the classifier head: the last layer now outputs the 5 flower classes
net.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 4096),
                                     torch.nn.ReLU(),
                                     torch.nn.Dropout(p=0.5),
                                     torch.nn.Linear(4096, 4096),
                                     torch.nn.ReLU(),
                                     torch.nn.Dropout(p=0.5),
                                     torch.nn.Linear(4096, 5))

# model_weight_path = "./vgg16.pth"
# missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)  # load model parameters
# for param in net.parameters():
#     param.requires_grad = False
# change fc layer structure
# inchannel = 512
# net.classifier = nn.Linear(inchannel, 5)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(net.classifier.parameters(), lr=0.001)  # only optimize the new classifier
# loss_function = nn.CrossEntropyLoss()
# optimizer = optim.Adam(net.parameters(), lr=0.0001)  # learning rate

net.to(device)
best_acc = 0.0
# save_path = './{}Net.pth'.format(model_name)
save_path = './vgg16Net.pth'

for epoch in range(15):
    # train
    net.train()
    running_loss = 0.0  # accumulate the training loss over the epoch
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        # with torch.no_grad():  # only needed in validation, where gradients are not back-propagated and would otherwise accumulate
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))  # loss between predictions and ground-truth labels
        loss.backward()
        optimizer.step()  # update parameters

        # print statistics
        running_loss += loss.item()
        # print training progress
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter() - t1)

    # validate
    net.eval()
    acc = 0.0  # accumulate the number of correct predictions per epoch
    with torch.no_grad():  # do not track gradients during validation
        for val_data in validate_loader:
            val_images, val_labels = val_data
            # optimizer.zero_grad()
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / len(train_loader), val_accurate))

print('Finished Training')
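After training, the best weights saved to ./vgg16Net.pth can be used for single-image prediction. The following is only a minimal inference sketch, not the author's own predict.py: the weight file and class_indices.json come from the training script above, while the image path test.jpg is a placeholder assumption.

# predict.py -- minimal inference sketch (assumes the files produced by train.py above)
import json
import torch
from PIL import Image
from torchvision import transforms, models

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# same preprocessing as the "val" transform used during training
data_transform = transforms.Compose([transforms.Resize((224, 224)),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# rebuild the network with the same 5-class classifier head, then load the fine-tuned weights
net = models.vgg16(pretrained=False)
net.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 4096),
                                     torch.nn.ReLU(),
                                     torch.nn.Dropout(p=0.5),
                                     torch.nn.Linear(4096, 4096),
                                     torch.nn.ReLU(),
                                     torch.nn.Dropout(p=0.5),
                                     torch.nn.Linear(4096, 5))
net.load_state_dict(torch.load("./vgg16Net.pth", map_location=device))
net.to(device)
net.eval()

# class index -> class name mapping written by train.py
with open("class_indices.json", "r") as f:
    class_indict = json.load(f)

img = Image.open("test.jpg").convert("RGB")  # placeholder image path (assumption)
img = data_transform(img)                    # [C, H, W]
img = torch.unsqueeze(img, dim=0)            # add batch dimension -> [1, C, H, W]

with torch.no_grad():
    output = torch.squeeze(net(img.to(device)))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).item()

print(class_indict[str(predict_cla)], predict[predict_cla].item())

Note that the classifier must be rebuilt with the same structure before load_state_dict, because the saved state dict contains the replaced 5-class head rather than the original 1000-class ImageNet head.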

