fbpx
共著書籍「データ分析の進め方 及び AI・機械学習の導入の指南」が出版されました

Python scikit-learnで機械学習モデルを保存&ロードする

Python scikit-learnで機械学習モデルを保存&ロードする AI・機械学習・ディープラーニング
Python scikit-learnで機械学習モデルを保存&ロードする

Python Machine Learningで、正確な機械学習モデルを見つけることは、プロジェクトの終わりではありません。

今回は、scikit-learnを使って機械学習モデルを保存して読み込む方法を紹介します。これにより、モデルをファイルに保存して、それを後で読み込んで予測を行うことができます。

Pickleでモデルを保存する

pickleは、Pythonでオブジェクトをシリアライズする一般的な方法です。シリアライズ(直列化)とは、プログラミング言語においてオブジェクトをバイト列などの表現に変換することを言います。対して、デシリアライズ(非直列化)とは、バイト列を元にオブジェクトを復元することを言います。

pickle操作を使用して、機械学習モデルをシリアライズし、シリアライズされたフォーマットをファイルに保存することができます。

その後、このファイルを読み込んでモデルをデシリアライズし、新しい予測を行うために使用することができます。

以下の例は、糖尿病データセットのPima Indians発症に関するロジスティック回帰モデルをトレーニングした後、モデルをファイルに保存し、それを読み込んで新しいテストセットを予測しています。

# Pickleによるモデル保存
import pandas
from sklearn import model_selection
from sklearn.linear_model import LogisticRegression
import pickle

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/pima-indians-diabetes/pima-indians-diabetes.data"

names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
dataframe = pandas.read_csv(url, names=names)
array = dataframe.values
X = array[:,0:8]
Y = array[:,8]
test_size = 0.33
seed = 7
X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y, test_size=test_size, random_state=seed)

# モデルをトレーニングする
model = LogisticRegression()
model.fit(X_train, Y_train)

# モデルを保存する
filename = 'finalized_model.sav'
pickle.dump(model, open(filename, 'wb'))

# 保存したモデルをロードする
loaded_model = pickle.load(open(filename, 'rb'))
result = loaded_model.score(X_test, Y_test)
print(result)

この例を実行すると、トレーニングしたモデルが、ローカルディレクトリにfinalized_model.savとして保存されます。そして、保存されたモデルを読み込み、新しいデータセットで予測することができます。以下のような出力結果になります。

0.755905511811

joblibでモデルを保存する

JoblibはSciPyのライブラリで、Pythonのジョブをパイプライン処理するためのユーティリティを提供しています。NumPyデータ構造を効率的に利用するPythonオブジェクトの保存とロードを行うユーティリティを提供します。

以下の例は、Pima Indians on糖尿病データセットのロジスティック回帰モデルをトレーニングし、モデルをjoblibを使ってファイルに保存し、それを読み込んで新しいテストセットを予測しています。

# joblibによるモデル保存
import pandas
from sklearn import model_selection
from sklearn.linear_model import LogisticRegression
from sklearn.externals import joblib

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/pima-indians-diabetes/pima-indians-diabetes.data"

names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
dataframe = pandas.read_csv(url, names=names)
array = dataframe.values
X = array[:,0:8]
Y = array[:,8]
test_size = 0.33
seed = 7
X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y, test_size=test_size, random_state=seed)

# モデルをトレーニングする
model = LogisticRegression()
model.fit(X_train, Y_train)

# モデルを保存
filename = 'finalized_model.sav'
joblib.dump(model, filename)

# モデルをロードする
loaded_model = joblib.load(filename)
result = loaded_model.score(X_test, Y_test)
print(result)

この例を実行すると、モデルがfinalized_model.savというファイルに保存され、モデル内のNumPy配列ごとに1つのファイル(4つの追加ファイル)が作成されます。そして、モデルがロードされた後、新しいデータに対するモデルの精度の推定値が出力されます。以下のような出力結果になります。

0.755905511811

モデル保存のヒント

このセクションでは、機械学習モデルを構築する際の大切なポイントを紹介します。

Pythonバージョン

Pythonのバージョンに注意してください。後でモデルをロードしてデシリアライズするときに、モデルをシリアライズするために使用されたPythonと同じメジャーバージョンが必要です。

ライブラリバージョン

保存されたモデルをデシリアライズする場合、機会学習プロジェクトで使用されるすべての主要ライブラリのバージョンはほぼ同じである必要があります。これは、NumPyのバージョンとscikit-learnのバージョンに限定されません。

手動シリアル化

学習したモデルのパラメータを手動で出力して、それを、scikit-learnや他のプラットフォームで使用することができます。多くの場合、機械学習アルゴリズムが予測を行うために使用するアルゴリズムは、制御しているカスタムコードで簡単に実装できるパラメータを知るために使用されるアルゴリズムよりもはるかに簡単です。

何らかの理由によって、別の環境でモデルを後でリロードできない場合は、環境を再作成できるように、バージョンなどをメモしておいたほうが良いでしょう。

【稼ぐフリーランスSEの思考】「ありがたくお仕事をいただく」なんて思う必要がない理由
こんにちは、荒井(@yutakarai)です。 フリーランスとして独立したあとは 自分で仕事を取ってこなければいけなくなります。 しかし、仕事を取りたいあまり 下手(したて)になりすぎて損をしてしまっている方をよく見かけます。 実績もたいし...
コードの写経だけで満足しないために。初心者のための機械学習WEBシステム構築のはじめの一歩。シンプルな予測システムを作ってみる
こんにちは、荒井(@yutakarai)です。 僕が最初に機械学習に学びはじめた当初、誰かが書いて公開したソースコードをそのまま写して動かす、ということを繰り返していました。 誰かが書いたソースコードをそのまま写す「写経」から学べることも大...

【ロカラボからのお知らせ】
自社事業にAIを活用しようとする前にこれだけは押さえておいてください。

【無料ダウンロード】成功するAIプロジェクトに共通する3つの最重要ポイント

事業でAIを活用する企業様が多くなってきました。
弊社でも主に製造業・医療業を中心にAIシステムの開発や導入支援をおこなってきました。

その中で見えてきた、成功するAIプロジェクトに共通する最重要ポイントをまとめたPDFファイルを無料で配布しています。

AI導入プロジェクトをスタートする際には是非ご参考にいただけたらと思います。
こちらのページからダウンロードしてください。

AI・機械学習・ディープラーニング
シェアする
ロカラボをフォローする
タイトルとURLをコピーしました