本記事は、deeplearn.jsのサイトのPort TensorFlow modelsを翻訳(適宜意訳)したものです。誤り等あればご指摘いただけたら幸いです。
TensorFlowモデルをdeeplearn.jsに移植する
このチュートリアルでは、TensorFlowモデルをトレーニングしてdeeplearn.jsに移植する方法を示します。このチュートリアルで使用されているコードと必要なリソースはすべてdemos/mnist
に格納されています。
MNISTデータセットの手書き数字を予測する、完全結合ニューラルネットワーク(fully connected neural network)を使用します。このコードは公式のTensorFlow MNISTチュートリアルからforkされています。
[aside type=”normal”]
注:deeplearn.js repo のベース・ディレクトリを $BASE として参照します。
[/aside]
まず最初に、deeplearn.jsリポジトリをクローンし、TensorFlowがインストールされていることを確認します。$BASE
に移動(cd)して次のコマンドを実行し、モデルをトレーニングします。
python demos/mnist/fully_connected_feed.py
トレーニングには約1分かかり、/tmp/tensorflow/mnist/tensorflow/mnist/logs/fully_connected_feed/
にモデルチェックポイントが格納されます。
次に、TensorFlowチェックポイントからdeeplearn.jsに重み(weight)を移植する必要があります。これを行うスクリプトを提供しています。 $BASE
ディレクトリから実行します。
python scripts/dump_checkpoint_vars.py --output_dir=demos/mnist/ --checkpoint_file=/tmp/tensorflow/mnist/logs/fully_connected_feed/model.ckpt-1999
このスクリプトは、demos/mnist
ディレクトリに一連のファイル(variableごとに1つのファイルと、manifest.json
)を保存します。manifest.json
は、変数名をファイルとその形状にマップする単純なディクショナリーです。
{ ..., "hidden1/weights": { "filename": "hidden1_weights", "shape": [784, 128] }, ... }
コーディングを開始する前に、$BASE
ディレクトリから静的なHTTPサーバーを起動する必要があります。
npm run prep ./node_modules/.bin/http-server >> Starting up http-server, serving ./ >> Available on: >> http://127.0.0.1:8080 >> Hit CTRL-C to stop the server
ブラウザでhttp://localhost:8080/demos/mnist/manifest.json
にアクセスして、HTTP経由でmanifest.json
にアクセスできることを確認してください。
これで、deeplearn.jsコードを書く準備が整いました。
[aside type=”normal”]
注:TypeScriptで記述する場合は、コードをJavaScriptにコンパイルして、静的HTTPサーバー経由で提供するようにしてください。
[/aside]
重み(weight)を読むには、CheckpointLoader
を作成し、manifestファイルを指し示す必要があります。次に、変数名をNDArrays
にマップするディクショナリーを返すloader.getAllVariables()
を呼び出します。これで、モデルを書く準備が整いました。以下は、CheckpointLoader
の使用方法を示す抜粋になります。
import {CheckpointLoader, Graph} from 'deeplearnjs'; // manifest.json is in the same dir as index.html. const varLoader = new CheckpointLoader('.'); varLoader.getAllVariables().then(vars => { // Write your model here. const g = new Graph(); const input = g.placeholder('input', [784]); const hidden1W = g.constant(vars['hidden1/weights']); const hidden1B = g.constant(vars['hidden1/biases']); const hidden1 = g.relu(g.add(g.matmul(input, hidden1W), hidden1B)); ... ... const math = new NDArrayMathGPU(); const sess = new Session(g, math); math.scope(() => { const result = sess.eval(...); console.log(result.getValues()); }); });
完全なモデルコードの詳細については、demos/mnist/mnist.ts
を参照してください。このデモでは、3つの異なるAPIを使用してMNISTモデルを正確に実装しています。
buildModelGraphAPI()
は、TensorFlow APIを模倣したGraph APIを使用して、フィードとフェッチを遅延実行(lazy execution)します。ユーザーは、入力データ以外のGPU関連のメモリリークを心配する必要はありません。buildModelLayerAPI()
は、Graph APIをKeraレイヤAPIを模倣するGraph.layers
と組み合わせて使用します。buildModelMathAPI()
は、Math APIを使用します。これはdeeplearn.jsの最も低いレベルのAPIであり、ユーザに最も多くの機能を与えます。数学コマンドはnumpyのようにすぐに実行されます。mathコマンドはmath.scope()に含まれ、中間のmathコマンドで作成されたNDArraysが自動的にクリーンアップされます。
このmnistデモを実行するために、変更されたときにタイプコピーコードを見て再コンパイルするwatch-demo
スクリプトがあります。さらに、スクリプトは、静的なhtml/jsファイルを提供する8080上の単純なHTTPサーバーを実行します。watch-demo
を実行する前に、8080ポートを解放するために、チュートリアルで前述したHTTPサーバーを終了させてください。次に、$BASE
から web app デモのエントリ・ポイントのdemos/mnist/mnist.ts
を指すように、watch-demo
を実行します。
./scripts/watch-demo demos/mnist/mnist.ts >> Starting up http-server, serving ./ >> Available on: >> http://127.0.0.1:8080 >> http://192.168.1.5:8080 >> Hit CTRL-C to stop the server >> 1410084 bytes written to demos/mnist/bundle.js (0.91 seconds) at 5:17:45 PM
http://localhost:8080/demos/mnist/
にアクセスすると、demos/mnist/sample_data.json
に保存されているテストイメージを使用して測定された〜90%のテスト精度を示す簡単なページが表示されます。