Using tflite_maven, a machine learning inference plugin for Flutter

Posted 1 week ago by caililin in Flutter

tflite_maven

tflite_maven is a fork of the tflite plugin, a Flutter plugin for accessing the TensorFlow Lite API. It supports image classification, object detection (SSD and YOLO), Pix2Pix, Deeplab and PoseNet on both iOS and Android.

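Breaking changes

Since 1.1.0:
  1. The iOS TensorFlow Lite library was upgraded from TensorFlowLite 1.x to TensorFlowLiteObjC 2.x. Changes to the native code are marked with TFLITE2.
Since 1.0.0:
  1. Updated to TensorFlow Lite API v1.12.0.
  2. The inputSize and numChannels parameters are no longer accepted; they are read from the input tensor.
  3. numThreads has moved to Tflite.loadModel.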

Installation

Add tflite as a dependency in your pubspec.yaml file.

Android

android/app/build.gradle 中,添加以下设置在 android 块内:

aaptOptions {
    noCompress 'tflite'
    noCompress 'lite'
}

iOS

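Fixes for build errors on iOS:

  • 'vector' file not found

    Open ios/Runner.xcworkspace in Xcode, click Runner > Targets > Runner > Build Settings, search for Compile Sources As, and change its value to Objective-C++.

  • 'tensorflow/lite/kernels/register.h' file not found

    The plugin assumes the tensorflow headers are located under the path "tensorflow/lite/kernels".

    However, for earlier versions of tensorflow, the header path is "tensorflow/contrib/lite/kernels".

    Use CONTRIB_PATH to switch the path. Uncomment the definition here: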

    //#define CONTRIB_PATH
    

Usage

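  1. Create an assets folder and put the label file and the model file in it. Declare them in pubspec.yaml (under the flutter: section):
assets:
  - assets/labels.txt
  - assets/mobilenet_v1_1.0_224.tflite
  2. Import the library:
import 'package:tflite/tflite.dart';
  3. Load the model and labels:
String res = await Tflite.loadModel(
  model: "assets/mobilenet_v1_1.0_224.tflite",
  labels: "assets/labels.txt",
  numThreads: 1, // defaults to 1
  isAsset: true, // defaults to true; set to false to load the files from outside the assets
  useGpuDelegate: false // defaults to false; set to true to use the GPU delegate
);
  4. See the section for the corresponding model below.

  5. Release resources: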

await Tflite.close();

GPU delegate

When using the GPU delegate, refer to this step for better performance.

Image classification

Output format:

{
  index: 0,
  label: "person",
  confidence: 0.629
}
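To consume the classification output in typed Dart code, a small wrapper helps. The sketch below is hypothetical (Recognition and parseRecognitions are not part of the plugin); it assumes each list element is a map with the index, label and confidence keys shown above, and treats a null result as "no predictions".

```dart
/// Hypothetical typed wrapper for one image-classification result. The map
/// keys (index, label, confidence) are the ones in the output format above.
class Recognition {
  final int index;
  final String label;
  final double confidence;

  const Recognition(this.index, this.label, this.confidence);

  factory Recognition.fromMap(Map<dynamic, dynamic> m) => Recognition(
        (m['index'] as num).toInt(),
        m['label'] as String,
        (m['confidence'] as num).toDouble(),
      );
}

/// Converts the raw list returned by runModelOnImage/runModelOnBinary into
/// typed results sorted by descending confidence. A null list (treated here
/// as "no result") yields an empty list.
List<Recognition> parseRecognitions(List<dynamic>? raw) {
  if (raw == null) return const [];
  final results = raw
      .map((e) => Recognition.fromMap(e as Map<dynamic, dynamic>))
      .toList();
  results.sort((a, b) => b.confidence.compareTo(a.confidence));
  return results;
}
```

For example, parseRecognitions(await Tflite.runModelOnImage(path: filepath)) would put the most confident label first.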

Run on image:

var recognitions = await Tflite.runModelOnImage(
  path: filepath,   // required
  imageMean: 0.0,   // defaults to 117.0
  imageStd: 255.0,  // defaults to 1.0
  numResults: 2,    // defaults to 5
  threshold: 0.2,   // defaults to 0.1
  asynch: true      // defaults to true
);
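The imageMean/imageStd pair describes a per-channel normalisation. Assuming the native side applies the same (value - mean) / std formula as the imageToByteListFloat32 helper in this document, the sketch below computes the range an 8-bit channel ends up in: 127.5/127.5 maps 0..255 to -1..1, and 0.0/255.0 maps it to 0..1. The function names are illustrative only.

```dart
/// Per-channel normalisation implied by imageMean/imageStd, mirroring
/// imageToByteListFloat32: (value - mean) / std.
double normalizeChannel(int value, double mean, double std) =>
    (value - mean) / std;

/// Returns [min, max] of the normalised range for 8-bit input (0..255).
List<double> normalizedRange(double mean, double std) =>
    [normalizeChannel(0, mean, std), normalizeChannel(255, mean, std)];
```

With the defaults (117.0 and 1.0) the inputs would land roughly in -117..138, so a float model trained on -1..1 inputs needs 127.5/127.5 passed explicitly, as recognizeImage in the demo does.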

Run on binary:

var recognitions = await Tflite.runModelOnBinary(
  binary: imageToByteListFloat32(image, 224, 127.5, 127.5),// required
  numResults: 6,    // defaults to 5
  threshold: 0.05,  // defaults to 0.1
  asynch: true      // defaults to true
);

Uint8List imageToByteListFloat32(
    img.Image image, int inputSize, double mean, double std) {
  var convertedBytes = Float32List(1 * inputSize * inputSize * 3);
  var buffer = Float32List.view(convertedBytes.buffer);
  int pixelIndex = 0;
  for (var i = 0; i < inputSize; i++) {
    for (var j = 0; j < inputSize; j++) {
      var pixel = image.getPixel(j, i);
      buffer[pixelIndex++] = (img.getRed(pixel) - mean) / std;
      buffer[pixelIndex++] = (img.getGreen(pixel) - mean) / std;
      buffer[pixelIndex++] = (img.getBlue(pixel) - mean) / std;
    }
  }
  return convertedBytes.buffer.asUint8List();
}

Uint8List imageToByteListUint8(img.Image image, int inputSize) {
  var convertedBytes = Uint8List(1 * inputSize * inputSize * 3);
  var buffer = Uint8List.view(convertedBytes.buffer);
  int pixelIndex = 0;
  for (var i = 0; i < inputSize; i++) {
    for (var j = 0; j < inputSize; j++) {
      var pixel = image.getPixel(j, i);
      buffer[pixelIndex++] = img.getRed(pixel);
      buffer[pixelIndex++] = img.getGreen(pixel);
      buffer[pixelIndex++] = img.getBlue(pixel);
    }
  }
  return convertedBytes.buffer.asUint8List();
}
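The two helpers above depend on package:image (img.Image, img.getRed, ...), whose pixel accessors may differ between major versions of that package. As an alternative, the hypothetical rgbBytesToFloat32 below works on a plain packed RGB buffer (row-major, 3 bytes per pixel, already resized) and produces the same float32 byte layout. Float32List uses the host byte order, which is little-endian on typical Android and iOS devices.

```dart
import 'dart:typed_data';

/// Hypothetical helper: converts a packed RGB buffer (row-major, 3 bytes per
/// pixel, already resized to inputSize x inputSize) into the same float32
/// byte layout that imageToByteListFloat32 produces.
Uint8List rgbBytesToFloat32(
    Uint8List rgb, int inputSize, double mean, double std) {
  final expected = inputSize * inputSize * 3;
  if (rgb.length != expected) {
    throw ArgumentError('expected $expected bytes, got ${rgb.length}');
  }
  final floats = Float32List(expected);
  for (var i = 0; i < expected; i++) {
    // Channels are already interleaved R, G, B, so a single pass suffices.
    floats[i] = (rgb[i] - mean) / std;
  }
  // Reinterpret the float buffer as raw bytes in host byte order.
  return floats.buffer.asUint8List();
}
```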

Run on image stream (video frames):

// here img is a CameraImage delivered by the camera plugin's image stream
var recognitions = await Tflite.runModelOnFrame(
  bytesList: img.planes.map((plane) {return plane.bytes;}).toList(),// required
  imageHeight: img.height,
  imageWidth: img.width,
  imageMean: 127.5,   // defaults to 127.5
  imageStd: 127.5,    // defaults to 127.5
  rotation: 90,       // defaults to 90, Android only
  numResults: 2,      // defaults to 5
  threshold: 0.1,     // defaults to 0.1
  asynch: true        // defaults to true
);
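Camera image streams usually deliver frames faster than inference completes, so overlapping runModelOnFrame calls are best avoided. A minimal sketch of a busy flag, assuming inference is triggered from a camera stream callback; FrameGate is a hypothetical name, not a plugin class.

```dart
/// Hypothetical guard for camera streams: drops frames while an inference
/// call is still in flight.
class FrameGate {
  bool _busy = false;

  bool get isBusy => _busy;

  /// Marks the gate busy and returns true if no inference is running;
  /// returns false if the caller should drop this frame.
  bool tryAcquire() {
    if (_busy) return false;
    _busy = true;
    return true;
  }

  /// Call when inference completes, whether it succeeded or failed.
  void release() {
    _busy = false;
  }
}
```

In a CameraController.startImageStream callback, return early when gate.tryAcquire() is false, and call gate.release() from whenComplete on the future returned by Tflite.runModelOnFrame, so the flag is cleared even if inference throws.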

Object detection

Output format:

{
  detectedClass: "hot dog",
  confidenceInClass: 0.123,
  rect: {
    x: 0.15,
    y: 0.33,
    w: 0.80,
    h: 0.27
  }
}
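The rect values are fractions of the image size, so drawing a box means multiplying by the displayed width and height. A hypothetical helper, assuming the keys shown above:

```dart
/// Hypothetical helper: scales the normalised rect (x, y, w, h in 0..1) from
/// the detection output to pixels of the displayed image.
/// Returns [left, top, width, height].
List<double> rectToPixels(
    Map<dynamic, dynamic> rect, double displayWidth, double displayHeight) {
  double v(String key) => (rect[key] as num).toDouble();
  return [
    v('x') * displayWidth,
    v('y') * displayHeight,
    v('w') * displayWidth,
    v('h') * displayHeight,
  ];
}
```

renderBoxes in the demo does the same with displayWidth = screen.width and displayHeight = _imageHeight / _imageWidth * screen.width, i.e. the height the image occupies once it is scaled to the full screen width.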

SSD MobileNet:

Run on image:
var recognitions = await Tflite.detectObjectOnImage(
  path: filepath,       // required
  model: "SSDMobileNet",
  imageMean: 127.5,
  imageStd: 127.5,
  threshold: 0.4,       // defaults to 0.1
  numResultsPerClass: 2,// defaults to 5
  asynch: true          // defaults to true
);
Run on binary:
var recognitions = await Tflite.detectObjectOnBinary(
  binary: imageToByteListUint8(resizedImage, 300), // required
  model: "SSDMobileNet",
  threshold: 0.4,                                  // defaults to 0.1
  numResultsPerClass: 2,                           // defaults to 5
  asynch: true                                     // defaults to true
);
Run on image stream (video frames):
var recognitions = await Tflite.detectObjectOnFrame(
  bytesList: img.planes.map((plane) {return plane.bytes;}).toList(),// required
  model: "SSDMobileNet",
  imageHeight: img.height,
  imageWidth: img.width,
  imageMean: 127.5,   // defaults to 127.5
  imageStd: 127.5,    // defaults to 127.5
  rotation: 90,       // defaults to 90, Android only
  numResults: 2,      // defaults to 5
  threshold: 0.1,     // defaults to 0.1
  asynch: true        // defaults to true
);
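When calling detectObjectOnBinary, the byte buffer has to match the model's input tensor. The SSD example feeds one byte per channel for a 300 x 300 input (imageToByteListUint8), while the YOLO example uses four bytes per channel at 416 x 416 (imageToByteListFloat32). A hypothetical sanity check, assuming a square input with 3 channels by default:

```dart
/// Hypothetical check for detectObjectOnBinary / runModelOnBinary input:
/// bytesPerValue is 1 for uint8 input and 4 for float32 input.
int expectedInputBytes(int inputSize,
        {int channels = 3, int bytesPerValue = 1}) =>
    inputSize * inputSize * channels * bytesPerValue;
```

For example, assert(binary.length == expectedInputBytes(300)); before the SSD call catches a forgotten resize early.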

Tiny YOLOv2:

Run on image:
var recognitions = await Tflite.detectObjectOnImage(
  path: filepath,       // required
  model: "YOLO",
  imageMean: 0.0,
  imageStd: 255.0,
  threshold: 0.3,       // defaults to 0.1
  numResultsPerClass: 2,// defaults to 5
  anchors: anchors,     // defaults to [0.57273,0.677385,1.87446,2.06253,3.33843,5.47434,7.88282,3.52778,9.77052,9.16828]
  blockSize: 32,        // defaults to 32
  numBoxesPerBlock: 5,  // defaults to 5
  asynch: true          // defaults to true
);
Run on binary:
var recognitions = await Tflite.detectObjectOnBinary(
  binary: imageToByteListFloat32(resizedImage, 416, 0.0, 255.0), // required
  model: "YOLO",
  threshold: 0.3,       // defaults to 0.1
  numResultsPerClass: 2,// defaults to 5
  anchors: anchors,     // defaults to [0.57273,0.677385,1.87446,2.06253,3.33843,5.47434,7.88282,3.52778,9.77052,9.16828]
  blockSize: 32,        // defaults to 32
  numBoxesPerBlock: 5,  // defaults to 5
  asynch: true          // defaults to true
);
Run on image stream (video frames):
var recognitions = await Tflite.detectObjectOnFrame(
  bytesList: img.planes.map((plane) {return plane.bytes;}).toList(),// required
  model: "YOLO",
  imageHeight: img.height,
  imageWidth: img.width,
  imageMean: 0,         // defaults to 127.5
  imageStd: 255.0,      // defaults to 127.5
  numResults: 2,        // defaults to 5
  threshold: 0.1,       // defaults to 0.1
  numResultsPerClass: 2,// defaults to 5
  anchors: anchors,     // defaults to [0.57273,0.677385,1.87446,2.06253,3.33843,5.47434,7.88282,3.52778,9.77052,9.16828]
  blockSize: 32,        // defaults to 32
  numBoxesPerBlock: 5,  // defaults to 5
  asynch: true          // defaults to true
);
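blockSize and numBoxesPerBlock define the YOLO output grid: a 416 x 416 input with blockSize: 32 gives a 13 x 13 grid, and with 5 boxes per cell that is 845 candidate boxes, each scaled by one anchor pair. The sketch below shows textbook YOLOv2 box decoding with those parameters. It is not taken from the plugin's native source, and the function names are illustrative.

```dart
import 'dart:math' as math;

/// Default anchors from the parameter comments above (5 boxes x [w, h]).
const defaultYoloAnchors = <double>[
  0.57273, 0.677385, 1.87446, 2.06253, 3.33843,
  5.47434, 7.88282, 3.52778, 9.77052, 9.16828,
];

/// Grid cells per side, e.g. 416 ~/ 32 == 13.
int gridSize(int inputSize, int blockSize) => inputSize ~/ blockSize;

double _sigmoid(double x) => 1 / (1 + math.exp(-x));

/// Textbook YOLOv2 decoding of one raw prediction (tx, ty, tw, th) in grid
/// cell (col, row) using anchor [box]. Returns
/// [centerX, centerY, width, height] in input-image pixels.
List<double> decodeYoloBox({
  required int col,
  required int row,
  required int box,
  required double tx,
  required double ty,
  required double tw,
  required double th,
  List<double> anchors = defaultYoloAnchors,
  int blockSize = 32,
}) {
  final cx = (col + _sigmoid(tx)) * blockSize;
  final cy = (row + _sigmoid(ty)) * blockSize;
  // Anchors are in grid-cell units, hence the multiplication by blockSize.
  final w = math.exp(tw) * anchors[2 * box] * blockSize;
  final h = math.exp(th) * anchors[2 * box + 1] * blockSize;
  return [cx, cy, w, h];
}
```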

Pix2Pix

var result = await Tflite.runPix2PixOnImage(
  path: filepath,       // required
  imageMean: 0.0,       // defaults to 0.0
  imageStd: 255.0,      // defaults to 255.0
  asynch: true      // defaults to true
);

Deeplab

var result = await Tflite.runSegmentationOnImage(
  path: filepath,     // required
  imageMean: 0.0,     // defaults to 0.0
  imageStd: 255.0,    // defaults to 255.0
  labelColors: [...], // defaults to https://github.com/shaqian/flutter_tflite/blob/master/lib/tflite.dart#L219
  outputType: "png",  // defaults to "png"
  asynch: true        // defaults to true
);
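Deeplab assigns a class index to each pixel, and labelColors supplies one color per label for the rendered output. The sketch below shows that mapping in plain Dart. It assumes each color is a 32-bit ARGB int (0xAARRGGBB, like Flutter's Color.value) and returns raw RGBA bytes rather than the PNG the plugin produces; colorizeSegmentation is hypothetical.

```dart
import 'dart:typed_data';

/// Hypothetical sketch: maps a per-pixel class-index map to RGBA bytes using
/// one ARGB color (0xAARRGGBB) per label index.
Uint8List colorizeSegmentation(List<int> classMap, List<int> argbColors) {
  final out = Uint8List(classMap.length * 4);
  for (var i = 0; i < classMap.length; i++) {
    final argb = argbColors[classMap[i]];
    out[i * 4] = (argb >> 16) & 0xFF; // red
    out[i * 4 + 1] = (argb >> 8) & 0xFF; // green
    out[i * 4 + 2] = argb & 0xFF; // blue
    out[i * 4 + 3] = (argb >> 24) & 0xFF; // alpha
  }
  return out;
}
```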

PoseNet

var result = await Tflite.runPoseNetOnImage(
  path: filepath,     // required
  imageMean: 125.0,   // defaults to 125.0
  imageStd: 125.0,    // defaults to 125.0
  numResults: 2,      // defaults to 5
  threshold: 0.7,     // defaults to 0.5
  nmsRadius: 10,      // defaults to 20
  asynch: true        // defaults to true
);
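PoseNet results are drawn by scaling each keypoint. The hypothetical helper below reads the same keypoint fields (x, y, part) that renderKeypoints in the demo uses, and assumes x and y are normalised to 0..1 as that code implies.

```dart
/// Hypothetical helper: converts one PoseNet result to screen positions.
/// Returns one {part, x, y} map per keypoint.
List<Map<String, Object>> keypointsToScreen(
    Map<dynamic, dynamic> pose, double displayWidth, double displayHeight) {
  final keypoints = pose['keypoints'] as Map<dynamic, dynamic>;
  return keypoints.values.map((k) {
    final kp = k as Map<dynamic, dynamic>;
    return <String, Object>{
      'part': kp['part'] as String,
      'x': (kp['x'] as num).toDouble() * displayWidth,
      'y': (kp['y'] as num).toDouble() * displayHeight,
    };
  }).toList();
}
```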

Examples

Prediction on static images

See the example

Real-time detection

See flutter_realtime_detection


Full example demo

import 'dart:async';
import 'dart:io';
import 'dart:math';
import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:image/image.dart' as img;

import 'package:tflite/tflite.dart';
import 'package:image_picker/image_picker.dart';

void main() => runApp(new App());

const String mobile = "MobileNet";
const String ssd = "SSD MobileNet";
const String yolo = "Tiny YOLOv2";
const String deeplab = "DeepLab";
const String posenet = "PoseNet";

class App extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: MyApp(),
    );
  }
}

class MyApp extends StatefulWidget {
  @override
  _MyAppState createState() => new _MyAppState();
}

class _MyAppState extends State<MyApp> {
  File _image;
  List _recognitions;
  String _model = mobile;
  double _imageHeight;
  double _imageWidth;
  bool _busy = false;

  Future predictImagePicker() async {
    final ImagePicker _picker = ImagePicker();
    var image = await _picker.pickImage(source: ImageSource.gallery);
    if (image == null) return;
    setState(() {
      _busy = true;
    });
    predictImage(File(image.path));
  }

  Future predictImage(File image) async {
    if (image == null) return;

    switch (_model) {
      case yolo:
        await yolov2Tiny(image);
        break;
      case ssd:
        await ssdMobileNet(image);
        break;
      case deeplab:
        await segmentMobileNet(image);
        break;
      case posenet:
        await poseNet(image);
        break;
      default:
        await recognizeImage(image);
      // await recognizeImageBinary(image);
    }

    new FileImage(image)
        .resolve(new ImageConfiguration())
        .addListener(ImageStreamListener((ImageInfo info, bool _) {
      setState(() {
        _imageHeight = info.image.height.toDouble();
        _imageWidth = info.image.width.toDouble();
      });
    }));

    setState(() {
      _image = image;
      _busy = false;
    });
  }

  @override
  void initState() {
    super.initState();

    _busy = true;

    loadModel().then((val) {
      setState(() {
        _busy = false;
      });
    });
  }

  Future loadModel() async {
    await Tflite.close();
    try {
      String res;
      switch (_model) {
        case yolo:
          res = await Tflite.loadModel(
            model: "assets/yolov2_tiny.tflite",
            labels: "assets/yolov2_tiny.txt",
            // useGpuDelegate: true,
          );
          break;
        case ssd:
          res = await Tflite.loadModel(
            model: "assets/ssd_mobilenet.tflite",
            labels: "assets/ssd_mobilenet.txt",
            // useGpuDelegate: true,
          );
          break;
        case deeplab:
          res = await Tflite.loadModel(
            model: "assets/deeplabv3_257_mv_gpu.tflite",
            labels: "assets/deeplabv3_257_mv_gpu.txt",
            // useGpuDelegate: true,
          );
          break;
        case posenet:
          res = await Tflite.loadModel(
            model: "assets/posenet_mv1_075_float_from_checkpoints.tflite",
            // useGpuDelegate: true,
          );
          break;
        default:
          res = await Tflite.loadModel(
            model: "assets/mobilenet_v1_1.0_224.tflite",
            labels: "assets/mobilenet_v1_1.0_224.txt",
            // useGpuDelegate: true,
          );
      }
      print(res);
    } on PlatformException {
      print('Failed to load model.');
    }
  }

  Uint8List imageToByteListFloat32(
      img.Image image, int inputSize, double mean, double std) {
    var convertedBytes = Float32List(1 * inputSize * inputSize * 3);
    var buffer = Float32List.view(convertedBytes.buffer);
    int pixelIndex = 0;
    for (var i = 0; i < inputSize; i++) {
      for (var j = 0; j < inputSize; j++) {
        var pixel = image.getPixel(j, i);
        buffer[pixelIndex++] = (img.getRed(pixel) - mean) / std;
        buffer[pixelIndex++] = (img.getGreen(pixel) - mean) / std;
        buffer[pixelIndex++] = (img.getBlue(pixel) - mean) / std;
      }
    }
    return convertedBytes.buffer.asUint8List();
  }

  Uint8List imageToByteListUint8(img.Image image, int inputSize) {
    var convertedBytes = Uint8List(1 * inputSize * inputSize * 3);
    var buffer = Uint8List.view(convertedBytes.buffer);
    int pixelIndex = 0;
    for (var i = 0; i < inputSize; i++) {
      for (var j = 0; j < inputSize; j++) {
        var pixel = image.getPixel(j, i);
        buffer[pixelIndex++] = img.getRed(pixel);
        buffer[pixelIndex++] = img.getGreen(pixel);
        buffer[pixelIndex++] = img.getBlue(pixel);
      }
    }
    return convertedBytes.buffer.asUint8List();
  }

  Future recognizeImage(File image) async {
    int startTime = new DateTime.now().millisecondsSinceEpoch;
    var recognitions = await Tflite.runModelOnImage(
      path: image.path,
      numResults: 6,
      threshold: 0.05,
      imageMean: 127.5,
      imageStd: 127.5,
    );
    setState(() {
      _recognitions = recognitions;
    });
    int endTime = new DateTime.now().millisecondsSinceEpoch;
    print("Inference took ${endTime - startTime}ms");
  }

  Future recognizeImageBinary(File image) async {
    int startTime = new DateTime.now().millisecondsSinceEpoch;
    // The picked image is a file on disk, not a bundled asset, so read it
    // directly instead of going through rootBundle.
    Uint8List imageBytes = await image.readAsBytes();
    img.Image oriImage = img.decodeImage(imageBytes);
    img.Image resizedImage = img.copyResize(oriImage, height: 224, width: 224);
    var recognitions = await Tflite.runModelOnBinary(
      binary: imageToByteListFloat32(resizedImage, 224, 127.5, 127.5),
      numResults: 6,
      threshold: 0.05,
    );
    setState(() {
      _recognitions = recognitions;
    });
    int endTime = new DateTime.now().millisecondsSinceEpoch;
    print("Inference took ${endTime - startTime}ms");
  }

  Future yolov2Tiny(File image) async {
    int startTime = new DateTime.now().millisecondsSinceEpoch;
    var recognitions = await Tflite.detectObjectOnImage(
      path: image.path,
      model: "YOLO",
      threshold: 0.3,
      imageMean: 0.0,
      imageStd: 255.0,
      numResultsPerClass: 1,
    );
    // var imageBytes = await image.readAsBytes();
    // img.Image oriImage = img.decodeImage(imageBytes);
    // img.Image resizedImage = img.copyResize(oriImage, height: 416, width: 416);
    // var recognitions = await Tflite.detectObjectOnBinary(
    //   binary: imageToByteListFloat32(resizedImage, 416, 0.0, 255.0),
    //   model: "YOLO",
    //   threshold: 0.3,
    //   numResultsPerClass: 1,
    // );
    setState(() {
      _recognitions = recognitions;
    });
    int endTime = new DateTime.now().millisecondsSinceEpoch;
    print("Inference took ${endTime - startTime}ms");
  }

  Future ssdMobileNet(File image) async {
    int startTime = new DateTime.now().millisecondsSinceEpoch;
    var recognitions = await Tflite.detectObjectOnImage(
      path: image.path,
      numResultsPerClass: 1,
    );
    // var imageBytes = await image.readAsBytes();
    // img.Image oriImage = img.decodeImage(imageBytes);
    // img.Image resizedImage = img.copyResize(oriImage, height: 300, width: 300);
    // var recognitions = await Tflite.detectObjectOnBinary(
    //   binary: imageToByteListUint8(resizedImage, 300),
    //   numResultsPerClass: 1,
    // );
    setState(() {
      _recognitions = recognitions;
    });
    int endTime = new DateTime.now().millisecondsSinceEpoch;
    print("Inference took ${endTime - startTime}ms");
  }

  Future segmentMobileNet(File image) async {
    int startTime = new DateTime.now().millisecondsSinceEpoch;
    var recognitions = await Tflite.runSegmentationOnImage(
      path: image.path,
      imageMean: 127.5,
      imageStd: 127.5,
    );

    setState(() {
      _recognitions = recognitions;
    });
    int endTime = new DateTime.now().millisecondsSinceEpoch;
    print("Inference took ${endTime - startTime}ms");
  }

  Future poseNet(File image) async {
    int startTime = new DateTime.now().millisecondsSinceEpoch;
    var recognitions = await Tflite.runPoseNetOnImage(
      path: image.path,
      numResults: 2,
    );

    print(recognitions);

    setState(() {
      _recognitions = recognitions;
    });
    int endTime = new DateTime.now().millisecondsSinceEpoch;
    print("Inference took ${endTime - startTime}ms");
  }

  onSelect(model) async {
    setState(() {
      _busy = true;
      _model = model;
      _recognitions = null;
    });
    await loadModel();

    if (_image != null)
      predictImage(_image);
    else
      setState(() {
        _busy = false;
      });
  }

  List<Widget> renderBoxes(Size screen) {
    if (_recognitions == null) return [];
    if (_imageHeight == null || _imageWidth == null) return [];

    double factorX = screen.width;
    double factorY = _imageHeight / _imageWidth * screen.width;
    Color blue = Color.fromRGBO(37, 213, 253, 1.0);
    return _recognitions.map((re) {
      return Positioned(
        left: re["rect"]["x"] * factorX,
        top: re["rect"]["y"] * factorY,
        width: re["rect"]["w"] * factorX,
        height: re["rect"]["h"] * factorY,
        child: Container(
          decoration: BoxDecoration(
            borderRadius: BorderRadius.all(Radius.circular(8.0)),
            border: Border.all(
              color: blue,
              width: 2,
            ),
          ),
          child: Text(
            "${re["detectedClass"]} ${(re["confidenceInClass"] * 100).toStringAsFixed(0)}%",
            style: TextStyle(
              background: Paint()..color = blue,
              color: Colors.white,
              fontSize: 12.0,
            ),
          ),
        ),
      );
    }).toList();
  }

  List<Widget> renderKeypoints(Size screen) {
    if (_recognitions == null) return [];
    if (_imageHeight == null || _imageWidth == null) return [];

    double factorX = screen.width;
    double factorY = _imageHeight / _imageWidth * screen.width;

    var lists = <Widget>[];
    _recognitions.forEach((re) {
      var color = Color((Random().nextDouble() * 0xFFFFFF).toInt() << 0)
          .withOpacity(1.0);
      var list = re["keypoints"].values.map<Widget>((k) {
        return Positioned(
          left: k["x"] * factorX - 6,
          top: k["y"] * factorY - 6,
          width: 100,
          height: 12,
          child: Text(
            "● ${k["part"]}",
            style: TextStyle(
              color: color,
              fontSize: 12.0,
            ),
          ),
        );
      }).toList();

      lists.addAll(list);
    });

    return lists;
  }

  @override
  Widget build(BuildContext context) {
    Size size = MediaQuery.of(context).size;
    List<Widget> stackChildren = [];

    if (_model == deeplab && _recognitions != null) {
      stackChildren.add(Positioned(
        top: 0.0,
        left: 0.0,
        width: size.width,
        child: _image == null
            ? Text('No image selected.')
            : Container(
                decoration: BoxDecoration(
                    image: DecorationImage(
                        alignment: Alignment.topCenter,
                        image: MemoryImage(_recognitions),
                        fit: BoxFit.fill)),
                child: Opacity(opacity: 0.3, child: Image.file(_image))),
      ));
    } else {
      stackChildren.add(Positioned(
        top: 0.0,
        left: 0.0,
        width: size.width,
        child: _image == null ? Text('No image selected.') : Image.file(_image),
      ));
    }

    if (_model == mobile) {
      stackChildren.add(Center(
        child: Column(
          children: _recognitions != null
              ? _recognitions.map((res) {
                  return Text(
                    "${res["index"]} - ${res["label"]}: ${res["confidence"].toStringAsFixed(3)}",
                    style: TextStyle(
                      color: Colors.black,
                      fontSize: 20.0,
                      background: Paint()..color = Colors.white,
                    ),
                  );
                }).toList()
              : [],
        ),
      ));
    } else if (_model == ssd || _model == yolo) {
      stackChildren.addAll(renderBoxes(size));
    } else if (_model == posenet) {
      stackChildren.addAll(renderKeypoints(size));
    }

    if (_busy) {
      stackChildren.add(const Opacity(
        child: ModalBarrier(dismissible: false, color: Colors.grey),
        opacity: 0.3,
      ));
      stackChildren.add(const Center(child: CircularProgressIndicator()));
    }

    return Scaffold(
      appBar: AppBar(
        title: const Text('tflite example app'),
        actions: <Widget>[
          PopupMenuButton<String>(
            onSelected: onSelect,
            itemBuilder: (context) {
              List<PopupMenuEntry<String>> menuEntries = [
                const PopupMenuItem<String>(
                  child: Text(mobile),
                  value: mobile,
                ),
                const PopupMenuItem<String>(
                  child: Text(ssd),
                  value: ssd,
                ),
                const PopupMenuItem<String>(
                  child: Text(yolo),
                  value: yolo,
                ),
                const PopupMenuItem<String>(
                  child: Text(deeplab),
                  value: deeplab,
                ),
                const PopupMenuItem<String>(
                  child: Text(posenet),
                  value: posenet,
                )
              ];
              return menuEntries;
            },
          )
        ],
      ),
      body: Stack(
        children: stackChildren,
      ),
      floatingActionButton: FloatingActionButton(
        onPressed: predictImagePicker,
        tooltip: 'Pick Image',
        child: Icon(Icons.image),
      ),
    );
  }
}

For more hands-on tutorials on using the Flutter machine learning inference plugin tflite_maven, visit https://www.itying.com/category-92-b0.html

1 reply


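In Flutter, a common way to run machine learning inference is the tflite_flutter plugin. Although you mentioned tflite_maven, note that tflite_maven is not an official plugin name widely recognized or used in the Flutter community. Flutter developers usually integrate TensorFlow Lite models with the tflite_flutter or tflite plugin. The following example shows how to set up and use the tflite_flutter plugin in a Flutter project.

1. Add the dependency

First, add the tflite_flutter dependency to your pubspec.yaml file: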

dependencies:
  flutter:
    sdk: flutter
  tflite_flutter: ^0.9.0+2  # check for the latest version

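Make sure to run flutter pub get to install the dependency.

2. Load a TensorFlow Lite model and run inference

Assume you have a trained TensorFlow Lite model (a .tflite file) placed in the assets directory of your Flutter project.

2.1 Declare the model file in pubspec.yaml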

flutter:
  assets:
    - assets/model.tflite

2.2 Load the model and run inference in Flutter code

import 'package:flutter/material.dart';
import 'package:tflite_flutter/tflite_flutter.dart';

void main() => runApp(MyApp());

class MyApp extends StatefulWidget {
  @override
  _MyAppState createState() => _MyAppState();
}

class _MyAppState extends State<MyApp> {
  late Interpreter _interpreter;
  List<dynamic> _outputs = [];

  @override
  void initState() {
    super.initState();
    loadModel().then((value) {
      setState(() {});
    });
  }

  Future<void> loadModel() async {
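    // Load the model. tflite_flutter 0.9.x prepends "assets/" to the name
    // itself; releases from 0.10 on expect the full path "assets/model.tflite".
    _interpreter = await Interpreter.fromAsset('model.tflite');
  }

  Future<void> runModel(List<List<double>> inputData) async {
    // Run inference. Interpreter.run(input, output) returns nothing and
    // writes the result into a preallocated output buffer
    // (this sketch assumes a float32 output tensor).
    final outputShape = _interpreter.getOutputTensor(0).shape;
    final output = List.filled(outputShape.reduce((a, b) => a * b), 0.0)
        .reshape(outputShape);
    _interpreter.run(inputData, output);
    _outputs = output;
    setState(() {});
  }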

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: Text('TFLite Flutter Demo'),
        ),
        body: Center(
          child: Column(
            mainAxisAlignment: MainAxisAlignment.center,
            children: [
              ElevatedButton(
                onPressed: () async {
                  // Example input data (adjust to your model)
                  List<List<double>> inputData = [
                    [1.0, 2.0, 3.0], // example only; must match the model's actual input shape
                  ];
                  await runModel(inputData);
                  // Print the output
                  print('Model Output: $_outputs');
                },
                child: Text('Run Model'),
              ),
            ],
          ),
        ),
      ),
    );
  }

  @override
  void dispose() {
    // Release resources
    _interpreter.close();
    super.dispose();
  }
}

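Notes

  1. Model input and output: make sure your input data matches the model's input shape and type. Likewise, parse the output according to your model's structure.
  2. Resource management: close the interpreter in the dispose method to release its resources.
  3. Error handling: in a real application, add appropriate error handling for model-loading failures and inference errors.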
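For the first note, a small check before calling run catches shape mismatches early. This is a hypothetical sketch that only handles a [1, n] input (the shape of the inputData example above); elementCount and toBatchOfOne are illustrative names, not tflite_flutter APIs.

```dart
/// Number of elements in a tensor of the given shape.
int elementCount(List<int> shape) => shape.fold(1, (a, b) => a * b);

/// Validates a flat input against a [1, n] shape and nests it as a batch
/// of one, throwing ArgumentError on mismatch.
List<List<double>> toBatchOfOne(List<double> flat, List<int> shape) {
  if (shape.length != 2 || shape[0] != 1) {
    throw ArgumentError('this sketch only handles [1, n] shapes, got $shape');
  }
  final n = elementCount(shape);
  if (flat.length != n) {
    throw ArgumentError('expected $n values, got ${flat.length}');
  }
  return [List<double>.of(flat)];
}
```

The shape could come from _interpreter.getInputTensor(0).shape.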

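This example shows how to load and run a TensorFlow Lite model in Flutter with the tflite_flutter plugin. Depending on your needs, you may have to adapt how the input data is prepared and how the output is parsed.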
