邏輯回歸為什么使用交叉熵而不用均方差?或者說邏輯回歸的損失函數(shù)為什么不用最小二乘?
下面主要從兩個角度進行闡述:
- 從邏輯回歸的角度出發(fā),邏輯回歸的預(yù)測值是一個概率,而交叉熵又表示真實概率分布與預(yù)測概率分布的相似程度,因此選擇使用交叉熵
- 從均方差(MSE)的角度來說,預(yù)測值概率與歐式距離沒有任何關(guān)系,并且在分類問題中,樣本的值不存在大小比較關(guān)系,與歐式距離更無關(guān)系,因此不適用MSE
1、損失函數(shù)的凸性(使用MSE可能會陷入局部最優(yōu))
前面我們在介紹線性回歸時,我們用到的損失函數(shù)是誤差(殘差)平方和
L = ∑ i = 1 m ( y i ? y ^ i ) 2 = ∑ i = 1 m ( y i ? x i ω ) 2 L=\sum_{i=1}^m(y_i-\hat y_i)^2=\sum_{i=1}^m(y_i-x_i\omega)^2 L=i=1∑m?(yi??y^?i?)2=i=1∑m?(yi??xi?ω)2
這是一個凸函數(shù),有全局最優(yōu)解
如果邏輯回歸也使用誤差平方和,由于邏輯回歸假設(shè)函數(shù)的外層函數(shù)是Sigmoid
函數(shù),Sigmoid
函數(shù)是一個復雜的非線性函數(shù),這就使得我們將邏輯回歸的假設(shè)函數(shù)代入上式時,即
L
=
∑
i
=
1
m
(
y
i
?
1
1
+
e
?
x
i
ω
)
2
L=\sum_{i=1}^m\left(y_i-\frac{1}{1+e^{-x_i\omega}}\right)^2
L=i=1∑m?(yi??1+e?xi?ω1?)2
那么,我們得到的 L L L是一個非凸函數(shù),不易優(yōu)化,容易陷入局部最優(yōu)解。所以邏輯回歸的損失函數(shù)使用的是對數(shù)損失函數(shù)(Log Loss Function)
在邏輯回歸(詳見:傳送門)一文中,我們已經(jīng)給出了KL散度與交叉熵的關(guān)系
交叉熵
=
K
L
散度
+
信息熵
交叉熵=KL散度+信息熵
交叉熵=KL散度+信息熵
即交叉熵等于KL散度加上信息熵。而信息熵是一個常數(shù),并且在計算的時候,交叉熵相較于KL散度更容易,所以我們直接使用了交叉熵作為損失函數(shù)
因此,我們在最小化交叉熵的時候,實際上就是在最小化 KL散度,也就是在讓預(yù)測概率分布盡可能地與真實概率分布相似
2、MSE的損失小于交叉熵的損失(導致對分類錯誤點的懲罰不夠)
邏輯回歸的數(shù)學表達式如下
h
θ
(
x
)
=
g
(
θ
T
x
)
=
1
1
+
e
?
θ
T
x
h_\theta(x)=\rm g(\theta^Tx)=\frac{1}{1+e^{-\theta^Tx}}
hθ?(x)=g(θTx)=1+e?θTx1?
對于一元邏輯回歸,其預(yù)測值為
y
^
=
σ
(
ω
x
+
b
)
\hat y = \sigma(\omega x+b)
y^?=σ(ωx+b)
其中,
σ
\sigma
σ為Sigmoid
函數(shù)
如果使用均方差作為損失函數(shù),我們以一個樣本為例,為方便計算,我們給均方差除以2(不改變函數(shù)的單調(diào)性)
C
=
1
2
(
y
?
y
^
)
2
C=\frac{1}{2}(y-\hat y)^2
C=21?(y?y^?)2
其中
y
^
\hat y
y^?=
σ
(
z
)
\sigma(z)
σ(z)=
1
1
+
e
?
z
\frac{1}{1+e^{-z}}
1+e?z1?,
z
z
z=
ω
x
+
b
\omega x+b
ωx+b,使用梯度下降法對
ω
\omega
ω進行更新,那么就需要將損失函數(shù)對
ω
\omega
ω進行求偏導數(shù)
?
C
?
ω
=
(
y
?
y
^
)
σ
′
(
z
)
x
=
(
y
?
y
^
)
y
^
(
1
?
y
^
)
x
\frac{\partial C}{\partial \omega}=(y-\hat y)\sigma'(z)x=(y-\hat y)\hat y(1-\hat y)x
?ω?C?=(y?y^?)σ′(z)x=(y?y^?)y^?(1?y^?)x
具體計算過程可參考如下或文末參考文章
可以看到,均方差損失函數(shù)的梯度與激活函數(shù)(Sigmoid
函數(shù))的梯度成正比,當預(yù)測值接近于1或0時,梯度會變得非常小,幾乎接近于0,這樣會導致當真實值與預(yù)測值差距很大時,損失函數(shù)收斂的很慢,無法進行有效學習,與我們的期望不符合
因此,如果使用均方差損失,訓練的時候可能看到的情況是預(yù)測值和真實值之間的差距越大,參數(shù)調(diào)整的越小,訓練的越慢
如果使用交叉熵作為損失函數(shù),對于二分類問題,交叉熵的形式是由極大似然估計下概率的連乘然后取對數(shù)得到的(推導見文章:傳送門)
C
=
?
[
y
ln
?
y
^
+
(
1
?
y
)
ln
?
(
1
?
y
^
)
]
C=-[y\ln \hat y +(1-y)\ln (1-\hat y)]
C=?[ylny^?+(1?y)ln(1?y^?)]
關(guān)于
ω
\omega
ω求偏導數(shù)得
?
C
?
ω
=
(
σ
(
z
)
?
y
)
x
\frac{\partial C}{\partial \omega}=(\sigma(z)-y)x
?ω?C?=(σ(z)?y)x
可以看到,交叉熵損失函數(shù)的梯度和當前預(yù)測值與真實值之間的差是有關(guān)的,沒有受到Sigmoid
函數(shù)的梯度的影響,且真實值與預(yù)測值的差越大,損失函數(shù)的梯度就越大,更新的速度也就越快,這正是我們想要的文章來源:http://www.zghlxwxcb.cn/news/detail-832871.html
參考文章:https://zhuanlan.zhihu.com/p/453411383?login=from_csdn文章來源地址http://www.zghlxwxcb.cn/news/detail-832871.html
到了這里,關(guān)于邏輯回歸為什么使用交叉熵而不用均方差?的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!