生成AIだけじゃない、GNNでデータの関係性を捉えよう!

こんにちは!FLINTERS BASEの梶山です!

今回はGNN(グラフニューラルネットワーク)について簡単に紹介したいと思います!
世の中は生成AI(LLM)ブーム真っ只中でChatGPTやGemini,Claudeなどが注目を集めています。最近ではMCPサーバーとかも流行っていますね。とはいえ、まだまだ生成AIだけではできないことや精度が足りなかったりするタスクがあります。 例えば、SNSでのユーザー間のつながり、分子構造、交通ネットワークなど、「関係性」が重要な意味を持つデータは日々増加しています。LLMでもこういった構造化データを扱おうとする試みはありますが限界もあります。
そこで登場するのがGNNです。GNNはデータ間の関係性そのものを学習できるディープラーニングの一種です。GNNはデータ間の関係性を直接学習することができます。

GNNとは

GNN(グラフニューラルネットワーク)は、ノード(点)とエッジ(線)で構成されるグラフデータを効果的に処理するための手法です。従来のニューラルネットワークでは、データが固定的な構造を持つことが前提でしたが、GNNは複雑な関係性を持つデータを扱うことができます。

例えば、ソーシャルネットワークでは、ユーザー同士のつながりをグラフとして表現できます。ここで、各ユーザーがノードとなり、友達関係がエッジとして示されます。このように、GNNはノード間の関係を学習し、ユーザーの行動予測や友達推薦などに応用されます。

さらに、GNNは画像や文章もネットワークデータとして扱うことができるため、画像のピクセル間の関係や、文章内の単語の関連性を学習することも可能です。これにより、画像認識や自然言語処理の精度を向上させることが期待されています。

GNNでノード分類をしてみる

今回はグラフのノード分類をGoogle Colabratoryで試しにやってみました。 使用したデータセットはCoraで論文の研究分野のラベルや引用や非引用などの関係性のデータが入っています。
※コードの説明や数理的な説明等は省略します。実装にはGCNフィルタを用いてます。

!pip install torch-geometric

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root="./dataset",name="Cora",split="full")

from torch_geometric.transforms import RandomNodeSplit

node_spliter = RandomNodeSplit(
    split="train_rest",
    num_splits=1,
    num_val=0.0,
    num_test=0.4,
    key="y"
)

splited_data = node_spliter(dataset._data)

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, Sequential

class GCN(torch.nn.Module):
    def __init__(self, num_node_features: int, projection_dim: int, num_classes: int) -> None:
        super().__init__()
        self.conv1 = GCNConv(num_node_features, projection_dim)
        self.conv2 = GCNConv(projection_dim, num_classes)

    def forward(self, data: Data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x,edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x,edge_index)

        return F.log_softmax(x, dim=1)

device = "cuda" if torch.cuda.is_available() else "cpu"
gcn_model = GCN(
    num_node_features=dataset.num_node_features,
    projection_dim=64,
    num_classes=dataset.num_classes
).to(device)

gcn_optimizer = torch.optim.Adam(list(gcn_model.parameters()),lr=0.01)

from tqdm import tqdm

def train_gcn() -> float:
    gcn_model.train()
    total_loss = 0.0
    gcn_optimizer.zero_grad()
    out = gcn_model(splited_data)
    loss = F.cross_entropy(
        out[splited_data.train_mask],
        splited_data.y[splited_data.train_mask]
    )
    loss.backward()
    gcn_optimizer.step()
    return loss.item()


for epoch in tqdm(range(200)):
    loss = train_gcn()
    print(f"train loss:{loss:.4f}")

from tqdm import tqdm

def train_gcn() -> float:
    gcn_model.train()
    total_loss = 0.0
    gcn_optimizer.zero_grad()
    out = gcn_model(splited_data)
    loss = F.cross_entropy(
        out[splited_data.train_mask],
        splited_data.y[splited_data.train_mask]
    )
    loss.backward()
    gcn_optimizer.step()
    return loss.item()


for epoch in tqdm(range(200)):
    loss = train_gcn()
    print(f"train loss:{loss:.4f}")

#              precision    recall  f1-score   support
#
#           0       0.71      0.81      0.76       135
#           1       0.82      0.80      0.81        90
#           2       0.94      0.98      0.96       167
#           3       0.90      0.85      0.87       328
#           4       0.86      0.85      0.86       170
#           5       0.80      0.80      0.80       117
#           6       0.87      0.79      0.83        76
#
#    accuracy                           0.85      1083
#   macro avg       0.84      0.84      0.84      1083
# weighted avg       0.85      0.85      0.85      1083

大体85%ぐらいは正解できてそうですね。 ある程度データがあれば数十行のコードで学習してここまで正解できるのは改めてすごいと感じました。

終わりに

実は今まであまりDeepLearningとかにあまり興味はなかったのですが、これから生成AI(LLM)が何ができて何ができないか知っておくことは重要だと思うので少しづつその周りの知識についても勉強しようと思います。(ちなみにGNNはこれ勉強すれば大体カバーできるじゃん!ていう浅い理由で勉強し始めました笑)
あと個人的に現在はDeepLearning系のモデルに大量にデータを食わせて精度を出していますが、20年後ぐらいには全く別のアーキテクチャのモデルがとてつもない精度を出してたりしないかなと思ってます。

参考書籍等

GNNを勉強する理由が載ってます speakerdeck.com

GNNについて勉強する時に最初の一冊としてオススメです! gihyo.jp