1. Building a unidirectional RNN with nn.RNN
PyTorch documentation for torch.nn.RNN: torch.nn.RNN(*args, **kwargs)
The usage of nn.RNN and its input and output parameters is easiest to explain directly through code:
import torch
import torch.nn as nn
# Unidirectional multi-layer RNN
embed_dim = 5 # feature dimension of each input element, e.g. each word is represented by a feature vector of length 5
hidden_dim = 6 # feature dimension of the hidden state, e.g. each word is represented in the hidden layer by a feature vector of length 6
rnn_layers = 4 # number of stacked recurrent layers
rnn = nn.RNN(input_size=embed_dim, hidden_size=hidden_dim, num_layers=rnn_layers, batch_first=True)
# Input
batch_size = 2
sequence_length = 3 # length of the input sequence, e.g. "i love you" has a sequence length of 3, and each word is represented by a feature vector of length embed_dim
input = torch.randn(batch_size, sequence_length, embed_dim)
h0 = torch.randn(rnn_layers, batch_size, hidden_dim)
# output holds the hidden states computed by the last recurrent layer at every time step
# hn holds the hidden states of all recurrent layers at the last time step, i.e. the hidden state of the word "you"
output, hn = rnn(input, h0)
print(f"output = {output}")
print(f"hn = {hn}")
print(f"output.shape = {output.shape}") # torch.Size([2, 3, 6]) [batch_size, sequence_length, hidden_dim]
print(f"hn.shape = {hn.shape}") # torch.Size([4, 2, 6]) [rnn_layers, batch_size, hidden_dim]
"""
output = tensor([[[-0.3727, -0.2137, -0.3619, -0.6116, -0.1483, 0.8292],
[ 0.1138, -0.6310, -0.3897, -0.5275, 0.2012, 0.3399],
[-0.0522, -0.5991, -0.3114, -0.7089, 0.3824, 0.1903]],
[[ 0.1370, -0.6037, 0.3906, -0.5222, 0.8498, 0.8887],
[-0.3463, -0.3293, -0.1874, -0.7746, 0.2287, 0.1343],
[-0.2588, -0.4145, -0.2608, -0.3799, 0.4464, 0.1960]]],
grad_fn=<TransposeBackward1>)
hn = tensor([[[-0.2892, 0.7568, 0.4635, -0.2106, -0.0123, -0.7278],
[ 0.3492, -0.3639, -0.4249, -0.6626, 0.7551, 0.9312]],
[[ 0.0154, 0.0190, 0.3580, -0.1975, -0.1185, 0.3622],
[ 0.0905, 0.6483, -0.1252, 0.3903, 0.0359, -0.3011]],
[[-0.2833, -0.3383, 0.2421, -0.2168, -0.6694, -0.5462],
[ 0.2976, 0.0724, -0.0116, -0.1295, -0.6324, -0.0302]],
[[-0.0522, -0.5991, -0.3114, -0.7089, 0.3824, 0.1903],
[-0.2588, -0.4145, -0.2608, -0.3799, 0.4464, 0.1960]]],
grad_fn=<StackBackward0>)
output.shape = torch.Size([2, 3, 6])
hn.shape = torch.Size([4, 2, 6])
"""
Pay special attention to the second output of nn.RNN: hn holds the hidden states of all recurrent layers at the last time step. This may sound confusing, but comparing the printed values above makes it clear: output[:, -1, :] = hn[-1, :, :]
Here hn stores the last-time-step hidden state of each of the four recurrent layers. Taking the input "i love you" as an example, hn stores the hidden states of the word "you".
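To confirm this relationship programmatically, here is a minimal check, reusing the rnn, input, and h0 variables from the code above:
print(torch.allclose(output[:, -1, :], hn[-1, :, :])) # True: the last time step of output equals the last layer's slice of hn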
2. Building a unidirectional LSTM with nn.LSTM
PyTorch documentation for torch.nn.LSTM: torch.nn.LSTM(*args, **kwargs)
The usage of nn.LSTM and its input and output parameters is easiest to explain directly through code:
import torch
import torch.nn as nn
batch_size = 4
seq_len = 3 # length of the input sequence
embed_dim = 5 # feature dimension of each input element
hidden_size = 5 * 2 # feature dimension of the hidden state; as a practical rule of thumb, hidden_size = embed_dim * 2
num_layers = 2 # number of LSTM layers, typically 1-4; for an introduction to multi-layer LSTMs see https://blog.csdn.net/weixin_41041772/article/details/88032093
lstm = nn.LSTM(input_size=embed_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
# the initial states (h0, c0) can be omitted; they default to zero tensors
input = torch.randn(batch_size, seq_len, embed_dim)
"""
output holds the hidden states computed by the last recurrent layer at every time step
hn holds the hidden states of all recurrent layers at the last time step, i.e. the hidden state of the word "you"; hence hn does not depend on the sentence length seq_len
hn[-1] is the hidden state of the last recurrent layer at the last time step, i.e. the output of the LSTM
cn holds the cell states at the last word of the sentence; hence cn does not depend on the sentence length seq_len either
where output[:, -1, :] = hn[-1, :, :]
"""
output, (hn, cn) = lstm(input)
print(f"output.shape = {output.shape}") # torch.Size([4, 3, 10])
print(f"hn.shape = {hn.shape}") # torch.Size([2, 4, 10])
print(f"cn.shape = {cn.shape}") # torch.Size([2, 4, 10])
print(f"output = {output}")
print(f"hn = {hn}")
print(f"output[:, -1, :] = {output[:, -1, :]}")
print(f"hn[-1,:,:] = {hn[-1,:,:]}")
"""
output = tensor([[[ 0.0447, 0.0111, 0.0292, 0.0692, -0.0547, -0.0120, -0.0202, -0.0243, 0.1216, 0.0643],
[ 0.0780, 0.0279, 0.0231, 0.1061, -0.0819, -0.0027, -0.0269, -0.0509, 0.1800, 0.0921],
[ 0.0993, 0.0160, 0.0516, 0.1402, -0.1146, -0.0177, -0.0607, -0.0715, 0.2110, 0.0954]],
[[ 0.0542, -0.0053, 0.0415, 0.0899, -0.0561, -0.0376, -0.0327, -0.0276, 0.1159, 0.0545],
[ 0.0819, -0.0015, 0.0640, 0.1263, -0.1021, -0.0502, -0.0495, -0.0464, 0.1814, 0.0750],
[ 0.0914, 0.0034, 0.0558, 0.1418, -0.1327, -0.0643, -0.0616, -0.0674, 0.2195, 0.0886]],
[[ 0.0552, -0.0006, 0.0351, 0.0864, -0.0486, -0.0192, -0.0305, -0.0289, 0.1103, 0.0554],
[ 0.0835, -0.0099, 0.0415, 0.1396, -0.0758, -0.0829, -0.0616, -0.0604, 0.1740, 0.0828],
[ 0.1202, -0.0113, 0.0570, 0.1608, -0.0836, -0.0801, -0.0792, -0.0874, 0.1923, 0.0829]],
[[ 0.0115, -0.0026, 0.0267, 0.0747, -0.0867, -0.0250, -0.0199, -0.0154, 0.1158, 0.0649],
[ 0.0628, 0.0003, 0.0297, 0.1191, -0.1028, -0.0342, -0.0509, -0.0496, 0.1759, 0.0831],
[ 0.0569, 0.0105, 0.0158, 0.1300, -0.1367, -0.0207, -0.0514, -0.0629, 0.2029, 0.1042]]], grad_fn=<TransposeBackward0>)
hn = tensor([[[-0.1933, -0.0058, -0.1237, 0.0348, -0.1394, 0.2403, 0.1591, -0.1143, 0.1211, -0.1971],
[-0.2387, 0.0433, -0.0296, 0.0877, -0.1198, 0.1919, 0.0832, 0.0738, 0.1907, -0.1807],
[-0.2174, 0.0721, -0.0447, 0.1081, -0.0520, 0.2519, 0.4040, -0.0033, 0.1378, -0.2930],
[-0.2130, -0.0404, -0.0588, -0.1346, -0.1865, 0.1032, -0.0269, 0.0265, -0.0664, -0.1800]],
[[ 0.0993, 0.0160, 0.0516, 0.1402, -0.1146, -0.0177, -0.0607, -0.0715, 0.2110, 0.0954],
[ 0.0914, 0.0034, 0.0558, 0.1418, -0.1327, -0.0643, -0.0616, -0.0674, 0.2195, 0.0886],
[ 0.1202, -0.0113, 0.0570, 0.1608, -0.0836, -0.0801, -0.0792, -0.0874, 0.1923, 0.0829],
[ 0.0569, 0.0105, 0.0158, 0.1300, -0.1367, -0.0207, -0.0514, -0.0629, 0.2029, 0.1042]]], grad_fn=<StackBackward0>)
Verify that output[:, -1, :] = hn[-1,:,:]
output[:, -1, :] = tensor([[ 0.0993, 0.0160, 0.0516, 0.1402, -0.1146, -0.0177, -0.0607, -0.0715,0.2110, 0.0954],
[ 0.0914, 0.0034, 0.0558, 0.1418, -0.1327, -0.0643, -0.0616, -0.0674, 0.2195, 0.0886],
[ 0.1202, -0.0113, 0.0570, 0.1608, -0.0836, -0.0801, -0.0792, -0.0874, 0.1923, 0.0829],
[ 0.0569, 0.0105, 0.0158, 0.1300, -0.1367, -0.0207, -0.0514, -0.0629, 0.2029, 0.1042]], grad_fn=<SliceBackward0>)
hn[-1,:,:] = tensor([[ 0.0993, 0.0160, 0.0516, 0.1402, -0.1146, -0.0177, -0.0607, -0.0715,0.2110, 0.0954],
[ 0.0914, 0.0034, 0.0558, 0.1418, -0.1327, -0.0643, -0.0616, -0.0674, 0.2195, 0.0886],
[ 0.1202, -0.0113, 0.0570, 0.1608, -0.0836, -0.0801, -0.0792, -0.0874, 0.1923, 0.0829],
[ 0.0569, 0.0105, 0.0158, 0.1300, -0.1367, -0.0207, -0.0514, -0.0629, 0.2029, 0.1042]], grad_fn=<SliceBackward0>)
"""
3. Recommended references
For an introduction to multi-layer LSTMs, see the blog post "RNN之多層LSTM理解:輸入,輸出,時(shí)間步,隱藏節(jié)點(diǎn)數(shù),層數(shù)".
For the principles of RNN and a from-scratch PyTorch reimplementation, see the video "PyTorch RNN的原理及其手寫復(fù)現(xiàn)" (a minimal sketch of the recurrence is also given after this list).
For the principles of LSTM and LSTMP and a from-scratch PyTorch reimplementation, see the video "PyTorch LSTM和LSTMP的原理及其手寫復(fù)現(xiàn)".
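For readers interested in the reimplementation mentioned above, here is a minimal, unofficial sketch of the per-time-step recurrence h_t = tanh(x_t @ W_ih^T + b_ih + h_{t-1} @ W_hh^T + b_hh) for a single-layer unidirectional nn.RNN, checked against the module's own output (the variable names are illustrative only):
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, hidden_dim = 5, 6
rnn = nn.RNN(input_size=embed_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True)

batch_size, seq_len = 2, 3
x = torch.randn(batch_size, seq_len, embed_dim)
h0 = torch.zeros(1, batch_size, hidden_dim)
output, hn = rnn(x, h0)

# Replay the recurrence manually with the layer's own parameters
W_ih, W_hh = rnn.weight_ih_l0, rnn.weight_hh_l0 # shapes: (hidden_dim, embed_dim), (hidden_dim, hidden_dim)
b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0     # shapes: (hidden_dim,), (hidden_dim,)
h = h0[0]                                       # (batch_size, hidden_dim)
steps = []
for t in range(seq_len):
    h = torch.tanh(x[:, t, :] @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
    steps.append(h)
manual_output = torch.stack(steps, dim=1)       # (batch_size, seq_len, hidden_dim)

print(torch.allclose(manual_output, output, atol=1e-6)) # True
print(torch.allclose(h, hn[0], atol=1e-6))              # True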