Flutter机器学习推理插件tflite_flutter_plus的使用
Flutter机器学习推理插件tflite_flutter_plus的使用
概述
TensorFlow Lite Flutter 插件提供了一种灵活且快速的方法来访问 TensorFlow Lite 解释器并执行推理。该 API 类似于 TFLite 的 Java 和 Swift API。它直接绑定到 TFLite C API,使其高效(低延迟)。支持使用 NNAPI、GPU 委托(Android)、Metal 和 CoreML 委托(iOS),以及 XNNPack 委托(桌面平台)。
关键特性
- 支持多平台:Android、iOS、Windows、Mac、Linux。
- 可以使用任何 TFLite 模型。
- 支持多线程和委托加速。
- 结构与 TensorFlow Lite Java API 类似。
- 推理速度接近原生 Android 应用程序。
- 可以选择使用任何 TensorFlow 版本通过本地构建二进制文件。
- 在不同的隔离中运行推理,以防止 UI 线程卡顿。
初始化设置:向应用添加动态库
Android
- 将脚本
install.sh
(Linux/Mac)或install.bat
(Windows)放在项目的根目录下。 - 执行
sh install.sh
(Linux)/install.bat
(Windows)在项目的根目录下自动下载并放置二进制文件到合适的文件夹。注意:安装的二进制文件将不包括对
GpuDelegateV2
和NnApiDelegate
的支持,但是可以使用InterpreterOptions().useNnApiForAndroid
。 - 如果希望使用这些
GpuDelegateV2
和NnApiDelegate
,则应使用sh install.sh -d
(Linux)或install.bat -d
(Windows)。
这些脚本基于最新的稳定 TensorFlow 发行版安装预构建的二进制文件。如需使用其他 TensorFlow 版本,请参阅 wiki 中的说明。
iOS
- 下载
<code>TensorFlowLiteC.framework</code>
。若要构建自定义版本的 TensorFlow,请参阅 wiki 中的说明。 - 将
<code>TensorFlowLiteC.framework</code>
放置在包的 pub-cache 文件夹中。
pub-cache 文件夹位置:
- Linux/Mac:
~/.pub-cache/hosted/pub.dartlang.org/tflite_flutter-<plugin-version>/ios/
- Windows:
%LOCALAPPDATA%\Pub\Cache\hosted\pub.dartlang.org\tflite_flutter-<plugin-version>\ios\
桌面
请参阅 指南 以了解如何构建和使用桌面二进制文件。
TFLite Flutter 辅助库
一个具有简单架构的专用库,用于处理和操作 TFLite 模型的输入和输出。API 设计和文档与 TensorFlow Lite Android Support Library 相同。强烈推荐与 <code>tflite_flutter_plugin</code>
一起使用。了解更多。
示例
标题 | 代码 | Demo |
---|---|---|
文本分类应用 | 代码 | ![]() |
图像分类应用 | 代码 | ![]() |
对象检测应用 | 代码 | ![]() |
强化学习应用 | 代码 | ![]() |
导入
import 'package:tflite_flutter/tflite_flutter.dart';
使用说明
创建解释器
从资源文件导入
将 <code>your_model.tflite</code>
放在 <code>assets</code>
目录中,并确保在 <code>pubspec.yaml</code>
中包含资源文件。
final interpreter = await tfl.Interpreter.fromAsset('your_model.tflite');
创建解释器从缓冲区或文件
有关更多信息,请参阅文档。
执行推理
请参阅 TFLite Flutter Helper Library 以获取有关输入和输出处理的简便方法。
单输入单输出
// 例如,如果输入张量形状为 [1,5] 且类型为 float32
var input = [[1.23, 6.54, 7.81, 3.21, 2.22]];
// 如果输出张量形状为 [1,2] 且类型为 float32
var output = List.filled(1 * 2, 0).reshape([1, 2]);
// 推理
interpreter.run(input, output);
// 打印输出
print(output);
多输入多输出
var input0 = [1.23];
var input1 = [2.43];
// 输入:List<Object>
var inputs = [input0, input1, input0, input1];
var output0 = List<double>.filled(1, 0);
var output1 = List<double>.filled(1, 0);
// 输出:Map<int, Object>
var outputs = {0: output0, 1: output1};
// 推理
interpreter.runForMultipleInputs(inputs, outputs);
// 打印输出
print(outputs)
关闭解释器
interpreter.close();
提高性能使用委托支持
注意:此功能正在测试中,可能在某些构建和设备上不稳定。
Android NNAPI 委托
var interpreterOptions = InterpreterOptions()..useNnApiForAndroid = true;
final interpreter = await Interpreter.fromAsset('your_model.tflite',
options: interpreterOptions);
// 或者
var interpreterOptions = InterpreterOptions()..addDelegate(NnApiDelegate());
final interpreter = await Interpreter.fromAsset('your_model.tflite',
options: interpreterOptions);
Android 和 iOS GPU 委托
Android GpuDelegateV2
final gpuDelegateV2 = GpuDelegateV2(
options: GpuDelegateOptionsV2(
false,
TfLiteGpuInferenceUsage.fastSingleAnswer,
TfLiteGpuInferencePriority.minLatency,
TfLiteGpuInferencePriority.auto,
TfLiteGpuInferencePriority.auto,
));
var interpreterOptions = InterpreterOptions()..addDelegate(gpuDelegateV2);
final interpreter = await Interpreter.fromAsset('your_model.tflite',
options: interpreterOptions);
iOS Metal 委托 (GpuDelegate)
final gpuDelegate = GpuDelegate(
options: GpuDelegateOptions(true, TFLGpuDelegateWaitType.active),
);
var interpreterOptions = InterpreterOptions()..addDelegate(gpuDelegate);
final interpreter = await Interpreter.fromAsset('your_model.tflite',
options: interpreterOptions);
代码示例
以下是一个简单的文本分类应用的示例代码:
import 'package:flutter/material.dart';
import 'classifier.dart';
void main() => runApp(const MyApp());
class MyApp extends StatefulWidget {
const MyApp({super.key});
[@override](/user/override)
_MyAppState createState() => _MyAppState();
}
class _MyAppState extends State<MyApp> {
late TextEditingController _controller;
late Classifier _classifier;
late List<Widget> _children;
[@override](/user/override)
void initState() {
super.initState();
_controller = TextEditingController();
_classifier = Classifier();
_children = [];
_children.add(Container());
}
[@override](/user/override)
Widget build(BuildContext context) {
return MaterialApp(
home: Scaffold(
appBar: AppBar(
backgroundColor: Colors.orangeAccent,
title: const Text('Text classification'),
),
body: Container(
padding: const EdgeInsets.all(4),
child: Column(
children: <Widget>[
Expanded(
child: ListView.builder(
itemCount: _children.length,
itemBuilder: (_, index) {
return _children[index];
},
)),
Container(
padding: const EdgeInsets.all(8),
decoration: BoxDecoration(
border: Border.all(color: Colors.orangeAccent)),
child: Row(children: <Widget>[
Expanded(
child: TextField(
decoration: const InputDecoration(
hintText: 'Write some text here'),
controller: _controller,
),
),
TextButton(
child: const Text('Classify'),
onPressed: () {
final text = _controller.text;
final prediction = _classifier.classify(text);
setState(() {
_children.add(Dismissible(
key: GlobalKey(),
onDismissed: (direction) {},
child: Card(
child: Container(
padding: const EdgeInsets.all(16),
color: prediction[1] > prediction[0]
? Colors.lightGreen
: Colors.redAccent,
child: Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: <Widget>[
Text(
"Input: $text",
style: const TextStyle(fontSize: 16),
),
const Text("Output:"),
Text(" Positive: ${prediction[1]}"),
Text(" Negative: ${prediction[0]}"),
],
),
),
),
));
_controller.clear();
});
},
),
]),
),
],
),
),
),
);
}
}
更多关于Flutter机器学习推理插件tflite_flutter_plus的使用的实战教程也可以访问 https://www.itying.com/category-92-b0.html
更多关于Flutter机器学习推理插件tflite_flutter_plus的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html
当然,以下是一个关于如何使用 tflite_flutter_plus
插件在 Flutter 中进行机器学习推理的示例代码。这个插件允许你在 Flutter 应用中加载和使用 TensorFlow Lite 模型进行推理。
1. 添加依赖
首先,你需要在你的 pubspec.yaml
文件中添加 tflite_flutter_plus
依赖:
dependencies:
flutter:
sdk: flutter
tflite_flutter_plus: ^2.10.0 # 请使用最新版本
然后运行 flutter pub get
来获取依赖。
2. 配置 Android 项目
在你的 android/app/build.gradle
文件中,确保你设置了正确的 minSdkVersion
和 targetSdkVersion
,并且添加了必要的权限:
android {
...
defaultConfig {
...
minSdkVersion 21
targetSdkVersion 30
...
}
...
}
dependencies {
...
implementation 'org.tensorflow:tensorflow-lite:2.5.0' // 确保 TensorFlow Lite 版本与插件兼容
implementation 'org.tensorflow:tensorflow-lite-gpu:2.5.0' // 如果你打算使用 GPU 加速
...
}
3. 加载模型并进行推理
以下是一个完整的 Flutter 应用示例,展示如何使用 tflite_flutter_plus
加载模型并进行推理:
import 'package:flutter/material.dart';
import 'package:tflite_flutter_plus/tflite_flutter_plus.dart';
void main() {
runApp(MyApp());
}
class MyApp extends StatefulWidget {
@override
_MyAppState createState() => _MyAppState();
}
class _MyAppState extends State<MyApp> {
late Interpreter _interpreter;
late List<Float32List> _outputs;
@override
void initState() {
super.initState();
loadModel();
}
Future<void> loadModel() async {
// 确保模型文件位于设备的正确位置,例如在 assets 文件夹中
var model = await Tflite.loadModel(
model: "assets/model.tflite",
labels: "assets/labels.txt", // 如果你的模型包含标签文件
);
if (model != null) {
setState(() {
_interpreter = model;
});
}
}
Future<void> runInference(List<List<Float32List>> input) async {
if (_interpreter != null) {
_outputs = await _interpreter.run(input);
setState(() {});
}
}
@override
Widget build(BuildContext context) {
return MaterialApp(
home: Scaffold(
appBar: AppBar(
title: Text('TFLite Flutter Plus Example'),
),
body: Center(
child: Column(
mainAxisAlignment: MainAxisAlignment.center,
children: <Widget>[
ElevatedButton(
onPressed: () async {
// 假设输入数据是一个 28x28 的灰度图像,这里以随机数为例
List<List<Float32List>> input = [
Float32List.fromList(List.generate(28 * 28, (index) => (index % 256).toDouble() / 255.0).toList())
.reshape([1, 28, 28, 1]) as Float32List,
];
await runInference(input);
// 显示推理结果
print("Inference Result: ${_outputs[0].toList()}");
},
child: Text('Run Inference'),
),
],
),
),
),
);
}
@override
void dispose() {
_interpreter?.close();
super.dispose();
}
}
注意事项
-
模型文件:确保你的 TensorFlow Lite 模型文件(例如
model.tflite
)和标签文件(如果有的话,例如labels.txt
)已经放置在assets
文件夹中,并在pubspec.yaml
中正确声明:flutter: assets: - assets/model.tflite - assets/labels.txt
-
输入数据:根据你的模型输入要求准备输入数据。例如,上面的示例假设输入是一个 28x28 的灰度图像,但实际使用中你需要根据模型的输入规格来准备数据。
-
清理资源:在
dispose
方法中关闭解释器,以释放资源。
这个示例展示了如何使用 tflite_flutter_plus
插件在 Flutter 应用中加载 TensorFlow Lite 模型并进行推理。你可以根据自己的需求进一步扩展和修改这个示例。