上次講了如何給節(jié)點(diǎn)分類,這次我們來(lái)看如何用GNN完成圖分類任務(wù),也就是Graph-level的任務(wù)。
【GNN 1】PyG實(shí)現(xiàn)圖神經(jīng)網(wǎng)絡(luò),完成節(jié)點(diǎn)分類任務(wù),人話、保姆級(jí)教程-CSDN博客
圖分類就是以圖為單位的分類,舉個(gè)例子:每個(gè)學(xué)校都有社交關(guān)系網(wǎng),圖分類就是通過(guò)這個(gè)社交網(wǎng)絡(luò)判別這個(gè)學(xué)校是小學(xué)、初中、高中還是大學(xué)。
實(shí)現(xiàn)方法就是通過(guò)利用圖的結(jié)構(gòu)信息,對(duì)圖進(jìn)行嵌入(embed),也就是用向量來(lái)表示這個(gè)圖,使得分類器基于這個(gè)向量能夠輕松分類,或者說(shuō)通過(guò)對(duì)圖進(jìn)行向量表示,使得圖的分類盡可能變成一個(gè)線性可分的任務(wù)。
下圖就是形象展示了我們要干的事:把一堆圖表示成線性可分的向量們,然后構(gòu)建個(gè)分類器,完成圖分類。
圖分類的一個(gè)經(jīng)典任務(wù)就是分子屬性預(yù)測(cè),我們可以把原子看成圖的節(jié)點(diǎn),化學(xué)鍵看成邊,整個(gè)分子就是圖,我們想知道分子有什么性質(zhì)(是否是藥物小分子,能否和蛋白相互作用等),其實(shí)就是看這個(gè)圖是屬于哪一個(gè)類別的。
數(shù)據(jù)集的選擇
我們這次選擇的數(shù)據(jù)集是TUDatasets,這是TU Dortumnd University收集的大量關(guān)于分子特征的圖數(shù)據(jù),他們還發(fā)表了論文TUDataset: A collection of benchmark datasets for learning with graphs。
這么重要的數(shù)據(jù)集,當(dāng)然也可以通過(guò)PyTorch Geometric直接加載啦。
下面是數(shù)據(jù)集的大致情況,可以看到這里包括酶、蛋白質(zhì)還有其他的一些。注意了第一列不是第二列的類別,最后一列class才是。
加載數(shù)據(jù)
下面我們來(lái)加載數(shù)據(jù)吧!
先說(shuō)一下TUDataset
這個(gè)函數(shù),有兩個(gè)參數(shù),
- root (str) – Root directory where the dataset should be saved.(保存的路徑)
- name (str) – The name of the dataset.(名字)
加載完就可以發(fā)現(xiàn)數(shù)據(jù)已經(jīng)下載到我們指定的root目錄了。
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
import torch
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='data/TUDataset', name='MUTAG')
# - root (str) – Root directory where the dataset should be saved.(保存的路徑)
# - name (str) – The name of the dataset.(名字)
# 查看一些數(shù)據(jù)集的基本信息
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0] # Get the first graph object.
print()
print(data)
print('=============================================================')
# 看一下第一張圖的信息
# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
Dataset: MUTAG(188):
====================
Number of graphs: 188
Number of features: 7
Number of classes: 2
Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
=============================================================
Number of nodes: 17
Number of edges: 38
Average node degree: 2.24
Has isolated nodes: False
Has self-loops: False
Is undirected: True
這個(gè)數(shù)據(jù)集有188張圖,有兩個(gè)類。
通過(guò)查看第一個(gè)圖的基本信息,我們可以看到它有17個(gè)節(jié)點(diǎn)(7維特征向量,用一個(gè)7維的向量來(lái)描述節(jié)點(diǎn))、38條無(wú)向邊(4維特征向量,用一個(gè)4維的向量給來(lái)描述邊,因?yàn)槭侨腴T(mén),這次我們不用)還有一個(gè)圖的標(biāo)簽y=[1]
表示圖是哪一類的(1維向量,一個(gè)數(shù))。
這里提一個(gè)小知識(shí)點(diǎn),在機(jī)器學(xué)習(xí)中,訓(xùn)練之前一般均會(huì)對(duì)數(shù)據(jù)集做shuffle,也就是打亂數(shù)據(jù)之間的順序,讓數(shù)據(jù)隨機(jī)化,這樣可以避免過(guò)擬合。
數(shù)據(jù)集shuffle的重要性 - 知乎 (zhihu.com)
PyTorch Geometric也提供了很多處理圖數(shù)據(jù)集的方法,比如shuffle()
,我們今天就首先對(duì)數(shù)據(jù)集進(jìn)行打亂,然后選擇前150個(gè)樣本進(jìn)行訓(xùn)練,剩下的38個(gè)進(jìn)行測(cè)試。
torch.manual_seed(12345)
dataset = dataset.shuffle()
train_dataset = dataset[:150]
test_dataset = dataset[150:]
print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')
Number of training graphs: 150
Number of test graphs: 38
圖的Mini-batching
因?yàn)樵趫D分類數(shù)據(jù)集中的圖通常來(lái)說(shuō)都比較小,這樣就不能充分利用GPU,所以一個(gè)想法就是就是先batch the graph,然后再把圖放到GNN中。
在圖像和自然語(yǔ)言處理領(lǐng)域,這個(gè)過(guò)程通常是通過(guò)rescaling或者padding來(lái)實(shí)現(xiàn)的,就是把每個(gè)樣本轉(zhuǎn)換成統(tǒng)一大小/形狀,然后再把它們放到一起。以圖片為例,我把所有圖片都轉(zhuǎn)換為100*100大小,然后再把這些圖像融合成一個(gè)大圖(或者以其他形式存儲(chǔ)),這個(gè)存儲(chǔ)的變量也是有維度的,這個(gè)維度的大小就是在一個(gè)batch中樣本的個(gè)數(shù),也就是batch size
。
在GNN中,rescaling和padding要么行不通,要么會(huì)造成不必要的內(nèi)存消耗。
因此,PyTorch Geometric選擇了另一種方法來(lái)實(shí)現(xiàn)多個(gè)樣本的并行化。在這里,鄰接矩陣以對(duì)角線方式堆疊(創(chuàng)建一個(gè)包含多個(gè)孤立子圖的大圖,A),node和target特征在節(jié)點(diǎn)維度上簡(jiǎn)單地拼接起來(lái)(X):
這個(gè)過(guò)程相對(duì)于其他的batching方法有一些關(guān)鍵的優(yōu)勢(shì):
- 依賴于消息傳遞方案(message passing scheme)的GNN operators不需要修改,因?yàn)橄⒉粫?huì)在屬于不同圖的兩個(gè)節(jié)點(diǎn)之間交換;
- 沒(méi)有計(jì)算或內(nèi)存開(kāi)銷,因?yàn)猷徑泳仃囈韵∈杈仃嚨姆绞奖4?,只保存非零條目,也就是只保留邊。
PyTorch Geometric在torch_geometrics .data. dataloader
類的幫助下自動(dòng)處理將多個(gè)圖批處理為一個(gè)大圖(batching multiple graphs into a single giant graph):
from torch_geometric.loader import DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
for step, data in enumerate(train_loader):
print(f'Step {step + 1}:')
print('=======')
print(f'Number of graphs in the current batch: {data.num_graphs}')
print(data)
print()
for step, data in enumerate(test_loader):
print(f'Step {step + 1}:')
print('=======')
print(f'Number of graphs in the current batch: {data.num_graphs}')
print(data)
print()
Step 1:
=======
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 2572], x=[1168, 7], edge_attr=[2572, 4], y=[64], batch=[1168], ptr=[65])
Step 2:
=======
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 2554], x=[1153, 7], edge_attr=[2554, 4], y=[64], batch=[1153], ptr=[65])
Step 3:
=======
Number of graphs in the current batch: 22
DataBatch(edge_index=[2, 868], x=[393, 7], edge_attr=[868, 4], y=[22], batch=[393], ptr=[23])
Step 1:
=======
Number of graphs in the current batch: 38
DataBatch(edge_index=[2, 1448], x=[657, 7], edge_attr=[1448, 4], y=[38], batch=[657], ptr=[39])
我們選擇將batch_size
設(shè)置為64,可以看到分成了3個(gè)隨機(jī)打亂的mini-batches。
對(duì)于每個(gè) Batch
對(duì)象都有一個(gè)對(duì)應(yīng)的batch vector,這就是起到一個(gè)索引的作用,即將每個(gè)節(jié)點(diǎn)映射到batch中各自的圖上。
batch = [ 0 , … , 0 , 1 , … , 1 , 2 , … ] \textrm{batch} = [ 0, \ldots, 0, 1, \ldots, 1, 2, \ldots ] batch=[0,…,0,1,…,1,2,…]
訓(xùn)練GNN
圖分類任務(wù)的GNN訓(xùn)練一般是這樣的流程:
- 通過(guò)幾次信息傳遞(message passing)對(duì)每個(gè)節(jié)點(diǎn)進(jìn)行嵌入
- 把節(jié)點(diǎn)嵌入聚合成圖嵌入(readout layer)
- 根據(jù)圖嵌入向量訓(xùn)練最終的分類器
有很多論文開(kāi)發(fā)了很多readout層,不過(guò)其實(shí)用的最多的還是直接把節(jié)點(diǎn)嵌入求平均:
x G = 1 ∣ V ∣ ∑ v ∈ V x v ( L ) \mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathcal{x}^{(L)}_v xG?=∣V∣1?v∈V∑?xv(L)?
PyTorch Geometric提供了torch_geometric.nn.global_mean_pool
這個(gè)函數(shù)。輸入:①mini batch中所有node的embeddings;②分配向量batch
;輸出:每個(gè)batch中每個(gè)圖經(jīng)過(guò)計(jì)算得到的graph embedding
,大小是 [batch_size, hidden_channels]
還有很多其他的pooling函數(shù),之后會(huì)試試的。
完整的架構(gòu)我們直接通過(guò)print(model)就可以看到,這個(gè)模型可以實(shí)現(xiàn)端到端的圖分類了!
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
class GCN(torch.nn.Module):
def __init__(self, hidden_channels):
super(GCN, self).__init__()
torch.manual_seed(12345)
self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
self.conv2 = GCNConv(hidden_channels, hidden_channels)
self.conv3 = GCNConv(hidden_channels, hidden_channels)
self.lin = Linear(hidden_channels, dataset.num_classes)
def forward(self, x, edge_index, batch):
# 1. Obtain node embeddings
x = self.conv1(x, edge_index)
x = x.relu()
x = self.conv2(x, edge_index)
x = x.relu()
x = self.conv3(x, edge_index)
# 2. Readout layer
x = global_mean_pool(x, batch) # [batch_size, hidden_channels]
# 3. Apply a final classifier
x = F.dropout(x, p=0.5, training=self.training)
x = self.lin(x)
return x
model = GCN(hidden_channels=64)
print(model)
GCN(
(conv1): GCNConv(7, 64)
(conv2): GCNConv(64, 64)
(conv3): GCNConv(64, 64)
(lin): Linear(in_features=64, out_features=2, bias=True)
)
我們?yōu)镚CN層選擇的激活函數(shù)是 R e L U ( x ) = max ? ( x , 0 ) \mathrm{ReLU}(x) = \max(x, 0) ReLU(x)=max(x,0) (除了最后的readout layer)。
讓我們訓(xùn)練一下我們的模型吧,看看它在測(cè)試集上表現(xiàn)如何!
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
model = GCN(hidden_channels=64)
# 模型的核心,64個(gè)hidden_channels,類似于神經(jīng)網(wǎng)絡(luò)的隱藏層的神經(jīng)元個(gè)數(shù)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
def train():
model.train()
# 把模型設(shè)置為訓(xùn)練模式
for data in train_loader: # Iterate in batches over the training dataset.
out = model(data.x, data.edge_index, data.batch) # Perform a single forward pass.
loss = criterion(out, data.y) # Compute the loss.
loss.backward() # Derive gradients.
# 計(jì)算梯度
optimizer.step() # Update parameters based on gradients.
# 根據(jù)上面計(jì)算的梯度更新參數(shù)
optimizer.zero_grad() # Clear gradients.
# 清除梯度,為下一個(gè)批次的數(shù)據(jù)做準(zhǔn)備,相當(dāng)于從頭開(kāi)始
def test(loader):
model.eval()
# 把模型設(shè)置為評(píng)估模式
correct = 0
# 初始化correct為0,表示預(yù)測(cè)對(duì)的個(gè)數(shù)
for data in loader: # Iterate in batches over the training/test dataset.
out = model(data.x, data.edge_index, data.batch)
# 預(yù)測(cè)的輸出值
pred = out.argmax(dim=1) # Use the class with highest probability.
# 每個(gè)類別對(duì)應(yīng)一個(gè)概率,概率最大的就是對(duì)應(yīng)的預(yù)測(cè)值
correct += int((pred == data.y).sum()) # Check against ground-truth labels.
# 如果一樣,就是True,也就是1,correct就+1
# 準(zhǔn)確率就是正確的/總的
return correct / len(loader.dataset) # Derive ratio of correct predictions.
for epoch in range(1, 171):
train()
train_acc = test(train_loader)
test_acc = test(test_loader)
print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
可以看到,我們的模型有0.76的測(cè)試集準(zhǔn)確度。
波動(dòng)的原因可以認(rèn)為是測(cè)試集太小了,通常來(lái)說(shuō),如果數(shù)據(jù)集比較大這種波動(dòng)情況就會(huì)消失。
換個(gè)架構(gòu)看看效果怎么樣
我們可以做的更好嗎?當(dāng)然可以。
正如多篇論文指出的那樣(Xu et al. (2018), Morris et al. (2018)),應(yīng)用鄰域歸一化降低了gnn在識(shí)別某些圖結(jié)構(gòu)方面的表達(dá)能力(neighborhood normalization decrease the expressivity of GNNs in distingushiing certain graph structures)。
另一種替代公式( Morris et al. (2018))完全省略了鄰域歸一化,并在GNN層中添加了一個(gè)簡(jiǎn)單的跳過(guò)連接,以保留中心節(jié)點(diǎn)信息:
x
v
(
?
+
1
)
=
W
1
(
?
+
1
)
x
v
(
?
)
+
W
2
(
?
+
1
)
∑
w
∈
N
(
v
)
x
w
(
?
)
\mathbf{x}_v^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_v^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{w \in \mathcal{N}(v)} \mathbf{x}_w^{(\ell)}
xv(?+1)?=W1(?+1)?xv(?)?+W2(?+1)?w∈N(v)∑?xw(?)?
這個(gè)layer也可以在PyG中輕松調(diào)用,也就是GraphConv
。
也就是說(shuō),我們想用PyG’s GraphConv
而不是 GCNConv
,然后看看效果怎么樣。
from torch_geometric.nn import GraphConv
class GNN(torch.nn.Module):
def __init__(self, hidden_channels):
super(GNN, self).__init__()
torch.manual_seed(12345)
self.conv1 = GraphConv(dataset.num_node_features, hidden_channels) # TODO
self.conv2 = GraphConv(hidden_channels, hidden_channels) # TODO
self.conv3 = GraphConv(hidden_channels, hidden_channels) # TODO
self.lin = Linear(hidden_channels, dataset.num_classes)
def forward(self, x, edge_index, batch):
x = self.conv1(x, edge_index)
x = x.relu()
x = self.conv2(x, edge_index)
x = x.relu()
x = self.conv3(x, edge_index)
x = global_mean_pool(x, batch)
x = F.dropout(x, p=0.5, training=self.training)
x = self.lin(x)
return x
model = GNN(hidden_channels=64)
print(model)
GNN(
(conv1): GraphConv(7, 64)
(conv2): GraphConv(64, 64)
(conv3): GraphConv(64, 64)
(lin): Linear(in_features=64, out_features=2, bias=True)
)
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
model = GNN(hidden_channels=64)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(1, 201):
train()
train_acc = test(train_loader)
test_acc = test(test_loader)
print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
GNN(
(conv1): GraphConv(7, 64)
(conv2): GraphConv(64, 64)
(conv3): GraphConv(64, 64)
(lin): Linear(in_features=64, out_features=2, bias=True)
)
總結(jié)
我們學(xué)習(xí)了如何應(yīng)用GNN完成圖分類,調(diào)用了GraphConv
和GCNConv
兩種架構(gòu),舉一反三,之后想用什么layer就用什么layer。
此外我們還學(xué)習(xí)了如何讓單個(gè)的圖組成batch,從而更好地利用GPU,以及如何應(yīng)用readout layer從node embedding中得到graph embedding。
參考資料:
3. Graph Classification.ipynb - Colaboratory (google.com)
Colab Notebooks and Video Tutorials — pytorch_geometric documentation (pytorch-geometric.readthedocs.io)文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-810344.html
如果覺(jué)得還不錯(cuò),記得點(diǎn)贊+收藏喲!謝謝大家的閱讀?。ǎ幔┄J文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-810344.html
到了這里,關(guān)于【GNN2】PyG完成圖分類任務(wù),新手入門(mén),保姆級(jí)教程的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!