????????簡(jiǎn)介:這個(gè)代碼可以用于時(shí)間序列修復(fù)和生成。使用transformer提取單變量或者多變時(shí)間窗口的趨勢(shì)分布情況。然后使用GAN生成分布類似的時(shí)間序列。
? ? ? ? 此外,還實(shí)現(xiàn)了基于prompt的數(shù)據(jù)生成,比如指定生成某個(gè)月份的數(shù)據(jù)、某半個(gè)月的數(shù)據(jù)、某一個(gè)星期的數(shù)據(jù)。
1、模型架構(gòu)
? ? ? ? 如下圖所示,生成器和鑒別器都使用Transformer的編碼器部分提取時(shí)間序列的特征,然后鑒別器使用這些進(jìn)行二分類、生成器使用這些特征生成偽造的數(shù)據(jù)。
? ? ? ? 重點(diǎn):在下面的圖的基礎(chǔ)上,我還添加了基于提示的生成代碼,類似于AI提示繪畫(huà)一樣,因此可以指定生成一月份、二月份等任意指定周期的數(shù)據(jù)。
2、訓(xùn)練GAN的代碼
? ? ? ? 下面是GAN的訓(xùn)練部分。
# 訓(xùn)練GAN
num_epochs = 100
for epoch in range(num_epochs):
for real_x,x_g,zz in loader: # 分別是真實(shí)值real_x、提示詞信息x_g、噪聲zz
real_data = real_x
noisy_data = x_g
# Train Discriminator
optimizer_D.zero_grad()
out = discriminator(real_data)
real_loss = criterion(discriminator(real_data), torch.ones(real_data.size(0), 1))
fake_data = generator(noisy_data,zz)
fake_loss = criterion(discriminator(fake_data.detach()), torch.zeros(fake_data.size(0), 1))
d_loss = real_loss + fake_loss
d_loss.backward()
optimizer_D.step()
# Train Generator
optimizer_G.zero_grad()
g_loss = criterion(discriminator(fake_data), torch.ones(fake_data.size(0), 1))
g_loss.backward()
optimizer_G.step()
print(f'Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')
3、生成器代碼
class Generator(nn.Module):
def __init__(self, seq_len=8, patch_size=2, channels=1, num_classes=9, latent_dim=100, embed_dim=10, depth=1,
num_heads=5, forward_drop_rate=0.5, attn_drop_rate=0.5):
super(Generator, self).__init__()
self.channels = channels
self.latent_dim = latent_dim
self.seq_len = seq_len
self.embed_dim = embed_dim
self.patch_size = patch_size
self.depth = depth
self.attn_drop_rate = attn_drop_rate
self.forward_drop_rate = forward_drop_rate
self.l1 = nn.Linear(self.latent_dim, self.seq_len * self.embed_dim)
self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, self.embed_dim))
self.blocks = Gen_TransformerEncoder(
depth=self.depth,
emb_size = self.embed_dim,
drop_p = self.attn_drop_rate,
)
self.deconv = nn.Sequential(
nn.Conv2d(self.embed_dim, self.channels, 1, 1, 0)
)
def forward(self, z):
x = self.l1(z).view(-1, self.seq_len, self.embed_dim)
x = x + self.pos_embed
H, W = 1, self.seq_len
x = self.blocks(x)
x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])
output = self.deconv(x.permute(0, 3, 1, 2))
output = output.view(-1, self.channels, H, W)
return output
4、生成數(shù)據(jù)和真實(shí)數(shù)據(jù)分布對(duì)比
? ? ? ? 使用PCA和TSNE對(duì)生成的時(shí)間窗口數(shù)據(jù)進(jìn)行降維,然后scatter這些二維點(diǎn)。如果生成的真實(shí)數(shù)據(jù)的互相混合在一起,說(shuō)明模型學(xué)習(xí)到了真東西,也就是模型偽造的數(shù)據(jù)和真實(shí)數(shù)據(jù)分布是一樣的,美滋滋。從下面的PCA可以看出,兩者的分布還是近似的。
? ? ? ? 進(jìn)一步的,可以擬合兩個(gè)二維正態(tài)分布,然后計(jì)算他們的KL散度作為一個(gè)評(píng)價(jià)指標(biāo)。
5、生成數(shù)據(jù)展示
? ? ? ? 上面是真實(shí)數(shù)據(jù)、下面是偽造的數(shù)據(jù)。由于只有幾百個(gè)樣本,以及參數(shù)都沒(méi)有進(jìn)行調(diào)整,但是效果還不錯(cuò)。
6、損失函數(shù)變化情況
? ? ? ? 模型還是學(xué)習(xí)到了一點(diǎn)東西的。文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-861675.html
文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-861675.html
到了這里,關(guān)于時(shí)間序列生成數(shù)據(jù),TransformerGAN的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!