【備忘録】PyTorchで黒橋研日本語BERT学習済みモデルを使ってみる

品川です。最近本格的にBERTを使い始めました。
京大黒橋研が公開している日本語学習済みBERTを試してみようとしてたのですが、Hugging Faceが若干仕様を変更していて少しだけハマったので、使い方を備忘録としてメモしておきます。

準備

学習済みモデルのダウンロード

下記の黒橋研のサイトから、学習済みモデル(.zip)をダウンロードして解凍します(この時ダウンロードするのはtransformers用のモデルです)
BERT日本語Pretrainedモデル - KUROHASHI-CHU-MURAWAKI LAB

特に、下記の点には注意です。今回はBASE 通常版を使ってみます。

(更新: 19/11/15) pytorch-pretrained-BERTはtransformersという名前にかわっています。こちらを使う場合は以下のモデルをお使いください。transformersで使う場合、モデルの絶対パスのどこかに「bert」の文字列を含んでいる必要があります。例えば、zipを解凍し、 /somewhere/bert/Japanese_L-12_H-768_A-12_E-30_BPE_transformers/ のように配置してください。

  • BASE 通常版: Japanese_L-12_H-768_A-12_E-30_BPE_transformers.zip (393M; 19/11/15公開)
  • BASE WWM版: Japanese_L-12_H-768_A-12_E-30_BPE_WWM_transformers.zip (393M; 19/11/15公開)
  • LARGE WWM版: Japanese_L-24_H-1024_A-16_E-30_BPE_WWM_transformers.zip (1.2G; 20/2/29公開)
Juman++のインストール

入力テキストの前処理にはJuman++が必要です。黒橋研のBERTモデルは最初に全角に正規化された入力文をJuman++で形態素単位に分割し、さらにBPEでサブワードに分割する処理をしているそうです。(詳しくは上記ページ参照)

sudo権限がないならば、下記のように自分のホーム以下にインストールしておくのが良いかと思います。

tar xJvf jumanpp-1.02.tar.xz
cd jumanpp-1.02
./configure --prefix=$HOME
make
make install
pyknp、transformersのpipインストール

また、pythonからJuman++とBERTを呼び出すためのライブラリもpipでいれておきます。

pip install pyknp
pip install transformers

このtransformersは過去の関連記事だとpytorch-pretrained-BERTだったようですが、今はTensorflow版と統合されて一つのライブラリになっているようです。
これに合わせて若干仕様が変更になっているので、使うときには注意が必要です。

ベースのコードを修正して動かしてみる

動かすのにベースとして参考にさせていただいたのは下記の記事です。
pytorchでBERTの日本語学習済みモデルを利用する - 文章埋め込み編 - Out-of-the-box

この記事ではコードも提供してくださってるのでありがたかったです。このリポジトリbert_juman.pyという名前のファイルをベースに動かしてみました。(この記事を書くついでにプルリクも一応送ってみました)
github.com

中身の詳細な説明は上記の記事に譲り、ここではbert_juman.pytransformersを利用するときの変更点のみについて書きます。
bert_juman.pyは全角入力のテキストに対して最上層ひとつ手前の隠れ層のベクトルをとってくるものになっています。
transformersでは、この隠れ層を取得するのに明示的に引数が必要になる点が異なります。

コード上の具体的な変更点は以下の2つです。
まず、ライブラリからのインポート自体は同じようにできるので、ライブラリ名のみを変えます。

import numpy as np
import torch
#from pytorch_pretrained_bert import BertTokenizer, BertModel
from transformers import BertTokenizer, BertModel
from pyknp import Juman

次に、get_sentence_embeddingのmodelのforward部分の引数にoutput_hidden_states=Trueを加えればOKです。ちなみにreturn_dict=Trueをつけてるのは、forwardした時の出力がdictでとれるので便利だからです。これをつけなければtupleで出力されます。

class BertWithJumanModel():
    ...
    def get_sentence_embedding(self, text, pooling_layer=-2, pooling_strategy=None):
        ...                                                                                                                                       
        with torch.no_grad():                                                                                                                                                                            
            # all_encoder_layers, _ = self.model(tokens_tensor) # for pytorch_pretrained_bert                                                                                                            
            all_encoder_layers = self.model(tokens_tensor, return_dict=True, output_hidden_states=True)["hidden_states"]  # for transformers                                                                                                                 

修正したコードで実際に試してみたらこんな感じです。

In [5]: from bert_juman_with_transformers import BertWithJumanModel

In [6]: bert = BertWithJumanModel("../../MODELS/bert/Japanese_L-12_H-768_A-12_E-30_BPE_transformers/")

In [7]: bert.get_sentence_embedding("吾輩は猫である。").shape
Out[7]: (768,)

In [8]: bert.get_sentence_embedding("吾輩は猫である。")
Out[8]:
array([-4.25627619e-01, -3.42006892e-01, -7.15176389e-02, -1.09820056e+00,
        1.08186698e+00, -2.35575914e-01, -1.89862773e-01, -5.50958455e-01,
        1.87978148e-01, -9.03697014e-01, -2.67813027e-01, -1.49959311e-01,
        5.91513515e-01, -3.52201462e-01,  1.84209332e-01,  4.01529483e-02,
        1.53244898e-01, -6.31160438e-01, -2.07539946e-01, -1.49968192e-01,
       -3.31581414e-01,  4.01663631e-01,  3.73950928e-01, -4.13331598e-01,

おまけ

get_sentence_embedding関数内の各変数の表示
text = "吾輩は猫である。"
preprocessed_text = _preprocess_text(text)
tokens = juman_tokenizer.tokenize(preprocessed_text)
bert_tokens = bert_tokenizer.tokenize(" ".join(tokens))
ids = bert_tokenizer.convert_tokens_to_ids(["[CLS]"] + bert_tokens[:126] + ["[SEP]"]) # max_seq_len-2
tokens_tensor = torch.tensor(ids).reshape(1, -1)

print(preprocessed_text)
print(tokens)
print(bert_tokens)
print(ids)
print(tokens_tensor)

#結果
吾輩は猫である。
['吾輩', 'は', '猫', 'である', '。']
['[UNK]', 'は', '猫', 'である', '。']
[2, 1, 9, 4817, 32, 7, 3]
tensor([[   2,    1,    9, 4817,   32,    7,    3]])
BERT modelのforwardの引数と出力の関係

transformersのmodeling_bert.pyを眺めるとわかります。

  • last_hidden_state: 最終層の隠れ層のベクトル(1xtoken数x各tokenのベクトル次元)
  • pooler_output: 最終層の隠れ層のベクトルの内、最初のtokenのみを取り出してdense+tanh()する操作(最終層のCLSに対応するベクトルの抽出)
  • hidden_states: 入力のembeddings、最終層の隠れ層のベクトルも含めた、全層の隠れベクトルのリスト(12段なら1xtoken数x各tokenのベクトル次元のtorch.tensorが13個できる)。リストの後ろの要素ほど最終層に近い層のベクトル。forwardの引数にoutput_hidden_states=Trueを入れると出力される
  • attentions: 各段のtransformerでのforward計算でのattentionのリスト(12段なら1xhead数xtoken数xtoken数のtorch.tensorが12個できる)。forwardの引数にoutput_attentions=Trueを入れると出力される
model.eval()
with torch.no_grad():
    outputs = model(tokens_tensor, return_dict=True, output_hidden_states=True, output_attentions=True)
print(outputs["last_hidden_state"].shape)
print(outputs["pooler_output"].shape)
print(len(outputs["attentions"]), [a.shape for a in outputs["attentions"]])
print(len(outputs["hidden_states"]), [h.shape for h in outputs["hidden_states"]])

#結果
odict_keys(['last_hidden_state', 'pooler_output', 'hidden_states', 'attentions'])
torch.Size([1, 12, 768])
torch.Size([1, 768])
13 [torch.Size([1, 12, 12, 12]), torch.Size([1, 12, 12, 12]), torch.Size([1, 12, 12, 12]), torch.Size([1, 12, 12, 12]), torch.Size([1, 12, 12, 12]), torch.Size([1, 12, 12, 12]), torch.Size([1, 12, 12, 12]), torch.Size([1, 12, 12, 12]), torch.Size([1, 12, 12, 12]), torch.Size([1, 12, 12, 12]), torch.Size([1, 12, 12, 12]), torch.Size([1, 12, 12, 12])]
12 [torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768])]

ちなみに

  • BASE版は隠れ層のtokenあたりの次元が768、層数が12、head数が12 (headあたりが担当する次元のサイズが64)
  • LARGE版は隠れ層のtokenあたりの次元数が1024、層数が24、head数が16 (headあたりが担当する次元のサイズが64)