pytorch api文档:nn.Module的.load_state_dict()-加载模型权重 作者:马育民 • 2026-01-30 15:04 • 阅读:10003 # 介绍 PyTorch 中 `nn.Module` 的 `load_state_dict()` 方法的使用,这个方法是PyTorch中**加载模型权重**的核心方法,专门用于将预训练的权重参数(state_dict)加载到模型中,而非直接加载整个模型,是模型保存/加载、迁移学习的关键操作 ### 先理解:.state_dict() 在执行 `.load_state_dict()`前,必须先执行 `.state_dict()`,因为加载的本质就是**把 `.state_dict()` 键值对映射到模型的可学习参数上**。 关于 `.state_dict()` 详见 [pytorch api文档:nn.Module的.state_dict()-提取模型中可学习的参数和持久化的缓冲区](https://www.malaoshi.top/show_1GW2gek3MZKA.html "pytorch api文档:nn.Module的.state_dict()-提取模型中可学习的参数和持久化的缓冲区") # 语法 ```python model.load_state_dict(state_dict, strict=True) ``` #### 参数解释 - `state_dict`:要加载的权重字典(OrderedDict/dict类型,来自模型保存的`.pth`/`.pt`文件); - `strict`:布尔值,默认`True`,表示**加载的state_dict与模型的state_dict必须键完全匹配**,缺一不可、不能多;设为`False`时,会忽略不匹配的键(仅加载键一致的参数),迁移学习中常用。 #### 返回值 一个`NamedTuple`,包含`missing_keys`(模型有但`state_dict` 中没有的键)和`unexpected_keys`(`state_dict`中有但模型没有的键),方便排查加载问题。 # 例子 PyTorch中模型保存分**仅保存state_dict(推荐)**和**保存整个模型**,前者轻量、兼容性好,是工业界标准做法,结合`load_state_dict()`使用。 ### 步骤1:保存模型的state_dict ```python import torch import torch.nn as nn # 1. 定义并初始化模型 class SimpleModel(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3) self.fc1 = nn.Linear(16*30*30, 10) def forward(self, x): x = self.conv1(x) x = x.flatten(1) return self.fc1(x) model = SimpleModel() # 模拟训练:更新模型参数 x = torch.randn(1, 3, 32, 32) _ = model(x) # 2. 保存模型的state_dict(推荐路径:model_weights.pth) torch.save(model.state_dict(), "model_weights.pth") print("权重保存完成") ``` ### 步骤2:加载state_dict到模型(核心) **注意**:加载前必须**先实例化模型**(因为`load_state_dict()`是给已存在的模型“赋值权重”,而非创建模型),再调用方法。 ```python import torch import torch.nn as nn # 1. 先重新定义并实例化**相同结构**的模型 class SimpleModel(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3) self.fc1 = nn.Linear(16*30*30, 10) def forward(self, x): x = self.conv1(x) x = x.flatten(1) return self.fc1(x) model = SimpleModel() # 必须先实例化! # 2. 加载保存的state_dict文件 state_dict = torch.load("model_weights.pth") # 读取权重字典 # 3. 加载到模型中 load_info = model.load_state_dict(state_dict) # 4. 查看加载信息(可选,排查问题) print("缺失的键:", load_info.missing_keys) print("多余的键:", load_info.unexpected_keys) # 正常情况输出:缺失的键:[] 多余的键:[] print("权重加载完成!") ``` # 进阶用法 实际开发中(如迁移学习、模型微调),常遇到**模型结构与预训练权重的state_dict键不匹配**的情况,这是`load_state_dict()`的高频使用场景,核心是通过`strict=False`或**修改state_dict的键**解决。 ### 场景1:迁移学习(加载预训练权重,忽略不匹配层) 比如用ResNet50做分类任务,预训练权重是ImageNet的1000类分类头,而你的任务是10类,此时分类头的键不匹配,用`strict=False`忽略: ```python import torch import torch.nn as nn from torchvision import models # 1. 加载预训练ResNet50,去掉原有分类头 resnet50 = models.resnet50(pretrained=True) # 加载预训练权重 # 替换为自定义分类头(10类) resnet50.fc = nn.Linear(2048, 10) # 2. 假设手动保存了预训练ResNet50的state_dict(含原fc层) torch.save(models.resnet50(pretrained=True).state_dict(), "resnet50_pretrain.pth") pretrain_state_dict = torch.load("resnet50_pretrain.pth") # 3. 加载预训练权重,忽略新模型没有的原fc层(strict=False) load_info = resnet50.load_state_dict(pretrain_state_dict, strict=False) print("缺失的键:", load_info.missing_keys) # 输出:['fc.weight', 'fc.bias'](新分类头无预训练权重) print("多余的键:", load_info.unexpected_keys) # 输出:[] ``` ### 场景2:手动修改state_dict的键,匹配模型结构 比如预训练权重的键带前缀(如`module.conv1.weight`,多GPU训练保存的权重会带`module.`),而单GPU模型的键是`conv1.weight`,需要去掉前缀再加载: ```python # 模拟多GPU保存的state_dict(键带module.) pretrain_state_dict = {f"module.{k}": v for k, v in model.state_dict().items()} # 去掉module.前缀,匹配单GPU模型的键 new_state_dict = {k.replace("module.", ""): v for k, v in pretrain_state_dict.items()} # 加载修改后的state_dict model.load_state_dict(new_state_dict) ``` ### 场景3:加载部分权重,冻结已加载层 迁移学习中,常加载骨干网络的权重并冻结,只训练分类头: ```python # 加载预训练权重(strict=False) model.load_state_dict(pretrain_state_dict, strict=False) # 冻结骨干网络的所有层(设置requires_grad=False) for param in model.conv1.parameters(): param.requires_grad = False # 仅训练fc层(fc层参数默认requires_grad=True) optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3) ``` # 常见错误及解决办法 1. **错误**:`RuntimeError: Error(s) in loading state_dict for SimpleModel: Missing key(s) in state_dict: ...` - 原因:`strict=True`时,模型的键在加载的state_dict中不存在; - 解决:确认模型结构与保存权重时一致,或设置`strict=False`忽略缺失键。 2. **错误**:`RuntimeError: Error(s) in loading state_dict for SimpleModel: Unexpected key(s) in state_dict: ...` - 原因:`strict=True`时,加载的state_dict有模型没有的键; - 解决:删除state_dict中多余的键,或设置`strict=False`忽略。 3. **错误**:`AttributeError: 'collections.OrderedDict' object has no attribute 'load_state_dict'` - 原因:直接对`state_dict`调用`load_state_dict()`(正确是**模型实例**调用); - 解决:先实例化模型(`model = SimpleModel()`),再用`model.load_state_dict(state_dict)`。 4. **错误**:加载后模型预测结果不对 - 原因:忘记将模型切换到对应模式(训练/评估); - 解决:加载权重后,若用于预测,执行`model.eval()`(关闭Dropout/BatchNorm的训练模式)。 # 补充:与torch.load()的区别 很多新手会混淆`torch.load()`和`load_state_dict()`,核心区别: - `torch.load()`:**读取文件**的方法,将`.pth`/`.pt`文件中的数据(可以是state_dict、整个模型、优化器参数等)加载到内存,返回的是文件中保存的原始对象(如OrderedDict); - `load_state_dict()`:**模型的方法**,专门将`torch.load()`读取的**权重字典**映射到模型的可学习参数上,是“赋值”操作。 **错误用法**(直接加载文件到模型): ```python model.load_state_dict(torch.load("model_weights.pth")) # 看似正确,实际是torch.load()先读文件返回state_dict,再传入方法,这是合法的简写! # 上述代码等价于: state_dict = torch.load("model_weights.pth") model.load_state_dict(state_dict) ``` **注意:**上面的“简写”是合法的,新手容易误以为是错误,实际只是两步合并为一步,推荐分开写(方便排查 `state_dict` 问题)。 # 总结 1. `load_state_dict()`是PyTorch**加载模型权重**的核心方法,作用是将state_dict的键值对映射到模型的可学习参数,**必须先实例化模型再调用**; 2. 核心参数`strict`:`True`(默认)要求键完全匹配,`False`忽略不匹配键,是迁移学习的关键设置; 3. 最佳实践:模型保存用`torch.save(model.state_dict(), path)`,加载用`model.load_state_dict(torch.load(path), strict=...)`; 4. 多GPU保存的权重带`module.`前缀,需手动去掉后再加载到单GPU模型;加载后根据场景执行`model.eval()`(预测)或`model.train()`(训练)。 原文出处:http://www.malaoshi.top/show_1GW2gfHOO6Yb.html