ご注文はリード化合物ですか?〜医薬化学録にわ〜

自分の勉強や備忘録などを兼ねて好き勝手なことを書いていくブログです。

グラフ畳み込みニューラルネットワークによる回帰モデル:Deep Graph Library(DGL)の使い方

グラフ畳み込みニューラルネットワークのコードの例です。
R2 は -0.05 ぐらいと低いですが、計算自体は行えます。
(解説は近日追記)

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import dgl
from dgl.dataloading import GraphDataLoader
from dgl.nn import GraphConv
from rdkit import Chem
from rdkit.Chem.rdPartialCharges import ComputeGasteigerCharges
from rdkit.Chem.Draw import IPythonConsole, MolsToGridImage
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score

# 化学構造の読み込み
structure = Chem.SDMolSupplier("../data/logSdataset1290_2d.sdf")

# データをランダムに分割
np.random.seed(0)
tag_list = list(range(len(structure)))
np.random.shuffle(tag_list)
train_tag = np.array(tag_list)[:1200]
test_tag = np.array(tag_list)[1200:]

train_mols = []
test_mols = []
train_y = []
test_y = []
for i in train_tag:
    mol = structure[int(i)]
    train_mols.append(mol)
    train_y.append(np.float(mol.GetProp("logS")))
    
for i in test_tag:
    mol = structure[int(i)]
    test_mols.append(mol)
    test_y.append(np.float(mol.GetProp("logS")))

# ノード単位で記述子を算出する関数
def get_node_features(mol):
    node_features = []
    for atom in mol.GetAtoms():
        node_features.append([atom.GetAtomicNum(), 
                              np.int(atom.GetIsAromatic()), 
                              np.float(atom.GetProp("_GasteigerCharge"))])
    
    return torch.from_numpy(np.array(node_features)).float()

# ノードの情報を取り出す関数
def get_graph_data(mols):
    graphs = []
    for mol in mols:

        # Gasteiger charge を計算
        ComputeGasteigerCharges(mol)

        # 結合しているノードの組み合わせを算出
        node_pair = np.where(np.triu(Chem.GetAdjacencyMatrix(mol)) == 1)

        # Graph 型のデータを作成
        graph = dgl.graph((node_pair[0], node_pair[1]), num_nodes = len(mol.GetAtoms()))

        # ノード毎の特徴量を入力
        graph.ndata["atom_features"] = get_node_features(mol)

        # 自己ノードを畳み込むための self loop を設定する
        graph = dgl.add_self_loop(graph)

        graphs.append(graph)
        
    return graphs

# 化学構造(Mol オブジェクト)から Graph 型のデータを作成
train_graph = get_graph_data(train_mols)
test_graph = get_graph_data(test_mols)

# 目的変数の読み込み
train_y_tensor = torch.tensor(train_y).float()
test_y_tensor = torch.tensor(test_y).float()

# Data loadder の定義
train_dataloader_input = list(zip(train_graph, train_y_tensor))
train_loader = GraphDataLoader(train_dataloader_input, batch_size = 64, drop_last = False, shuffle = True)

test_dataloader_input = list(zip(test_graph, test_y_tensor))
test_loader = GraphDataLoader(test_dataloader_input, batch_size = 64, drop_last = False, shuffle = False)

# モデルのクラスを設定
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = GraphConv(3, 8)
        self.conv_2 = GraphConv(8, 16)
        self.linear_1 = nn.Linear(16, 8)
        self.norm = nn.BatchNorm1d(8)
        self.linear_2 = nn.Linear(8, 1)
        
    def forward(self, g, f):
        h = F.relu(self.conv_1(g, f))
        h = self.conv_2(g, h)
        g.ndata["hidden"] = h
        
        x = dgl.mean_nodes(g, "hidden")
        x = F.relu(self.linear_1(x))
        x = self.norm(x)
        x = self.linear_2(x)
        
        return x

# モデルのインスタンスを作成
model = Model()

# 最適化関数の設定
optimizer = Adam(model.parameters(), lr = 0.001)

# 損失関数の設定
loss_function = nn.MSELoss()

# モデルの学習
epoch = 30
model.train()
for i in list(range(epoch)):
    
    epoch_loss = 0
    
    # バッチごとの処理
    for batch_graph, batch_y in train_loader:
        
        # トレーニングデータに対する誤差の算出
        batch_pred_y = model(batch_graph, batch_graph.ndata["atom_features"])
        batch_loss = loss_function(batch_pred_y, batch_y)
        
        # 誤差逆伝播と重み更新
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        
        epoch_loss += batch_loss.item()
        
    # エポックごとの誤差を表示
    print(epoch_loss)

# モデルの推論
model.eval()
with torch.no_grad():
    test_loss = 0
    pred_y = torch.Tensor()
    
    # トレーニングと同様の処理を行うが、重みの更新はしない
    for batch_graph, batch_y in test_loader:
        
        batch_pred_y = model(batch_graph, batch_graph.ndata["atom_features"])
        batch_loss = loss_function(batch_pred_y, batch_y)        
        test_loss += batch_loss.item()
        pred_y = torch.cat([pred_y, batch_pred_y])
        
    # 誤差を表示
    print(test_loss)

# Numpy 型に戻した後に R2 を算出
pred_y = pred_y.detach().numpy()
print(r2_score(test_y, pred_y))