访问作者github: https://github.com/NefelibataBIGR/PyTorch_Notes ,获取笔记代码
七、现有模型的使用、修改、保存、加载
具体代码见Python>pytorch文件中的 ==recent_models.ipynb== 文件
以torchvision为例(视觉识别)
- 一般模型参考官方文档:Models and pre-trained weights — Torchvision 0.21 documentation
- VGG16模型参考官方文档:vgg16 — Torchvision 0.21 documentation
- 具体代码见代码文件
模型使用:
- 见官方文档
模型修改:
- 添加一层:xxx.add_module(“name” , 添加的层)
- 改里面的层:xxx.sequential_name[“name”] = 修改成的层
保存模型:
- 保存整个模型:torch.save(xxx , “xxx.pth”)
- 只保存参数:torch.save(xxx.state_dict() , “xxx.pth”)
- 空间占更小,官方推荐
加载模型:
- 加载整个模型:
- torch.load(“xxx.pth”, weights_only=False)
- weights_only=False:加载整个模型
- weights_only=True(默认):只加载模型的参数,不加载模型本身
- torch.load(“xxx.pth”, weights_only=False)
- 只加载参数:
- torch.load(“xxx_dict.pth”)
- “xxx_dict.pth”为torch.save(xxx.state_dict() , “xxx.pth”)保存的参数
- torch.load(“xxx_dict.pth”)
- 加载整个模型:
模型+参数组合:
- xxx.load_state_dict(yyy_dict)
- xxx:没训练过的模型
- yyy_dict:参数字典(用只保存参数的方法保存模型得到的字典)
- xxx.load_state_dict(yyy_dict)

评论
使用 GitHub 账号留言 · 评论存在 仓库 issues