PyTorch 5:现有模型的使用、修改、保存、加载

2026-05-11 深度学习 826 字 · 约 3 分钟

访问作者github: https://github.com/NefelibataBIGR/PyTorch_Notes ,获取笔记代码

七、现有模型的使用、修改、保存、加载

具体代码见Python>pytorch文件中的 ==recent_models.ipynb== 文件

  • 以torchvision为例(视觉识别)

  • 模型使用:

    • 见官方文档
  • 模型修改:

    • 添加一层:xxx.add_module(“name” , 添加的层)
    • 改里面的层:xxx.sequential_name[“name”] = 修改成的层
  • 保存模型:

    • 保存整个模型:torch.save(xxx , “xxx.pth”)
    • 只保存参数:torch.save(xxx.state_dict() , “xxx.pth”)
      • 空间占更小,官方推荐
  • 加载模型:

    • 加载整个模型:
      • torch.load(“xxx.pth”, weights_only=False)
        • weights_only=False:加载整个模型
        • weights_only=True(默认):只加载模型的参数,不加载模型本身
    • 只加载参数:
      • torch.load(“xxx_dict.pth”)
        • “xxx_dict.pth”为torch.save(xxx.state_dict() , “xxx.pth”)保存的参数
  • 模型+参数组合:

    • xxx.load_state_dict(yyy_dict)
      • xxx:没训练过的模型
      • yyy_dict:参数字典(用只保存参数的方法保存模型得到的字典)
#深度学习 #python #pytorch

评论

使用 GitHub 账号留言 · 评论存在 仓库 issues