Python scikit-learnで機械学習モデルを保存&ロードする

Python scikit-learnで機械学習モデルを保存&ロードする

Python Machine Learningで、正確な機械学習モデルを見つけることは、プロジェクトの終わりではありません。

今回は、scikit-learnを使って機械学習モデルを保存して読み込む方法を紹介します。これにより、モデルをファイルに保存して、それを後で読み込んで予測を行うことができます。

Pickleでモデルを保存する

pickleは、Pythonでオブジェクトをシリアライズする一般的な方法です。シリアライズ(直列化)とは、プログラミング言語においてオブジェクトをバイト列などの表現に変換することを言います。対して、デシリアライズ(非直列化)とは、バイト列を元にオブジェクトを復元することを言います。

pickle操作を使用して、機械学習モデルをシリアライズし、シリアライズされたフォーマットをファイルに保存することができます。

その後、このファイルを読み込んでモデルをデシリアライズし、新しい予測を行うために使用することができます。

以下の例は、糖尿病データセットのPima Indians発症に関するロジスティック回帰モデルをトレーニングした後、モデルをファイルに保存し、それを読み込んで新しいテストセットを予測しています。

この例を実行すると、トレーニングしたモデルが、ローカルディレクトリにfinalized_model.savとして保存されます。そして、保存されたモデルを読み込み、新しいデータセットで予測することができます。以下のような出力結果になります。

joblibでモデルを保存する

JoblibはSciPyのライブラリで、Pythonのジョブをパイプライン処理するためのユーティリティを提供しています。NumPyデータ構造を効率的に利用するPythonオブジェクトの保存とロードを行うユーティリティを提供します。

以下の例は、Pima Indians on糖尿病データセットのロジスティック回帰モデルをトレーニングし、モデルをjoblibを使ってファイルに保存し、それを読み込んで新しいテストセットを予測しています。

この例を実行すると、モデルがfinalized_model.savというファイルに保存され、モデル内のNumPy配列ごとに1つのファイル(4つの追加ファイル)が作成されます。そして、モデルがロードされた後、新しいデータに対するモデルの精度の推定値が出力されます。以下のような出力結果になります。

モデル保存のヒント

このセクションでは、機械学習モデルを構築する際の大切なポイントを紹介します。

Pythonバージョン

Pythonのバージョンに注意してください。後でモデルをロードしてデシリアライズするときに、モデルをシリアライズするために使用されたPythonと同じメジャーバージョンが必要です。

ライブラリバージョン

保存されたモデルをデシリアライズする場合、機会学習プロジェクトで使用されるすべての主要ライブラリのバージョンはほぼ同じである必要があります。これは、NumPyのバージョンとscikit-learnのバージョンに限定されません。

手動シリアル化

学習したモデルのパラメータを手動で出力して、それを、scikit-learnや他のプラットフォームで使用することができます。多くの場合、機械学習アルゴリズムが予測を行うために使用するアルゴリズムは、制御しているカスタムコードで簡単に実装できるパラメータを知るために使用されるアルゴリズムよりもはるかに簡単です。

何らかの理由によって、別の環境でモデルを後でリロードできない場合は、環境を再作成できるように、バージョンなどをメモしておいたほうが良いでしょう。

Related Post

ポッドキャスト配信中

SE社長アライの「海外スタートアップ研究室」
今、何よりも必要なのは「多様化するマーケットを読み解く力」です。
海外スタートアップの研究や、そこから学んだビジネスのヒントを共有しています。

お問合せはこちら

サービスに関する質問や記事に関するお問合せなど、お気軽にご連絡ください。お問合せをいただいてから原則24時間以内に返信させていただきます。