AI・機械学習 コラム

事前学習の改善で自然言語処理の精度が向上|Transformerベースのモデル「ProphetNet」を紹介

はじめに

インフォマティクスで機械学習業務を担当している大橋です。

今回はディープラーニングを用いた自然言語処理の論文、ProphetNet [1] を紹介したいと思います。既存の構造に対する修正を最小限にし、ちょっとした発想で事前学習の方法を改善して精度を向上させたことが書かれている興味深い論文です。

自然言語処理と転移学習

自然言語処理とは、機械に日本語や英語等の言語を理解・判別させたり生成させたりする処理のことです。例えば迷惑メールの判別、文章の自動要約、日英翻訳などです。

これらを機械学習で実現しようとしたときに問題になるのが、正解付き学習データの作成です。十分な精度を出すためには相応の量を用意する必要があり、多くの場合人間が手作業で作成するため、かなりの時間がかかります。先の要約の例だと、ニュース記事等の文章に対してその要約文を作成する作業です。

この問題を軽減するため、機械学習では「転移学習」という手法がよく用いられます。この手法では①事前学習②ファインチューニングの2段階で学習させます。

まず「①事前学習」で、Web等に大量に存在しているテキストデータを用いて文法や単語の意味等を学習させます。この学習では人間が正解データを作成する必要がなく、データを収集してくるだけでよいので安価です。

その後「②ファインチューニング」で、人間が用意したデータを用いて学習させます。既に言語の基本的な知識を備えた状態で学習するので、0から学習するのに比べて少量のデータで済みます。①で学習したものを②の学習に転用・転移するので転移学習と呼ばれます。

事前学習のさせ方にも色々なものが提案されています。この部分を工夫したら文章要約や文章生成の精度が向上したというのが今回紹介する「ProphetNet」です。

※本記事は前回記事同様、Transformer [2] の知識がある程度あることを前提として書かれています。Transformerについても、いずれ解説したいと思います。

ProphetNet概要

先述した通り、事前学習にはいろいろなものがあります。

代表的なものとして、文章を途中まで与えて次の単語を予測させるもの(Language Model)や、文章の中からランダムに単語を隠し虫食い状態にしてその部分を予測させるもの(Denoise)等があります。

例えば「吾輩は猫である」という文章を考えます。このとき「吾輩は」までを与えて「猫」を予測させるのがLanguage Modelで、「吾輩は×である」を与えて×の部分を「猫」と予測させるのがDenoiseです。

今回紹介するProphetNetはLanguage Modelに類する事前学習を行います。

具体的には、文章を途中まで与えたときに次の単語のみを予測するのではなく\(N\)個先まで予測させます。つまり\(N=2\)の場合、先ほどの例では「吾輩は」を与えたときに「猫で」までを予測させることに対応します。図示すると図1です。

図下方の\(X_0\)~\(X_5\)がモデルへの入力で、図上方の\(X_1\)~\(X_7\)がモデルの出力です。Language Modelは各単語に対して一つの次の単語を予測していますが、Bigram ProphetNetは次とその次の2つの単語を予測します。

図1:参考文献[1]より抜粋

このような手法が提案された経緯には、Language Modelの場合、直前の単語に強く影響されてしまい文章全体としての意味を捉えきれていないと指摘されていたということがあります。

予測する範囲を大きくして、より文章全体を見渡せるよう工夫したのが今回提案された手法です。

ProphetNet詳細

それではどのように\(N\)個先まで予測させるのでしょうか?

ProphetNetはTransformerをベースに作られたエンコーダ・デコーダモデルです。まず基本的な構成を図示すると図2のようになります。

図2:参考文献[1]より抜粋

順を追って説明します。

ネットワークの入出力

図2左下がネットワークの入力です。事前学習の際は、単語列を一定確率で虫食い式に隠したものを入力とします。

図2右上がネットワークの出力です。ネットワークは、入力文章をもとに隠された単語から\(n\)個先までの単語を予測します。

エンコーダ

図2の左中はエンコーダにあたります。エンコーダ部分は、オリジナルのTrasnformerのものと変更ありません。入力単語列\(x_1,\dots ,x_M\)に対して出力を\(H_{enc}\)と書くことにすると、エンコーダ部分は

\(H_{enc}=Encoder(x_1,\dots ,x_M)\)

と抽象的に書けます。

デコーダ

図2の右中はデコーダにあたります。Transformerのデコーダは、エンコーダの出力と\(t-1\)番目までの予測を用いて\(t\)番目の単語を予測しますが、その予測確率\(p(y_t|y_{< t},x)\)は以下のように書けます。

\(p(y_t|y_{< t},x) =Decoder(y_{< t},H_{enc})\)

ProphetNetデコーダ部分は、この部分を次の単語だけでなく\(n\)個先まで予測するよう、以下のように変更します。

\(p(y_t|y_{< t},x),\dots ,p(y_{t+n-1}|y_{< t},x) =Decoder(y_{< t},H_{enc})\)

目的関数

Language Modelではステップにつき1単語しか予測しませんでしたが、ProphetNetでは複数単語を予測します。そのため、それらの出力に対して目的関数を以下のように拡張します。

\(\mathcal{L}=-\sum^{N-1}_{n=0}\alpha_{n}\left(\sum^{T-n}_{t=1}\log p_{\theta}(y_{t+n}|y_{< t},x)\right)\)

\(\alpha_{n}\)はハイパーパラメータ、\(\theta \)はモデルのパラメータです。最初の和の\(n=0\)の部分がオリジナルのTransformerで採用されているLanguage Modelの損失関数で、\(n\ge 1\)の部分が今回追加された損失関数の部分です。

デコーダの注意機構

図2の右中のデコーダの具体的な構成を紹介します。構成をまとめたものが図3になります。

図3:参考文献[1]より抜粋

まずAttention機構を復習すると、3種類のベクトル\(Q\)(query), \(K\)(key), \(V\)(value)を入力としたとき以下のように演算するものです。

\( Attention(Q,K,V)=Softmax(\frac{QK^{T}}{\sqrt{d_k}})V\)

Transformerのオリジナルのデコーダは2種類のAttention機構を持ちますが、ProphetNetではMasked Attention [2]と呼ばれる方に変更を加えます。

\(k+1\)層目でのTransformerのオリジナルのMasked Attentionの出力は\(k\)層目での\(t\)番目単語埋め込みを\(h^{(k)}_{t}\)と書くと、図3(a)のように以下となります。

\( h^{(k+1)}_{t}=Attention(h^{(k)}_t, h^{(k)}_{\le t}, h^{(k)}_{\le t} )\)

これを論文[1]ではmain streamと呼んでいます。

ProphetNetでは今までの単語の情報を用いて\(n\)個先までの単語の予測のためのベクトルをそれぞれに出力するようにします。そのため、main streamで計算されていた\(h^{(k)}_{l}\)に加えて、1個先の単語を予測するための新たなベクトル\(g^{(k)}_{l}\)を導入します。

\(k+1\)層目での\(g\)の出力は図3(b)のように以下で計算されます。

\( g^{(k+1)}_{t}=Attention(g^{(k)}_t, h^{(k)}_{\le t}\oplus g^{(k)}, h^{(k)}_{\le t}\oplus g^{(k)} )\)

この計算過程をfirst predicting streamと呼びます。

2個先の単語まで予測する場合は、さらにsecond predicting streamを導入します。新たなベクトル\(s^{(k)}_{l}\)を導入し、\(k+1\)層目での\(s\)の出力は\(g\)同様、図3(c)のように以下で計算されます。

\( s^{(k+1)}_{t}=Attention(s^{(k)}_t, h^{(k)}_{\le t}\oplus s^{(k)}, h^{(k)}_{\le t}\oplus s^{(k)} )\)

以降同様に\(N\)個先まで予測させるときは\(N\)-th predicting streamまで導入します。

デコーダへの入力

図3(d)下部はデコーダへの入力です。main streamに関してはオリジナルと変更ありませんが、\(n\)-th prediction stream (\(1 \leq n \leq N\))に関しては学習可能なパラメータ\(p_{n}\)を入力としています。

 

精度

肝心の事前学習ですが、BERTに倣って英語のWikipediaとBookCorpusを用いています。その事前学習したモデルを個々のタスクにファインチューニングした精度を以下で見ていきます。

要約

まずは表1、表2に文章要約のタスクであるCNN/Daily MailGiwawordの精度をそれぞれ掲載します。これらは記事からその要約文を推定するタスクです。

表1:参考文献[1]より抜粋

表2:参考文献[1]より抜粋

記事執筆時にはさらに精度の高いモデルが考案されていますが、論文[1]で比較されているモデルの中では最も精度が高い結果になっています。

文章生成

次に表3に文章生成タスクの精度を掲載します。使用したデータセットSQuAD1.1では、記事とその内容に基づく質問文とその答えの3つが与えられます。

表3は記事と答えを与えて、その答えを導く質問を生成するタスクの精度です。

表3:参考文献[1]より抜粋

こちらも記事執筆時にはさらに精度の高いモデルが考案されていますが、論文[1]で比較されているモデルの中では最も精度が高い結果になっています。

データサイズと精度

事前学習で用いるデータセットを大きくすると精度が上がることが知られていますが、ProphetNetでもデータセットを大きくしてみて精度を検証しています。結果を表4に掲載します。

表4:参考文献[1]より抜粋

表2、表3と比べて精度が向上していることが分かります。また、そのほかのモデルと比較すると、小さなデータサイスで同等の精度を達成していることが分かります。他のモデルに比べて効率的に学習できているようです。

予測単語数の違い

最後にどのくらい先までの単語を予測させるかで、どのくらい精度に違いが生じるかの結果を掲載します。

表5では1~3単語先までを予測させたものでの精度の違いを見ています。評価に用いたデータセットは文章要約タスクであるCNN/DailyMailデータセットです。

表5:参考文献[1]より抜粋

予測単語数を大きくすれば精度が上がっていることが分かります。予測する単語数が多くなれば、それだけ正しく文脈を理解しないといけないので自然な結果ですね。

まとめ

一つ先の単語を予測するだけではなく、N個先まで予測させてみるという発想で精度を向上させたという論文でした。

どのような手法が精度が出るか等を実験で確かめるというのはこの分野の王道ですが、それに加え数学的な解析がなされるようになるとより面白いかなと思いました。そのためにはもう少し分野の成熟が必要かもしれませんが、今後の研究に期待したいです。

本記事では省略した部分もあるので、詳細が気になる方は元論文[1]にあたられることをお勧めします。

参考文献

[1]Yan, Y., Qi, W., Gong, Y., Liu, D., Duan, N., Chen, J., Zhang, R., & Zhou, M. (2020). ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training. ArXiv, abs/2001.04063.

[2]Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). Attention is All you Need. ArXiv, abs/1706.03762.

-AI・機械学習, コラム
-, ,

© 2020 株式会社インフォマティクス