pytorch api文档:nn.Module的.state_dict()-提取模型中可学习的参数和持久化的缓冲区 作者:马育民 • 2026-01-30 15:54 • 阅读:10005 # 介绍 PyTorch 中,`nn.Module` 的 `state_dict()` 方法,用于**提取模型中可学习的参数和持久化的缓冲区**,返回一个 **有序字典**(`OrderedDict`),键是参数/缓冲区的名称,值是对应的张量(`Tensor`)。 是 PyTorch 中**保存/加载模型参数**、**迁移学习**的基础。 ### 作用 `state_dict()` 返回 **有序字典(OrderedDict)**: - 键是模型层的名称(如`conv1.weight`、`fc.bias`) - 值是对应层的可学习参数(权重/偏置,tensor类型) **包含:** - **模型中需要被优化更新的可学习参数:**如`nn.Linear`/`nn.Conv2d`的`weight`、`bias` - **持久化的缓冲区:**如`nn.BatchNorm2d`的`running_mean`、`running_var`,这类数据不参与梯度下降,但需要随模型保存 **不包含:** - 模型的结构 - 无参数层(`nn.ReLU`/`nn.Sequential`)不会出现在其中 - 优化器状态,优化器有自己独立的`state_dict()`(保存学习率、动量、梯度累积等)。 # 例子 先定义一个简单的自定义模型,再调用`state_dict()`直观感受其返回结果: ```python import torch import torch.nn as nn # 定义简单模型 class SimpleModel(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3) # 卷积层:有weight/bias self.bn1 = nn.BatchNorm2d(16) # BN层:有weight/bias + running_mean/running_var self.fc1 = nn.Linear(16*30*30, 10)# 全连接层:有weight/bias def forward(self, x): x = self.bn1(torch.relu(self.conv1(x))) x = x.flatten(1) return self.fc1(x) # 实例化模型 model = SimpleModel() # 获取模型的state_dict model_state = model.state_dict() # 打印state_dict的键(参数/缓冲区名称) print("model.state_dict()的键:") for key in model_state.keys(): print(key) print("\n某一参数的形状:", model_state["conv1.weight"].shape) ``` **输出结果**(键的命名规则:`层名.参数/缓冲区名`): ``` model.state_dict()的键: conv1.weight conv1.bias bn1.weight bn1.bias bn1.running_mean bn1.running_var bn1.num_batches_tracked fc1.weight fc1.bias 某一参数的形状: torch.Size([16, 3, 3, 3]) ``` 可以看到: 1. 卷积/全连接层的`weight`和`bias`被保存; 2. BN层除了`weight/bias`,还有`running_mean`(均值)、`running_var`(方差)、`num_batches_tracked`(训练批次计数)这些缓冲区; 3. 键的命名与模型中定义的层名严格对应,嵌套层(如`nn.Sequential`)的键会是`seq层名.子层索引.参数名`(例:`features.0.weight`)。 # 特性 1. **仅保存可训练/持久化张量**:模型中手动定义的普通张量(未用`nn.Parameter`包装)不会被保存,只有`nn.Parameter`(可学习参数)和`nn.Buffer`(持久化缓冲区)会被纳入; 2. **返回OrderedDict**:有序字典保证参数的顺序与模型定义一致,加载时不会因顺序问题出错(PyTorch 1.7+也兼容普通dict,但推荐用原生OrderedDict); 3. **子模块自动递归收集**:如果模型包含子模块(如自定义层、`nn.Sequential`、`nn.ModuleList`),`state_dict()`会自动递归收集所有子模块的参数/缓冲区,无需手动处理; 4. **eval/train模式不影响内容**:模型在训练(train)或评估(eval)模式下,`state_dict()`的内容完全一致,仅影响BN/Dropout的运行逻辑,不改变参数本身。 # 应用场景(保存/加载模型) `state_dict()`的最主要用途是**模型的保存与加载**,也是PyTorch推荐的方式(比直接保存模型实例更灵活、更省空间)。 ### 场景1:仅保存/加载模型参数(最常用) 适合模型结构不变,仅恢复参数的场景(如继续训练、模型推理): ```python # 1. 保存模型state_dict(推荐存为.pth/.pt格式) torch.save(model.state_dict(), "simple_model_params.pth") # 2. 加载模型参数(步骤:先实例化模型结构,再加载参数) new_model = SimpleModel() # 必须先定义和原模型完全一致的结构 new_model.load_state_dict(torch.load("simple_model_params.pth")) new_model.eval() # 推理前务必切换到评估模式(关闭BN/Dropout的训练特性) ``` ### 场景2:保存/加载「模型+优化器」状态(继续训练) 如果需要中断训练后继续训练,需要同时保存模型参数和优化器状态(优化器的`state_dict()`保存学习率、动量等): ```python # 定义优化器 optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 保存模型+优化器的state_dict(存入一个字典) checkpoint = { "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "epoch": 50, # 还可以保存当前训练轮数、损失值等 "loss": 0.123 } torch.save(checkpoint, "train_checkpoint.pth") # 加载检查点继续训练 checkpoint = torch.load("train_checkpoint.pth") model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optimizer_state"]) start_epoch = checkpoint["epoch"] + 1 last_loss = checkpoint["loss"] ``` ### 场景3:部分加载参数(模型微调/迁移学习) 实际开发中常遇到**预训练模型结构与目标模型不完全一致**的情况(如最后一层输出类别不同),此时可通过过滤键来部分加载参数: ```python # 假设预训练模型是SimpleModel(输出10类),目标模型是NewModel(输出20类,仅fc1层不同) class NewModel(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3) self.bn1 = nn.BatchNorm2d(16) self.fc1 = nn.Linear(16*30*30, 20) # 输出20类,与原模型不同 # 加载预训练参数,过滤掉fc1层的参数 pretrained_state = torch.load("simple_model_params.pth") new_model = NewModel() # 过滤键:保留除fc1.weight/fc1.bias外的所有参数 new_state = {k: v for k, v in pretrained_state.items() if not k.startswith("fc1.")} # strict=False:允许模型与state_dict的键不完全匹配(未匹配的参数随机初始化) new_model.load_state_dict(new_state, strict=False) ``` # 与model.parameters()的区别 容易混淆`state_dict()`和`model.parameters()`,区别如下: | 特性 | `model.state_dict()` | `model.parameters()` | |---------------------|----------------------------|-----------------------------| | 返回类型 | 有序字典(键+张量值)| 生成器(仅张量,无键)| | 包含内容 | 可学习参数 + 持久化缓冲区 | 仅可学习参数(无缓冲区)| | 核心用途 | 保存/加载/迁移模型参数 | 传给优化器(指定待优化参数)| | 是否可直接索引 | 是(如`state['conv1.weight']`) | 否(需转列表,如`list(params)[0]`) | **示例**: ```python # parameters():仅返回可学习参数的生成器,用于优化器 optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # state_dict():返回带键的字典,用于保存 torch.save(model.state_dict(), "model.pth") ``` # 拓展:手动修改state_dict `state_dict()`返回的是有序字典,可直接修改其值(如修改参数、冻结部分层),修改后重新传给`load_state_dict()`即可生效: ```python # 冻结conv1层:将conv1的参数梯度置为False(等价于修改state_dict后加载) for k, v in model.state_dict().items(): if k.startswith("conv1."): v.requires_grad = False # 冻结参数,不参与梯度更新 model.load_state_dict(model.state_dict()) ``` # 总结 1. `nn.Module.state_dict()`是PyTorch模型参数管理的核心,返回**有序字典**,键为参数/缓冲区名称,值为对应张量,仅包含可学习参数(`nn.Parameter`)和持久化缓冲区(`nn.Buffer`); 2. 其核心用途是**模型保存/加载/迁移**,PyTorch推荐优先保存`state_dict()`而非模型实例,灵活性更高; 3. 加载参数时需先**实例化与原模型结构一致的模型**,`strict=False`支持部分加载参数(适用于微调/迁移学习); 4. 与`model.parameters()`的核心区别:`state_dict()`带键、包含缓冲区,用于保存;`parameters()`仅返回可学习参数生成器,用于传给优化器; 5. 优化器也有独立的`state_dict()`,保存训练相关状态,继续训练时需与模型`state_dict()`一起保存。 原文出处:http://www.malaoshi.top/show_1GW2gek3MZKA.html