GRU模型
隨著深度學(xué)習(xí)領(lǐng)域的快速發(fā)展,循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)已成為自然語(yǔ)言處理(NLP)等領(lǐng)域中常用的模型之一。但是,在RNN中,如果時(shí)間步數(shù)較大,會(huì)導(dǎo)致梯度消失或爆炸的問(wèn)題,這影響了模型的訓(xùn)練效果。為了解決這個(gè)問(wèn)題,研究人員提出了新的模型,其中GRU是其中的一種。
本文將介紹GRU的數(shù)學(xué)原理、代碼實(shí)現(xiàn),并通過(guò)pytorch和sklearn的數(shù)據(jù)集進(jìn)行試驗(yàn),最后對(duì)該模型進(jìn)行總結(jié)。
數(shù)學(xué)原理
GRU是一種門(mén)控循環(huán)單元(Gated Recurrent Unit)模型。與傳統(tǒng)的RNN相比,它具有更強(qiáng)的建模能力和更好的性能。
重置門(mén)和更新門(mén)
在GRU中,每個(gè)時(shí)間步有兩個(gè)狀態(tài):隱藏狀態(tài) h t h_t ht?和更新門(mén) r t r_t rt?。。更新門(mén)控制如何從先前的狀態(tài)中獲得信息,而隱藏狀態(tài)捕捉序列中的長(zhǎng)期依賴(lài)關(guān)系。
GRU的核心思想是使用“門(mén)”來(lái)控制信息的流動(dòng)。這些門(mén)是由sigmoid激活函數(shù)控制的,它們決定了哪些信息被保留和傳遞。
在每個(gè)時(shí)間步
t
t
t,GRU模型執(zhí)行以下操作:
1.計(jì)算重置門(mén)
r
t
=
σ
(
W
r
[
x
t
,
h
t
?
1
]
)
r_t = \sigma(W_r[x_t, h_{t-1}])
rt?=σ(Wr?[xt?,ht?1?])
其中,
W
r
W_r
Wr?是權(quán)重矩陣,
σ
\sigma
σ表示sigmoid函數(shù)。重置門(mén)
r
t
r_t
rt?告訴模型是否要忽略先前的隱藏狀態(tài)
h
t
?
1
h_{t-1}
ht?1?,并只依賴(lài)于當(dāng)前輸入
x
t
x_t
xt?。
2.計(jì)算更新門(mén)
z
t
=
σ
(
W
z
[
x
t
,
h
t
?
1
]
)
z_t = \sigma(W_z[x_t, h_{t-1}])
zt?=σ(Wz?[xt?,ht?1?])
其中,更新門(mén)
z
t
z_t
zt?告訴模型新的隱藏狀態(tài)
h
t
h_t
ht?在多大程度上應(yīng)該使用先前的狀態(tài)
h
t
?
1
h_{t-1}
ht?1?。
候選隱藏狀態(tài)和隱藏狀態(tài)
在計(jì)算完重置門(mén)和更新門(mén)之后,我們可以計(jì)算候選隱藏狀態(tài) h ~ t \tilde{h}_{t} h~t?和隱藏狀態(tài) h t h_t ht?。
1.計(jì)算候選隱藏狀態(tài)
h
~
t
=
tanh
?
(
W
[
x
t
,
r
t
?
h
t
?
1
]
)
\tilde{h}_{t} = \tanh(W[x_t, r_t * h_{t-1}])
h~t?=tanh(W[xt?,rt??ht?1?])
其中,
W
W
W是權(quán)重矩陣。候選隱藏狀態(tài)
h
~
t
\tilde{h}_{t}
h~t?利用當(dāng)前輸入
x
t
x_t
xt?和重置門(mén)
r
t
r_t
rt?來(lái)估計(jì)下一個(gè)可能的隱藏狀態(tài)。
2.計(jì)算隱藏狀態(tài)
h
t
=
(
1
?
z
t
)
?
h
t
?
1
+
z
t
?
h
~
t
h_{t} = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_{t}
ht?=(1?zt?)?ht?1?+zt??h~t?
這是GRU的最終隱藏狀態(tài)公式。它在候選隱藏狀態(tài)
h
~
t
\tilde{h}_{t}
h~t?和先前的隱藏狀態(tài)
h
t
h_t
ht?之間進(jìn)行加權(quán),其中權(quán)重由更新門(mén)
z
t
z_t
zt?控制。
代碼實(shí)現(xiàn)
下面是使用pytorch和sklearn的房?jī)r(jià)數(shù)據(jù)集實(shí)現(xiàn)GRU的示例代碼:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
# 加載數(shù)據(jù)集并進(jìn)行標(biāo)準(zhǔn)化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)
# 轉(zhuǎn)換為張量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)
# 定義GRU模型
class GRUNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(GRUNet, self).__init__()
self.hidden_size = hidden_size
self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
out, _ = self.gru(x)
out = self.fc(out[:, -1, :])
return out
input_size = X.shape[2]
hidden_size = 32
output_size = 1
model = GRUNet(input_size, hidden_size, output_size)
# 定義損失函數(shù)和優(yōu)化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 訓(xùn)練模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):
optimizer.zero_grad()
outputs = model(X)
loss = criterion(outputs, y)
loss.backward()
optimizer.step()
if (epoch+1) % 100 == 0:
loss_list.append(loss.item())
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')
# 可視化損失曲線
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of GRU Training')
plt.show()
# 預(yù)測(cè)新數(shù)據(jù)
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')
上述代碼首先加載并標(biāo)準(zhǔn)化房?jī)r(jià)數(shù)據(jù)集,然后定義了一個(gè)包含GRU層和全連接層的GRUNet模型,并使用均方誤差作為損失函數(shù)和Adam優(yōu)化器進(jìn)行訓(xùn)練。訓(xùn)練完成后,使用matplotlib庫(kù)繪制損失曲線(如下圖所示),并使用訓(xùn)練好的模型對(duì)新的數(shù)據(jù)點(diǎn)進(jìn)行預(yù)測(cè)。文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-828991.html
總結(jié)
GRU是一種門(mén)控循環(huán)單元模型,它通過(guò)更新門(mén)和重置門(mén),有效地解決了梯度消失或爆炸的問(wèn)題。在本文中,我們介紹了GRU的數(shù)學(xué)原理、代碼實(shí)現(xiàn)和代碼解釋?zhuān)⑼ㄟ^(guò)pytorch和sklearn的房?jī)r(jià)數(shù)據(jù)集進(jìn)行了試驗(yàn)。文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-828991.html
到了這里,關(guān)于機(jī)器學(xué)習(xí)入門(mén)--門(mén)控循環(huán)單元(GRU)原理與實(shí)踐的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!