Flutter深度学习模型推理插件flutter_pytorch的使用
Flutter深度学习模型推理插件flutter_pytorch的使用
简介
flutter_pytorch
是一个Flutter插件,用于在移动设备上运行PyTorch Lite模型进行推理。它支持图像分类和目标检测任务。例如,它可以用于运行YOLOv5模型,但不支持YOLOv7。iOS平台的支持可以通过参考PyTorch iOS示例应用来添加,欢迎提交PR。
准备工作
1. 准备模型
图像分类模型
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
# 加载预训练模型
model = torch.load('model_scripted.pt', map_location="cpu")
model.eval()
# 创建示例输入
example = torch.rand(1, 3, 224, 224)
# 跟踪模型并优化
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
# 保存优化后的模型
optimized_traced_model._save_for_lite_interpreter("model.pt")
目标检测模型(YOLOv5)
!python export.py --weights "your_model_weights.pt" --include torchscript --img 640 --optimize
例如:
!python export.py --weights yolov5s.pt --include torchscript --img 640 --optimize
安装
1. 添加依赖
在 pubspec.yaml
文件中添加 pytorch_lite
依赖:
dependencies:
flutter_pytorch: ^latest_version
2. 添加模型和标签文件
创建一个 assets
文件夹,并将你的PyTorch模型和标签文件放入其中。修改 pubspec.yaml
文件以包含这些资源:
assets:
- assets/models/model_classification.pt
- assets/labels/label_classification_imageNet.txt
- assets/models/model_objectDetection.torchscript
- assets/labels/labels_objectDetection_Coco.txt
3. 运行 flutter pub get
flutter pub get
4. 配置发布版本
对于发布版本,编辑 android/app/build.gradle
文件,在 release
配置中添加以下内容:
buildTypes {
release {
shrinkResources false
minifyEnabled false
signingConfig signingConfigs.debug
}
}
使用
1. 导入库
import 'package:flutter_pytorch/flutter_pytorch.dart';
2. 加载模型
图像分类模型
ClassificationModel classificationModel = await FlutterPytorch.loadClassificationModel(
"assets/models/model_classification.pt",
224,
224,
labelPath: "assets/labels/label_classification_imageNet.txt"
);
目标检测模型
ModelObjectDetection objectModel = await FlutterPytorch.loadObjectDetectionModel(
"assets/models/yolov5s.torchscript",
80,
640,
640,
labelPath: "assets/labels/labels_objectDetection_Coco.txt"
);
3. 获取分类预测结果
获取分类预测标签
String imagePrediction = await classificationModel.getImagePrediction(
await File(image.path).readAsBytes()
);
获取分类预测的原始输出层
List<double?>? predictionList = await _imageModel!.getImagePredictionList(
await File(image.path).readAsBytes()
);
获取分类预测的概率(如果模型未使用softmax)
List<double?>? predictionListProbabilites = await _imageModel!.getImagePredictionListProbabilities(
await File(image.path).readAsBytes()
);
4. 获取目标检测预测结果
List<ResultObjectDetection?> objDetect = await _objectModel.getImagePrediction(
await File(image.path).readAsBytes(),
minimumScore: 0.1,
IOUThershold: 0.3
);
5. 在图像上绘制检测框
objectModel.renderBoxesOnImage(_image!, objDetect);
6. 使用自定义均值和标准差进行图像预测
final mean = [0.5, 0.5, 0.5];
final std = [0.5, 0.5, 0.5];
String prediction = await classificationModel.getImagePrediction(
image,
mean: mean,
std: std
);
示例代码
以下是一个完整的示例代码,展示了如何在Flutter应用中使用 flutter_pytorch
插件进行模型推理:
import 'package:flutter/material.dart';
import 'package:flutter_pytorch/flutter_pytorch.dart';
void main() => runApp(MyApp());
class MyApp extends StatelessWidget {
[@override](/user/override)
Widget build(BuildContext context) {
return MaterialApp(
home: ChooseDemo(),
);
}
}
class ChooseDemo extends StatefulWidget {
const ChooseDemo({Key? key}) : super(key: key);
[@override](/user/override)
State<ChooseDemo> createState() => _ChooseDemoState();
}
class _ChooseDemoState extends State<ChooseDemo> {
ClassificationModel? classificationModel;
ModelObjectDetection? objectModel;
[@override](/user/override)
void initState() {
super.initState();
_loadModel();
}
Future<void> _loadModel() async {
// 加载分类模型
classificationModel = await FlutterPytorch.loadClassificationModel(
"assets/models/model_classification.pt",
224,
224,
labelPath: "assets/labels/label_classification_imageNet.txt"
);
// 加载目标检测模型
objectModel = await FlutterPytorch.loadObjectDetectionModel(
"assets/models/yolov5s.torchscript",
80,
640,
640,
labelPath: "assets/labels/labels_objectDetection_Coco.txt"
);
setState(() {});
}
[@override](/user/override)
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
title: Text('Pytorch Mobile Example'),
),
body: Center(
child: Column(
mainAxisAlignment: MainAxisAlignment.center,
children: [
if (classificationModel != null && objectModel != null)
TextButton(
onPressed: () async {
// 选择图片或相机
final image = await ImagePicker().pickImage(source: ImageSource.gallery);
if (image == null) return;
// 获取分类预测
String classificationResult = await classificationModel!.getImagePrediction(
await File(image.path).readAsBytes()
);
// 获取目标检测预测
List<ResultObjectDetection?> detectionResults = await objectModel!.getImagePrediction(
await File(image.path).readAsBytes(),
minimumScore: 0.1,
IOUThershold: 0.3
);
// 显示结果
showDialog(
context: context,
builder: (context) => AlertDialog(
title: Text('Prediction Results'),
content: Column(
mainAxisSize: MainAxisSize.min,
children: [
Text('Classification: $classificationResult'),
if (detectionResults.isNotEmpty)
ListView.builder(
shrinkWrap: true,
itemCount: detectionResults.length,
itemBuilder: (context, index) {
final result = detectionResults[index];
return ListTile(
title: Text(result?.className ?? 'Unknown'),
subtitle: Text('Score: ${result?.score}'),
);
},
),
],
),
actions: [
TextButton(
onPressed: () => Navigator.pop(context),
child: Text('Close'),
),
],
),
);
},
style: TextButton.styleFrom(
backgroundColor: Colors.blue,
),
child: Text(
"Run Model with Image",
style: TextStyle(
color: Colors.white,
),
),
),
else
CircularProgressIndicator(), // 模型加载中
],
),
),
);
}
}
更多关于Flutter深度学习模型推理插件flutter_pytorch的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html
更多关于Flutter深度学习模型推理插件flutter_pytorch的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html
当然,关于Flutter深度学习模型推理插件flutter_pytorch
的使用,这里提供一个简单的代码案例来展示如何在Flutter应用中进行PyTorch模型的加载和推理。
首先,确保你已经在Flutter项目中添加了flutter_pytorch
依赖。在pubspec.yaml
文件中添加以下依赖:
dependencies:
flutter:
sdk: flutter
flutter_pytorch: ^x.y.z # 替换为最新版本号
然后运行flutter pub get
来安装依赖。
接下来,我们编写一个简单的Flutter应用来加载一个预训练的PyTorch模型并进行推理。假设我们有一个已经训练好的PyTorch模型并保存为model.pt
文件。
1. 准备PyTorch模型(在Python中)
确保你的PyTorch模型已经保存为model.pt
文件。这里是一个简单的模型保存示例:
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的线性模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(1, 1)
def forward(self, x):
return self.linear(x)
# 实例化模型,训练并保存
model = SimpleModel()
dummy_input = torch.tensor([[1.0]])
dummy_target = torch.tensor([[2.0]])
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
optimizer.zero_grad()
output = model(dummy_input)
loss = criterion(output, dummy_target)
loss.backward()
optimizer.step()
torch.save(model.state_dict(), 'model.pt')
2. 在Flutter中使用flutter_pytorch
加载模型
接下来,在Flutter项目中编写代码来加载这个模型并进行推理。
import 'package:flutter/material.dart';
import 'package:flutter_pytorch/flutter_pytorch.dart';
void main() {
runApp(MyApp());
}
class MyApp extends StatelessWidget {
@override
Widget build(BuildContext context) {
return MaterialApp(
home: Scaffold(
appBar: AppBar(
title: Text('Flutter PyTorch Example'),
),
body: Center(
child: PyTorchModelWidget(),
),
),
);
}
}
class PyTorchModelWidget extends StatefulWidget {
@override
_PyTorchModelWidgetState createState() => _PyTorchModelWidgetState();
}
class _PyTorchModelWidgetState extends State<PyTorchModelWidget> {
late PyTorchModel pytorchModel;
@override
void initState() {
super.initState();
loadModel();
}
Future<void> loadModel() async {
// 加载模型文件
pytorchModel = await PyTorchModel.loadAsset('assets/model.pt');
// 准备输入数据
var inputTensor = Tensor.fromList([1.0], dtype: DType.float32).reshape([1, 1]);
// 执行推理
var outputTensor = await pytorchModel.forward(inputTensor);
// 处理输出
var outputValue = outputTensor.data.toList().single[0];
print('Model output: $outputValue');
// 如果需要在UI中显示结果,可以使用setState更新状态
// setState(() {
// // 更新UI状态
// });
}
@override
Widget build(BuildContext context) {
return Text('Loading model...');
// 可以在这里添加更多UI元素来显示推理结果
}
}
注意事项
-
模型文件路径:确保
model.pt
文件已经放置在Flutter项目的assets
文件夹中,并在pubspec.yaml
中声明该资产:flutter: assets: - assets/model.pt
-
Tensor处理:输入和输出数据都是以Tensor的形式处理,确保输入数据的维度和类型与模型训练时一致。
-
异步操作:加载模型和进行推理都是异步操作,使用
await
关键字等待操作完成。
这个示例展示了如何在Flutter中使用flutter_pytorch
插件加载PyTorch模型并进行简单的推理。根据你的具体需求,你可以进一步扩展这个示例,比如处理更复杂的模型、在UI中显示结果等。