TensorflowLiteでAndroidでCNN
この記事はAizu Adc 2018 20日目にかかれた4日目の記事です。
前の人は@xatu0202氏, 次の人@ywkw氏です。
こんにちは。@masapontoです。最近は東京でデータ分析太郎として暮らしております。
気がづくと本ブログはアドベントカレンダーでしか書かなくなってしまいました。 今回はTensorflow Lite を使って Android上で画像分類(CNN)をする話をしようと思います。
TonsorFlow Liteを紹介する多くの日本語記事は、サンプルとして用意されている学習済みモデルを動かしてみたっていう話が多い気がします(主観)。
それに対し本記事では、自身で定義して学習を行ったCNNを動かしてみようと思います。
データはCIFAR10っていう画像データセットを使います。
CIFAR10はairplane, automobile, bird, cat, deer, dog, frog, horse, ship, truckの10クラスの画像データセットです。
https://www.cs.toronto.edu/~kriz/cifar.html
※本記事の内容は、単に動かしてみたよって感じの記事なので、詳細な説明を求める方にはおすすめできません。 まぁ何いっても公式ドキュメントを読めばいい話です(完)。 https://www.tensorflow.org/lite/devguide
Tensorflow Liteとは
モバイル端末や組み込み端末で機械学習の推論をするぞいっていうライブラリです。
https://www.tensorflow.org/lite/
基本的な使い方としては、以下の3ステップです。
- PC上のTensorFlowでNNを学習させてモデルMを作る。 (on Python)
- MをAndroid向けに変換し、 M'を作る。 (on Python)
- M'をAndroidアプリのassets/フォルダとかにいれて、呼び出す。 (on Android)
手順 on Python
私の開発環境は、ArchLinux (x64) / Python 3.6.1/ tensorflow 0.12.0 です。
で、下記のようにCNNを学習させるコードを書きます。 いつのまにかTensorFlowにKerasインターフェースが入っていたので、それを使って書いてみました。
# !/usr/bin/env python import tensorflow as tf def convolutional(): model = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(filters=32, kernel_size=[5, 5], padding='same', activation='relu', input_shape=(32, 32, 3), name='input'), tf.keras.layers.MaxPooling2D(pool_size=[2, 2], strides=2), tf.keras.layers.Conv2D(filters=32, kernel_size=[5, 5], padding='same', activation='relu'), tf.keras.layers.MaxPooling2D(pool_size=[2, 2], strides=2), tf.keras.layers.Flatten(), tf.keras.layers.Dense(units=1024, activation='relu'), tf.keras.layers.Dropout(rate=0.4, trainable=True), tf.keras.layers.Dense(units=10, activation='softmax', name='output') ]) return model def main(): (X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.cifar10.load_data() X_train = X_train/255 X_test = X_test/255 print(X_train.shape) Y_train = tf.keras.utils.to_categorical(Y_train, 10) Y_test = tf.keras.utils.to_categorical(Y_test, 10) model = convolutional() model.compile(optimizer=tf.keras.optimizers.Adam(0.0001), loss='categorical_crossentropy', metrics=['accuracy']) model.fit(X_train, Y_train, epochs=100, batch_size=50, verbose=1) # モデルの図示 (任意) tf.keras.utils.plot_model(model, to_file='model.png', show_shapes=True) # モデルをh5で保存 keras_file = "model_keras/cnn_model.h5" model.save(keras_file) # TFLiteでコンバータを用意 converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file(keras_file, input_arrays=['input_input'], output_arrays=['output/Softmax'], input_shapes={'input_input': [None, 32, 32, 3]}) # コンバート tflite_model = converter.convert() open("model_keras/converted_model.tflite", "wb").write(tflite_model) if __name__ == '__main__': main()
ドキュメントにあまり詳しくない部分があって、ググって試行錯誤してなんとか動きました。
converter
を作るときに、input_arrays
とかoutput_arrays
で名前を指定する必要があるっぽいです。
あと入力層のinput_shapesも指定する必要があるみたいです。(このへんよくわかってないので教えてください)
手順 on Android
Kotlinコードを下記においておきます。
動作確認環境はAndroidStudio20172.3/Kotlin 1.2.71/Zenfone 5/Android 8.0.0です。
準備 (TensorFlowLite)
下記を追記 (src: https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-tflite/#6)
- app/build.gradle
android { // Add aaptOptions { noCompress "tflite" } } dependencies { // Add implementation 'org.tensorflow:tensorflow-lite:1.12.0' }
また、作成したモデル(.tfliteファイル)は、AppName/app/src/main/assets/
に入れます。
コード
下記のようなコードを書いて、MainActivityから呼びます。 classifyImageFromPath(file) でファイルパスを指定してやって分類する感じです。
package io.github.masaponto.tflitecifarten import android.app.Activity import android.graphics.Bitmap import org.tensorflow.lite.Interpreter import java.io.FileInputStream import java.io.IOException import java.nio.ByteBuffer import java.nio.MappedByteBuffer import java.nio.channels.FileChannel import android.graphics.BitmapFactory import java.io.File import java.nio.ByteOrder class Classifier(activity: Activity) { private val MODEL_NAME = "converted_model.tflite" private val IMAGE_SIZE = 32 private val IMAGE_MEAN = 128 private val IMAGE_STD = 128.0f private var tffile: Interpreter private var labelProbArray: Array<FloatArray> init { tffile = Interpreter(loadModelFile(activity)) // deprecated labelProbArray = Array(1){FloatArray(10)} } @Throws(IOException::class) private fun loadModelFile(activity: Activity): MappedByteBuffer { val fileDescriptor = activity.assets.openFd(MODEL_NAME) val inputStream = FileInputStream(fileDescriptor.fileDescriptor) val fileChannel = inputStream.channel val startOffset = fileDescriptor.startOffset val declaredLength = fileDescriptor.declaredLength return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength) } fun classifyImageFromPath(path: String): Int { val file = File(path) if (!file.exists()) { throw Exception("Fail to load image") } // load image val bitmap = BitmapFactory.decodeFile(file.path) val scaledBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_SIZE, IMAGE_SIZE,true) // convert bitmap to bytebuffer val byteBuffer = convertBitmapToByteBuffer(scaledBitmap) // classification with TF Lite val pred = classifyImage(byteBuffer) return onehotToLabel(pred[0]) } private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer { val byteBuffer = ByteBuffer.allocateDirect( IMAGE_SIZE * IMAGE_SIZE * 3 * 4) byteBuffer.order(ByteOrder.nativeOrder()) val intValues = IntArray(IMAGE_SIZE * IMAGE_SIZE) bitmap.getPixels(intValues, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height) var pixel = 0 for (i in 0 until IMAGE_SIZE) { for (j in 0 until IMAGE_SIZE) { val v = intValues[pixel++] byteBuffer.putFloat((((v.shr(16) and 0xFF) - IMAGE_MEAN) / IMAGE_STD)) byteBuffer.putFloat((((v.shr(8) and 0xFF) - IMAGE_MEAN) / IMAGE_STD)) byteBuffer.putFloat((((v and 0xFF) - IMAGE_MEAN) / IMAGE_STD)) } } return byteBuffer } fun classifyImage(bytebuffer: ByteBuffer): Array<FloatArray> { tffile.run(bytebuffer, labelProbArray) return labelProbArray } private fun onehotToLabel(floatArray: FloatArray): Int { val tmp = floatArray.indices.maxBy { floatArray[it] } ?: -1 return tmp + 1 } }
全体のリポジトリ
convertBitmapToByteBuffer
もよくわかってないです。(つらい..)
動かす
適当に拾った猫ちゃん画像を分類してみましょう。
クラスラベルは1:airplane, 2:automobile, 3:bird, 4:cat, 5:deer, 6:dog, 7:frog, 8:horse, 9:ship, 10:truckの順ぽいのでたぶん動いてますね(雑)。
まとめ。
Kerasインターフェースを使って簡単にCNNの学習コードかいて、携帯端末用に変換する事ができました。 ちなみに、tfliteに変換するときにNoneでなくintegerで指定すればバッチサイズを変更指定できます。 Android端末で機械学習ができるので、わりと作れるアプリの幅が広がった気がします。 こんな調子でお仕事がんばるぞい。