Flutter深度学习模型推理插件flutter_pytorch的使用

发布于 1周前 作者 nodeper 来自 Flutter

Flutter深度学习模型推理插件flutter_pytorch的使用

简介

flutter_pytorch 是一个Flutter插件,用于在移动设备上运行PyTorch Lite模型进行推理。它支持图像分类和目标检测任务。例如,它可以用于运行YOLOv5模型,但不支持YOLOv7。iOS平台的支持可以通过参考PyTorch iOS示例应用来添加,欢迎提交PR。

准备工作

1. 准备模型
图像分类模型
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

# 加载预训练模型
model = torch.load('model_scripted.pt', map_location="cpu")
model.eval()

# 创建示例输入
example = torch.rand(1, 3, 224, 224)

# 跟踪模型并优化
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)

# 保存优化后的模型
optimized_traced_model._save_for_lite_interpreter("model.pt")
目标检测模型(YOLOv5)
!python export.py --weights "your_model_weights.pt" --include torchscript --img 640 --optimize

例如:

!python export.py --weights yolov5s.pt --include torchscript --img 640 --optimize

安装

1. 添加依赖

pubspec.yaml 文件中添加 pytorch_lite 依赖:

dependencies:
  flutter_pytorch: ^latest_version
2. 添加模型和标签文件

创建一个 assets 文件夹,并将你的PyTorch模型和标签文件放入其中。修改 pubspec.yaml 文件以包含这些资源:

assets:
  - assets/models/model_classification.pt
  - assets/labels/label_classification_imageNet.txt
  - assets/models/model_objectDetection.torchscript
  - assets/labels/labels_objectDetection_Coco.txt
3. 运行 flutter pub get
flutter pub get
4. 配置发布版本

对于发布版本,编辑 android/app/build.gradle 文件,在 release 配置中添加以下内容:

buildTypes {
    release {
        shrinkResources false
        minifyEnabled false
        signingConfig signingConfigs.debug
    }
}

使用

1. 导入库
import 'package:flutter_pytorch/flutter_pytorch.dart';
2. 加载模型
图像分类模型
ClassificationModel classificationModel = await FlutterPytorch.loadClassificationModel(
  "assets/models/model_classification.pt", 
  224, 
  224, 
  labelPath: "assets/labels/label_classification_imageNet.txt"
);
目标检测模型
ModelObjectDetection objectModel = await FlutterPytorch.loadObjectDetectionModel(
  "assets/models/yolov5s.torchscript", 
  80, 
  640, 
  640, 
  labelPath: "assets/labels/labels_objectDetection_Coco.txt"
);
3. 获取分类预测结果
获取分类预测标签
String imagePrediction = await classificationModel.getImagePrediction(
  await File(image.path).readAsBytes()
);
获取分类预测的原始输出层
List<double?>? predictionList = await _imageModel!.getImagePredictionList(
  await File(image.path).readAsBytes()
);
获取分类预测的概率(如果模型未使用softmax)
List<double?>? predictionListProbabilites = await _imageModel!.getImagePredictionListProbabilities(
  await File(image.path).readAsBytes()
);
4. 获取目标检测预测结果
List<ResultObjectDetection?> objDetect = await _objectModel.getImagePrediction(
  await File(image.path).readAsBytes(),
  minimumScore: 0.1, 
  IOUThershold: 0.3
);
5. 在图像上绘制检测框
objectModel.renderBoxesOnImage(_image!, objDetect);
6. 使用自定义均值和标准差进行图像预测
final mean = [0.5, 0.5, 0.5];
final std = [0.5, 0.5, 0.5];
String prediction = await classificationModel.getImagePrediction(
  image, 
  mean: mean, 
  std: std
);

示例代码

以下是一个完整的示例代码,展示了如何在Flutter应用中使用 flutter_pytorch 插件进行模型推理:

import 'package:flutter/material.dart';
import 'package:flutter_pytorch/flutter_pytorch.dart';

void main() => runApp(MyApp());

class MyApp extends StatelessWidget {
  [@override](/user/override)
  Widget build(BuildContext context) {
    return MaterialApp(
      home: ChooseDemo(),
    );
  }
}

class ChooseDemo extends StatefulWidget {
  const ChooseDemo({Key? key}) : super(key: key);

  [@override](/user/override)
  State<ChooseDemo> createState() => _ChooseDemoState();
}

class _ChooseDemoState extends State<ChooseDemo> {
  ClassificationModel? classificationModel;
  ModelObjectDetection? objectModel;

  [@override](/user/override)
  void initState() {
    super.initState();
    _loadModel();
  }

  Future<void> _loadModel() async {
    // 加载分类模型
    classificationModel = await FlutterPytorch.loadClassificationModel(
      "assets/models/model_classification.pt", 
      224, 
      224, 
      labelPath: "assets/labels/label_classification_imageNet.txt"
    );

    // 加载目标检测模型
    objectModel = await FlutterPytorch.loadObjectDetectionModel(
      "assets/models/yolov5s.torchscript", 
      80, 
      640, 
      640, 
      labelPath: "assets/labels/labels_objectDetection_Coco.txt"
    );

    setState(() {});
  }

  [@override](/user/override)
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('Pytorch Mobile Example'),
      ),
      body: Center(
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: [
            if (classificationModel != null && objectModel != null)
              TextButton(
                onPressed: () async {
                  // 选择图片或相机
                  final image = await ImagePicker().pickImage(source: ImageSource.gallery);
                  if (image == null) return;

                  // 获取分类预测
                  String classificationResult = await classificationModel!.getImagePrediction(
                    await File(image.path).readAsBytes()
                  );

                  // 获取目标检测预测
                  List<ResultObjectDetection?> detectionResults = await objectModel!.getImagePrediction(
                    await File(image.path).readAsBytes(),
                    minimumScore: 0.1, 
                    IOUThershold: 0.3
                  );

                  // 显示结果
                  showDialog(
                    context: context,
                    builder: (context) => AlertDialog(
                      title: Text('Prediction Results'),
                      content: Column(
                        mainAxisSize: MainAxisSize.min,
                        children: [
                          Text('Classification: $classificationResult'),
                          if (detectionResults.isNotEmpty)
                            ListView.builder(
                              shrinkWrap: true,
                              itemCount: detectionResults.length,
                              itemBuilder: (context, index) {
                                final result = detectionResults[index];
                                return ListTile(
                                  title: Text(result?.className ?? 'Unknown'),
                                  subtitle: Text('Score: ${result?.score}'),
                                );
                              },
                            ),
                        ],
                      ),
                      actions: [
                        TextButton(
                          onPressed: () => Navigator.pop(context),
                          child: Text('Close'),
                        ),
                      ],
                    ),
                  );
                },
                style: TextButton.styleFrom(
                  backgroundColor: Colors.blue,
                ),
                child: Text(
                  "Run Model with Image",
                  style: TextStyle(
                    color: Colors.white,
                  ),
                ),
              ),
            else
              CircularProgressIndicator(), // 模型加载中
          ],
        ),
      ),
    );
  }
}

更多关于Flutter深度学习模型推理插件flutter_pytorch的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html

1 回复

更多关于Flutter深度学习模型推理插件flutter_pytorch的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html


当然,关于Flutter深度学习模型推理插件flutter_pytorch的使用,这里提供一个简单的代码案例来展示如何在Flutter应用中进行PyTorch模型的加载和推理。

首先,确保你已经在Flutter项目中添加了flutter_pytorch依赖。在pubspec.yaml文件中添加以下依赖:

dependencies:
  flutter:
    sdk: flutter
  flutter_pytorch: ^x.y.z  # 替换为最新版本号

然后运行flutter pub get来安装依赖。

接下来,我们编写一个简单的Flutter应用来加载一个预训练的PyTorch模型并进行推理。假设我们有一个已经训练好的PyTorch模型并保存为model.pt文件。

1. 准备PyTorch模型(在Python中)

确保你的PyTorch模型已经保存为model.pt文件。这里是一个简单的模型保存示例:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的线性模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

# 实例化模型,训练并保存
model = SimpleModel()
dummy_input = torch.tensor([[1.0]])
dummy_target = torch.tensor([[2.0]])
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

optimizer.zero_grad()
output = model(dummy_input)
loss = criterion(output, dummy_target)
loss.backward()
optimizer.step()

torch.save(model.state_dict(), 'model.pt')

2. 在Flutter中使用flutter_pytorch加载模型

接下来,在Flutter项目中编写代码来加载这个模型并进行推理。

import 'package:flutter/material.dart';
import 'package:flutter_pytorch/flutter_pytorch.dart';

void main() {
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: Text('Flutter PyTorch Example'),
        ),
        body: Center(
          child: PyTorchModelWidget(),
        ),
      ),
    );
  }
}

class PyTorchModelWidget extends StatefulWidget {
  @override
  _PyTorchModelWidgetState createState() => _PyTorchModelWidgetState();
}

class _PyTorchModelWidgetState extends State<PyTorchModelWidget> {
  late PyTorchModel pytorchModel;

  @override
  void initState() {
    super.initState();
    loadModel();
  }

  Future<void> loadModel() async {
    // 加载模型文件
    pytorchModel = await PyTorchModel.loadAsset('assets/model.pt');

    // 准备输入数据
    var inputTensor = Tensor.fromList([1.0], dtype: DType.float32).reshape([1, 1]);

    // 执行推理
    var outputTensor = await pytorchModel.forward(inputTensor);

    // 处理输出
    var outputValue = outputTensor.data.toList().single[0];
    print('Model output: $outputValue');

    // 如果需要在UI中显示结果,可以使用setState更新状态
    // setState(() {
    //   // 更新UI状态
    // });
  }

  @override
  Widget build(BuildContext context) {
    return Text('Loading model...');
    // 可以在这里添加更多UI元素来显示推理结果
  }
}

注意事项

  1. 模型文件路径:确保model.pt文件已经放置在Flutter项目的assets文件夹中,并在pubspec.yaml中声明该资产:

    flutter:
      assets:
        - assets/model.pt
    
  2. Tensor处理:输入和输出数据都是以Tensor的形式处理,确保输入数据的维度和类型与模型训练时一致。

  3. 异步操作:加载模型和进行推理都是异步操作,使用await关键字等待操作完成。

这个示例展示了如何在Flutter中使用flutter_pytorch插件加载PyTorch模型并进行简单的推理。根据你的具体需求,你可以进一步扩展这个示例,比如处理更复杂的模型、在UI中显示结果等。

回到顶部