Flutter机器学习推理插件tflite_flutter_processing的使用
Flutter机器学习推理插件tflite_flutter_processing的使用
TensorFlow Lite Flutter Helper Library
TFLite Flutter Helper Library 将 TensorFlow Lite 支持库(TFLite Support Library)和任务库(TFLite Support Task Library)引入到 Flutter 中,并帮助用户快速开发机器学习应用并部署到移动设备上,同时不会影响性能。
获取开始
安装TFLite Flutter插件
请遵循以下链接中的初始设置说明进行安装: 安装指南
基本图像处理与转换
TFLite Helper 依赖于 flutter image
包来处理图像。
TensorFlow Lite 支持库包含一系列基本的图像处理方法,例如裁剪和调整大小。要使用这些方法,可以创建一个 ImageProcessor
并添加所需的处理操作。为了将图像转换为 TensorFlow Lite 解释器所需的张量格式,可以创建一个 TensorImage
作为输入:
// 初始化代码
// 创建一个包含所有所需操作的 ImageProcessor
ImageProcessor imageProcessor = ImageProcessorBuilder()
.add(ResizeOp(224, 224, ResizeMethod.NEAREST_NEIGHBOUR))
.build();
// 从文件创建 TensorImage 对象
TensorImage tensorImage = TensorImage.fromFile(imageFile);
// 预处理图像
// 图像文件中的图像将被调整大小为 (224, 224)
tensorImage = imageProcessor.process(tensorImage);
示例应用:Image Classification
基本音频数据处理
TensorFlow Lite 支持库还定义了一个 TensorAudio
类,用于封装一些基本的音频数据处理方法。
TensorAudio tensorAudio = TensorAudio.create(
TensorAudioFormat.create(1, sampleRate), size);
tensorAudio.loadShortBytes(audioBytes);
TensorBuffer inputBuffer = tensorAudio.tensorBuffer;
示例应用:Audio Classification
创建输出对象并运行模型
// 创建一个用于存储结果的容器,并指定这是一个量化模型。
// 因此,'DataType' 被定义为 UINT8 (8位无符号整数)
TensorBuffer probabilityBuffer =
TensorBuffer.createFixedSize(<int>[1, 1001], TfLiteType.kTfLiteUInt8);
加载模型并运行推理:
import 'package:tflite_flutter/tflite_flutter.dart';
try {
// 从资产创建解释器
Interpreter interpreter =
await Interpreter.fromAsset("mobilenet_v1_1.0_224_quant.tflite");
interpreter.run(tensorImage.buffer, probabilityBuffer.buffer);
} catch (e) {
print('Error loading model: ' + e.toString());
}
访问结果
开发者可以直接通过 probabilityBuffer.getDoubleList()
访问输出结果。如果模型产生的是量化输出,记得转换结果。对于 MobileNet 量化模型,开发者需要将每个输出值除以 255 来获得每个类别的概率,范围从 0(最不可能)到 1(最可能)。
可选:映射结果到标签
开发者还可以选择将结果映射到标签。首先,将包含标签的文本文件复制到模块的 assets 目录中。然后,使用以下代码加载标签文件:
List<String> labels = await FileUtil.loadLabels("assets/labels.txt");
以下代码演示了如何将概率与类别标签关联起来:
TensorLabel tensorLabel = TensorLabel.fromList(
labels, probabilityProcessor.process(probabilityBuffer));
Map<String, double> doubleMap = tensorLabel.getMapWithFloatValue();
ImageProcessor架构
ImageProcessor 的设计允许在构建过程中提前定义图像处理操作并进行优化。目前,ImageProcessor 支持三种基本的预处理操作:
int cropSize = min(_inputImage.height, _inputImage.width);
ImageProcessor imageProcessor = ImageProcessorBuilder()
// 将图像中心裁剪为最大的正方形
.add(ResizeWithCropOrPadOp(cropSize, cropSize))
// 使用双线性或最近邻调整大小
.add(ResizeOp(224, 224, ResizeMethod.NEAREST_NEIGHBOUR))
// 按 90 度增量顺时针旋转
.add(Rot90Op(rotationDegrees ~/ 90))
.add(NormalizeOp(127.5, 127.5))
.add(QuantizeOp(128.0, 1 / 128.0))
.build();
有关归一化和量化参数的更多信息,请参见 此处。
量化
TensorProcessor 可用于量化输入张量或解量化输出张量。例如,在处理量化输出 TensorBuffer 时,开发者可以使用 DequantizeOp 将结果解量化为介于 0 和 1 之间的浮点概率:
// 解量化结果的后处理器
TensorProcessor probabilityProcessor =
TensorProcessorBuilder().add(DequantizeOp(0, 1 / 255.0)).build();
TensorBuffer dequantizedBuffer =
probabilityProcessor.process(probabilityBuffer);
读取量化参数
// 输入张量在索引 0 处的量化参数
QuantizationParams inputParams = interpreter.getInputTensor(0).params;
// 输出张量在索引 0 处的量化参数
QuantizationParams outputParams = interpreter.getOutputTensor(0).params;
任务库
目前,像 NLClassifier
、BertNLClassifier
和 BertQuestionAnswerer
这样的基于文本的模型可以与 Flutter 任务库一起使用。
集成自然语言分类器
任务库的 NLClassifier
API 将输入文本分类到不同的类别,是一个灵活且可配置的 API,可以处理大多数文本分类模型。详细指南可以在 这里 查看。
final classifier = await NLClassifier.createFromAsset('assets/$_modelFileName',
options: NLClassifierOptions());
List<Category> predictions = classifier.classify(rawText);
示例应用:Text Classification。
集成BERT自然语言分类器
任务库的 BertNLClassifier
API 与 NLClassifier
非常相似,它将输入文本分类到不同的类别,但这个 API 是专门为需要 Wordpiece 和 Sentencepiece 分词的 BERT 相关模型设计的。详细指南可以在 这里 查看。
final classifier = await BertNLClassifier.createFromAsset('assets/$_modelFileName',
options: BertNLClassifierOptions());
List<Category> predictions = classifier.classify(rawText);
集成BERT问答器
任务库的 BertQuestionAnswerer
API 加载一个 BERT 模型并根据给定段落的内容回答问题。更多详细信息,请参阅 Question-Answer 模型文档。详细指南可以在 这里 查看。
final bertQuestionAnswerer = await BertQuestionAnswerer.createFromAsset('assets/$_modelFileName');
List<QaAnswer> answeres = bertQuestionAnswerer.answer(context, question);
更多关于Flutter机器学习推理插件tflite_flutter_processing的使用的实战教程也可以访问 https://www.itying.com/category-92-b0.html
更多关于Flutter机器学习推理插件tflite_flutter_processing的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html
tflite_flutter_processing
是一个 Flutter 插件,用于在 Flutter 应用中执行 TensorFlow Lite 模型的推理。它提供了一个简单易用的 API,允许你在 Flutter 应用中加载和运行 TensorFlow Lite 模型,并对输入数据进行推理。
以下是使用 tflite_flutter_processing
插件的步骤:
1. 添加依赖
首先,在你的 pubspec.yaml
文件中添加 tflite_flutter_processing
插件的依赖:
dependencies:
flutter:
sdk: flutter
tflite_flutter_processing: ^1.0.0
然后,运行 flutter pub get
来安装依赖。
2. 加载模型
在你的 Dart 代码中,首先需要加载 TensorFlow Lite 模型。你可以将模型文件放在 assets
文件夹中,并在 pubspec.yaml
中声明:
flutter:
assets:
- assets/model.tflite
然后,在 Dart 代码中加载模型:
import 'package:tflite_flutter_processing/tflite_flutter_processing.dart';
Future<void> loadModel() async {
String modelPath = "assets/model.tflite";
try {
Interpreter interpreter = await Interpreter.fromAsset(modelPath);
print("Model loaded successfully");
} catch (e) {
print("Failed to load model: $e");
}
}
3. 准备输入数据
在进行推理之前,你需要准备好输入数据。输入数据的格式应该与模型的输入张量相匹配。例如,如果模型接受一个形状为 [1, 224, 224, 3]
的输入张量,你需要准备一个符合该形状的浮点数数组。
import 'dart:typed_data';
Float32List prepareInputData() {
// 假设输入是一个 224x224 的 RGB 图像
List<double> input = List<double>.filled(1 * 224 * 224 * 3, 0.0);
// 在这里填充输入数据,例如图像的像素值
return Float32List.fromList(input);
}
4. 执行推理
使用加载的 Interpreter
对象执行推理:
void runInference(Interpreter interpreter, Float32List input) {
// 获取输入和输出张量的形状
var inputShape = interpreter.getInputTensor(0).shape;
var outputShape = interpreter.getOutputTensor(0).shape;
// 准备输出缓冲区
var output = List<double>.filled(outputShape[1], 0.0).asMap().map((index, value) => MapEntry(index, 0.0));
// 执行推理
interpreter.run(input, output);
// 处理输出结果
print("Inference result: $output");
}
5. 释放资源
在推理完成后,记得释放 Interpreter
占用的资源:
void disposeInterpreter(Interpreter interpreter) {
interpreter.close();
}
完整示例
以下是一个完整的示例,展示如何加载模型、准备输入数据、执行推理并释放资源:
import 'package:flutter/material.dart';
import 'package:tflite_flutter_processing/tflite_flutter_processing.dart';
import 'dart:typed_data';
void main() async {
WidgetsFlutterBinding.ensureInitialized();
await loadModel();
}
Future<void> loadModel() async {
String modelPath = "assets/model.tflite";
try {
Interpreter interpreter = await Interpreter.fromAsset(modelPath);
print("Model loaded successfully");
Float32List input = prepareInputData();
runInference(interpreter, input);
disposeInterpreter(interpreter);
} catch (e) {
print("Failed to load model: $e");
}
}
Float32List prepareInputData() {
// 假设输入是一个 224x224 的 RGB 图像
List<double> input = List<double>.filled(1 * 224 * 224 * 3, 0.0);
// 在这里填充输入数据,例如图像的像素值
return Float32List.fromList(input);
}
void runInference(Interpreter interpreter, Float32List input) {
// 获取输入和输出张量的形状
var inputShape = interpreter.getInputTensor(0).shape;
var outputShape = interpreter.getOutputTensor(0).shape;
// 准备输出缓冲区
var output = List<double>.filled(outputShape[1], 0.0).asMap().map((index, value) => MapEntry(index, 0.0));
// 执行推理
interpreter.run(input, output);
// 处理输出结果
print("Inference result: $output");
}
void disposeInterpreter(Interpreter interpreter) {
interpreter.close();
}