Flutter机器学习插件yjy_tflite_flutter的使用
Flutter机器学习插件yjy_tflite_flutter的使用
概述
TensorFlow Lite Flutter 插件提供了一种灵活且快速的方法来访问 TensorFlow Lite 解释器并执行推理。该 API 类似于 TensorFlow Lite 的 Java 和 Swift API。它直接绑定到 TensorFlow Lite C API,使其高效(低延迟)。支持使用 NNAPI、Android 上的 GPU 委托和 iOS 上的 Metal 委托进行加速。
关键特性
- 可以使用任何 TensorFlow Lite 模型。
- 支持多线程和委托加速。
- API 结构与 TensorFlow Lite Java API 类似。
- 推理速度接近原生 Android 应用程序。
- 可以选择使用任何版本的 TensorFlow 构建二进制文件。
- 可以在不同的隔离中运行推理,以防止 UI 线程卡顿。
初始化设置:添加动态库到你的应用
Android
- 将脚本
install.sh
(Linux/Mac)或install.bat
(Windows)放置在项目的根目录下。 - 在项目的根目录下执行
sh install.sh
(Linux)/install.bat
(Windows),自动下载并放置二进制文件到适当的文件夹。注意:安装的二进制文件不包括对
GpuDelegateV2
和NnApiDelegate
的支持,但仍然可以使用InterpreterOptions().useNnApiForAndroid
。 - 如果希望使用
GpuDelegateV2
和NnApiDelegate
,请改用sh install.sh -d
(Linux)或install.bat -d
(Windows)。
这些脚本会根据最新稳定的 TensorFlow 版本安装预构建的二进制文件。有关如何使用其他 TensorFlow 版本的信息,请参阅此处。
TFLite Flutter 辅助库
一个专门用于处理和操作 TensorFlow Lite 模型输入和输出的简单架构库。API 设计和文档与 TensorFlow Lite Android 支持库相同。强烈建议与 tflite_flutter_plugin
一起使用。了解更多。
示例
标题 | 代码 | Demo | Blog |
---|---|---|---|
文本分类应用程序 | 代码 | ![]() |
博客/教程 |
图像分类应用程序 | 代码 | ![]() |
- |
实时对象检测应用程序 | 代码 | ![]() |
博客/教程 |
强化学习应用程序 | 代码 | ![]() |
博客/教程 |
导入
import 'package:tflite_flutter/tflite_flutter.dart';
使用说明
创建解释器
从资产加载模型
将 your_model.tflite
放置在 assets
目录中,并确保在 pubspec.yaml
中包含资产。
final interpreter = await Interpreter.fromAsset('your_model.tflite');
有关从缓冲区或文件创建解释器的更多信息,请参阅文档。
执行推理
推荐使用TFLite Flutter Helper Library来简化输入和输出的处理。
单输入单输出
// 对于输入张量形状为 [1,5] 且类型为 float32 的情况
var input = [[1.23, 6.54, 7.81, 3.21, 2.22]];
// 如果输出张量形状为 [1,2] 且类型为 float32
var output = List.filled(1 * 2, 0).reshape([1, 2]);
// 推理
interpreter.run(input, output);
// 打印输出
print(output);
多输入多输出
var input0 = [1.23];
var input1 = [2.43];
// 输入列表
var inputs = [input0, input1, input0, input1];
var output0 = List<double>.filled(1, 0);
var output1 = List<double>.filled(1, 0);
// 输出映射
var outputs = {0: output0, 1: output1};
// 推理
interpreter.runForMultipleInputs(inputs, outputs);
// 打印输出
print(outputs);
关闭解释器
interpreter.close();
使用委托提高性能
Android 的 NNAPI 委托
var interpreterOptions = InterpreterOptions()..useNnApiForAndroid = true;
final interpreter = await Interpreter.fromAsset('your_model.tflite',
options: interpreterOptions);
或者
var interpreterOptions = InterpreterOptions()..addDelegate(NnApiDelegate());
final interpreter = await Interpreter.fromAsset('your_model.tflite',
options: interpreterOptions);
Android 和 iOS 的 GPU 委托
Android 的 GpuDelegateV2
final gpuDelegateV2 = GpuDelegateV2(
options: GpuDelegateOptionsV2(
false,
TfLiteGpuInferenceUsage.fastSingleAnswer,
TfLiteGpuInferencePriority.minLatency,
TfLiteGpuInferencePriority.auto,
TfLiteGpuInferencePriority.auto,
));
var interpreterOptions = InterpreterOptions()..addDelegate(gpuDelegateV2);
final interpreter = await Interpreter.fromAsset('your_model.tflite',
options: interpreterOptions);
iOS 的 Metal 委托 (GpuDelegate)
final gpuDelegate = GpuDelegate(
options: GpuDelegateOptions(true, TFLGpuDelegateWaitType.active),
);
var interpreterOptions = InterpreterOptions()..addDelegate(gpuDelegate);
final interpreter = await Interpreter.fromAsset('your_model.tflite',
options: interpreterOptions);
更多示例代码请参考测试文件。
使用任意 TensorFlow 版本
预构建的二进制文件会随着每个稳定版 TensorFlow 更新。如果您想使用最新的不稳定版 TensorFlow 或旧版本,请按照以下步骤本地构建。
Android
配置 Android 构建环境,请参考TensorFlow Lite 官方指南。
对于 TensorFlow >= v2.2:
bazel build -c opt --cxxopt=--std=c++11 --config=android_arm //tensorflow/lite/c:tensorflowlite_c
// 类似地,对于 arm64 使用 --config=android_arm64
对于 TensorFlow <= v2.1:
bazel build -c opt --cxxopt=--std=c++11 --config=android_arm //tensorflow/lite/experimental/c:libtensorflowlite_c.so
// 类似地,对于 arm64 使用 --config=android_arm64
iOS
请参考TensorFlow Lite 官方指南来本地构建 iOS 版本。
注意:必须使用 macOS 来构建 iOS。
动态链接信息
tflite_flutter
动态链接到 TensorFlow Lite C API,Android 上以 libtensorflowlite_c.so
形式提供,iOS 上以 TensorFlowLiteC.framework
形式提供。
对于 Android,需要手动从发布资产下载这些二进制文件,并将 libtensorflowlite_c.so
文件放置到 <root>/android/app/src/main/jniLibs/
目录下的每个架构文件夹中(如 arm、arm64、x86、x86_64),如示例应用程序中所示。
未来工作
- 支持 Flutter 桌面应用程序。
- 改进错误处理。
致谢
- TensorFlow Lite 团队成员 Tian LIN、Jared Duke、Andrew Selle、YoungSeok Yoon、Shuangfeng Li。
- dart-lang/tflite_native 的作者们。
示例代码
以下是完整的文本分类应用程序示例代码:
import 'package:flutter/material.dart';
import 'package:tflite_flutter_plugin_example/classifier.dart';
void main() => runApp(MyApp());
class MyApp extends StatefulWidget {
[@override](/user/override)
_MyAppState createState() => _MyAppState();
}
class _MyAppState extends State<MyApp> {
late TextEditingController _controller;
late Classifier _classifier;
late List<Widget> _children;
[@override](/user/override)
void initState() {
super.initState();
_controller = TextEditingController();
_classifier = Classifier();
_children = [];
_children.add(Container());
}
[@override](/user/override)
Widget build(BuildContext context) {
return MaterialApp(
home: Scaffold(
appBar: AppBar(
backgroundColor: Colors.orangeAccent,
title: const Text('Text classification'),
),
body: Container(
padding: const EdgeInsets.all(4),
child: Column(
children: <Widget>[
Expanded(
child: ListView.builder(
itemCount: _children.length,
itemBuilder: (_, index) {
return _children[index];
},
),
),
Container(
padding: const EdgeInsets.all(8),
decoration: BoxDecoration(
border: Border.all(color: Colors.orangeAccent)),
child: Row(children: <Widget>[
Expanded(
child: TextField(
decoration: const InputDecoration(hintText: 'Write some text here'),
controller: _controller,
),
),
TextButton(
child: const Text('Classify'),
onPressed: () {
final text = _controller.text;
final prediction = _classifier.classify(text);
setState(() {
_children.add(Dismissible(
key: GlobalKey(),
onDismissed: (direction) {},
child: Card(
child: Container(
padding: const EdgeInsets.all(16),
color: prediction[1] > prediction[0]
? Colors.lightGreen
: Colors.redAccent,
child: Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: <Widget>[
Text(
"Input: $text",
style: const TextStyle(fontSize: 16),
),
Text("Output:"),
Text(" Positive: ${prediction[1]}"),
Text(" Negative: ${prediction[0]}"),
],
),
),
),
));
_controller.clear();
});
},
),
]),
),
],
),
),
),
);
}
}
更多关于Flutter机器学习插件yjy_tflite_flutter的使用的实战教程也可以访问 https://www.itying.com/category-92-b0.html
更多关于Flutter机器学习插件yjy_tflite_flutter的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html
yjy_tflite_flutter
是一个用于在 Flutter 应用中集成 TensorFlow Lite 的插件。它允许你在 Flutter 应用中使用 TensorFlow Lite 模型进行机器学习推理。以下是如何使用 yjy_tflite_flutter
插件的基本步骤:
1. 添加依赖
首先,你需要在 pubspec.yaml
文件中添加 yjy_tflite_flutter
插件的依赖:
dependencies:
flutter:
sdk: flutter
yjy_tflite_flutter: ^0.1.0 # 请使用最新版本
然后运行 flutter pub get
来获取依赖。
2. 加载 TensorFlow Lite 模型
在使用模型之前,你需要将 TensorFlow Lite 模型文件(.tflite
)放到你的 Flutter 项目中。通常,你可以将模型文件放在 assets
文件夹中,并在 pubspec.yaml
中声明:
flutter:
assets:
- assets/model.tflite
然后,你可以使用 yjy_tflite_flutter
插件加载模型:
import 'package:yjy_tflite_flutter/yjy_tflite_flutter.dart';
Future<void> loadModel() async {
try {
await YjyTfliteFlutter.loadModel(
model: "assets/model.tflite",
labels: "assets/labels.txt", // 如果有标签文件
);
print("Model loaded successfully");
} catch (e) {
print("Failed to load model: $e");
}
}
3. 运行推理
加载模型后,你可以使用 runModelOnImage
或 runModelOnFrame
等方法来进行推理。以下是一个使用 runModelOnImage
的示例:
import 'package:image_picker/image_picker.dart';
Future<void> runInference() async {
final picker = ImagePicker();
final pickedFile = await picker.getImage(source: ImageSource.camera);
if (pickedFile != null) {
try {
var recognitions = await YjyTfliteFlutter.runModelOnImage(
path: pickedFile.path,
imageMean: 127.5, // 根据需要调整
imageStd: 127.5, // 根据需要调整
numResults: 5, // 返回的结果数量
threshold: 0.5, // 置信度阈值
);
print(recognitions);
} catch (e) {
print("Failed to run inference: $e");
}
}
}
4. 释放资源
在使用完模型后,记得释放资源以避免内存泄漏:
Future<void> disposeModel() async {
await YjyTfliteFlutter.dispose();
print("Model disposed");
}
5. 处理结果
推理结果通常是一个包含预测标签和置信度的列表。你可以根据需要处理这些结果并显示在你的应用中。
6. 完整示例
以下是一个完整的示例,展示了如何加载模型、运行推理并显示结果:
import 'package:flutter/material.dart';
import 'package:yjy_tflite_flutter/yjy_tflite_flutter.dart';
import 'package:image_picker/image_picker.dart';
void main() => runApp(MyApp());
class MyApp extends StatelessWidget {
[@override](/user/override)
Widget build(BuildContext context) {
return MaterialApp(
home: HomeScreen(),
);
}
}
class HomeScreen extends StatefulWidget {
[@override](/user/override)
_HomeScreenState createState() => _HomeScreenState();
}
class _HomeScreenState extends State<HomeScreen> {
List<dynamic> _recognitions = [];
[@override](/user/override)
void initState() {
super.initState();
loadModel();
}
Future<void> loadModel() async {
try {
await YjyTfliteFlutter.loadModel(
model: "assets/model.tflite",
labels: "assets/labels.txt",
);
print("Model loaded successfully");
} catch (e) {
print("Failed to load model: $e");
}
}
Future<void> runInference() async {
final picker = ImagePicker();
final pickedFile = await picker.getImage(source: ImageSource.camera);
if (pickedFile != null) {
try {
var recognitions = await YjyTfliteFlutter.runModelOnImage(
path: pickedFile.path,
imageMean: 127.5,
imageStd: 127.5,
numResults: 5,
threshold: 0.5,
);
setState(() {
_recognitions = recognitions;
});
} catch (e) {
print("Failed to run inference: $e");
}
}
}
[@override](/user/override)
void dispose() {
YjyTfliteFlutter.dispose();
super.dispose();
}
[@override](/user/override)
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
title: Text('TensorFlow Lite Flutter'),
),
body: Column(
children: [
ElevatedButton(
onPressed: runInference,
child: Text('Run Inference'),
),
Expanded(
child: ListView.builder(
itemCount: _recognitions.length,
itemBuilder: (context, index) {
return ListTile(
title: Text(_recognitions[index]['label']),
subtitle: Text((_recognitions[index]['confidence'] * 100).toStringAsFixed(2) + '%'),
);
},
),
),
],
),
);
}
}