一、网络模型的保存与读取方式1
方法讲解
保存模型
import torch
import torchvision
model = torchvision.models.vgg16(weights='DEFAULT')
#保存模型和参数
torch.save(model,"save_method1.pth")
读取模型
import torch
model = torch.load("save_method1.pth")
print(model)
输出:
比较坑人的点
使用 torch.save 必须将该模型的架构引入到该文件中(可以使用from A import B
的方式来解决),这里举一个例子来说明
保存模型
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear#保存模型和参数class Mary(nn.Module):def __init__(self):super(Mary,self).__init__()self.model1 = nn.Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x = self.model1(x)return x
Yorelee = Mary()
torch.save(Yorelee,"save_method1_question.pth")
读取模型
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearmodel = torch.load("save_method1_question.pth")print(model)
报错如下
说明我们还要把 Mary 这个框架复制到读取模型的.py文件中
重新更正后的读取模型代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearclass Mary(nn.Module):def __init__(self):super(Mary,self).__init__()self.model1 = nn.Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x = self.model1(x)return xmodel = torch.load("save_method1_question.pth")print(model)
或者
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch_save import Mary #这里仅举一个例子model = torch.load("save_method1_question.pth")print(model)
二、网络模型的保存与读取方式2
保存模型参数
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearvgg_model = torchvision.models.vgg16(weights='DEFAULT')
#保存参数
torch.save(vgg_model.state_dict(),"save_method2.pth")
读取模型参数
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearvgg_model = torchvision.models.vgg16(weights='DEFAULT')
parameter = torch.load("save_method2.pth")
vgg_model.load_state_dict(parameter)
print(vgg_model)