Flutter机器学习插件yjy_tflite_flutter的使用

Flutter机器学习插件yjy_tflite_flutter的使用

概述

TensorFlow Lite Flutter 插件提供了一种灵活且快速的方法来访问 TensorFlow Lite 解释器并执行推理。该 API 类似于 TensorFlow Lite 的 Java 和 Swift API。它直接绑定到 TensorFlow Lite C API,使其高效(低延迟)。支持使用 NNAPI、Android 上的 GPU 委托和 iOS 上的 Metal 委托进行加速。

封面图


关键特性

  • 可以使用任何 TensorFlow Lite 模型。
  • 支持多线程和委托加速。
  • API 结构与 TensorFlow Lite Java API 类似。
  • 推理速度接近原生 Android 应用程序。
  • 可以选择使用任何版本的 TensorFlow 构建二进制文件。
  • 可以在不同的隔离中运行推理,以防止 UI 线程卡顿。

初始化设置:添加动态库到你的应用

Android
  1. 将脚本 install.sh(Linux/Mac)或 install.bat(Windows)放置在项目的根目录下。
  2. 在项目的根目录下执行 sh install.sh(Linux)/ install.bat(Windows),自动下载并放置二进制文件到适当的文件夹。

    注意:安装的二进制文件不包括对 GpuDelegateV2NnApiDelegate 的支持,但仍然可以使用 InterpreterOptions().useNnApiForAndroid

  3. 如果希望使用 GpuDelegateV2NnApiDelegate,请改用 sh install.sh -d(Linux)或 install.bat -d(Windows)。

这些脚本会根据最新稳定的 TensorFlow 版本安装预构建的二进制文件。有关如何使用其他 TensorFlow 版本的信息,请参阅此处


TFLite Flutter 辅助库

一个专门用于处理和操作 TensorFlow Lite 模型输入和输出的简单架构库。API 设计和文档与 TensorFlow Lite Android 支持库相同。强烈建议与 tflite_flutter_plugin 一起使用。了解更多


示例

标题 代码 Demo Blog
文本分类应用程序 代码 博客/教程
图像分类应用程序 代码 -
实时对象检测应用程序 代码 博客/教程
强化学习应用程序 代码 博客/教程

导入

import 'package:tflite_flutter/tflite_flutter.dart';

使用说明

创建解释器
从资产加载模型

your_model.tflite 放置在 assets 目录中,并确保在 pubspec.yaml 中包含资产。

final interpreter = await Interpreter.fromAsset('your_model.tflite');

有关从缓冲区或文件创建解释器的更多信息,请参阅文档。

执行推理

推荐使用TFLite Flutter Helper Library来简化输入和输出的处理。

单输入单输出
// 对于输入张量形状为 [1,5] 且类型为 float32 的情况
var input = [[1.23, 6.54, 7.81, 3.21, 2.22]];

// 如果输出张量形状为 [1,2] 且类型为 float32
var output = List.filled(1 * 2, 0).reshape([1, 2]);

// 推理
interpreter.run(input, output);

// 打印输出
print(output);
多输入多输出
var input0 = [1.23];  
var input1 = [2.43];  

// 输入列表
var inputs = [input0, input1, input0, input1];  

var output0 = List<double>.filled(1, 0);  
var output1 = List<double>.filled(1, 0);

// 输出映射
var outputs = {0: output0, 1: output1};

// 推理  
interpreter.runForMultipleInputs(inputs, outputs);

// 打印输出
print(outputs);
关闭解释器
interpreter.close();
使用委托提高性能
Android 的 NNAPI 委托
var interpreterOptions = InterpreterOptions()..useNnApiForAndroid = true;
final interpreter = await Interpreter.fromAsset('your_model.tflite',
    options: interpreterOptions);

或者

var interpreterOptions = InterpreterOptions()..addDelegate(NnApiDelegate());
final interpreter = await Interpreter.fromAsset('your_model.tflite',
    options: interpreterOptions);
Android 和 iOS 的 GPU 委托
Android 的 GpuDelegateV2
final gpuDelegateV2 = GpuDelegateV2(
        options: GpuDelegateOptionsV2(
        false,
        TfLiteGpuInferenceUsage.fastSingleAnswer,
        TfLiteGpuInferencePriority.minLatency,
        TfLiteGpuInferencePriority.auto,
        TfLiteGpuInferencePriority.auto,
    ));

var interpreterOptions = InterpreterOptions()..addDelegate(gpuDelegateV2);
final interpreter = await Interpreter.fromAsset('your_model.tflite',
    options: interpreterOptions);
iOS 的 Metal 委托 (GpuDelegate)
final gpuDelegate = GpuDelegate(
      options: GpuDelegateOptions(true, TFLGpuDelegateWaitType.active),
    );
var interpreterOptions = InterpreterOptions()..addDelegate(gpuDelegate);
final interpreter = await Interpreter.fromAsset('your_model.tflite',
    options: interpreterOptions);

更多示例代码请参考测试文件


使用任意 TensorFlow 版本

预构建的二进制文件会随着每个稳定版 TensorFlow 更新。如果您想使用最新的不稳定版 TensorFlow 或旧版本,请按照以下步骤本地构建。

Android

配置 Android 构建环境,请参考TensorFlow Lite 官方指南

对于 TensorFlow >= v2.2:

bazel build -c opt --cxxopt=--std=c++11 --config=android_arm //tensorflow/lite/c:tensorflowlite_c

// 类似地,对于 arm64 使用 --config=android_arm64

对于 TensorFlow <= v2.1:

bazel build -c opt --cxxopt=--std=c++11 --config=android_arm //tensorflow/lite/experimental/c:libtensorflowlite_c.so

// 类似地,对于 arm64 使用 --config=android_arm64
iOS

请参考TensorFlow Lite 官方指南来本地构建 iOS 版本。

注意:必须使用 macOS 来构建 iOS。


动态链接信息

tflite_flutter 动态链接到 TensorFlow Lite C API,Android 上以 libtensorflowlite_c.so 形式提供,iOS 上以 TensorFlowLiteC.framework 形式提供。

对于 Android,需要手动从发布资产下载这些二进制文件,并将 libtensorflowlite_c.so 文件放置到 <root>/android/app/src/main/jniLibs/ 目录下的每个架构文件夹中(如 arm、arm64、x86、x86_64),如示例应用程序中所示。


未来工作

  • 支持 Flutter 桌面应用程序。
  • 改进错误处理。

致谢

  • TensorFlow Lite 团队成员 Tian LIN、Jared Duke、Andrew Selle、YoungSeok Yoon、Shuangfeng Li。
  • dart-lang/tflite_native 的作者们。

示例代码

以下是完整的文本分类应用程序示例代码:

import 'package:flutter/material.dart';
import 'package:tflite_flutter_plugin_example/classifier.dart';

void main() => runApp(MyApp());

class MyApp extends StatefulWidget {
  [@override](/user/override)
  _MyAppState createState() => _MyAppState();
}

class _MyAppState extends State<MyApp> {
  late TextEditingController _controller;
  late Classifier _classifier;
  late List<Widget> _children;

  [@override](/user/override)
  void initState() {
    super.initState();
    _controller = TextEditingController();
    _classifier = Classifier();
    _children = [];
    _children.add(Container());
  }

  [@override](/user/override)
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          backgroundColor: Colors.orangeAccent,
          title: const Text('Text classification'),
        ),
        body: Container(
          padding: const EdgeInsets.all(4),
          child: Column(
            children: <Widget>[
              Expanded(
                child: ListView.builder(
                  itemCount: _children.length,
                  itemBuilder: (_, index) {
                    return _children[index];
                  },
                ),
              ),
              Container(
                padding: const EdgeInsets.all(8),
                decoration: BoxDecoration(
                    border: Border.all(color: Colors.orangeAccent)),
                child: Row(children: <Widget>[
                  Expanded(
                    child: TextField(
                      decoration: const InputDecoration(hintText: 'Write some text here'),
                      controller: _controller,
                    ),
                  ),
                  TextButton(
                    child: const Text('Classify'),
                    onPressed: () {
                      final text = _controller.text;
                      final prediction = _classifier.classify(text);
                      setState(() {
                        _children.add(Dismissible(
                          key: GlobalKey(),
                          onDismissed: (direction) {},
                          child: Card(
                            child: Container(
                              padding: const EdgeInsets.all(16),
                              color: prediction[1] > prediction[0]
                                  ? Colors.lightGreen
                                  : Colors.redAccent,
                              child: Column(
                                crossAxisAlignment: CrossAxisAlignment.start,
                                children: <Widget>[
                                  Text(
                                    "Input: $text",
                                    style: const TextStyle(fontSize: 16),
                                  ),
                                  Text("Output:"),
                                  Text("   Positive: ${prediction[1]}"),
                                  Text("   Negative: ${prediction[0]}"),
                                ],
                              ),
                            ),
                          ),
                        ));
                        _controller.clear();
                      });
                    },
                  ),
                ]),
              ),
            ],
          ),
        ),
      ),
    );
  }
}

更多关于Flutter机器学习插件yjy_tflite_flutter的使用的实战教程也可以访问 https://www.itying.com/category-92-b0.html

1 回复

更多关于Flutter机器学习插件yjy_tflite_flutter的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html


yjy_tflite_flutter 是一个用于在 Flutter 应用中集成 TensorFlow Lite 的插件。它允许你在 Flutter 应用中使用 TensorFlow Lite 模型进行机器学习推理。以下是如何使用 yjy_tflite_flutter 插件的基本步骤:

1. 添加依赖

首先,你需要在 pubspec.yaml 文件中添加 yjy_tflite_flutter 插件的依赖:

dependencies:
  flutter:
    sdk: flutter
  yjy_tflite_flutter: ^0.1.0  # 请使用最新版本

然后运行 flutter pub get 来获取依赖。

2. 加载 TensorFlow Lite 模型

在使用模型之前,你需要将 TensorFlow Lite 模型文件(.tflite)放到你的 Flutter 项目中。通常,你可以将模型文件放在 assets 文件夹中,并在 pubspec.yaml 中声明:

flutter:
  assets:
    - assets/model.tflite

然后,你可以使用 yjy_tflite_flutter 插件加载模型:

import 'package:yjy_tflite_flutter/yjy_tflite_flutter.dart';

Future<void> loadModel() async {
  try {
    await YjyTfliteFlutter.loadModel(
      model: "assets/model.tflite",
      labels: "assets/labels.txt",  // 如果有标签文件
    );
    print("Model loaded successfully");
  } catch (e) {
    print("Failed to load model: $e");
  }
}

3. 运行推理

加载模型后,你可以使用 runModelOnImagerunModelOnFrame 等方法来进行推理。以下是一个使用 runModelOnImage 的示例:

import 'package:image_picker/image_picker.dart';

Future<void> runInference() async {
  final picker = ImagePicker();
  final pickedFile = await picker.getImage(source: ImageSource.camera);

  if (pickedFile != null) {
    try {
      var recognitions = await YjyTfliteFlutter.runModelOnImage(
        path: pickedFile.path,
        imageMean: 127.5,  // 根据需要调整
        imageStd: 127.5,   // 根据需要调整
        numResults: 5,     // 返回的结果数量
        threshold: 0.5,    // 置信度阈值
      );

      print(recognitions);
    } catch (e) {
      print("Failed to run inference: $e");
    }
  }
}

4. 释放资源

在使用完模型后,记得释放资源以避免内存泄漏:

Future<void> disposeModel() async {
  await YjyTfliteFlutter.dispose();
  print("Model disposed");
}

5. 处理结果

推理结果通常是一个包含预测标签和置信度的列表。你可以根据需要处理这些结果并显示在你的应用中。

6. 完整示例

以下是一个完整的示例,展示了如何加载模型、运行推理并显示结果:

import 'package:flutter/material.dart';
import 'package:yjy_tflite_flutter/yjy_tflite_flutter.dart';
import 'package:image_picker/image_picker.dart';

void main() => runApp(MyApp());

class MyApp extends StatelessWidget {
  [@override](/user/override)
  Widget build(BuildContext context) {
    return MaterialApp(
      home: HomeScreen(),
    );
  }
}

class HomeScreen extends StatefulWidget {
  [@override](/user/override)
  _HomeScreenState createState() => _HomeScreenState();
}

class _HomeScreenState extends State<HomeScreen> {
  List<dynamic> _recognitions = [];

  [@override](/user/override)
  void initState() {
    super.initState();
    loadModel();
  }

  Future<void> loadModel() async {
    try {
      await YjyTfliteFlutter.loadModel(
        model: "assets/model.tflite",
        labels: "assets/labels.txt",
      );
      print("Model loaded successfully");
    } catch (e) {
      print("Failed to load model: $e");
    }
  }

  Future<void> runInference() async {
    final picker = ImagePicker();
    final pickedFile = await picker.getImage(source: ImageSource.camera);

    if (pickedFile != null) {
      try {
        var recognitions = await YjyTfliteFlutter.runModelOnImage(
          path: pickedFile.path,
          imageMean: 127.5,
          imageStd: 127.5,
          numResults: 5,
          threshold: 0.5,
        );

        setState(() {
          _recognitions = recognitions;
        });
      } catch (e) {
        print("Failed to run inference: $e");
      }
    }
  }

  [@override](/user/override)
  void dispose() {
    YjyTfliteFlutter.dispose();
    super.dispose();
  }

  [@override](/user/override)
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('TensorFlow Lite Flutter'),
      ),
      body: Column(
        children: [
          ElevatedButton(
            onPressed: runInference,
            child: Text('Run Inference'),
          ),
          Expanded(
            child: ListView.builder(
              itemCount: _recognitions.length,
              itemBuilder: (context, index) {
                return ListTile(
                  title: Text(_recognitions[index]['label']),
                  subtitle: Text((_recognitions[index]['confidence'] * 100).toStringAsFixed(2) + '%'),
                );
              },
            ),
          ),
        ],
      ),
    );
  }
}
回到顶部