ご注文はリード化合物ですか?〜医薬化学録にわ〜

自分の勉強や備忘録などを兼ねて好き勝手なことを書いていくブログです。

Deep Graph Library によるグラフ畳み込みネットワークの基本(追記予定)

最近、論文を書いてたり会議が増えたり人事関連で色々あったり投資を始めたりごちうさを観たりで忙しく、中々記事を書けていませんでしたが、久々にブログを更新しました。

AI ブーム、深層学習ブームは化学分野でも続いていますが、化学者としては、構造式を取り扱いたいことが多く、構造式を AI に入力をして、物性とか薬理作用とか出てこないかなぁ、と思うこともあります。
それを実現するのがグラフ畳み込みネットワーク(Graph convolutional neural network: GCNN)です。
GCNN を使った研究は時々あるものの、GCNN って何をしているのか、どうやって実装するのか、意外と日本語でまとまっていないので、以下のサンプルコードを元に解説していきます。

なお、オリジナルのコードは、pen さんのブログにあり、それを劣化させました。

Try GCN QSPR with pytorch based graph library #RDKit #Pytorch #dgliwatobipen.wordpress.com


import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from sklearn.metrics import r2_score
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
import dgl
import dgl.function as fn
from dgl import DGLGraph

# 原子の種類のリスト
element_list = ["C", "N", "O", "S", "P", "F", "Cl", "Br", "I", "B", "Si", "Se", 'H', "Unknown"]
 
atom_features_dim = len(element_list) + 6 + 5 + 1
MAX_ATOMNUM =60
BOND_FDIM = 5
MAX_NB = 10
 
# 特徴リストの one hot vector を算出
def atom_feature_encoding(feature, feature_list):
    
    # 特徴リストに含まれない場合、 "Unknown" 表記にする
    if feature not in feature_list:
        feature = feature_list[-1]
        
    # 返り値は True or False で構成されるリスト
    return [feature == f for f in feature_list]
 

# 各原子の特徴量を出力
# 原子の種類(23 次元)+ 結合次数(6 次元)+電荷(5 次元)+ 芳香族性(1 次元)
def atom_features(atom):
    return (atom_feature_encoding(atom.GetSymbol(), element_list)
            + atom_feature_encoding(atom.GetDegree(), [0,1,2,3,4,5])
            + atom_feature_encoding(atom.GetFormalCharge(), [-1,-2,1,2,0])
            + [atom.GetIsAromatic()])
 
# 結合の種類(単結合、二重結合、三重結合、共役系、環構造 それぞれ 1 or 0 (True of False) で表現)
def bond_features(bond):
    bond_type = bond.GetBondType()
    return (torch.Tensor([bond_type == Chem.rdchem.BondType.SINGLE,
                                      bond_type == Chem.rdchem.BondType.DOUBLE,
                                      bond_type == Chem.rdchem.BondType.TRIPLE,
                                      bond_type == Chem.rdchem.BondType.AROMATIC,
                                      bond.IsInRing()]))
 
# Mol class の分子データを GCNN に入力する DGL Graph 型に変換
def mol2dgl_single(mols):
    
    # 各分子の DGL Graph 型を入れるリスト
    cand_graphs = []
    
    n_nodes = 0
    n_edges = 0
 
    # 各分子ごとに処理
    for mol in mols:
        
        # DGL Graph の型を用意
        g = DGLGraph() 
        
        # 各原子(グラフのノード)の特徴量を格納するリスト
        atom_feature_list = []
        
        # 原子の情報(Atom class)を atom_feature_list に突っ込む
        for atom in mol.GetAtoms():
            atom_feature_list.append(atom_features(atom))
            
        # 分子中の原子数を DGL Graph 型に入力
        g.add_nodes(mol.GetNumAtoms())
        
        # 各原子の特徴量ベクトルを DGL Graph 型に "h" という名前で入力
        g.ndata["h"] = torch.Tensor(atom_feature_list)
 
        bond_begin_list = []
        bond_end_list = []
        
        # 結合毎に処理、結合の向きも考慮
        for bond in mol.GetBonds():
            
            # 結合の始点と終点の原子のインデックスを取得
            begin_idx = bond.GetBeginAtom().GetIdx()
            end_idx = bond.GetEndAtom().GetIdx()
                        
            # 結合の特徴量を算出
            features = bond_features(bond)
 
            # インデックスに関する情報をリストとしてまとめる
            bond_begin_list.append(begin_idx)
            bond_end_list.append(end_idx)
            
            # (無向グラフなので)向きを変えて同じ操作を行う
            bond_begin_list.append(end_idx)
            bond_end_list.append(begin_idx)
            
        # 隣接するノードのインデックスを DGL Graph 型に入力
        g.add_edges(bond_begin_list, bond_end_list)

        # 出力用リストに入力
        cand_graphs.append(g)
        
    # DGL グラフ型として出力する
    return cand_graphs

少し長いですが、上が構造式を GCNN に入力できるよう変換するためのコードです。
最初に、いつものようにライブラリを読んでいきます。
今回は深層学習ライブラリとして PyTorch を用います。Tensorflow よりも GCNN 周りのライブラリが充実している気がします。
GCNN のライブラリは Deep Graph Library (DGL) を用います。
PyTorch と GCNN は terminal から pip で簡単に入れられます。

pip install torch torchvision
pip install dgl

構造式中の各原子(グラフの用語ではノードとも呼ぶ)ごとに、特徴量を binary vector(0 or 1 で構成されたベクトル)として算出します。
今回は、原子の種類は何か、結合次数と形式電荷はいくつか、芳香属性はあるか、これらの特徴を取り出します。
atom_feature_encoding という関数は、第一引数の feature (文字でも数字でも良い)が第二引数に入力しているリストに含まれるどの要素に対応しているかを確認します。返り値は引数に用いたリストと同じ長さの binary vector で、対応する要素のみ 1 、他の要素は 0 となっています。
atom features と bond features という関数は atom_feature_encoding を実行するだけの関数で、引数が RDKit の atom クラス bond クラスを用いているかという違いだけで、役回りは一緒です(ただ、pen さんのオリジナルのコードだと、bond クラスで取得した情報がその後のモデルで使われていない気がします)。

mol2dgl_single が本丸で、RDKit の mol オブジェクトを DGL Graph という GCNN に入力可能な型に変換する関数となります。
引数の mols は mol オブジェクトのリストです(Chem.SDMolSupplier で直接読み込んだものは使えないので注意。[mol for mol in supplier] といった変換が必要)。
cand_graphs が出力用のリストです。このリストの中に、各分子の DGL Graph オブフェクトが入ります。

for 文で各分子の特徴量を算出と DGL Graph オブフェクトの用意をしていきます。
最初に出力する DGL Graph インスタンスを g と定義して用意します。atom_features で原子の特徴量を計算したら、g.ndata["h"] に Tensor 型として入力します。この操作で、DGL Graph オブジェクトに "h" という名称でノード(原子)の情報を定義します。入力する Tensor 型は原子数 * 原子の特徴量数の二次元ベクトルです。

(以下、追記予定)