1. 仅保存模型参数（state_dict）
# ... (define / train the model here)
# Save only the learned parameters (the module's state_dict), not the module
# object itself. To reuse them, rebuild the same architecture and call
# load_state_dict on it.
XXXmodel = model(*args, **kwargs)
torch.save(XXXmodel.state_dict(), 'onlyDict.pkl')
读取 state_dict：先重新构建结构相同的模型，再把参数加载进去
# Rebuild a model with the same constructor arguments used at save time;
# load_state_dict only copies parameters into an already-built module.
pre_model = model(*args, **kwargs)
saved_params = torch.load('onlyDict.pkl')
pre_model.load_state_dict(saved_params)
2. 保存整个模型（相对更费时、占用更多空间；并且是用 pickle 序列化整个对象，文件依赖原模型类的定义和代码目录结构，可移植性较差）
# ... (define / train the model here)
# Pickle the entire module object (architecture + parameters). Loading it
# later requires the model's class definition to be importable again.
XXXmodel = model(*args, **kwargs)
torch.save(XXXmodel, 'fullmodel.pkl')
读取整个模型（加载时，原模型类的定义必须可以被导入）
# Unpickle the whole module object saved with torch.save(model, ...).
# weights_only=False is needed because, starting with PyTorch 2.6, torch.load
# defaults to weights_only=True and refuses to unpickle arbitrary objects
# such as an nn.Module (the keyword itself exists since PyTorch 1.13).
# WARNING: this executes pickle -- only load files from a trusted source.
pre_model = torch.load('fullmodel.pkl', weights_only=False)