损失函数-交叉熵(多分类) 作者:马育民 • 2019-10-31 15:34 • 阅读:10391 需要掌握:[softmax 概率归一化](https://www.malaoshi.top/show_1GW2biBfUWT5.html "softmax 概率归一化") # 交叉熵 用于处理 **逻辑回归** 问题 交叉熵刻画的是 **实际输出(概率)** 与 **期望输出(概率)** 的**距离**,也就是交叉熵的值 **越小**,两个概率分布就 **越接近**。 假设概率分布 **p** 为真实值, **q** 为实际输出,**H(p,q)** 为交叉熵,**公式如下**: [](https://www.malaoshi.top/upload/0/0/1EF3xMC4l6AO.png) 放大概率分布之间的损失 # 计算过程 ### 一、准备测试数据 为了方便手动计算,设定: - **模型输出logits**:`input = torch.tensor([[1.0, 2.0, 3.0]])`(batch_size=1,num_classes=3) - **真实标签**:`target = torch.tensor([2])`(类别索引为2,对应第三个类别) - 无权重、无ignore_index、默认reduction='mean' --- ### 二、手动分步计算 cross_entropy #### 步骤1:计算 softmax(将logits转为概率) softmax公式(详见[链接](https://www.malaoshi.top/show_1GW2biBfUWT5.html "链接")): $$p\_i = \frac{e^{z\_i}}{\sum\_{j=1}^C e^{z\_j}}$$ 解释:$$C$$ 为类别数,这里 $$C=3$$ - 其中,$$z\_0=1.0, z\_1=2.0, z\_2=3.0$$ - 分子: $$e^{1.0} \approx 2.7182$$,$$e^{2.0} \approx 7.3890$$,$$e^{3.0} \approx 20.0855$$ - 分母(总和):$$2.7182 + 7.3890 + 20.0855 = 30.1928$$ - softmax结果: - 结果:$$p\_0 = 2.7182/30.1928 \approx 0.0900$$ - 结果:$$p\_1 = 7.3890/30.1928 \approx 0.2447$$ - 结果:$$p\_2 = 20.0855/30.1928 \approx 0.6652$$ #### 步骤2:计算 log_softmax(对softmax结果取自然对数) $$\log(p_i)$$: - 结果:$$\log(0.0900) \approx -2.4076$$ - 结果:$$\log(0.2447) \approx -1.4076$$ - 结果:$$\log(0.6652) \approx -0.4076$$ #### 步骤3:计算负对数似然(NLL) 取真实标签对应位置的log_softmax值,取负数: - 真实标签是2 → 取 $$\log(p_2)$$,即:$$\log(0.6652)$$,所以是: $$ -0.4076$$ - NLL = $$-(-0.4076) = 0.4076$$ #### 步骤4:reduction='mean'(平均) 因为批次是 `1`,平均后结果仍为 **0.4076**(这就是最终的cross_entropy损失值)。 # 为什么使用对数 通过上面例子可知,如果不用对数,也就是用 **输出值(概率)直接相乘**,由于 **输出值(概率)**是小数,多个概率(`0~1` 之间的数)相乘,结果会 **越来越小** 一批数据中,有多个样本: - 样本 1:p1 =0.6652 - 样本 2:p2 =0.2447 ### 不用对数:直接连乘 ``` p1 * p2 = 0.6652 * 0.2447 ≈ 0.1628 ``` 因为概率是小数,数据越多,结果越会无限趋近于 `0`,计算机浮点数无法精确表示(**数值下溢**),后续计算完全失效 ### 使用对数:乘积变加法 对数的性质:`log(ab)=loga+logb`,能把 **乘法转加法** 如下: $$ \log(p\_1 \ast p\_2) = \log(p\_1)+ \log(p\_2 )$$ 加法运算,数值范围稳定,不会出现快速趋近于 0 的问题 # keras中的交叉熵 ### 多分类交叉熵损失函数 1. 目标值是顺序数字的,如:1,2,3,4 ``` sparse_categorical_crossentropy ``` 详见:https://www.malaoshi.top/show_1EF5F1ygQuiB.html 2. 目标值是one-hot独热编码的 ``` categorical_crossentropy ``` 详见:https://www.malaoshi.top/show_1EF5F1s5lR6z.html # pytorch中的交叉熵 详见:[pytorch api文档:nn.functional.cross_entropy()交叉熵损失函数](https://www.malaoshi.top/show_1GW2fy7aYCus.html "pytorch api文档:nn.functional.cross_entropy()交叉熵损失函数") 感谢: https://blog.csdn.net/qq_43258953/article/details/103642353 原文出处:http://www.malaoshi.top/show_1EF4La0KENhc.html