Flutter机器学习推理插件tflite_flutter的使用

发布于 1周前 作者 ionicwang 来自 Flutter

Flutter机器学习推理插件tflite_flutter的使用

简介

TensorFlow Lite Flutter插件提供了一个灵活且快速的解决方案,用于访问TensorFlow Lite解释器并执行推理。其API与TFLite Java和Swift API类似,直接绑定到TFLite C API,使其高效(低延迟)。它支持使用NNAPI、GPU委托(Android)、Metal和CoreML委托(iOS)以及XNNPack委托(桌面平台)进行加速。

tflite_flutter_cover

Platform-Flutter Pub Package Docs License-Apache 2.0

关键特性

  • 支持Android和iOS多平台
  • 可以使用任何TFLite模型
  • 使用多线程加速
  • 类似于TensorFlow Lite Java API的结构
  • 推理速度接近使用Java API构建的原生Android应用程序
  • 在不同的isolate中运行推理以防止UI线程卡顿

(重要) 初始设置:将动态库添加到您的应用

Android & iOS

  • 示例和支持现在支持动态库下载!iOS示例可以通过以下命令运行:
    flutter build ios
    flutter install ios
    
  • Android可以通过以下命令运行:
    flutter build android
    flutter install android
    
    注意:这需要设备的最低API级别为26。 注意:TFLite可能无法在iOS模拟器上工作。建议您使用物理设备进行测试。
  • 当创建发布存档(IPA)时,符号会被Xcode剥离,因此flutter build ipa命令可能会抛出Failed to lookup symbol ... symbol not found错误。解决方法是:
    1. 在Xcode中,转到Target Runner > Build Settings > Strip Style
    2. All Symbols更改为Non-Global Symbols

MacOS

对于MacOS,需要手动将TensorFlow Lite动态库添加到项目中。首先需要构建一个.dylib文件。可以参考Bazel构建指南CMake构建指南来构建库文件。

lipo -create arm64/libtensorflowlite_c.dylib x86/libtensorflowlite_c.dylib -output libtensorflowlite_c.dylib

然后根据官方Flutter指南中的步骤1和2将库添加到您的XCode项目中。

Linux

对于Linux,需要手动将TensorFlow Lite动态库添加到项目中。首先需要构建一个.so文件。可以参考Bazel构建指南CMake构建指南来构建库文件。

  1. 创建一个名为blobs的文件夹在项目的顶级目录
  2. libtensorflowlite_c-linux.so复制到此文件夹
  3. 将以下行附加到您的linux/CMakeLists.txt
    # get tf lite binaries
    install(
      FILES ${PROJECT_BUILD_DIR}/../blobs/libtensorflowlite_c-linux.so
      DESTINATION ${INSTALL_BUNDLE_DATA_DIR}/../blobs/
    )
    

Windows

对于Windows,需要手动将TensorFlow Lite动态库添加到项目中。首先需要构建一个.dll文件。可以参考Bazel构建指南CMake构建指南来构建库文件。

  1. 创建一个名为blobs的文件夹在项目的顶级目录
  2. libtensorflowlite_c-win.dll复制到此文件夹
  3. 将以下行附加到您的windows/CMakeLists.txt
    # get tf lite binaries
    install(
      FILES ${PROJECT_BUILD_DIR}/../blobs/libtensorflowlite_c-win.dll 
      DESTINATION ${INSTALL_BUNDLE_DATA_DIR}/../blobs/
    )
    

TFLite Flutter Helper Library

辅助库已被弃用。新的开发正在进行中,替代方案位于flutter-mediapipe。当前计划是在2023年8月底之前提供广泛的支持。

导入

import 'package:tflite_flutter/tflite_flutter.dart';

使用说明

导入库

pubspec.yaml文件的依赖部分添加tflite_flutter: ^0.10.1(根据最新版本调整版本号)

创建解释器

从资源加载

your_model.tflite放在assets目录下,并确保在pubspec.yaml中包含资源。

final interpreter = await Interpreter.fromAsset('assets/your_model.tflite');

执行推理

单输入输出

使用void run(Object input, Object output)

// 输入张量形状 [1,5] 类型 float32
var input = [[1.23, 6.54, 7.81, 3.21, 2.22]];

// 输出张量形状 [1,2] 类型 float32
var output = List.filled(1*2, 0).reshape([1,2]);

// 推理
interpreter.run(input, output);

// 打印输出
print(output);

多输入输出

使用void runForMultipleInputs(List<Object> inputs, Map<int, Object> outputs)

var input0 = [1.23];  
var input1 = [2.43];  

// 输入: List<Object>
var inputs = [input0, input1, input0, input1];  

var output0 = List<double>.filled(1, 0);  
var output1 = List<double>.filled(1, 0);

// 输出: Map<int, Object>
var outputs = {0: output0, 1: output1};

// 推理  
interpreter.runForMultipleInputs(inputs, outputs);

// 打印输出
print(outputs);

关闭解释器

interpreter.close();

异步推理使用IsolateInterpreter

为了利用异步推理,首先创建你的Interpreter,然后用IsolateInterpreter包装它。

final interpreter = await Interpreter.fromAsset('assets/your_model.tflite');
final isolateInterpreter =
        await IsolateInterpreter.create(address: interpreter.address);

isolateInterpreterrunrunForMultipleInputs方法都是异步的:

await isolateInterpreter.run(input, output);
await isolateInterpreter.runForMultipleInputs(inputs, outputs);

通过使用IsolateInterpreter,推理将在单独的isolate中运行。这确保了负责UI任务的主要isolate保持未被阻塞且响应迅速。

贡献给这个包

该包使用melos管理。开始工作之前,请确保运行引导命令。

dart pub global activate melos # 安装或激活melos全局
melos bootstrap # 初始化工作区并引导包

生成代码

该包使用ffigen生成FFI绑定。要运行代码生成,可以使用以下melos命令:

melos run ffigen 

完整示例Demo

下面是一个完整的示例演示如何在Flutter中使用tflite_flutter插件进行推理:

import 'package:flutter/material.dart';
import 'package:tflite_flutter/tflite_flutter.dart';

void main() {
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(title: Text('TFLite Flutter Demo')),
        body: Center(child: MyHomePage()),
      ),
    );
  }
}

class MyHomePage extends StatefulWidget {
  @override
  _MyHomePageState createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  late Interpreter interpreter;
  List<List<double>>? output;

  @override
  void initState() {
    super.initState();
    loadModel();
  }

  Future<void> loadModel() async {
    try {
      interpreter = await Interpreter.fromAsset('assets/your_model.tflite');
      print('Model loaded successfully.');
    } catch (e) {
      print('Failed to load model: $e');
    }
  }

  Future<void> runInference() async {
    var input = [[1.23, 6.54, 7.81, 3.21, 2.22]];
    var outputData = List.filled(1 * 2, 0).reshape([1, 2]);
    
    if (interpreter != null) {
      interpreter.run(input, outputData);
      setState(() {
        output = outputData;
      });
    }
  }

  @override
  Widget build(BuildContext context) {
    return Column(
      mainAxisAlignment: MainAxisAlignment.center,
      children: <Widget>[
        ElevatedButton(
          onPressed: runInference,
          child: Text('Run Inference'),
        ),
        SizedBox(height: 20),
        if (output != null)
          Text('Output: ${output!.map((row) => row.join(', ')).join('\n')}')
        else
          Text('No output yet.'),
      ],
    );
  }

  @override
  void dispose() {
    interpreter.close();
    super.dispose();
  }
}

这个示例展示了如何加载模型、执行推理并在界面上显示结果。希望这对您有所帮助!


更多关于Flutter机器学习推理插件tflite_flutter的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html

1 回复

更多关于Flutter机器学习推理插件tflite_flutter的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html


当然,以下是如何在Flutter项目中使用tflite_flutter插件进行机器学习推理的一个基本示例。这个示例将展示如何加载一个TensorFlow Lite模型并进行推理。

1. 添加依赖

首先,你需要在pubspec.yaml文件中添加tflite_flutter依赖:

dependencies:
  flutter:
    sdk: flutter
  tflite_flutter: ^0.9.0  # 请检查最新版本号

然后运行flutter pub get来安装依赖。

2. 导入必要的包

在你的Flutter项目中,打开你需要使用TensorFlow Lite推理的Dart文件,并导入tflite_flutter包:

import 'package:flutter/material.dart';
import 'package:tflite_flutter/tflite_flutter.dart';

3. 加载模型和进行推理

以下是一个简单的示例,展示了如何加载一个TensorFlow Lite模型并进行推理:

void main() => runApp(MyApp());

class MyApp extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: Text('TFLite Flutter Example'),
        ),
        body: Center(
          child: MyTFLiteModel(),
        ),
      ),
    );
  }
}

class MyTFLiteModel extends StatefulWidget {
  @override
  _MyTFLiteModelState createState() => _MyTFLiteModelState();
}

class _MyTFLiteModelState extends State<MyTFLiteModel> {
  late Interpreter _interpreter;
  late List<Float32List> _inputTensorBuffer;
  late List<Float32List> _outputTensorBuffer;

  @override
  void initState() {
    super.initState();
    loadModel().then((_) => performInference());
  }

  Future<void> loadModel() async {
    // 加载TensorFlow Lite模型
    _interpreter = await Interpreter.fromAsset('model.tflite');
    // 假设模型有一个输入和一个输出,并且输入形状为[1, 224, 224, 3](例如MobileNet)
    // 这里需要根据实际模型的输入形状进行调整
    _inputTensorBuffer = List.generate(1, () => Float32List(1 * 224 * 224 * 3));
    _outputTensorBuffer = List.generate(1, () => Float32List(1001)); // 假设输出有1001个类
    setState(() {});
  }

  Future<void> performInference() async {
    // 这里你需要填充输入张量,这里只是示例,用随机数据填充
    for (int i = 0; i < _inputTensorBuffer[0].length; i++) {
      _inputTensorBuffer[0][i] = Random().nextDouble() * 255;
    }

    // 执行推理
    await _interpreter.run(_inputTensorBuffer, _outputTensorBuffer);

    // 处理输出结果,例如找到概率最高的类
    double maxProbability = _outputTensorBuffer[0].reduce((a, b) => Math.max(a, b));
    int predictedClass = _outputTensorBuffer[0].asMap().entries
        .firstWhere((entry) => entry.value == maxProbability)
        .key;

    // 打印结果
    print('Predicted class: $predictedClass');
  }

  @override
  Widget build(BuildContext context) {
    return Column(
      mainAxisAlignment: MainAxisAlignment.center,
      children: <Widget>[
        Text('Loading model...'),
        ElevatedButton(
          onPressed: () => performInference(),
          child: Text('Run Inference'),
        ),
      ],
    );
  }
}

4. 放置模型文件

确保你的TensorFlow Lite模型文件(例如model.tflite)已经被放置在Flutter项目的assets文件夹中,并在pubspec.yaml中正确配置:

flutter:
  assets:
    - assets/model.tflite

注意事项

  1. 模型兼容性:确保你的TensorFlow Lite模型与Flutter插件兼容,并且输入输出的形状和数据类型正确。
  2. 性能优化:对于大型模型或复杂任务,可能需要优化输入数据的预处理和后处理步骤,以提高推理性能。
  3. 权限管理:如果你的应用需要访问设备的摄像头或文件系统等,请确保在AndroidManifest.xmlInfo.plist中正确配置权限。

这个示例展示了基本的模型加载和推理过程,你可以根据实际需求进行扩展和修改。

回到顶部