Flutter机器学习推理插件pytorch_lite的使用

发布于 1周前 作者 h691938207 来自 Flutter

Flutter机器学习推理插件pytorch_lite的使用

pytorch_lite 是一个Flutter包,旨在帮助运行PyTorch Lite模型进行分类和目标检测(包括YOLOV5和YOLOV8)。以下是详细的使用指南,包括如何准备模型、安装依赖以及编写完整的示例代码。

准备模型

分类模型

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torch.load('model_scripted.pt', map_location="cpu")
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("model.pt")

目标检测模型 (YOLOV5)

!python export.py --weights "the weights of your model" --include torchscript --img 640 --optimize

例如:

!python export.py --weights yolov5s.pt --include torchscript --img 640 --optimize

目标检测模型 (YOLOV8)

!yolo mode=export model="your model" format=torchscript optimize

例如:

!yolo mode=export model=yolov8s.pt format=torchscript optimize

安装

pubspec.yaml 文件中添加 pytorch_lite 作为依赖项:

dependencies:
  pytorch_lite: ^latest_version

创建一个 assets 文件夹,并将你的 PyTorch 模型和标签文件放入其中。修改 pubspec.yaml 文件以包含这些资源:

assets:
  - assets/models/model_classification.pt
  - assets/labels_classification.txt
  - assets/models/model_objectDetection.torchscript
  - assets/labels_objectDetection.txt

运行 flutter pub get

对于发布版本,在 android/app/build.gradle 文件中添加以下配置:

buildTypes {
    release {
        shrinkResources false
        minifyEnabled false
        signingConfig signingConfigs.debug
    }
}

使用

导入库

import 'package:pytorch_lite/pytorch_lite.dart';

加载模型

分类模型

ClassificationModel classificationModel = await PytorchLite.loadClassificationModel(
  "assets/models/model_classification.pt", 
  224, 224,
  labelPath: "assets/labels/label_classification_imageNet.txt"
);

目标检测模型

ModelObjectDetection objectModel = await PytorchLite.loadObjectDetectionModel(
  "assets/models/yolov5s.torchscript", 
  80, 640, 640,
  labelPath: "assets/labels/labels_objectDetection_Coco.txt",
  objectDetectionModelType: ObjectDetectionModelType.yolov5
);

获取分类预测结果

从图片获取预测结果

String imagePrediction = await classificationModel.getImagePrediction(
  await File(image.path).readAsBytes()
);

从摄像头图像获取预测结果

String imagePrediction = await _objectModel.getCameraImagePrediction(
  cameraImage,
  rotation, // 检查示例中的旋转值
);

获取原始输出层

List<double>? predictionList = await _imageModel!.getImagePredictionList(
  await File(image.path).readAsBytes(),
);

从摄像头图像获取原始输出层

List<double>? predictionList = await _imageModel!.getCameraImagePredictionList(
  cameraImage,
  rotation, // 检查示例中的旋转值
);

获取概率(如果模型未使用softmax)

List<double>? predictionListProbabilities = await _imageModel!.getImagePredictionListProbabilities(
  await File(image.path).readAsBytes(),
);

从摄像头图像获取概率

List<double>? predictionListProbabilities = await _imageModel!.getCameraPredictionListProbabilities(
  cameraImage,
  rotation, // 检查示例中的旋转值
);

获取目标检测预测结果

从图片获取预测结果

List<ResultObjectDetection> objDetect = await _objectModel.getImagePrediction(
  await File(image.path).readAsBytes(),
  minimumScore: 0.1, 
  iOUThreshold: 0.3
);

从摄像头图像获取预测结果

List<ResultObjectDetection> objDetect = await _objectModel.getCameraImagePrediction(
  cameraImage,
  rotation, // 检查示例中的旋转值
  minimumScore: 0.1, 
  iOUThreshold: 0.3
);

在图像上绘制边界框

objectModel.renderBoxesOnImage(_image!, objDetect);

自定义均值和标准差的图像预测

final mean = [0.5, 0.5, 0.5];
final std = [0.5, 0.5, 0.5];
String prediction = await classificationModel.getImagePrediction(
  image, 
  mean: mean, 
  std: std
);

示例代码

以下是一个完整的示例代码,展示了如何在Flutter应用中使用 pytorch_lite 插件:

import 'package:flutter/material.dart';
import 'package:pytorch_lite/pytorch_lite.dart';

void main() async {
  runApp(const ChooseDemo());
}

class ChooseDemo extends StatefulWidget {
  const ChooseDemo({Key? key}) : super(key: key);

  @override
  State<ChooseDemo> createState() => _ChooseDemoState();
}

class _ChooseDemoState extends State<ChooseDemo> {
  late ClassificationModel classificationModel;
  late ModelObjectDetection objectModel;

  @override
  void initState() {
    super.initState();
    loadModels();
  }

  Future<void> loadModels() async {
    classificationModel = await PytorchLite.loadClassificationModel(
      "assets/models/model_classification.pt", 
      224, 224,
      labelPath: "assets/labels/label_classification_imageNet.txt"
    );

    objectModel = await PytorchLite.loadObjectDetectionModel(
      "assets/models/yolov5s.torchscript", 
      80, 640, 640,
      labelPath: "assets/labels/labels_objectDetection_Coco.txt",
      objectDetectionModelType: ObjectDetectionModelType.yolov5
    );
  }

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: const Text('Pytorch Mobile Example'),
        ),
        body: Builder(builder: (context) {
          return Center(
            child: Column(
              children: [
                TextButton(
                  onPressed: () async {
                    // 这里可以添加逻辑来处理模型推理
                    String prediction = await classificationModel.getImagePrediction(
                      await File(image.path).readAsBytes()
                    );
                    print(prediction);
                  },
                  style: TextButton.styleFrom(
                    backgroundColor: Colors.blue,
                  ),
                  child: const Text(
                    "Run Classification Model",
                    style: TextStyle(
                      color: Colors.white,
                    ),
                  ),
                ),
                TextButton(
                  onPressed: () async {
                    // 这里可以添加逻辑来处理模型推理
                    List<ResultObjectDetection> objDetect = await objectModel.getImagePrediction(
                      await File(image.path).readAsBytes(),
                      minimumScore: 0.1, 
                      iOUThreshold: 0.3
                    );
                    print(objDetect);
                  },
                  style: TextButton.styleFrom(
                    backgroundColor: Colors.blue,
                  ),
                  child: const Text(
                    "Run Object Detection Model",
                    style: TextStyle(
                      color: Colors.white,
                    ),
                  ),
                ),
              ],
            ),
          );
        }),
      ),
    );
  }
}

通过以上步骤,你可以轻松地在Flutter应用中集成并使用 pytorch_lite 插件进行机器学习推理。希望这些信息对你有所帮助!


更多关于Flutter机器学习推理插件pytorch_lite的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html

1 回复

更多关于Flutter机器学习推理插件pytorch_lite的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html


在Flutter中使用PyTorch Lite进行机器学习推理涉及多个步骤,包括模型转换、插件集成以及调用推理接口。以下是一个简要的指南和代码示例,帮助你开始在Flutter项目中集成PyTorch Lite。

步骤 1: 准备你的PyTorch模型

首先,确保你有一个训练好的PyTorch模型,并将其转换为PyTorch Lite格式。PyTorch Lite模型通常是一个.ptl.tflite文件(取决于你使用的转换工具)。这里假设你已经有一个训练好的模型,并进行了转换。

步骤 2: 添加Flutter插件

在Flutter项目中,我们需要一个能够调用本地代码(如PyTorch Lite推理)的插件。虽然目前没有官方的PyTorch Lite Flutter插件,但你可以使用torchvisionpytorch_mobile(如果是针对Android/iOS的本地代码)并通过MethodChannel与Flutter通信。

以下是一个简化的示例,展示如何设置原生代码(Android/iOS)并通过MethodChannel与Flutter通信。

Android部分

  1. 添加依赖

    android/app/build.gradle中添加PyTorch Mobile依赖:

    implementation 'org.pytorch:pytorch_android_lite:1.9.0'
    implementation 'org.pytorch:pytorch_android_torchvision_lite:1.9.0'
    
  2. 加载模型并进行推理

    创建一个新的Kotlin/Java类来处理模型加载和推理。例如,创建一个名为TorchModel.kt的文件:

    package com.example.yourapp
    
    import android.content.Context
    import org.pytorch.IValue
    import org.pytorch.Module
    import org.pytorch.Tensor
    import org.pytorch.torchvision.TensorImageUtils
    import java.io.File
    
    class TorchModel(context: Context, modelAsset: String) {
        private val module: Module
    
        init {
            module = Module.load(File(context.filesDir, modelAsset))
        }
    
        fun predict(bitmap: Bitmap): List<Float> {
            val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB)
    
            val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
    
            val scores = ArrayList<Float>()
            for (i in 0 until outputTensor.numel()) {
                scores.add(outputTensor.getDataAsFloatArray()[i])
            }
    
            return scores
        }
    }
    
  3. 设置MethodChannel

    在你的MainActivity.kt中设置MethodChannel以与Flutter通信:

    package com.example.yourapp
    
    import android.os.Bundle
    import io.flutter.embedding.android.FlutterActivity
    import io.flutter.embedding.engine.FlutterEngine
    import io.flutter.plugin.common.MethodChannel
    
    class MainActivity: FlutterActivity() {
        private val CHANNEL = "com.example.yourapp/torch"
    
        override fun configureFlutterEngine(flutterEngine: FlutterEngine) {
            super.configureFlutterEngine(flutterEngine)
            MethodChannel(flutterEngine.dartExecutor.binaryMessenger, CHANNEL).setMethodCallHandler { call, result ->
                if (call.method == "predict") {
                    val bitmap = // 获取或转换你的Bitmap
                    val torchModel = TorchModel(this, "model.ptl")
                    val scores = torchModel.predict(bitmap)
                    result.success(scores)
                } else {
                    result.notImplemented()
                }
            }
        }
    }
    

iOS部分

iOS部分的设置类似,但你需要使用Swift或Objective-C来编写代码。这里只提供一个简要的方向:

  1. 添加PyTorch Mobile依赖

    在你的Podfile中添加PyTorch Mobile:

    pod 'LibTorch', '~> 1.9.0'
    pod 'LibTorchVision', '~> 0.10.0'
    
  2. 加载模型并进行推理

    创建一个新的Swift/Objective-C类来处理模型加载和推理。

  3. 设置FlutterMethodChannel

    AppDelegate.swiftAppDelegate.m中设置FlutterMethodChannel以与Flutter通信。

Flutter部分

最后,在Flutter中调用原生方法:

import 'package:flutter/material.dart';
import 'package:flutter/services.dart';

void main() => runApp(MyApp());

class MyApp extends StatelessWidget {
  static const platform = MethodChannel('com.example.yourapp/torch');

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: const Text('Flutter PyTorch Lite Example'),
        ),
        body: Center(
          child: ElevatedButton(
            onPressed: _predict,
            child: Text('Predict'),
          ),
        ),
      ),
    );
  }

  Future<void> _predict() async {
    try {
      // 这里可以传递Bitmap数据到原生代码,但这里简化处理
      final result = await platform.invokeMethod('predict');
      print(result);
    } on PlatformException catch (e) {
      print("Failed to invoke: '${e.message}'.");
    }
  }
}

注意事项

  1. 模型转换:确保你的模型正确转换为PyTorch Lite格式。
  2. Bitmap处理:在Android中,你需要将图像数据转换为Bitmap,然后传递给PyTorch Lite进行推理。
  3. 错误处理:添加适当的错误处理机制以处理模型加载和推理中的潜在问题。

这个示例只是一个起点,你可能需要根据具体需求进行调整和扩展。

回到顶部