Flutter深度学习集成插件pytorch_dart的使用

Flutter深度学习集成插件pytorch_dart的使用

Pytorch_Dart 是一个 Dart 包装器,用于 Libtorch,旨在为 Dart/Flutter 项目提供类似于 PyTorch 的无缝体验。

它作为 Dart/Flutter 项目中 NumPy 的替代品。

注意: 该包处于实验阶段,未来可能会有 API 变更。

平台支持

平台 状态 预构建二进制文件
Windows x64(无CUDA)
Android arm64-v8a
armeabi-v7a
x86_64
x86
Linux x64(无CUDA)
iOS 即将上线
MacOS 即将上线

注意: 若要在 MacOS 上运行 Pytorch_Dart,请用适用于 MacOS 的 libtorch 替换 /libtorch-linux/libtorch

开始使用

将 pytorch_dart 添加到你的 pubspec.yaml

要将 Pytorch_Dart 包含到你的 Dart/Flutter 项目中,请在 pubspec.yaml 文件中添加以下内容,然后保存:

dependencies:
  pytorch_dart:^0.2.2

设置

运行以下命令进行设置:

flutter pub get
dart run pytorch_dart:setup --platform <your_platform>

目前仅支持 <your_platform> 参数的值为 linux, androidwindows(iOS 即将上线)。

对于 Windows 开发者,如果你使用的是调试版本的 libtorch,程序在调试模式下可以正常工作,但在发布模式下会抛出一些异常。反之亦然。如果你需要在发布模式下构建,你需要安装 libtorch 的发布版本。

默认情况下,设置过程会安装调试版本。若要获取 libtorch 的发布版本,请运行:

dart run pytorch_dart:setup --platform windows release

导入 Pytorch_Dart

现在你可以在 Dart/Flutter 项目中导入 Pytorch_Dart:

import 'package:pytorch_dart/pytorch_dart.dart' as torch;

对于 Android 开发者

Libtorch for Android 需要特定版本的 NDK。请按照此处的说明安装 NDK 版本 21.4.7075529:

ndk.dir=/home/pc/Android/Sdk/ndk/21.4.7075529

确保你的 local.properties 文件类似以下内容:

flutter.sdk=/home/pc/flutter
sdk.dir=/home/pc/Android/Sdk
flutter.buildMode=debug
ndk.dir=/home/pc/Android/Sdk/ndk/21.4.7075529

注意,torch.load()torch.save() 在 Android 上不可用。

排除故障

Windows

Launching lib\main.dart on Windows in debug mode...
√  Built build\windows\x64\runner\Debug\example.exe.
Error waiting for a debug connection: The log reader stopped unexpectedly, or never started.
Error launching application on Windows.

解决方案:

  1. 从这里下载 libtorch(如果你想在发布模式下运行,请下载 libtorch-win-shared-with-deps-2.2.2+cpu.zip;如果你想在调试模式下运行,请下载 libtorch-win-shared-with-deps-debug-2.2.2+cpu.zip)。
  2. 解压它。
  3. libtorch\lib\ 目录下的所有文件复制到 build\windows\x64\runner\Debug\(调试模式)或 build\windows\x64\runner\Release\(发布模式)。

使用

简介

  1. 它包括了当前版本中 torch 中的一些基本函数。
  2. 支持对 TorchScript 模型的推理。
  3. 几乎所有的函数用法与 PyTorch 保持一致。
  4. 广播也在 pytorch_dart 中有效。
  5. 即将支持 torch.nn
  6. 示例
var d=torch.eye(3,2);
print(d);

结果:

flutter:
 1  0
 0  1
 0  0
[ CPUFloatType{3,2} ]

运算符重载

注意:Dart 没有魔法函数(如 Python 中的 _radd_)。因此,在二元运算符中,张量只能位于左侧。

示例:

import 'package:pytorch_dart/pytorch_dart.dart' as torch;
...

var c=torch.DoubleTensor([[1.0,2.0,3.0],[4.0,5.0,6.0]]);
var d=c+10; // 无错误
var e=10+c; // 引发错误

其他二元运算符(-, *, /)与 + 类似。

对于运算符 [ ],你可以像在 PyTorch 中一样使用它。

但是,在当前版本中,切片不被支持。因此,你不能使用 [a:b] 来选择子张量。

示例:

import 'package:pytorch_dart/pytorch_dart.dart' as torch;
...

var c=torch.DoubleTensor([[1.0,2.0,3.0],[4.0,5.0,6.0]]);
print(c[0][0]);

结果:

flutter: 1
[ CPUDoubleType{} ]

模型推理

关于如何获得 TorchScript 模型,请参见此处。

在 Pytorch 中,我们使用 torch.jit.load() 加载 TorchScript 模型,并使用 module.forward() 进行推理。

在 Pytorch_Dart 中,我们有等效函数 torch.jit_load()module.forward()。它们与 Pytorch 版本有一些小差异。

torch.jit_load() 类似于 Pytorch 中的 torch.jit.load(),但由于我们使用了 rootBundle,所以它是一个异步函数。

加载模型的示例:

torch.JITModule? module;
void _loadModel() async{
  module=await torch.jit_load('assets/traced_resnet_model.pt');
}

然而,forward() 与原始 Pytorch 版本有一些不同。

在 Dart 中,它接收 List&lt;Dynamic&gt;,这意味着 forward() 函数的输入可以是 List&lt;Tensor&gt;, List&lt;Scalar&gt; 或其他类型。

如果您的模型的输入是一个单一的张量:

在 Python 中,以下代码如下编写:

outputTensor = module.forward(inputTensor)

但是在 Dart 中,你必须将 inputTensor 放入一个列表:

var outputTensor = module!.forward([inputTensor]);   //! 是一个空检查操作符

示例

我们提供了一个图像分类示例在 /example 目录下。

1721127375812

运行以下命令以运行它:

git clone https://github.com/Playboy-Player/pytorch_dart
cd pytorch_dart
git submodule init
git submodule update --remote
dart run pytorch_dart:setup --platform <your_platform>
cd example
flutter run --debug // 或 "flutter run --release"

函数/APIs

就像 Pytorch 一样,Pytorch_Dart 中的函数被分为多个部分。

在当前版本中,API 被分为 3 部分:

  • torch
  • torch.tensor
  • torch.jit

torch

支持的函数

  1. torch.tensor() 在 pytorch_dart 中不受支持,使用 torch.IntTensor(), torch.FloatTensor()torch.DoubleTensor() 创建张量。

  2. 当前可用的函数:

    • 注意:用 {} 包围的参数是可选参数。
    torch.ones(List<int> size, {bool requiresGrad = false, int dtype = float32, Device? device_used})
    torch.full(List<int> size, num values, {int dtype = float32, bool requiresGrad = false, Device? device_used})
    torch.eye(int n, int m, {bool requiresGrad = false, int dtype = float32, Device? device_used})
    torch.IntTensor(List<int> list)
    torch.FloatTensor(List<double> list)
    torch.DoubleTensor(List<double> list)
    torch.arange(double start, double end, double step, {bool requiresGrad = false})
    torch.linspace(double start, double end, int steps, {bool requiresGrad = false})
    torch.logspace(double start, double end, int steps, double base, {bool requiresGrad = false})
    torch.equal(Tensor a, Tensor b)
    torch.add(Tensor a, Tensor b, {double alpha=1})
    torch.sub(Tensor a, Tensor b, {double alpha=1})
    torch.mul(Tensor a, Tensor b)
    torch.div(Tensor a, Tensor b)
    torch.add_(Tensor a, Tensor b, {double alpha=1})
    torch.sub_(Tensor a, Tensor b, {double alpha=1})
    torch.mul_(Tensor a, Tensor b)
    torch.div_(Tensor a, Tensor b)
    torch.sum(Tensor a)
    torch.mm(Tensor a, Tensor b)
    torch.transpose(Tensor a, int dim0, int dim1)
    torch.permute(Tensor a, List<int> permute_list)
    torch.save(Tensor a, String path)
    torch.load(String path)
    torch.relu()
    torch.leaky_relu()
    torch.tanh()
    torch.sigmoid()
    torch.flatten(Tensor a, int startDim, int endDim)
    torch.unsqueeze(Tensor tensor, int dim)
    torch.clone(Tensor tensor)
    torch.topk(Tensor a, int k, {int dim = -1, bool largest = true, bool sorted = true})
    torch.allClose(Tensor left, Tensor right, {double rtol = 1e-08, double atol = 1e-05, bool equal_nan = false})
    torch.empty(List<int> size, {bool requiresGrad = false, int dtype = float32, Device? device_used})
    torch.ones(List<int> size, {bool requiresGrad = false, int dtype = float32, Device? device_used})
    torch.full(List<int> size, num values, {int dtype = float32, bool requiresGrad = false, Device? device_used})
    torch.eye(int n, int m, {bool requiresGrad = false, int dtype = float32, Device? device_used})
    
  3. 几乎所有的函数用法与 PyTorch 保持一致。

  4. 支持一些就地操作,例如 torch.add_()

示例用法

import 'package:pytorch_dart/pytorch_dart.dart' as torch;
...

var c=torch.DoubleTensor([[1.0,2.0,3.0],[4.0,5.0,6.0]]);
var d=torch.add(10,c)
print(d)

结果:

flutter:
 11  12  13
 14  15  16
[ CPUDoubleType{2,3} ]

torch.tensor

方法

  • .dim()
  • .dtype()
  • .shape()
  • .size()
  • .detach()
  • .add_()
  • .sub_()
  • .mul_()
  • .div_()
  • .toList()
  • .unsqueeze(int dim)
  • .clone()
  • .relu()
  • .leaky_relu()
  • .sigmoid()
  • .tanh()
  • .flatten()
  • .equal(Tensor other)
  • .sum()
  • .mm(Tensor other)
  • .view(List<int> size)

注意: 在 Pytorch_Dart 中,.dtype() 方法与 PyTorch 不同。在 PyTorch 中,.dtype 返回表示张量数据类型的对象。而在 Pytorch_Dart 中,.dtype() 返回数据类型的数值表示。这可能在未来版本中更新。

示例

import 'package:pytorch_dart/pytorch_dart.dart' as torch;
...

var c=torch.DoubleTensor([[1.0,2.0,3.0],[4.0,5.0,6.0]]);
print(c.dtype())

结果:

flutter: 7

更多关于Flutter深度学习集成插件pytorch_dart的使用的实战教程也可以访问 https://www.itying.com/category-92-b0.html

1 回复

更多关于Flutter深度学习集成插件pytorch_dart的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html


当然,关于在Flutter中集成PyTorch的插件pytorch_dart,这里是一个简单的代码示例,展示了如何使用该插件来进行深度学习模型的推理。请注意,pytorch_dart可能是一个假定的或者非官方的插件名称,因此实际使用时可能需要替换为现有的、官方支持的插件,例如torchvision配合Flutter的原生平台通道(Platform Channels)来实现类似功能。不过,为了回答你的问题,这里将假设pytorch_dart是一个封装良好的Flutter插件。

Flutter项目结构

首先,确保你的Flutter项目已经创建并配置好。你的项目结构应该类似于:

my_flutter_app/
├── android/
├── ios/
├── lib/
│   ├── main.dart
│   └── model_inference.dart
├── pubspec.yaml
└── ...

pubspec.yaml

pubspec.yaml文件中添加对pytorch_dart的依赖(注意:这是一个假设的依赖名,实际使用时请替换为真实存在的插件):

dependencies:
  flutter:
    sdk: flutter
  pytorch_dart: ^x.y.z  # 替换为实际的版本号

main.dart

main.dart中,我们可以导入pytorch_dart插件并调用模型推理功能。这里假设pytorch_dart提供了加载模型和进行推理的API。

import 'package:flutter/material.dart';
import 'package:pytorch_dart/pytorch_dart.dart';
import 'model_inference.dart';

void main() {
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: Text('Flutter PyTorch Demo'),
        ),
        body: Center(
          child: FutureBuilder<String>(
            future: performModelInference(),
            builder: (context, snapshot) {
              if (snapshot.connectionState == ConnectionState.done) {
                if (snapshot.hasError) {
                  return Text('Error: ${snapshot.error}');
                } else {
                  return Text('Model Output: ${snapshot.data}');
                }
              } else {
                return CircularProgressIndicator();
              }
            },
          ),
        ),
      ),
    );
  }
}

model_inference.dart

model_inference.dart文件中,我们封装了加载模型和进行推理的逻辑。

import 'package:pytorch_dart/pytorch_dart.dart';

Future<String> performModelInference() async {
  // 假设 pytorch_dart 提供了 loadModel 方法来加载模型
  final model = await PyTorchModel.loadModel('path/to/your/model.pt');

  // 假设模型接受一个输入张量,并返回一个输出张量
  final inputTensor = PyTorchTensor.fromList([
    // 根据你的模型输入格式调整这个张量
    [1.0, 2.0, 3.0],
  ]);

  // 执行推理
  final outputTensor = await model.infer(inputTensor);

  // 假设输出是一个简单的标量,我们可以直接转换为字符串
  final output = outputTensor.toList().first.toString();

  return output;
}

注意事项

  1. 插件依赖:上述代码假设存在一个名为pytorch_dart的Flutter插件,该插件提供了加载模型和进行推理的API。实际上,你可能需要寻找或编写一个支持PyTorch的Flutter插件,或者使用原生平台通道来调用PyTorch的C++ API。

  2. 模型路径:在model_inference.dart中,确保模型路径'path/to/your/model.pt'是正确的,并且该模型已经训练好且兼容你的推理代码。

  3. 张量格式:根据你的模型输入和输出格式,调整输入张量的创建和输出张量的处理。

  4. 错误处理:在实际应用中,你应该添加更多的错误处理逻辑,以确保在模型加载或推理失败时能够优雅地处理错误。

  5. 性能优化:对于复杂的模型或实时应用,你可能需要考虑性能优化,如使用GPU加速等。

由于pytorch_dart可能是一个不存在的插件名,你可能需要寻找现有的Flutter与PyTorch集成的解决方案,如使用torchmobile(如果存在)或者通过原生平台通道来实现自定义的集成。

回到顶部