KerasのPretrained Modelをpbファイルで保存してTensorflowで解凍する (Tensorflow 1.15)
もともとImageNetで事前学習したモデルをTensorflowで使用したかったのですが,
Tensorflowの事前学習済みモデルの解凍方法が難しい (ビルドを失敗した) & Tensorflow 1.xの記事が少なくて苦労しました...
今回は, Kerasで事前学習したモデル (Pretrained model)をpbファイルで保存してTensorflow上で解凍するための操作を行いました.
事前学習モデル取得からpbファイル作成(1~3)までをスクリプトにまとめましたのでよろしければご覧下さい.
github.com
1. バージョン確認 & 事前準備
Tensorflow 1.15.2で行います.import tensorflow as tf print(tf.__version__)
1.15.2
また, tensorflow (r1.15)のfreeze_graph.pyを使用するのでcloneします.
git clone -b r1.15 --single-branch https://github.com/tensorflow/tensorflow.git
2. Kerasで事前学習済みモデルを保存 (ckpt)
Kerasの事前学習済みモデルは以下から確認することができます (ここで自分でデータセットを用意して学習しても構いません).Applications - Keras Documentation
事前学習済みモデルを読み込む or 自分でデータセットを用意してモデルを訓練した後に, .ckptファイルで保存します.
import tensorflow as tf # from keras import backend as Kでも動作する from tensorflow.keras import backend as K IMG_HEIGHT = 299 IMG_WIDTH = 299 # モデルをロード (Imagenetで事前学習済みのInceptionV3をロード) model = tf.keras.applications.inception_v3.InceptionV3(include_top=True, weights='imagenet', input_tensor=None, input_shape=(IMG_HEIGHT,IMG_WIDTH,3), pooling=None, classes=1000) # outputのノード名が必要なのでprintして確認する print(model.output.op.name) #ファイル名を.ckptとしてモデルを保存 saver = tf.train.Saver() saver.save(K.get_session(), 'frozen_model.ckpt')
predictions/Softmax
ckpt.meta, ckpt.data-00000-of-00001 (data-以降の記述は異なる場合がある), ckpt.indexの3種類のファイルが作成されるはずです.
1つだけの.ckptというファイルが作成されるわけではないので注意して下さい.
3. ckpt → pbに変換
cloneしたfreeze_graph.pyをオプションを指定して実行するとpbファイルが得られます.- --input_meta_graph = frozen_model.ckpt.meta (2で作成したファイル)
- --input_checkpoint = frozen_model.ckpt (2で作成したファイル)
- --output_graph = frozen_model.pb (出力ファイル名 ※.pbにする)
- --output_node_names = predictions/Softmax (output ノード名)
- --input_binary = true
python tensorflow/tensorflow/python/tools/freeze_graph.py --input_meta_graph=frozen_model.meta --input_checkpoint=frozen_model.ckpt --output_graph=frozen_model.pb --output_node_names=predictions/Softmax --input_binary=true
4. Tensorflow上でpbファイルを解凍する
以下の操作を行うことで, 訓練済みモデルのgraphがdefault_graphに追加されます.# pbファイルのあるディレクトリ model_dir = '' with tf.Session() as sess: with gfile.FastGFile(os.path.join(model_dir, 'frozen_model.pb'), 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) _ = tf.import_graph_def(graph_def, name = '')
.h5ファイルをpbファイルに変換したい場合は以下の記事がわかりやすいのでおすすめです.
keras_to_tensorflow - Qiita