RNN、LSTMで時系列データを生成

文脈を持つ時系列データをRNN、LSTMに学習させて、新たなデータを生成することを目的とします。

 

以下の本を参考にしました。

shop.ohmsha.co.jp

 

今回は3種類のデータを用意します。

1つめは、テキストデータです。英文です。

2つめは、アルファベット順に並んだアルファベット列です。スタートの単語、単語数は自由とします。

3つめは、数字が前の2数の足し算になっている数列です。スタートの2数はランダムに決めます。また、数字は1桁のみとします。3、9ときたら次は12ではなく2です。

 

この3種類のデータを学習させて、生成させることが目的です。

 

1:データを用意する

2:環境を整えて、学習準備

3:学習する

4:生成する

5:結果確認

 

この5ステップで説明します。

 

作成したプログラムは全てgithubに上げてあります。

github.com

 

       1:データを用意する

 

作成したプログラムは data_pre.py、data_pre2.pyです。

まず、テキストデータについてです。GPUがないため、大容量のデータは学習できないので、自分で適当に打ったやつを使います。とりあえずRNNを試せればいいので10行程度のかなり軽いものです。今度、暇があったら、しっかりとしたデータで取り組みたいと思います。これをtrain.txtとして保存しました。

次に、アルファベット列データです。自分で作ります。data_pre.pyです。このデータをtrain_alpha_raw.pklに保存します。

最後に、数列データです。これも自分で作ります。data_pre2.pyです。このデータをtrain_number_raw.pklに保存します。

 

       2:環境を整えて、学習準備

 作成したプログラムはconvert.pyです。

学習できるようにそれぞれのデータを数値データにする必要があります。それぞれの単語などに対応するidを用意して、idでデータを置き換えます。そのデータを保存します。convert.pyでこれをやっています。

 

アルファベット列のidデータセットは、train_alpha_con.pkl、id辞書はtrain_alpha_voc.pklです。

数列は、train_number_con.pkl、train_number_voc.pklです。

テキストデータは、train_txt_con.pkl、train_txt_voc.pklです。

 

 

       3:学習する

作成したプログラムはrnn.py、train.pyです。

まずRNNのネットワークをクラスで定義します。rnn.pyです。これを使って2で用意したファイルからデータを読み込んで学習します。最後に学習したパラメータをファイルに保存します。これはtrain.pyです。テキストデータは、train_txt_embed.pkl、train_txt_H.pkl、train_txt_W.pklに保存します。アルファベット列と数列では、txtの部分がそれぞれalpha、numberに代わります。LSTMは、L.LinearをL.LSTMに変えるだけで使えます。今回は、RNNを学習します。

  

        4:生成する

学習は完了しました。結果をもとにデータを生成します。3同様、RNNのネットワークを使うのでrnnをimportしています。LSTMではなくRNNを使っています。

 predict.py です。

 

       5:結果確認

テキストデータ、アルファベット列、数列、各条件で5回ずつ生成します。LSTMではなく、RNNの結果です。

 

1:テキストデータ

テキストデータはそもそも単語数が20種類くらいしかないので、結果は知れていますが、やってみます。

 

10epoch:

many am many am many am many am many am many am many am many am many am many am
night hello you am family do do do person person person from from from from person person am everyone from
likes have do my you am family do do do person person person from from from from person person am
am Tokyo hello you my you am family Tokyo do am you from likes am everyone my are am many
night have am everyone from from from from person person am everyone from from from from person person am everyone

 

100epoch:

name is <eos>
your name <eos>
have <eos>
is your name <eos>
is your name <eos>

 

1000epoch:

are you from <eos>
are you from <eos>
name is person <eos>
<eos>
beautiful day <eos>

 

2:アルファベット列

データ数も変化させて比べてみようと思います。100epoch、100dataで学習に1分かかりました。きついです。アルファベット順になればokです。

 

10epoch:データ数100

<eos>
<eos>
<eos>
h i j <eos>
<eos>

 

100epoch:データ数100:

k l m n o p q r <eos>
c d e f g h i j <eos>
i j k l m n <eos>
<eos>
i j k l m n <eos>

 

10epoch:データ数1000:

<eos>
m n o p <eos>
h i j k l m n o <eos>
f g h i j k l <eos>
k l m n o p q <eos>

 

3:数列

前の2数足した値になればokです。2桁になる場合は1桁目だけです。あんまりうまくいってないです。epoch数とデータ数を増やせば、いい感じになると思います。

 

10epoch:データ数100:

8 8 <eos>
2 0 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
4 2 8 <eos>
<eos>
<eos>

 

100epoch:データ数100:

4 2 6 8 4 2 6 8 4 2 6 8 4 2 6 8 4 2 6 8
<eos>
4 7 6 <eos>
<eos>
<eos>

 

10epoch:データ数1000:

8 7 9 4 9 5 4 9 7 8 7 7 4 3 7 8 3 1 8 9
6 5 4 4 7 0 6 3 1 9 1 2 <eos>
2 <eos>
6 5 4 4 7 0 6 3 1 9 1 2 <eos>
8 8 0 0 4 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0

 

以上です。今度、GPUを用意して本格的にやってみたいです。

アドバイス、改善点、質問があればお願いします。