pytorch api文档:torch.save()函数-保存整个模型、模型状态字典(可训练参数) 作者:马育民 • 2026-01-30 12:40 • 阅读:10005 # 介绍 PyTorch 的 `torch.save()` 函数,用于**保存模型/张量/字典等数据**的核心函数 # 语法 ```python torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) ``` #### 参数说明 | 参数 | 作用 | 新手常用值 | |------|------|------------| | `obj` | **要保存的对象**(核心),可是模型、张量、列表、字典、优化器等任意Python对象 | 模型/模型状态字典/张量 | | `f` | 保存路径+文件名,支持字符串路径、文件对象 | 如 `'model.pth'`/`'checkpoint.pt'`(后缀常用.pth/.pt,无强制要求) | | 其余参数 | 底层序列化相关,默认值完全满足日常使用,**无需修改** | - | # 例子 ```python import torch # 1. 保存单个张量 x = torch.tensor([1,2,3,4]) torch.save(x, 'tensor.pth') # 2. 保存字典(实战常用,可同时存多个数据) data = { 'epoch': 50, # 训练轮数 'loss': 0.01, # 最终损失 'weight': x # 张量/模型参数 } torch.save(data, 'data_dict.pth') ``` # 应用场景 这是 `torch.save()` 最常用的场景,**有两种保存方式**,新手务必区分(推荐使用**方式2**,工业界标准) ### 准备 先定义一个简单的测试模型: ```python import torch import torch.nn as nn # 定义测试模型 class SimpleModel(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 2) # 输入10维,输出2维 def forward(self, x): return self.linear(x) model = SimpleModel() # 实例化模型 optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 定义优化器 ``` ### 方式1:保存**整个模型**(不推荐,移植性差) 直接将模型对象传入 `obj`,会保存模型的**完整结构+参数** **缺点:**保存的文件更大、跨PyTorch版本/跨设备(CPU/GPU)时容易加载失败。 ```python # 保存整个模型 torch.save(model, 'whole_model.pth') ``` 使用 `torch.load()` 加载: ``` # 1. 加载整个模型(不推荐,直接用) model_whole = torch.load('whole_model.pth', map_location='cpu') ``` ### 方式2-1:保存**模型状态字典(state_dict)**(强烈推荐) PyTorch中,`nn.Module` 类型的模型都有 `state_dict()` 方法,返回一个**有序字典**,仅保存模型的**可训练参数**(如卷积核、全连接层权重),不保存模型结构。 **优点:**文件小、移植性强、灵活度高(加载时只需先定义模型结构,再加载参数)。 ```python # 保存模型状态字典(核心推荐) torch.save(model.state_dict(), 'model_state_dict.pth') ``` 使用 `model.load_state_dict() `加载: ``` # 2. 加载模型状态字典(推荐,核心步骤:实例化模型 → 加载参数 → 加载到模型) model = SimpleModel() # 先实例化模型(结构必须和保存时一致) model.load_state_dict(torch.load('model_state_dict.pth', map_location='cpu')) # 关键:load_state_dict() model.eval() # 推理前必须调用,关闭Dropout/BatchNorm等层 ``` ### 方式2-2:保存训练断点(强烈推荐) ``` # 进阶:保存训练断点(模型参数+优化器参数+训练信息) checkpoint = { 'epoch': 50, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'train_loss': 0.01, 'val_acc': 0.98 } torch.save(checkpoint, 'checkpoint.pth') # 断点续训必备 ``` 使用 `model.load_state_dict() `加载: ``` # 3. 加载训练断点(断点续训) model = SimpleModel() optimizer = optim.SGD(model.parameters(), lr=0.01) # 先实例化优化器(参数需和保存时一致) checkpoint = torch.load('checkpoint.pth', map_location='cpu') model.load_state_dict(checkpoint['model_state_dict']) # 加载模型参数 optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 加载优化器参数 epoch = checkpoint['epoch'] # 恢复训练轮数 train_loss = checkpoint['train_loss'] # 恢复训练信息 # 继续训练时,直接从epoch+1开始 model.train() # 训练前调用,开启Dropout/BatchNorm等层 ``` # 注意(避坑重点) 1. **文件后缀**:PyTorch官方推荐用 `.pth` 或 `.pt`,也可用 `.pkl`,**无强制要求**,只是约定俗成的标识,不影响加载。 2. **GPU/CPU 跨设备保存/加载** - GPU保存,CPU加载:必须指定 `map_location='cpu'`,否则会报CUDA设备不存在错误; - CPU保存,GPU加载:可指定 `map_location='cuda:0'`,或加载后用 `model.cuda()` 移到GPU; - 多GPU保存,单GPU加载:先按单GPU定义模型,再加载参数(需配合 `nn.DataParallel` 解包,新手暂不深究)。 3. **推理前必做**:加载模型后,若用于**预测/验证**,必须调用 `model.eval()`,否则Dropout、BatchNorm等层会随机改变输出,导致推理结果不一致。 4. **断点续训必做**:加载优化器参数后,若继续训练,需调用 `model.train()`,开启训练模式。 5. **序列化安全**:`torch.save()` 基于Python的 `pickle` 模块序列化,**不要加载未知来源的.pth文件**,存在安全风险。 # 完整实战流程(保存+加载+推理) ```python import torch import torch.nn as nn # 1. 定义并初始化模型 class SimpleModel(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 2) def forward(self, x): return self.linear(x) model = SimpleModel() # 模拟训练:给模型赋随机参数(实际训练后是训练好的参数) for param in model.parameters(): torch.nn.init.normal_(param) # 2. 保存训练好的模型(状态字典,推荐) torch.save(model.state_dict(), 'trained_model.pth') print("模型保存完成!") # 3. 加载模型 model_load = SimpleModel() # 重新实例化模型 model_load.load_state_dict(torch.load('trained_model.pth', map_location='cpu')) model_load.eval() # 推理模式 print("模型加载完成!") # 4. 用加载的模型推理 x_test = torch.randn(1, 10) # 测试输入:batch_size=1,维度=10 with torch.no_grad(): # 推理时关闭梯度计算,节省资源 output = model_load(x_test) print("推理输出:", output) ``` # 总结 1. `torch.save(obj, f)` 是PyTorch的**通用保存函数**,`obj` 支持任意Python对象,核心用于保存模型; 2. 保存模型**优先用状态字典(model.state_dict())**,而非整个模型,兼顾文件大小和移植性; 3. 训练断点需保存**模型+优化器状态字典+训练信息**,配合 `torch.load()` 实现断点续训; 4. 跨设备加载必须指定 `map_location`,推理前调用 `model.eval()`,训练前调用 `model.train()`; 5. 加载状态字典的核心步骤:**先定义模型结构 → 实例化 → 调用model.load_state_dict()加载参数**。 原文出处:http://www.malaoshi.top/show_1GW2gdkWbINz.html