Flutter机器学习推理插件pytorch_lite的使用
Flutter机器学习推理插件pytorch_lite的使用
pytorch_lite
是一个Flutter包,旨在帮助运行PyTorch Lite模型进行分类和目标检测(包括YOLOV5和YOLOV8)。以下是详细的使用指南,包括如何准备模型、安装依赖以及编写完整的示例代码。
准备模型
分类模型
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
model = torch.load('model_scripted.pt', map_location="cpu")
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("model.pt")
目标检测模型 (YOLOV5)
!python export.py --weights "the weights of your model" --include torchscript --img 640 --optimize
例如:
!python export.py --weights yolov5s.pt --include torchscript --img 640 --optimize
目标检测模型 (YOLOV8)
!yolo mode=export model="your model" format=torchscript optimize
例如:
!yolo mode=export model=yolov8s.pt format=torchscript optimize
安装
在 pubspec.yaml
文件中添加 pytorch_lite
作为依赖项:
dependencies:
pytorch_lite: ^latest_version
创建一个 assets
文件夹,并将你的 PyTorch 模型和标签文件放入其中。修改 pubspec.yaml
文件以包含这些资源:
assets:
- assets/models/model_classification.pt
- assets/labels_classification.txt
- assets/models/model_objectDetection.torchscript
- assets/labels_objectDetection.txt
运行 flutter pub get
。
对于发布版本,在 android/app/build.gradle
文件中添加以下配置:
buildTypes {
release {
shrinkResources false
minifyEnabled false
signingConfig signingConfigs.debug
}
}
使用
导入库
import 'package:pytorch_lite/pytorch_lite.dart';
加载模型
分类模型
ClassificationModel classificationModel = await PytorchLite.loadClassificationModel(
"assets/models/model_classification.pt",
224, 224,
labelPath: "assets/labels/label_classification_imageNet.txt"
);
目标检测模型
ModelObjectDetection objectModel = await PytorchLite.loadObjectDetectionModel(
"assets/models/yolov5s.torchscript",
80, 640, 640,
labelPath: "assets/labels/labels_objectDetection_Coco.txt",
objectDetectionModelType: ObjectDetectionModelType.yolov5
);
获取分类预测结果
从图片获取预测结果
String imagePrediction = await classificationModel.getImagePrediction(
await File(image.path).readAsBytes()
);
从摄像头图像获取预测结果
String imagePrediction = await _objectModel.getCameraImagePrediction(
cameraImage,
rotation, // 检查示例中的旋转值
);
获取原始输出层
List<double>? predictionList = await _imageModel!.getImagePredictionList(
await File(image.path).readAsBytes(),
);
从摄像头图像获取原始输出层
List<double>? predictionList = await _imageModel!.getCameraImagePredictionList(
cameraImage,
rotation, // 检查示例中的旋转值
);
获取概率(如果模型未使用softmax)
List<double>? predictionListProbabilities = await _imageModel!.getImagePredictionListProbabilities(
await File(image.path).readAsBytes(),
);
从摄像头图像获取概率
List<double>? predictionListProbabilities = await _imageModel!.getCameraPredictionListProbabilities(
cameraImage,
rotation, // 检查示例中的旋转值
);
获取目标检测预测结果
从图片获取预测结果
List<ResultObjectDetection> objDetect = await _objectModel.getImagePrediction(
await File(image.path).readAsBytes(),
minimumScore: 0.1,
iOUThreshold: 0.3
);
从摄像头图像获取预测结果
List<ResultObjectDetection> objDetect = await _objectModel.getCameraImagePrediction(
cameraImage,
rotation, // 检查示例中的旋转值
minimumScore: 0.1,
iOUThreshold: 0.3
);
在图像上绘制边界框
objectModel.renderBoxesOnImage(_image!, objDetect);
自定义均值和标准差的图像预测
final mean = [0.5, 0.5, 0.5];
final std = [0.5, 0.5, 0.5];
String prediction = await classificationModel.getImagePrediction(
image,
mean: mean,
std: std
);
示例代码
以下是一个完整的示例代码,展示了如何在Flutter应用中使用 pytorch_lite
插件:
import 'package:flutter/material.dart';
import 'package:pytorch_lite/pytorch_lite.dart';
void main() async {
runApp(const ChooseDemo());
}
class ChooseDemo extends StatefulWidget {
const ChooseDemo({Key? key}) : super(key: key);
@override
State<ChooseDemo> createState() => _ChooseDemoState();
}
class _ChooseDemoState extends State<ChooseDemo> {
late ClassificationModel classificationModel;
late ModelObjectDetection objectModel;
@override
void initState() {
super.initState();
loadModels();
}
Future<void> loadModels() async {
classificationModel = await PytorchLite.loadClassificationModel(
"assets/models/model_classification.pt",
224, 224,
labelPath: "assets/labels/label_classification_imageNet.txt"
);
objectModel = await PytorchLite.loadObjectDetectionModel(
"assets/models/yolov5s.torchscript",
80, 640, 640,
labelPath: "assets/labels/labels_objectDetection_Coco.txt",
objectDetectionModelType: ObjectDetectionModelType.yolov5
);
}
@override
Widget build(BuildContext context) {
return MaterialApp(
home: Scaffold(
appBar: AppBar(
title: const Text('Pytorch Mobile Example'),
),
body: Builder(builder: (context) {
return Center(
child: Column(
children: [
TextButton(
onPressed: () async {
// 这里可以添加逻辑来处理模型推理
String prediction = await classificationModel.getImagePrediction(
await File(image.path).readAsBytes()
);
print(prediction);
},
style: TextButton.styleFrom(
backgroundColor: Colors.blue,
),
child: const Text(
"Run Classification Model",
style: TextStyle(
color: Colors.white,
),
),
),
TextButton(
onPressed: () async {
// 这里可以添加逻辑来处理模型推理
List<ResultObjectDetection> objDetect = await objectModel.getImagePrediction(
await File(image.path).readAsBytes(),
minimumScore: 0.1,
iOUThreshold: 0.3
);
print(objDetect);
},
style: TextButton.styleFrom(
backgroundColor: Colors.blue,
),
child: const Text(
"Run Object Detection Model",
style: TextStyle(
color: Colors.white,
),
),
),
],
),
);
}),
),
);
}
}
通过以上步骤,你可以轻松地在Flutter应用中集成并使用 pytorch_lite
插件进行机器学习推理。希望这些信息对你有所帮助!
更多关于Flutter机器学习推理插件pytorch_lite的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html
更多关于Flutter机器学习推理插件pytorch_lite的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html
在Flutter中使用PyTorch Lite进行机器学习推理涉及多个步骤,包括模型转换、插件集成以及调用推理接口。以下是一个简要的指南和代码示例,帮助你开始在Flutter项目中集成PyTorch Lite。
步骤 1: 准备你的PyTorch模型
首先,确保你有一个训练好的PyTorch模型,并将其转换为PyTorch Lite格式。PyTorch Lite模型通常是一个.ptl
或.tflite
文件(取决于你使用的转换工具)。这里假设你已经有一个训练好的模型,并进行了转换。
步骤 2: 添加Flutter插件
在Flutter项目中,我们需要一个能够调用本地代码(如PyTorch Lite推理)的插件。虽然目前没有官方的PyTorch Lite Flutter插件,但你可以使用torchvision
或pytorch_mobile
(如果是针对Android/iOS的本地代码)并通过MethodChannel
与Flutter通信。
以下是一个简化的示例,展示如何设置原生代码(Android/iOS)并通过MethodChannel
与Flutter通信。
Android部分
-
添加依赖
在
android/app/build.gradle
中添加PyTorch Mobile依赖:implementation 'org.pytorch:pytorch_android_lite:1.9.0' implementation 'org.pytorch:pytorch_android_torchvision_lite:1.9.0'
-
加载模型并进行推理
创建一个新的Kotlin/Java类来处理模型加载和推理。例如,创建一个名为
TorchModel.kt
的文件:package com.example.yourapp import android.content.Context import org.pytorch.IValue import org.pytorch.Module import org.pytorch.Tensor import org.pytorch.torchvision.TensorImageUtils import java.io.File class TorchModel(context: Context, modelAsset: String) { private val module: Module init { module = Module.load(File(context.filesDir, modelAsset)) } fun predict(bitmap: Bitmap): List<Float> { val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB) val outputTensor = module.forward(IValue.from(inputTensor)).toTensor() val scores = ArrayList<Float>() for (i in 0 until outputTensor.numel()) { scores.add(outputTensor.getDataAsFloatArray()[i]) } return scores } }
-
设置MethodChannel
在你的
MainActivity.kt
中设置MethodChannel
以与Flutter通信:package com.example.yourapp import android.os.Bundle import io.flutter.embedding.android.FlutterActivity import io.flutter.embedding.engine.FlutterEngine import io.flutter.plugin.common.MethodChannel class MainActivity: FlutterActivity() { private val CHANNEL = "com.example.yourapp/torch" override fun configureFlutterEngine(flutterEngine: FlutterEngine) { super.configureFlutterEngine(flutterEngine) MethodChannel(flutterEngine.dartExecutor.binaryMessenger, CHANNEL).setMethodCallHandler { call, result -> if (call.method == "predict") { val bitmap = // 获取或转换你的Bitmap val torchModel = TorchModel(this, "model.ptl") val scores = torchModel.predict(bitmap) result.success(scores) } else { result.notImplemented() } } } }
iOS部分
iOS部分的设置类似,但你需要使用Swift或Objective-C来编写代码。这里只提供一个简要的方向:
-
添加PyTorch Mobile依赖
在你的
Podfile
中添加PyTorch Mobile:pod 'LibTorch', '~> 1.9.0' pod 'LibTorchVision', '~> 0.10.0'
-
加载模型并进行推理
创建一个新的Swift/Objective-C类来处理模型加载和推理。
-
设置FlutterMethodChannel
在
AppDelegate.swift
或AppDelegate.m
中设置FlutterMethodChannel
以与Flutter通信。
Flutter部分
最后,在Flutter中调用原生方法:
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
void main() => runApp(MyApp());
class MyApp extends StatelessWidget {
static const platform = MethodChannel('com.example.yourapp/torch');
@override
Widget build(BuildContext context) {
return MaterialApp(
home: Scaffold(
appBar: AppBar(
title: const Text('Flutter PyTorch Lite Example'),
),
body: Center(
child: ElevatedButton(
onPressed: _predict,
child: Text('Predict'),
),
),
),
);
}
Future<void> _predict() async {
try {
// 这里可以传递Bitmap数据到原生代码,但这里简化处理
final result = await platform.invokeMethod('predict');
print(result);
} on PlatformException catch (e) {
print("Failed to invoke: '${e.message}'.");
}
}
}
注意事项
- 模型转换:确保你的模型正确转换为PyTorch Lite格式。
- Bitmap处理:在Android中,你需要将图像数据转换为Bitmap,然后传递给PyTorch Lite进行推理。
- 错误处理:添加适当的错误处理机制以处理模型加载和推理中的潜在问题。
这个示例只是一个起点,你可能需要根据具体需求进行调整和扩展。