Flutter机器学习推理插件tflite_flutter_plus的使用

Flutter机器学习推理插件tflite_flutter_plus的使用



Platform Pub Package Docs

概述

TensorFlow Lite Flutter 插件提供了一种灵活且快速的方法来访问 TensorFlow Lite 解释器并执行推理。该 API 类似于 TFLite 的 Java 和 Swift API。它直接绑定到 TFLite C API,使其高效(低延迟)。支持使用 NNAPI、GPU 委托(Android)、Metal 和 CoreML 委托(iOS),以及 XNNPack 委托(桌面平台)。

关键特性

  • 支持多平台:Android、iOS、Windows、Mac、Linux。
  • 可以使用任何 TFLite 模型。
  • 支持多线程和委托加速。
  • 结构与 TensorFlow Lite Java API 类似。
  • 推理速度接近原生 Android 应用程序。
  • 可以选择使用任何 TensorFlow 版本通过本地构建二进制文件。
  • 在不同的隔离中运行推理,以防止 UI 线程卡顿。

初始化设置:向应用添加动态库

Android

  1. 将脚本 install.sh(Linux/Mac)或 install.bat(Windows)放在项目的根目录下。
  2. 执行 sh install.sh(Linux)/ install.bat(Windows)在项目的根目录下自动下载并放置二进制文件到合适的文件夹。

    注意:安装的二进制文件将不包括对 GpuDelegateV2NnApiDelegate 的支持,但是可以使用 InterpreterOptions().useNnApiForAndroid

  3. 如果希望使用这些 GpuDelegateV2NnApiDelegate,则应使用 sh install.sh -d(Linux)或 install.bat -d(Windows)。

这些脚本基于最新的稳定 TensorFlow 发行版安装预构建的二进制文件。如需使用其他 TensorFlow 版本,请参阅 wiki 中的说明。

iOS

  1. 下载 <code>TensorFlowLiteC.framework</code>。若要构建自定义版本的 TensorFlow,请参阅 wiki 中的说明。
  2. <code>TensorFlowLiteC.framework</code> 放置在包的 pub-cache 文件夹中。

pub-cache 文件夹位置:

  • Linux/Mac: ~/.pub-cache/hosted/pub.dartlang.org/tflite_flutter-&lt;plugin-version&gt;/ios/
  • Windows: %LOCALAPPDATA%\Pub\Cache\hosted\pub.dartlang.org\tflite_flutter-&lt;plugin-version&gt;\ios\

桌面

请参阅 指南 以了解如何构建和使用桌面二进制文件。

TFLite Flutter 辅助库

一个具有简单架构的专用库,用于处理和操作 TFLite 模型的输入和输出。API 设计和文档与 TensorFlow Lite Android Support Library 相同。强烈推荐与 <code>tflite_flutter_plugin</code> 一起使用。了解更多

示例

标题 代码 Demo
文本分类应用 代码
图像分类应用 代码
对象检测应用 代码
强化学习应用 代码

导入

import 'package:tflite_flutter/tflite_flutter.dart';

使用说明

创建解释器

从资源文件导入

<code>your_model.tflite</code> 放在 <code>assets</code> 目录中,并确保在 <code>pubspec.yaml</code> 中包含资源文件。

final interpreter = await tfl.Interpreter.fromAsset('your_model.tflite');
创建解释器从缓冲区或文件

有关更多信息,请参阅文档。

执行推理

请参阅 TFLite Flutter Helper Library 以获取有关输入和输出处理的简便方法。

单输入单输出
// 例如,如果输入张量形状为 [1,5] 且类型为 float32
var input = [[1.23, 6.54, 7.81, 3.21, 2.22]];

// 如果输出张量形状为 [1,2] 且类型为 float32
var output = List.filled(1 * 2, 0).reshape([1, 2]);

// 推理
interpreter.run(input, output);

// 打印输出
print(output);
多输入多输出
var input0 = [1.23];  
var input1 = [2.43];  

// 输入:List<Object>
var inputs = [input0, input1, input0, input1];  

var output0 = List<double>.filled(1, 0);  
var output1 = List<double>.filled(1, 0);

// 输出:Map<int, Object>
var outputs = {0: output0, 1: output1};

// 推理  
interpreter.runForMultipleInputs(inputs, outputs);

// 打印输出
print(outputs)

关闭解释器

interpreter.close();

提高性能使用委托支持

注意:此功能正在测试中,可能在某些构建和设备上不稳定。

Android NNAPI 委托
var interpreterOptions = InterpreterOptions()..useNnApiForAndroid = true;
final interpreter = await Interpreter.fromAsset('your_model.tflite',
    options: interpreterOptions);

// 或者
var interpreterOptions = InterpreterOptions()..addDelegate(NnApiDelegate());
final interpreter = await Interpreter.fromAsset('your_model.tflite',
    options: interpreterOptions);
Android 和 iOS GPU 委托
Android GpuDelegateV2
final gpuDelegateV2 = GpuDelegateV2(
        options: GpuDelegateOptionsV2(
        false,
        TfLiteGpuInferenceUsage.fastSingleAnswer,
        TfLiteGpuInferencePriority.minLatency,
        TfLiteGpuInferencePriority.auto,
        TfLiteGpuInferencePriority.auto,
    ));

var interpreterOptions = InterpreterOptions()..addDelegate(gpuDelegateV2);
final interpreter = await Interpreter.fromAsset('your_model.tflite',
    options: interpreterOptions);
iOS Metal 委托 (GpuDelegate)
final gpuDelegate = GpuDelegate(
      options: GpuDelegateOptions(true, TFLGpuDelegateWaitType.active),
    );

var interpreterOptions = InterpreterOptions()..addDelegate(gpuDelegate);
final interpreter = await Interpreter.fromAsset('your_model.tflite',
    options: interpreterOptions);

代码示例

以下是一个简单的文本分类应用的示例代码:

import 'package:flutter/material.dart';

import 'classifier.dart';

void main() => runApp(const MyApp());

class MyApp extends StatefulWidget {
  const MyApp({super.key});

  [@override](/user/override)
  _MyAppState createState() => _MyAppState();
}

class _MyAppState extends State<MyApp> {
  late TextEditingController _controller;
  late Classifier _classifier;
  late List<Widget> _children;

  [@override](/user/override)
  void initState() {
    super.initState();
    _controller = TextEditingController();
    _classifier = Classifier();
    _children = [];
    _children.add(Container());
  }

  [@override](/user/override)
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          backgroundColor: Colors.orangeAccent,
          title: const Text('Text classification'),
        ),
        body: Container(
          padding: const EdgeInsets.all(4),
          child: Column(
            children: <Widget>[
              Expanded(
                  child: ListView.builder(
                    itemCount: _children.length,
                    itemBuilder: (_, index) {
                      return _children[index];
                    },
                  )),
              Container(
                padding: const EdgeInsets.all(8),
                decoration: BoxDecoration(
                    border: Border.all(color: Colors.orangeAccent)),
                child: Row(children: <Widget>[
                  Expanded(
                    child: TextField(
                      decoration: const InputDecoration(
                          hintText: 'Write some text here'),
                      controller: _controller,
                    ),
                  ),
                  TextButton(
                    child: const Text('Classify'),
                    onPressed: () {
                      final text = _controller.text;
                      final prediction = _classifier.classify(text);
                      setState(() {
                        _children.add(Dismissible(
                          key: GlobalKey(),
                          onDismissed: (direction) {},
                          child: Card(
                            child: Container(
                              padding: const EdgeInsets.all(16),
                              color: prediction[1] > prediction[0]
                                  ? Colors.lightGreen
                                  : Colors.redAccent,
                              child: Column(
                                crossAxisAlignment: CrossAxisAlignment.start,
                                children: <Widget>[
                                  Text(
                                    "Input: $text",
                                    style: const TextStyle(fontSize: 16),
                                  ),
                                  const Text("Output:"),
                                  Text("   Positive: ${prediction[1]}"),
                                  Text("   Negative: ${prediction[0]}"),
                                ],
                              ),
                            ),
                          ),
                        ));
                        _controller.clear();
                      });
                    },
                  ),
                ]),
              ),
            ],
          ),
        ),
      ),
    );
  }
}

更多关于Flutter机器学习推理插件tflite_flutter_plus的使用的实战教程也可以访问 https://www.itying.com/category-92-b0.html

1 回复

更多关于Flutter机器学习推理插件tflite_flutter_plus的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html


当然,以下是一个关于如何使用 tflite_flutter_plus 插件在 Flutter 中进行机器学习推理的示例代码。这个插件允许你在 Flutter 应用中加载和使用 TensorFlow Lite 模型进行推理。

1. 添加依赖

首先,你需要在你的 pubspec.yaml 文件中添加 tflite_flutter_plus 依赖:

dependencies:
  flutter:
    sdk: flutter
  tflite_flutter_plus: ^2.10.0  # 请使用最新版本

然后运行 flutter pub get 来获取依赖。

2. 配置 Android 项目

在你的 android/app/build.gradle 文件中,确保你设置了正确的 minSdkVersiontargetSdkVersion,并且添加了必要的权限:

android {
    ...
    defaultConfig {
        ...
        minSdkVersion 21
        targetSdkVersion 30
        ...
    }
    ...
}

dependencies {
    ...
    implementation 'org.tensorflow:tensorflow-lite:2.5.0'  // 确保 TensorFlow Lite 版本与插件兼容
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.5.0'  // 如果你打算使用 GPU 加速
    ...
}

3. 加载模型并进行推理

以下是一个完整的 Flutter 应用示例,展示如何使用 tflite_flutter_plus 加载模型并进行推理:

import 'package:flutter/material.dart';
import 'package:tflite_flutter_plus/tflite_flutter_plus.dart';

void main() {
  runApp(MyApp());
}

class MyApp extends StatefulWidget {
  @override
  _MyAppState createState() => _MyAppState();
}

class _MyAppState extends State<MyApp> {
  late Interpreter _interpreter;
  late List<Float32List> _outputs;

  @override
  void initState() {
    super.initState();
    loadModel();
  }

  Future<void> loadModel() async {
    // 确保模型文件位于设备的正确位置,例如在 assets 文件夹中
    var model = await Tflite.loadModel(
      model: "assets/model.tflite",
      labels: "assets/labels.txt",  // 如果你的模型包含标签文件
    );
    if (model != null) {
      setState(() {
        _interpreter = model;
      });
    }
  }

  Future<void> runInference(List<List<Float32List>> input) async {
    if (_interpreter != null) {
      _outputs = await _interpreter.run(input);
      setState(() {});
    }
  }

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: Text('TFLite Flutter Plus Example'),
        ),
        body: Center(
          child: Column(
            mainAxisAlignment: MainAxisAlignment.center,
            children: <Widget>[
              ElevatedButton(
                onPressed: () async {
                  // 假设输入数据是一个 28x28 的灰度图像,这里以随机数为例
                  List<List<Float32List>> input = [
                    Float32List.fromList(List.generate(28 * 28, (index) => (index % 256).toDouble() / 255.0).toList())
                      .reshape([1, 28, 28, 1]) as Float32List,
                  ];

                  await runInference(input);

                  // 显示推理结果
                  print("Inference Result: ${_outputs[0].toList()}");
                },
                child: Text('Run Inference'),
              ),
            ],
          ),
        ),
      ),
    );
  }

  @override
  void dispose() {
    _interpreter?.close();
    super.dispose();
  }
}

注意事项

  1. 模型文件:确保你的 TensorFlow Lite 模型文件(例如 model.tflite)和标签文件(如果有的话,例如 labels.txt)已经放置在 assets 文件夹中,并在 pubspec.yaml 中正确声明:

    flutter:
      assets:
        - assets/model.tflite
        - assets/labels.txt
    
  2. 输入数据:根据你的模型输入要求准备输入数据。例如,上面的示例假设输入是一个 28x28 的灰度图像,但实际使用中你需要根据模型的输入规格来准备数据。

  3. 清理资源:在 dispose 方法中关闭解释器,以释放资源。

这个示例展示了如何使用 tflite_flutter_plus 插件在 Flutter 应用中加载 TensorFlow Lite 模型并进行推理。你可以根据自己的需求进一步扩展和修改这个示例。

回到顶部