Using flutter_d2go, a Flutter plugin for deploying deep learning models

Posted 1 week ago by sinazl in Flutter


flutter_d2go

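A Flutter plugin for object detection (Android and iOS), keypoint estimation (Android and iOS), and instance segmentation (Android only) on mobile devices. The plugin is based on Facebook Research's d2go models.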

Features

  • Get classes and bounding boxes (Android and iOS)
  • Get keypoints (Android and iOS)
  • Get mask data (Android only)
  • Run real-time inference on camera stream images (Android only)

Preview

(Screenshots: real-time inference on camera stream images, with and without keypoints; object detection and instance segmentation; keypoint estimation.)

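Installation

Add flutter_d2go to your pubspec.yaml file, put the d2go model and class file in your assets directory, and declare them under the flutter: section. The asset paths must match the paths you later pass to loadModel (note that the full example below loads assets/models/d2go.ptl).

dependencies:
  flutter_d2go: ^<version>

flutter:
  assets:
    - assets/models/d2go.pt
    - assets/models/classes.txt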

Usage

1. Load the model and classes

The model must be in PyTorch format. The class file format is described here.

await FlutterD2go.loadModel(
    modelPath: 'assets/models/d2go.pt',     // required
    labelPath: 'assets/models/classes.txt', // required
);
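The post does not spell out the class file format. As a sketch, assume classes.txt holds one class name per line, in the order of the model's class indices. Under that assumption, the labels could be parsed as shown below. parseClassLabels is a hypothetical helper and is not part of flutter_d2go. Inside a Flutter app, the string could come from rootBundle.loadString('assets/models/classes.txt').

```dart
/// Hypothetical helper, not part of flutter_d2go.
/// Assumes one class name per line, with line order matching the model's
/// class indices.
List<String> parseClassLabels(String contents) {
  // Drop only trailing whitespace/newlines so interior lines keep their index.
  final trimmed = contents.trimRight();
  if (trimmed.isEmpty) return <String>[];
  return trimmed
      .split(RegExp(r'\r?\n')) // accept both LF and CRLF line endings
      .map((line) => line.trim())
      .toList();
}
```

Blank lines in the middle of the file are deliberately kept, so that a placeholder line does not shift the index of every class after it.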

2. Get predictions for a still image

List<Map<String, dynamic>> output = await FlutterD2go.getImagePrediction(
    image: image,           // required: File (dart:io)
    width: 320,             // default: 320
    height: 320,            // default: 320
    mean: [0.0, 0.0, 0.0],  // default: [0.0, 0.0, 0.0]
    std: [1.0, 1.0, 1.0],   // default: [1.0, 1.0, 1.0]
    minScore: 0.7,          // default: 0.5
);
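mean and std are per-channel normalization parameters. The post does not say exactly how the plugin applies them. Assuming it follows the common PyTorch Android convention (as in TensorImageUtils), each channel value is scaled to [0, 1] and then normalized as (value / 255 - mean) / std. With the defaults, mean = [0, 0, 0] and std = [1, 1, 1], this simply maps 0..255 to 0..1. The hypothetical normalizePixel below only illustrates that arithmetic. It is not plugin code.

```dart
/// Hypothetical illustration, not flutter_d2go code: normalizes one RGB
/// pixel (channels 0..255) as (value / 255 - mean) / std, per channel.
List<double> normalizePixel(
  List<int> rgb, {
  List<double> mean = const [0.0, 0.0, 0.0],
  List<double> std = const [1.0, 1.0, 1.0],
}) {
  if (rgb.length != 3 || mean.length != 3 || std.length != 3) {
    throw ArgumentError('rgb, mean and std must each have 3 entries');
  }
  return List<double>.generate(
    3,
    (c) => (rgb[c] / 255.0 - mean[c]) / std[c],
  );
}
```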

3. Get predictions for camera stream images

List<Map<String, dynamic>> output = await FlutterD2go.getStreamImagePrediction(
    imageBytesList: cameraImage.planes.map((plane) => plane.bytes).toList(),             // required: List<Uint8List>, the bytes of each image plane
    imageBytesPerPixel: cameraImage.planes.map((plane) => plane.bytesPerPixel).toList(), // default: [1, 2, 2]
    width: cameraImage.width,               // default: 720
    height: cameraImage.height,             // default: 1280
    inputWidth: 320,                        // default: 320
    inputHeight: 320,                       // default: 320
    mean: [0.0, 0.0, 0.0],                  // default: [0.0, 0.0, 0.0]
    std: [1.0, 1.0, 1.0],                   // default: [1.0, 1.0, 1.0]
    minScore: 0.7,                          // default: 0.5
    rotation: 90,                           // default: 0
);
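In the full example later in this post, a stream prediction made with rotation: 90 is drawn as if the frame's width and height were swapped: _imageWidth is set to cameraImage.height and _imageHeight to cameraImage.width. The sketch below captures that rule for any multiple of 90 degrees. orientedSize and FrameSize are hypothetical names, not plugin API.

```dart
/// Hypothetical helper, not part of flutter_d2go: the frame size as seen
/// after rotating it by [rotation] degrees (must be a multiple of 90).
class FrameSize {
  const FrameSize(this.width, this.height);
  final int width;
  final int height;
}

FrameSize orientedSize(int width, int height, int rotation) {
  if (rotation % 90 != 0) {
    throw ArgumentError.value(rotation, 'rotation', 'must be a multiple of 90');
  }
  // Dart's % is non-negative for a positive divisor, so -90 maps to 3 turns.
  final quarterTurns = (rotation ~/ 90) % 4;
  // An odd number of quarter turns (90 or 270 degrees) swaps the axes.
  return quarterTurns.isOdd
      ? FrameSize(height, width)
      : FrameSize(width, height);
}
```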

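Prediction output format

rect is given in the scale of the original image. mask and keypoints are present only if the d2go model produces masks and keypoints.

mask is a Uint8List containing the bytes of a bitmap image. keypoints is a list of 17 (x, y) pairs.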
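17 is the number of person keypoints in the COCO dataset. Assume the keypoint model was trained on COCO keypoints; the post does not say so explicitly. Under that assumption, index i of keypoints corresponds to the i-th name in the standard COCO order, and the raw list can be labeled as follows. cocoKeypointNames and labelKeypoints are hypothetical helpers.

```dart
/// Standard COCO person-keypoint order (17 points).
const cocoKeypointNames = <String>[
  'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
  'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
  'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
  'left_knee', 'right_knee', 'left_ankle', 'right_ankle',
];

/// Hypothetical helper: pairs each [x, y] entry of a prediction's
/// "keypoints" list with its COCO name (assumes COCO ordering).
Map<String, List<double>> labelKeypoints(List<dynamic> keypoints) {
  if (keypoints.length != cocoKeypointNames.length) {
    throw ArgumentError('expected ${cocoKeypointNames.length} keypoints, '
        'got ${keypoints.length}');
  }
  return {
    for (var i = 0; i < keypoints.length; i++)
      cocoKeypointNames[i]: [
        (keypoints[i][0] as num).toDouble(),
        (keypoints[i][1] as num).toDouble(),
      ],
  };
}
```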

[
  {
    "rect": {
      "left": 74.65713500976562,
      "top": 76.94147491455078,
      "right": 350.64324951171875,
      "bottom": 323.0279846191406
    },
    "mask": [66, 77, 122, 0, 0, 0, 0, 0, 0, 0, 122, ...],
    "keypoints": [[117.14504, 77.277405], [122.74037, 73.53044], [105.95437, 73.53044], ...],
    "confidenceInClass": 0.985002338886261,
    "detectedClass": "bicycle"
  }, // one entry per detected instance
...
]
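The sample mask bytes start with 66, 77, the ASCII codes for 'BM'. This is the signature of a BMP file, which is consistent with the full example passing mask straight to Image.memory. Assume the mask is a standard BMP with a BITMAPINFOHEADER or newer header; the post does not say so. Under that assumption, its pixel size can be read from the header without decoding the image. readBmpSize and BmpSize are hypothetical.

```dart
import 'dart:typed_data';

/// Hypothetical result type for readBmpSize.
class BmpSize {
  const BmpSize(this.width, this.height);
  final int width;
  final int height;
}

/// Hypothetical helper: reads width and height from a BMP header.
/// Returns null if the bytes do not start with 'BM' or use the old
/// 12-byte BITMAPCOREHEADER (which stores 16-bit dimensions).
BmpSize? readBmpSize(Uint8List bytes) {
  if (bytes.length < 26 || bytes[0] != 0x42 || bytes[1] != 0x4D) {
    return null;
  }
  final data = ByteData.sublistView(bytes);
  // Offset 14: size of the DIB header that follows the 14-byte file header.
  final dibHeaderSize = data.getUint32(14, Endian.little);
  if (dibHeaderSize < 40) return null;
  final width = data.getInt32(18, Endian.little);
  final height = data.getInt32(22, Endian.little);
  // A negative height means the rows are stored top-down.
  return BmpSize(width, height.abs());
}
```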

Issues

If you find a bug or would like a new feature, please get in touch here.


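Full example code

The following complete example shows how to use the flutter_d2go plugin for both real-time (camera stream) inference and still-image inference.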

import 'dart:async';
import 'dart:io';
import 'dart:typed_data';

import 'package:camera/camera.dart';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:flutter_d2go/flutter_d2go.dart';
import 'package:image_picker/image_picker.dart';
import 'package:path_provider/path_provider.dart';

List<CameraDescription> cameras = [];

Future<void> main() async {
  try {
    WidgetsFlutterBinding.ensureInitialized();
    cameras = await availableCameras();
  } on CameraException catch (e) {
    debugPrint('Error: ${e.code}, Message: ${e.description}');
  }
  runApp(
    const MaterialApp(
      debugShowCheckedModeBanner: false,
      home: MyApp(),
    ),
  );
}

class MyApp extends StatefulWidget {
  const MyApp({Key? key}) : super(key: key);

  @override
  State<MyApp> createState() => _MyAppState();
}

class _MyAppState extends State<MyApp> {
  List<RecognitionModel>? _recognitions;
  File? _selectedImage;
  final List<String> _imageList = ['test1.png', 'test2.jpeg', 'test3.png'];
  int _index = 0;
  int? _imageWidth;
  int? _imageHeight;
  final ImagePicker _picker = ImagePicker();

  CameraController? controller;
  bool _isDetecting = false;
  bool _isLiveModeOn = false;

  @override
  void initState() {
    super.initState();
    loadModel();
  }

  @override
  void dispose() {
    controller?.dispose();
    super.dispose();
  }

  Future<void> live() async {
    controller = CameraController(
      cameras[0],
      ResolutionPreset.high,
    );
    await controller!.initialize().then(
      (_) {
        if (!mounted) {
          return;
        }
        setState(() {});
      },
    );
    await controller!.startImageStream(
      (CameraImage cameraImage) async {
        if (_isDetecting) return;

        _isDetecting = true;

        await FlutterD2go.getStreamImagePrediction(
          imageBytesList:
              cameraImage.planes.map((plane) => plane.bytes).toList(),
          width: cameraImage.width,
          height: cameraImage.height,
          minScore: 0.5,
          rotation: 90,
        ).then(
          (predictions) {
            List<RecognitionModel>? recognitions;
            if (predictions.isNotEmpty) {
              recognitions = predictions.map(
                (e) {
                  return RecognitionModel(
                      Rectangle(
                        e['rect']['left'],
                        e['rect']['top'],
                        e['rect']['right'],
                        e['rect']['bottom'],
                      ),
                      e['mask'],
                      e['keypoints'] != null
                          ? (e['keypoints'] as List)
                              .map((k) => Keypoint(k[0], k[1]))
                              .toList()
                          : null,
                      e['confidenceInClass'],
                      e['detectedClass']);
                },
              ).toList();
            }
            setState(
              () {
                // With android, the inference result of the camera streaming image is tilted 90 degrees,
                // so the vertical and horizontal directions are reversed.
                _imageWidth = cameraImage.height;
                _imageHeight = cameraImage.width;
                _recognitions = recognitions;
              },
            );
          },
        ).whenComplete(
          () {
            setState(() => _isDetecting = false);
          },
        );
      },
    );
  }

  Future loadModel() async {
    String modelPath = 'assets/models/d2go.ptl';
    String labelPath = 'assets/models/classes.txt';
    try {
      await FlutterD2go.loadModel(
        modelPath: modelPath,
        labelPath: labelPath,
      );
      setState(() {});
    } on PlatformException {
      debugPrint('Load model or label file failed.');
    }
  }

  Future detect() async {
    final image = _selectedImage ??
        await getImageFileFromAssets('assets/images/${_imageList[_index]}');
    final decodedImage = await decodeImageFromList(image.readAsBytesSync());
    final predictions = await FlutterD2go.getImagePrediction(
      image: image,
      minScore: 0.8,
    );
    List<RecognitionModel>? recognitions;
    if (predictions.isNotEmpty) {
      recognitions = predictions.map(
        (e) {
          return RecognitionModel(
              Rectangle(
                e['rect']['left'],
                e['rect']['top'],
                e['rect']['right'],
                e['rect']['bottom'],
              ),
              e['mask'],
              e['keypoints'] != null
                  ? (e['keypoints'] as List)
                      .map((k) => Keypoint(k[0], k[1]))
                      .toList()
                  : null,
              e['confidenceInClass'],
              e['detectedClass']);
        },
      ).toList();
    }

    setState(
      () {
        _imageWidth = decodedImage.width;
        _imageHeight = decodedImage.height;
        _recognitions = recognitions;
      },
    );
  }

  Future<File> getImageFileFromAssets(String path) async {
    final byteData = await rootBundle.load(path);
    final fileName = path.split('/').last;
    final file = File('${(await getTemporaryDirectory()).path}/$fileName');
    await file.writeAsBytes(byteData.buffer
        .asUint8List(byteData.offsetInBytes, byteData.lengthInBytes));

    return file;
  }

  @override
  Widget build(BuildContext context) {
    double screenWidth = MediaQuery.of(context).size.width;
    List<Widget> stackChildren = [];
    stackChildren.add(
      Positioned(
        top: 0.0,
        left: 0.0,
        width: screenWidth,
        child: _selectedImage == null
            ? Image.asset(
                'assets/images/${_imageList[_index]}',
              )
            : Image.file(_selectedImage!),
      ),
    );

    if (_isLiveModeOn) {
      stackChildren.add(
        Positioned(
          top: 0.0,
          left: 0.0,
          width: screenWidth,
          child: CameraPreview(controller!),
        ),
      );
    }

    if (_recognitions != null) {
      final aspectRatio = _imageHeight! / _imageWidth! * screenWidth;
      final widthScale = screenWidth / _imageWidth!;
      final heightScale = aspectRatio / _imageHeight!;

      if (_recognitions!.first.mask != null) {
        stackChildren.addAll(_recognitions!.map(
          (recognition) {
            return RenderSegments(
              imageWidthScale: widthScale,
              imageHeightScale: heightScale,
              recognition: recognition,
            );
          },
        ).toList());
      }

      if (_recognitions!.first.keypoints != null) {
        for (RecognitionModel recognition in _recognitions!) {
          List<Widget> keypointChildren = [];
          for (Keypoint keypoint in recognition.keypoints!) {
            keypointChildren.add(
              RenderKeypoints(
                keypoint: keypoint,
                imageWidthScale: widthScale,
                imageHeightScale: heightScale,
              ),
            );
          }
          stackChildren.addAll(keypointChildren);
        }
      }

      stackChildren.addAll(_recognitions!.map(
        (recognition) {
          return RenderBoxes(
            imageWidthScale: widthScale,
            imageHeightScale: heightScale,
            recognition: recognition,
          );
        },
      ).toList());
    }

    return Scaffold(
      appBar: AppBar(
        title: const Text('Flutter D2Go'),
        backgroundColor: Colors.deepPurpleAccent,
      ),
      body: Column(
        mainAxisAlignment: MainAxisAlignment.center,
        children: [
          const SizedBox(height: 48),
          Expanded(
            child: Stack(
              children: stackChildren,
            ),
          ),
          const SizedBox(height: 48),
          MyButton(
            onPressed: !_isLiveModeOn ? detect : null,
            text: 'Detect',
          ),
          Padding(
            padding: const EdgeInsets.symmetric(vertical: 48),
            child: Row(
              mainAxisAlignment: MainAxisAlignment.spaceAround,
              children: [
                MyButton(
                    onPressed: () {
                      setState(
                        () {
                          _recognitions = null;
                          if (_selectedImage == null) {
                            _index = (_index + 1) % _imageList.length;
                          } else {
                            _selectedImage = null;
                          }
                        },
                      );
                    },
                    text: 'Test Image\n${_index + 1}/${_imageList.length}'),
                MyButton(
                    onPressed: () async {
                      final XFile? pickedFile =
                          await _picker.pickImage(source: ImageSource.gallery);
                      if (pickedFile == null) return;
                      setState(
                        () {
                          _recognitions = null;
                          _selectedImage = File(pickedFile.path);
                        },
                      );
                    },
                    text: 'Select'),
                MyButton(
                    onPressed: () async {
                      _isLiveModeOn
                          ? await controller!.stopImageStream()
                          : await live();
                      setState(
                        () {
                          _isLiveModeOn = !_isLiveModeOn;
                          _recognitions = null;
                          _selectedImage = null;
                        },
                      );
                    },
                    text: 'Live'),
              ],
            ),
          ),
        ],
      ),
    );
  }
}

class MyButton extends StatelessWidget {
  const MyButton({Key? key, required this.onPressed, required this.text})
      : super(key: key);

  final VoidCallback? onPressed;
  final String text;

  @override
  Widget build(BuildContext context) {
    return SizedBox(
      width: 96,
      height: 42,
      child: ElevatedButton(
        onPressed: onPressed,
        style: ElevatedButton.styleFrom(
          backgroundColor: Colors.grey[300], // `primary` is deprecated in newer Flutter
          elevation: 0,
        ),
        child: Text(
          text,
          textAlign: TextAlign.center,
          style: const TextStyle(
            color: Colors.black,
            fontSize: 12,
          ),
        ),
      ),
    );
  }
}

class RenderBoxes extends StatelessWidget {
  const RenderBoxes({
    Key? key,
    required this.recognition,
    required this.imageWidthScale,
    required this.imageHeightScale,
  }) : super(key: key);

  final RecognitionModel recognition;
  final double imageWidthScale;
  final double imageHeightScale;

  @override
  Widget build(BuildContext context) {
    final left = recognition.rect.left * imageWidthScale;
    final top = recognition.rect.top * imageHeightScale;
    final right = recognition.rect.right * imageWidthScale;
    final bottom = recognition.rect.bottom * imageHeightScale;
    return Positioned(
      left: left,
      top: top,
      width: right - left,
      height: bottom - top,
      child: Container(
        decoration: BoxDecoration(
          borderRadius: const BorderRadius.all(Radius.circular(8.0)),
          border: Border.all(
            color: Colors.yellow,
            width: 2,
          ),
        ),
        child: Text(
          "${recognition.detectedClass} ${(recognition.confidenceInClass * 100).toStringAsFixed(0)}%",
          style: TextStyle(
            background: Paint()..color = Colors.yellow,
            color: Colors.black,
            fontSize: 15.0,
          ),
        ),
      ),
    );
  }
}

class RenderSegments extends StatelessWidget {
  const RenderSegments({
    Key? key,
    required this.recognition,
    required this.imageWidthScale,
    required this.imageHeightScale,
  }) : super(key: key);

  final RecognitionModel recognition;
  final double imageWidthScale;
  final double imageHeightScale;

  @override
  Widget build(BuildContext context) {
    final left = recognition.rect.left * imageWidthScale;
    final top = recognition.rect.top * imageHeightScale;
    final right = recognition.rect.right * imageWidthScale;
    final bottom = recognition.rect.bottom * imageHeightScale;
    final mask = recognition.mask!;
    return Positioned(
      left: left,
      top: top,
      width: right - left,
      height: bottom - top,
      child: Image.memory(
        mask,
        fit: BoxFit.fill,
      ),
    );
  }
}

class RenderKeypoints extends StatelessWidget {
  const RenderKeypoints({
    Key? key,
    required this.keypoint,
    required this.imageWidthScale,
    required this.imageHeightScale,
  }) : super(key: key);

  final Keypoint keypoint;
  final double imageWidthScale;
  final double imageHeightScale;

  @override
  Widget build(BuildContext context) {
    final x = keypoint.x * imageWidthScale;
    final y = keypoint.y * imageHeightScale;
    return Positioned(
      left: x,
      top: y,
      child: Container(
        width: 8,
        height: 8,
        decoration: const BoxDecoration(
          color: Colors.red,
          shape: BoxShape.circle,
        ),
      ),
    );
  }
}

class RecognitionModel {
  RecognitionModel(
    this.rect,
    this.mask,
    this.keypoints,
    this.confidenceInClass,
    this.detectedClass,
  );
  Rectangle rect;
  Uint8List? mask;
  List<Keypoint>? keypoints;
  double confidenceInClass;
  String detectedClass;
}

class Rectangle {
  Rectangle(this.left, this.top, this.right, this.bottom);
  double left;
  double top;
  double right;
  double bottom;
}

class Keypoint {
  Keypoint(this.x, this.y);
  double x;
  double y;
}
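In build, the image is drawn screenWidth wide with its aspect ratio kept. Its displayed height is therefore imageHeight / imageWidth * screenWidth, which is the value the code calls aspectRatio. Dividing that by imageHeight gives heightScale = screenWidth / imageWidth, the same factor as widthScale, so boxes, masks and keypoints are all scaled uniformly. The sketch below reproduces that arithmetic in pure Dart so it runs outside Flutter. scaleRectToScreen and ScaledBox are hypothetical names.

```dart
/// Hypothetical value type for a box in screen coordinates.
class ScaledBox {
  const ScaledBox(this.left, this.top, this.width, this.height);
  final double left, top, width, height;
}

/// Hypothetical helper mirroring the scale computation in the example's
/// build method for one prediction's "rect" map.
ScaledBox scaleRectToScreen(
  Map<String, dynamic> rect, {
  required int imageWidth,
  required int imageHeight,
  required double screenWidth,
}) {
  final displayedHeight = imageHeight / imageWidth * screenWidth;
  final widthScale = screenWidth / imageWidth;
  final heightScale = displayedHeight / imageHeight; // equals widthScale
  final left = (rect['left'] as num).toDouble() * widthScale;
  final top = (rect['top'] as num).toDouble() * heightScale;
  final right = (rect['right'] as num).toDouble() * widthScale;
  final bottom = (rect['bottom'] as num).toDouble() * heightScale;
  return ScaledBox(left, top, right - left, bottom - top);
}
```

For example, take a 640 x 480 image shown 320 px wide. Both factors are 0.5, so a rect from (100, 50) to (300, 250) becomes a 100 x 100 box at (50, 25).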

For more hands-on tutorials on using the flutter_d2go plugin, visit https://www.itying.com/category-92-b0.html

1 reply



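Below is example code showing how to use the flutter_d2go plugin to deploy a deep learning model in a Flutter project. flutter_d2go runs d2go models (PyTorch models from Facebook Research) for object detection, keypoint estimation, and instance segmentation.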

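Step 1: Add the dependency

First, add flutter_d2go to the dependencies in your pubspec.yaml file:

dependencies:
  flutter:
    sdk: flutter
  flutter_d2go: ^<latest version>  # replace with the latest published version

Then run flutter pub get to install it.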

Step 2: Configure the Android project

flutter_d2go runs d2go models, which are PyTorch models, so it does not depend on TensorFlow Lite and no TensorFlow Lite dependencies need to be added; the plugin declares its own native dependencies. Optionally, to restrict which CPU architectures are packaged into the APK, you can add ABI filters to android/app/build.gradle:

android {
    ...
    defaultConfig {
        ...
        ndk {
            abiFilters 'armeabi-v7a', 'arm64-v8a', 'x86'
        }
    }
}

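Step 3: Load the model and run inference

The following simple example shows how to use flutter_d2go to load a d2go model and run object detection on a photo taken with the camera.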

Dart code

import 'dart:io';

import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:flutter_d2go/flutter_d2go.dart';
import 'package:image_picker/image_picker.dart';

void main() {
  runApp(const MyApp());
}

class MyApp extends StatelessWidget {
  const MyApp({Key? key}) : super(key: key);

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: const Text('Flutter D2Go Example'),
        ),
        body: const Center(
          child: MyHomePage(),
        ),
      ),
    );
  }
}

class MyHomePage extends StatefulWidget {
  const MyHomePage({Key? key}) : super(key: key);

  @override
  State<MyHomePage> createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  final ImagePicker _picker = ImagePicker();
  bool _modelLoaded = false;
  String? _result;

  @override
  void initState() {
    super.initState();
    loadModel();
  }

  Future<void> loadModel() async {
    try {
      // Load the d2go model and class file declared under flutter: assets: in pubspec.yaml
      await FlutterD2go.loadModel(
        modelPath: 'assets/models/d2go.pt',
        labelPath: 'assets/models/classes.txt',
      );
      if (!mounted) return;
      setState(() => _modelLoaded = true);
    } on PlatformException catch (e) {
      debugPrint('Failed to load model: $e');
    }
  }

  Future<void> pickImage() async {
    final XFile? picked = await _picker.pickImage(source: ImageSource.camera);
    if (picked == null) return;

    // getImagePrediction takes a dart:io File and returns one map per detected instance
    final predictions = await FlutterD2go.getImagePrediction(
      image: File(picked.path),
      minScore: 0.5,
    );
    if (!mounted) return;

    setState(() {
      _result = predictions.isEmpty
          ? 'Nothing detected'
          : predictions
              .map((p) => '${p['detectedClass']} '
                  '${((p['confidenceInClass'] as num) * 100).toStringAsFixed(0)}%')
              .join('\n');
    });
  }

  @override
  Widget build(BuildContext context) {
    return Column(
      mainAxisAlignment: MainAxisAlignment.center,
      children: [
        TextButton(
          // Disabled until the model has finished loading
          onPressed: _modelLoaded ? pickImage : null,
          child: const Text('Pick an image'),
        ),
        if (_result != null)
          Text(
            'Prediction:\n$_result',
            textAlign: TextAlign.center,
            style: const TextStyle(fontSize: 24),
          ),
      ],
    );
  }
}
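Each prediction returned by getImagePrediction carries a confidenceInClass score. When only one headline result is wanted, the most confident detection can be picked with a simple scan. The helper below is a hypothetical pure-Dart function over the output format described earlier in the post.

```dart
/// Hypothetical helper: returns the prediction with the highest
/// "confidenceInClass", or null when the list is empty.
Map<String, dynamic>? mostConfident(List<Map<String, dynamic>> predictions) {
  Map<String, dynamic>? best;
  for (final p in predictions) {
    final score = (p['confidenceInClass'] as num).toDouble();
    // Keep the first prediction seen, then replace it only on a higher score.
    if (best == null || score > (best['confidenceInClass'] as num).toDouble()) {
      best = p;
    }
  }
  return best;
}
```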

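Notes

  1. Model files: make sure your model and class files (e.g. d2go.pt and classes.txt) are in the project's assets/models/ directory and are declared under flutter: assets: in pubspec.yaml. Flutter bundles assets from the paths listed in pubspec.yaml, not from android/app/src/main/assets/.

  2. Permissions: because the image comes from the camera or the gallery, you may need to add camera and storage permissions to android/app/src/main/AndroidManifest.xml.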

<uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
  3. Image Picker: the example above uses the image_picker plugin to pick images, so you also need to add image_picker as a dependency in pubspec.yaml.
dependencies:
  flutter:
    sdk: flutter
  flutter_d2go: ^最新版本号
  image_picker: ^最新版本号

With the steps above, you can deploy a d2go model in a Flutter app and run inference with it.
