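My original goal was to port a model to Android. When I ran into a problem, I built a simple model to check it, and the same problem occurred.
Python training code: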
import numpy as np
from tensorflow import keras

model = keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=np.float32)
ys = np.array([-3.0, -1.0, 0.0, 3.0, 5.0, 7.0], dtype=np.float32)
model.fit(xs, ys, epochs=500)
keras_file = 'linear.h5'
keras.models.save_model(model, keras_file)
After converting it to .tflite, I use it on Android like this:
Interpreter interpreter = new Interpreter(FileUtil.loadMappedFile(activity, "linear.tflite"));
interpreter.allocateTensors();
int probabilityTensorIndex = 0;
int[] probabilityShape = interpreter.getOutputTensor(probabilityTensorIndex).shape();
DataType probabilityDataType = interpreter.getOutputTensor(probabilityTensorIndex).dataType();
TensorBuffer outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);
int inputTensorIndex = 0;
DataType inputDataType = interpreter.getInputTensor(inputTensorIndex).dataType();
int[] inputShape = interpreter.getInputTensor(inputTensorIndex).shape();
TensorBuffer inputBuffer = TensorBuffer.createFixedSize(inputShape, inputDataType);
final float[] input = {10};
inputBuffer.loadArray(input);
interpreter.run(inputBuffer, outputProbabilityBuffer);
The error is:
I/tflite: Initialized TensorFlow Lite runtime.
E/AndroidRuntime: FATAL EXCEPTION: inference
Process: com.example.my1application, PID: 26839
java.lang.IllegalArgumentException: DataType error: cannot resolve DataType of org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat
at org.tensorflow.lite.Tensor.dataTypeOf(Tensor.java:344)
at org.tensorflow.lite.Tensor.throwIfTypeIsIncompatible(Tensor.java:397)
at org.tensorflow.lite.Tensor.getInputShapeIfDifferent(Tensor.java:287)
at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:137)
at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:316)
at org.tensorflow.lite.Interpreter.run(Interpreter.java:277)
at com.example.my1application.DisplayMessageActivity$1.run(DisplayMessageActivity.java:114)
at android.os.Handler.handleCallback(Handler.java:815)
at android.os.Handler.dispatchMessage(Handler.java:104)
at android.os.Looper.loop(Looper.java:207)
at android.os.HandlerThread.run(HandlerThread.java:61)
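Judging from the stack trace, the failure happens in Tensor.dataTypeOf, before any inference runs. Interpreter.run(Object, Object) expects Java arrays or a ByteBuffer and does not recognize the support library's TensorBuffer wrapper object. A common remedy is to pass the underlying buffers, i.e. interpreter.run(inputBuffer.getBuffer(), outputProbabilityBuffer.getBuffer()) and then read outputProbabilityBuffer.getFloatArray(), or to use plain arrays such as float[][] in = {{10f}} with float[][] out = new float[1][1]. The sketch below, assuming a float32 model whose input and output shapes are both [1, 1], shows what such a ByteBuffer looks like when built by hand. The class TfliteBuffers and its helpers are hypothetical, and the interpreter call is left as a comment because it needs the Android TFLite dependency.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical helper class: builds the raw ByteBuffers that
// Interpreter.run(Object, Object) accepts, instead of passing a TensorBuffer.
public class TfliteBuffers {

    // Packs float values into a direct, native-order ByteBuffer.
    // Assumption: the model's input tensor is float32.
    public static ByteBuffer floatBuffer(float... values) {
        ByteBuffer buf = ByteBuffer.allocateDirect(values.length * Float.BYTES)
                .order(ByteOrder.nativeOrder());
        for (float v : values) {
            buf.putFloat(v);
        }
        buf.rewind(); // reset position so the interpreter reads from the start
        return buf;
    }

    // Allocates an empty direct buffer with room for `count` float outputs.
    public static ByteBuffer emptyFloatBuffer(int count) {
        return ByteBuffer.allocateDirect(count * Float.BYTES)
                .order(ByteOrder.nativeOrder());
    }

    public static void main(String[] args) {
        ByteBuffer input = floatBuffer(10f);      // one sample, shape [1, 1]
        ByteBuffer output = emptyFloatBuffer(1);  // one prediction, shape [1, 1]

        // On Android, with the interpreter from the question (not runnable here):
        // interpreter.run(input, output);
        // float prediction = output.getFloat(0);

        System.out.println(input.getFloat(0)); // prints 10.0
    }
}
```

Building the buffer by hand mainly matters when the TensorBuffer helpers are not wanted; with them, calling getBuffer() on both wrappers is the shorter change.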