pytorch api文档:torch.load()函数-加载整个模型、模型状态字典(可训练参数) 作者:马育民 • 2026-01-30 15:07 • 阅读:10006 # 介绍 PyTorch 中`torch.load()`的用法,这个函数是 PyTorch 中加载模型、张量、字典等序列化文件的核心方法,常和`torch.save()`配合使用,支持加载`.pth`/`.pt`格式的文件。 ### 作用 读取`torch.save()`保存的文件,还原为原本的 PyTorch 对象(张量、模型、字典等) # 语法 ```python torch.load(f, map_location=None, pickle_module=pickle, weights_only=False, **kwargs) ``` #### 参数解释 **`map_location`和`weights_only`是日常使用中最常用的**: 1. **`f`**:必选参数,文件路径(字符串/路径对象)或已打开的文件流,比如`"model.pth"`。 2. **`map_location`**:**指定加载数据的设备**(CPU/GPU),解决「保存时用GPU,加载时用CPU/其他GPU」的设备不匹配问题,核心用法: - 加载到**CPU**(最常用,跨设备兼容):`map_location="cpu"` - 加载到**指定GPU**(比如GPU 0):`map_location="cuda:0"` - 动态适配(有GPU用GPU,无则用CPU):`map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")` 3. **`weights_only`**:**安全参数,建议设置为True**(PyTorch 1.13+支持),限制仅加载张量/参数,禁止加载任意Python对象,避免恶意文件的代码注入风险,**生产环境/第三方文件加载必须开启**。 4. **`pickle_module`**:底层序列化模块,默认用Python的`pickle`,一般无需修改。 # 简单例子 先通过`torch.save()`保存对象,再用`torch.load()`加载,形成完整流程: ```python # 1. 保存张量 x = torch.tensor([1,2,3]) torch.save(x, "tensor.pth") # 加载张量 x_load = torch.load("tensor.pth") print(x_load) # tensor([1, 2, 3]) # 2. 保存字典(最常用,比如同时保存模型参数、优化器参数、训练轮数) checkpoint = { "model_state_dict": model.state_dict(), # 模型参数 "optimizer_state_dict": optimizer.state_dict(), # 优化器参数 "epoch": 100, # 训练到第100轮 "loss": 0.01 # 最终损失 } torch.save(checkpoint, "checkpoint.pth") # 加载字典 ckpt = torch.load("checkpoint.pth") print(ckpt["epoch"]) # 100 ``` # 应用场景 这是`torch.load()`最核心的使用场景,分**加载完整模型**和**加载模型参数(推荐)**两种,后者更灵活(解耦模型结构和参数),是工业界标准做法。 ### 场景1:加载模型参数(推荐) 步骤:先定义模型结构 → 加载保存的参数字典 → 将参数加载到模型中 ```python import torch import torch.nn as nn # 1. 先定义和保存时完全一致的模型结构 class MyModel(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(3, 1) def forward(self, x): return self.linear(x) model = MyModel() # 初始化模型(空参数) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 初始化优化器 # 2. 加载检查点(字典),指定设备+开启安全模式 ckpt = torch.load( "checkpoint.pth", map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), weights_only=True # 安全加载 ) # 3. 将参数加载到模型和优化器中(断点续训关键) model.load_state_dict(ckpt["model_state_dict"]) # 加载模型参数 optimizer.load_state_dict(ckpt["optimizer_state_dict"]) # 加载优化器参数 start_epoch = ckpt["epoch"] + 1 # 从下一轮开始训练 best_loss = ckpt["loss"] # 加载历史最优损失 # 4. 模型启用(训练/推理) model.train() # 训练模式 # model.eval() # 推理/测试模式 ``` ### 场景2:加载完整模型(不推荐,灵活性差) 如果保存时直接保存了整个模型对象(`torch.save(model, "full_model.pth")`),可直接加载,但**缺点**是模型结构和文件强绑定,修改模型代码后无法加载,仅适合快速测试: ```python model = torch.load( "full_model.pth", map_location="cpu", weights_only=True ) model.eval() # 推理前必须切换为评估模式 ``` # 注意(避坑重点) 1. **设备匹配问题**:如果保存时用GPU(张量/模型在`cuda`上),直接在CPU环境加载会报错,**必须通过`map_location`指定设备**,这是新手最常踩的坑。 ✅ 正确做法:加载时统一指定`map_location`,跨设备兼容。 2. **安全问题**:**加载非自己保存的文件时,必须设置`weights_only=True`**,禁止加载未知文件的任意Python对象,防止恶意代码执行。 3. **模型结构一致**:加载`state_dict`(模型参数)时,**定义的模型结构必须和保存时完全一致**(层的名称、维度、数量不能变),否则会报`key mismatch`错误。 - 解决小修改的兼容问题:加载时用`model.load_state_dict(ckpt["model_state_dict"], strict=False)`,忽略不匹配的层(谨慎使用)。 4. **PyTorch版本兼容**:高版本PyTorch保存的文件,低版本可能无法加载,建议保持训练和部署的PyTorch版本一致(至少大版本一致,如2.0+)。 5. **文件路径问题**:确保加载的文件路径正确,相对路径以**当前运行脚本的目录**为基准,绝对路径更稳妥。 # 拓展:加载多个对象/自定义对象 如果`torch.save()`保存了元组/列表等多个对象,`torch.load()`会还原为对应结构,直接解包即可: ```python # 保存多个对象 a = torch.tensor([1,2]) b = torch.tensor([3,4]) torch.save((a, b), "two_tensors.pth") # 加载并解包 a_load, b_load = torch.load("two_tensors.pth", weights_only=True) print(a_load, b_load) # tensor([1, 2]) tensor([3, 4]) ``` 如果保存了**自定义类/对象**,加载时必须保证当前环境中有该类的定义(否则会报`AttributeError`)。 --- # 总结 `torch.load()`是PyTorch序列化加载的核心,核心要点记3个: 1. **核心搭配**:和`torch.save()`配合,主要加载`.pth`/`.pt`文件,支持张量、模型、字典等所有PyTorch对象; 2. **必设参数**:生产环境/日常使用建议指定`map_location`(设备兼容)+`weights_only=True`(安全加载); 3. **主流用法**:加载模型时优先用「定义结构+加载`state_dict`」的方式,而非直接加载完整模型,灵活性和解耦性更强。 附**加载模型的万能模板**:直接复制修改模型结构和文件路径,即可适配99%的场景。 原文出处:http://www.malaoshi.top/show_1GW2gdrxR5YJ.html