Using the Flutter machine learning plugin tensorflow_lite_flutter

Published 1 week ago · by vueper · in Flutter

tensorflow_lite_flutter is a Flutter plugin for accessing the TensorFlow Lite API. It supports image classification, object detection (SSD and YOLO), Pix2Pix, Deeplab, and PoseNet on both iOS and Android. The dependency used in this post is the tflite package, which is what the code below imports.
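
Besides the image-classification flow shown later in this post, the same Tflite class exposes entry points for the other model types. The sketch below outlines SSD object detection; the asset names (assets/ssd_mobilenet.tflite and assets/ssd_mobilenet.txt) and the parameter values are illustrative placeholders, not files from this post.

import 'package:tflite/tflite.dart';

// A rough sketch of SSD object detection with the plugin.
Future<void> detectObjects(String imagePath) async {
  await Tflite.loadModel(
    model: "assets/ssd_mobilenet.tflite", // hypothetical asset name
    labels: "assets/ssd_mobilenet.txt",   // hypothetical asset name
  );
  final detections = await Tflite.detectObjectOnImage(
    path: imagePath,          // path to an image file on disk
    model: "SSDMobileNet",    // or "YOLO" for a YOLO model
    imageMean: 127.5,
    imageStd: 127.5,
    threshold: 0.4,           // drop low-confidence boxes
    numResultsPerClass: 2,
  );
  // Each result is a map with the detected class, its confidence and a bounding box.
  print(detections);
}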

Installation

First, add the tflite dependency to your pubspec.yaml:

dependencies:
  tflite: ^1.1.2
  image_picker: ^0.8.4+3

Android configuration

In android/app/build.gradle, add the following settings inside the android block:

aaptOptions {
    noCompress 'tflite'
    noCompress 'lite'
}

iOS configuration

For iOS, if you run into a build error such as 'vector' file not found, open ios/Runner.xcworkspace, go to Runner > Targets > Runner > Build Settings, search for Compile Sources As, and change its value to Objective-C++.

Usage

Step 1: Prepare the model files

Create an assets folder, put the label file and the model file in it, and then declare them in pubspec.yaml:

flutter:
  assets:
    - assets/labels.txt
    - assets/mobilenet_v1_1.0_224.tflite

Step 2: Import the library

Import the tflite library in your Dart file:

import 'package:tflite/tflite.dart';

Step 3: Load the model and labels

String res = await Tflite.loadModel(
  model: "assets/mobilenet_v1_1.0_224.tflite",
  labels: "assets/labels.txt",
  numThreads: 1,
  isAsset: true,
  useGpuDelegate: false
);
print(res); // prints the load result
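
loadModel is matched by Tflite.close(), which releases the native interpreter; the full example below calls it before reloading a model. If you load a model inside a widget, a typical place to release it is the State's dispose method, as in this small sketch:

@override
void dispose() {
  // Release the interpreter and any native buffers held by the plugin.
  Tflite.close();
  super.dispose();
}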

Example code

Below is a complete example program that picks an image from the gallery and runs a prediction on it:

import 'dart:io';

import 'package:flutter/material.dart';
import 'package:flutter/services.dart'; // PlatformException
import 'package:image_picker/image_picker.dart';
import 'package:tflite/tflite.dart';

void main() => runApp(App());

class App extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: MyApp(),
    );
  }
}

class MyApp extends StatefulWidget {
  @override
  _MyAppState createState() => _MyAppState();
}

class _MyAppState extends State<MyApp> {
  File? _image;
  List? _recognitions;
  bool _busy = false;

  Future predictImagePicker() async {
    final imagePicker = ImagePicker();
    final pickedFile = await imagePicker.pickImage(source: ImageSource.gallery);
    if (pickedFile == null) return;
    setState(() {
      _busy = true;
    });
    // pickImage returns an XFile, not a dart:io File, so convert via the path.
    await predictImage(File(pickedFile.path));
  }

  Future predictImage(File image) async {
    var recognitions = await Tflite.runModelOnImage(
      path: image.path,
      numResults: 6,
      threshold: 0.05,
      imageMean: 127.5,
      imageStd: 127.5,
    );
    setState(() {
      _recognitions = recognitions;
      _busy = false;
    });
  }

  @override
  void initState() {
    super.initState();
    loadModel().then((val) {});
  }

  Future loadModel() async {
    Tflite.close();
    try {
      String? res = await Tflite.loadModel(
        model: "assets/mobilenet_v1_1.0_224.tflite",
        labels: "assets/labels.txt",
      );
      print(res!);
    } on PlatformException {
      print('Failed to load model.');
    }
  }

  @override
  Widget build(BuildContext context) {
    Size size = MediaQuery.of(context).size;
    List<Widget> stackChildren = [];

    stackChildren.add(Positioned(
      top: 0.0,
      left: 0.0,
      width: size.width,
      child: _image == null ? Text('No image selected.') : Image.file(_image!),
    ));

    if (_recognitions != null)
      stackChildren.add(Center(
        child: Column(
          children: _recognitions!.map((res) {
            return Text(
              "${res["index"]} - ${res["label"]}: ${res["confidence"].toStringAsFixed(3)}",
              style: TextStyle(color: Colors.black, fontSize: 20.0, background: Paint()..color = Colors.white),
            );
          }).toList(),
        ),
      ));

    if (_busy) {
      stackChildren.add(const Opacity(
        child: ModalBarrier(dismissible: false, color: Colors.grey),
        opacity: 0.3,
      ));
      stackChildren.add(const Center(child: CircularProgressIndicator()));
    }

    return Scaffold(
      appBar: AppBar(
        title: const Text('tflite example app'),
      ),
      body: Stack(
        children: stackChildren,
      ),
      floatingActionButton: FloatingActionButton(
        onPressed: predictImagePicker,
        tooltip: 'Pick Image',
        child: Icon(Icons.image),
      ),
    );
  }
}
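
The same classification model can also be run on live camera frames instead of picked images. The sketch below is only an outline: it assumes the camera plugin has been added to pubspec.yaml and that a CameraController has already been created and initialized elsewhere in your app.

import 'package:camera/camera.dart';
import 'package:tflite/tflite.dart';

bool _isDetecting = false;

// A minimal sketch: classify frames streamed from an initialized CameraController.
void classifyCameraStream(CameraController controller) {
  controller.startImageStream((CameraImage img) async {
    if (_isDetecting) return; // skip frames while a prediction is still running
    _isDetecting = true;
    final recognitions = await Tflite.runModelOnFrame(
      bytesList: img.planes.map((plane) => plane.bytes).toList(),
      imageHeight: img.height,
      imageWidth: img.width,
      imageMean: 127.5,
      imageStd: 127.5,
      numResults: 2,
      threshold: 0.1,
    );
    print(recognitions);
    _isDetecting = false;
  });
}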

More hands-on tutorials in this series on the Flutter machine learning plugin tensorflow_lite_flutter are available at https://www.itying.com/category-92-b0.html

1 Reply

Sure, here is a detailed code example showing how to use the tensorflow_lite_flutter plugin in a Flutter app. It demonstrates how to load a TensorFlow Lite model, run inference, and handle the output.

1. Add the dependency

First, add the tensorflow_lite_flutter dependency to your pubspec.yaml:

dependencies:
  flutter:
    sdk: flutter
  tensorflow_lite_flutter: ^2.5.0  # check for the latest version number

2. Import the required packages

In your Dart file (for example main.dart), import the required packages:

import 'package:flutter/material.dart';
import 'package:flutter/services.dart'; // rootBundle
import 'package:tensorflow_lite_flutter/tensorflow_lite_flutter.dart';
import 'dart:typed_data';
import 'dart:math' as math;

3. Load the model and initialize the interpreter

In your app you need to load a trained TensorFlow Lite model and initialize the interpreter. This assumes a model file named model.tflite is placed in the assets directory and declared in pubspec.yaml.

class _MyAppState extends State<MyApp> {
  late Interpreter _interpreter;
  late Uint8List _inputBuffer;
  late List<Float32List> _outputBuffers;

  @override
  void initState() {
    super.initState();
    loadModel().then((interpreter) {
      setState(() {
        _interpreter = interpreter;
        // Initialize the input and output buffers.
        _inputBuffer = Uint8List(1 * 28 * 28); // assuming a 28x28 grayscale image as input
        _outputBuffers = List.filled(1, Float32List(10)); // assuming 10 class probabilities as output
      });
    });
  }

  Future<Interpreter> loadModel() async {
    // Make sure the model file is in the assets directory and declared in pubspec.yaml.
    // A .tflite file is binary, so load it as raw bytes (not with loadString/utf8).
    final data = await rootBundle.load('assets/model.tflite');
    final bytes = data.buffer.asUint8List();
    final interpreter = await Interpreter.fromBuffer(bytes);
    return interpreter;
  }

4. Run inference

Next, use the interpreter to run inference. Here we assume you have preprocessed image data imageBuffer that has already been converted to the format the model expects (for example a 28x28 grayscale image).

  void runInference(Uint8List imageBuffer) {
    // Copy the image data into the input buffer (apply any preprocessing first if needed).
    _inputBuffer.setAll(0, imageBuffer);

    // Run model inference.
    _interpreter.run(_inputBuffer, _outputBuffers).then((result) {
      // Handle the output: pick the class with the highest probability.
      var output = _outputBuffers[0];
      var bestClassIdx = output.indexOf(output.reduce((a, b) => math.max(a, b)));
      print("Predicted class: $bestClassIdx");
    });
  }

5. Complete example

Below is a complete example showing how to integrate and use the tensorflow_lite_flutter plugin in a Flutter app:

import 'package:flutter/material.dart';
import 'package:flutter/services.dart'; // rootBundle
import 'package:tensorflow_lite_flutter/tensorflow_lite_flutter.dart';
import 'dart:typed_data';
import 'dart:math' as math;

void main() {
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'TensorFlow Lite Flutter Demo',
      theme: ThemeData(
        primarySwatch: Colors.blue,
      ),
      home: MyHomePage(),
    );
  }
}

class MyHomePage extends StatefulWidget {
  @override
  _MyAppState createState() => _MyAppState();
}

class _MyAppState extends State<MyHomePage> {
  late Interpreter _interpreter;
  late Uint8List _inputBuffer;
  late List<Float32List> _outputBuffers;

  @override
  void initState() {
    super.initState();
    loadModel().then((interpreter) {
      setState(() {
        _interpreter = interpreter;
        _inputBuffer = Uint8List(1 * 28 * 28); // assuming a 28x28 grayscale image as input
        _outputBuffers = List.filled(1, Float32List(10)); // assuming 10 class probabilities as output
      });
    });
  }

  Future<Interpreter> loadModel() async {
    // A .tflite file is binary, so load it as raw bytes (not with loadString/utf8).
    final data = await rootBundle.load('assets/model.tflite');
    final bytes = data.buffer.asUint8List();
    final interpreter = await Interpreter.fromBuffer(bytes);
    return interpreter;
  }

  void runInference(Uint8List imageBuffer) {
    _inputBuffer.setAll(0, imageBuffer);
    _interpreter.run(_inputBuffer, _outputBuffers).then((result) {
      var output = _outputBuffers[0];
      var bestClassIdx = output.indexOf(output.reduce((a, b) => math.max(a, b)));
      print("Predicted class: $bestClassIdx");
    });
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('TensorFlow Lite Flutter Demo'),
      ),
      body: Center(
        child: ElevatedButton(
          onPressed: () {
            // In a real app you would obtain preprocessed image data here;
            // a randomly generated buffer is used purely as a placeholder.
            final rng = math.Random();
            final randomImageBuffer =
                Uint8List.fromList(List<int>.generate(28 * 28, (_) => rng.nextInt(256)));
            runInference(randomImageBuffer);
          },
          child: Text('Run Inference'),
        ),
      ),
    );
  }
}

This example shows how to load a TensorFlow Lite model, initialize the interpreter, and run inference. Note that model.tflite must be placed in the assets directory and correctly declared in pubspec.yaml. Depending on your model's input and output formats, you may also need to preprocess the input data and post-process the output.
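
As a rough illustration of that last point, here is a small, self-contained sketch (the helper names are made up for this example) of the kind of pre- and post-processing such a model typically needs: normalizing grayscale bytes to floats, and picking the top class from a probability vector.

import 'dart:typed_data';

/// Normalizes 0-255 grayscale pixels to 0.0-1.0 floats.
/// Only needed if the model expects float input rather than raw bytes.
Float32List normalizePixels(Uint8List pixels) {
  final floats = Float32List(pixels.length);
  for (var i = 0; i < pixels.length; i++) {
    floats[i] = pixels[i] / 255.0;
  }
  return floats;
}

/// Returns the index and score of the highest-scoring class.
MapEntry<int, double> argmax(Float32List scores) {
  var bestIdx = 0;
  for (var i = 1; i < scores.length; i++) {
    if (scores[i] > scores[bestIdx]) bestIdx = i;
  }
  return MapEntry(bestIdx, scores[bestIdx]);
}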
