柠檬少年|迁移学习—通俗易懂地介绍(常见网络模型pytorch实现)( 五 )
densenet121
#train.pyimport torchimport torch.nn as nnfrom torchvision import transforms, datasetsimport jsonimport matplotlib.pyplot as pltfrom model import densenet121import osimport torch.optim as optimimport torchvision.models.densenetimport torchvision.models as modelsdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(device)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#来自官网参数"val": transforms.Compose([transforms.Resize(256),#将最小边长缩放到256transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../../.."))# get data root pathimage_path = data_root + "/data_set/flower_data/"# flower data set pathtrain_dataset = datasets.ImageFolder(root=image_path + "train",transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 16train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=0)#迁移学习net = models.densenet121(pretrained=False)model_weight_path="./densenet121-a.pth"missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict= False)inchannel = net.classifier.in_featuresnet.classifier = nn.Linear(inchannel, 5)net.to(device)loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0001)#普通# model_name = "densenet121"# net = densenet121(model_name=model_name, num_classes=5)best_acc = 0.0save_path = './densenet121.pth'for epoch in range(12):# trainnet.train()running_loss = 0.0for step, data in enumerate(train_loader, start=0):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()# print train processrate = (step+1)/len(train_loader)a = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")print()# validatenet.eval()acc = 0.0# accumulate accurate number / epochwith torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# eval model only have last output layer# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == val_labels.to(device)).sum().item()val_accurate = acc / val_numif val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3ftest_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')
推荐阅读
- 金戈鐵馬|特朗普力挺铁杆支持者,枪杀两人的少年没错?抗议活动不断发酵
- 开封于七一|育迎宾尚法好少年,借温柔秋风多送法
- 少年|央视要搞选秀,热搜沸了!网友提名他当导师,点赞数第一
- 暖夏少年|2020电脑硬盘销量排行榜:七彩虹加入战局,硬是打倒了金士顿
- 少年一梦|2020畅销手机排行,iPhone无人超越,安卓旗舰全线溃败
- 少年帮|进入倒计时,华为突然宣布,供应链将迎来“洗牌”?
- 艾希大人|竹内结子拍摄杂志封面 柠檬黄上衣青春活力
- 少年帮|纯国产“龙芯”即将来临,正式确认?中科院宣布决定
- 上线|原创央视也搞成团选秀!《上线吧!华彩少年》有哪些优势?粉丝放心了
- 海报|央视首档少年成团选秀节目《上线吧!华彩少年》开始全球招募
