Flutter深度学习集成插件pytorch_dart的使用
Flutter深度学习集成插件pytorch_dart的使用
Pytorch_Dart 是一个 Dart 包装器,用于 Libtorch,旨在为 Dart/Flutter 项目提供类似于 PyTorch 的无缝体验。
它作为 Dart/Flutter 项目中 NumPy 的替代品。
注意: 该包处于实验阶段,未来可能会有 API 变更。
平台支持
平台 | 状态 | 预构建二进制文件 |
---|---|---|
Windows | ✅ | x64(无CUDA) |
Android | ✅ | arm64-v8a armeabi-v7a x86_64 x86 |
Linux | ✅ | x64(无CUDA) |
iOS | ❌ | 即将上线 |
MacOS | ❌ | 即将上线 |
注意: 若要在 MacOS 上运行 Pytorch_Dart,请用适用于 MacOS 的 libtorch 替换 /libtorch-linux/libtorch
。
开始使用
将 pytorch_dart 添加到你的 pubspec.yaml
要将 Pytorch_Dart 包含到你的 Dart/Flutter 项目中,请在 pubspec.yaml
文件中添加以下内容,然后保存:
dependencies:
pytorch_dart:^0.2.2
设置
运行以下命令进行设置:
flutter pub get
dart run pytorch_dart:setup --platform <your_platform>
目前仅支持 <your_platform>
参数的值为 linux
, android
和 windows
(iOS 即将上线)。
对于 Windows 开发者,如果你使用的是调试版本的 libtorch,程序在调试模式下可以正常工作,但在发布模式下会抛出一些异常。反之亦然。如果你需要在发布模式下构建,你需要安装 libtorch 的发布版本。
默认情况下,设置过程会安装调试版本。若要获取 libtorch 的发布版本,请运行:
dart run pytorch_dart:setup --platform windows release
导入 Pytorch_Dart
现在你可以在 Dart/Flutter 项目中导入 Pytorch_Dart:
import 'package:pytorch_dart/pytorch_dart.dart' as torch;
对于 Android 开发者
Libtorch for Android 需要特定版本的 NDK。请按照此处的说明安装 NDK 版本 21.4.7075529:
ndk.dir=/home/pc/Android/Sdk/ndk/21.4.7075529
确保你的 local.properties
文件类似以下内容:
flutter.sdk=/home/pc/flutter
sdk.dir=/home/pc/Android/Sdk
flutter.buildMode=debug
ndk.dir=/home/pc/Android/Sdk/ndk/21.4.7075529
注意,torch.load()
和 torch.save()
在 Android 上不可用。
排除故障
Windows
Launching lib\main.dart on Windows in debug mode...
√ Built build\windows\x64\runner\Debug\example.exe.
Error waiting for a debug connection: The log reader stopped unexpectedly, or never started.
Error launching application on Windows.
解决方案:
- 从这里下载 libtorch(如果你想在发布模式下运行,请下载
libtorch-win-shared-with-deps-2.2.2+cpu.zip
;如果你想在调试模式下运行,请下载libtorch-win-shared-with-deps-debug-2.2.2+cpu.zip
)。 - 解压它。
- 将
libtorch\lib\
目录下的所有文件复制到build\windows\x64\runner\Debug\
(调试模式)或build\windows\x64\runner\Release\
(发布模式)。
使用
简介
- 它包括了当前版本中
torch
中的一些基本函数。 - 支持对 TorchScript 模型的推理。
- 几乎所有的函数用法与 PyTorch 保持一致。
- 广播也在 pytorch_dart 中有效。
- 即将支持
torch.nn
。 - 示例
var d=torch.eye(3,2);
print(d);
结果:
flutter:
1 0
0 1
0 0
[ CPUFloatType{3,2} ]
运算符重载
注意:Dart 没有魔法函数(如 Python 中的 _radd_
)。因此,在二元运算符中,张量只能位于左侧。
示例:
import 'package:pytorch_dart/pytorch_dart.dart' as torch;
...
var c=torch.DoubleTensor([[1.0,2.0,3.0],[4.0,5.0,6.0]]);
var d=c+10; // 无错误
var e=10+c; // 引发错误
其他二元运算符(-
, *
, /
)与 +
类似。
对于运算符 [ ]
,你可以像在 PyTorch 中一样使用它。
但是,在当前版本中,切片不被支持。因此,你不能使用 [a:b]
来选择子张量。
示例:
import 'package:pytorch_dart/pytorch_dart.dart' as torch;
...
var c=torch.DoubleTensor([[1.0,2.0,3.0],[4.0,5.0,6.0]]);
print(c[0][0]);
结果:
flutter: 1
[ CPUDoubleType{} ]
模型推理
关于如何获得 TorchScript 模型,请参见此处。
在 Pytorch 中,我们使用 torch.jit.load()
加载 TorchScript 模型,并使用 module.forward()
进行推理。
在 Pytorch_Dart 中,我们有等效函数 torch.jit_load()
和 module.forward()
。它们与 Pytorch 版本有一些小差异。
torch.jit_load()
类似于 Pytorch 中的 torch.jit.load()
,但由于我们使用了 rootBundle
,所以它是一个异步函数。
加载模型的示例:
torch.JITModule? module;
void _loadModel() async{
module=await torch.jit_load('assets/traced_resnet_model.pt');
}
然而,forward()
与原始 Pytorch 版本有一些不同。
在 Dart 中,它接收 List<Dynamic>
,这意味着 forward()
函数的输入可以是 List<Tensor>
, List<Scalar>
或其他类型。
如果您的模型的输入是一个单一的张量:
在 Python 中,以下代码如下编写:
outputTensor = module.forward(inputTensor)
但是在 Dart 中,你必须将 inputTensor
放入一个列表:
var outputTensor = module!.forward([inputTensor]); //! 是一个空检查操作符
示例
我们提供了一个图像分类示例在 /example
目录下。
运行以下命令以运行它:
git clone https://github.com/Playboy-Player/pytorch_dart
cd pytorch_dart
git submodule init
git submodule update --remote
dart run pytorch_dart:setup --platform <your_platform>
cd example
flutter run --debug // 或 "flutter run --release"
函数/APIs
就像 Pytorch 一样,Pytorch_Dart 中的函数被分为多个部分。
在当前版本中,API 被分为 3 部分:
torch
torch.tensor
torch.jit
torch
支持的函数
-
torch.tensor()
在 pytorch_dart 中不受支持,使用torch.IntTensor()
,torch.FloatTensor()
或torch.DoubleTensor()
创建张量。 -
当前可用的函数:
- 注意:用
{}
包围的参数是可选参数。
torch.ones(List<int> size, {bool requiresGrad = false, int dtype = float32, Device? device_used}) torch.full(List<int> size, num values, {int dtype = float32, bool requiresGrad = false, Device? device_used}) torch.eye(int n, int m, {bool requiresGrad = false, int dtype = float32, Device? device_used}) torch.IntTensor(List<int> list) torch.FloatTensor(List<double> list) torch.DoubleTensor(List<double> list) torch.arange(double start, double end, double step, {bool requiresGrad = false}) torch.linspace(double start, double end, int steps, {bool requiresGrad = false}) torch.logspace(double start, double end, int steps, double base, {bool requiresGrad = false}) torch.equal(Tensor a, Tensor b) torch.add(Tensor a, Tensor b, {double alpha=1}) torch.sub(Tensor a, Tensor b, {double alpha=1}) torch.mul(Tensor a, Tensor b) torch.div(Tensor a, Tensor b) torch.add_(Tensor a, Tensor b, {double alpha=1}) torch.sub_(Tensor a, Tensor b, {double alpha=1}) torch.mul_(Tensor a, Tensor b) torch.div_(Tensor a, Tensor b) torch.sum(Tensor a) torch.mm(Tensor a, Tensor b) torch.transpose(Tensor a, int dim0, int dim1) torch.permute(Tensor a, List<int> permute_list) torch.save(Tensor a, String path) torch.load(String path) torch.relu() torch.leaky_relu() torch.tanh() torch.sigmoid() torch.flatten(Tensor a, int startDim, int endDim) torch.unsqueeze(Tensor tensor, int dim) torch.clone(Tensor tensor) torch.topk(Tensor a, int k, {int dim = -1, bool largest = true, bool sorted = true}) torch.allClose(Tensor left, Tensor right, {double rtol = 1e-08, double atol = 1e-05, bool equal_nan = false}) torch.empty(List<int> size, {bool requiresGrad = false, int dtype = float32, Device? device_used}) torch.ones(List<int> size, {bool requiresGrad = false, int dtype = float32, Device? device_used}) torch.full(List<int> size, num values, {int dtype = float32, bool requiresGrad = false, Device? device_used}) torch.eye(int n, int m, {bool requiresGrad = false, int dtype = float32, Device? device_used})
- 注意:用
-
几乎所有的函数用法与 PyTorch 保持一致。
-
支持一些就地操作,例如
torch.add_()
。
示例用法
import 'package:pytorch_dart/pytorch_dart.dart' as torch;
...
var c=torch.DoubleTensor([[1.0,2.0,3.0],[4.0,5.0,6.0]]);
var d=torch.add(10,c)
print(d)
结果:
flutter:
11 12 13
14 15 16
[ CPUDoubleType{2,3} ]
torch.tensor
方法
.dim()
.dtype()
.shape()
.size()
.detach()
.add_()
.sub_()
.mul_()
.div_()
.toList()
.unsqueeze(int dim)
.clone()
.relu()
.leaky_relu()
.sigmoid()
.tanh()
.flatten()
.equal(Tensor other)
.sum()
.mm(Tensor other)
.view(List<int> size)
注意: 在 Pytorch_Dart 中,.dtype()
方法与 PyTorch 不同。在 PyTorch 中,.dtype
返回表示张量数据类型的对象。而在 Pytorch_Dart 中,.dtype()
返回数据类型的数值表示。这可能在未来版本中更新。
示例
import 'package:pytorch_dart/pytorch_dart.dart' as torch;
...
var c=torch.DoubleTensor([[1.0,2.0,3.0],[4.0,5.0,6.0]]);
print(c.dtype())
结果:
flutter: 7
更多关于Flutter深度学习集成插件pytorch_dart的使用的实战教程也可以访问 https://www.itying.com/category-92-b0.html
更多关于Flutter深度学习集成插件pytorch_dart的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html
当然,关于在Flutter中集成PyTorch的插件pytorch_dart
,这里是一个简单的代码示例,展示了如何使用该插件来进行深度学习模型的推理。请注意,pytorch_dart
可能是一个假定的或者非官方的插件名称,因此实际使用时可能需要替换为现有的、官方支持的插件,例如torchvision
配合Flutter的原生平台通道(Platform Channels)来实现类似功能。不过,为了回答你的问题,这里将假设pytorch_dart
是一个封装良好的Flutter插件。
Flutter项目结构
首先,确保你的Flutter项目已经创建并配置好。你的项目结构应该类似于:
my_flutter_app/
├── android/
├── ios/
├── lib/
│ ├── main.dart
│ └── model_inference.dart
├── pubspec.yaml
└── ...
pubspec.yaml
在pubspec.yaml
文件中添加对pytorch_dart
的依赖(注意:这是一个假设的依赖名,实际使用时请替换为真实存在的插件):
dependencies:
flutter:
sdk: flutter
pytorch_dart: ^x.y.z # 替换为实际的版本号
main.dart
在main.dart
中,我们可以导入pytorch_dart
插件并调用模型推理功能。这里假设pytorch_dart
提供了加载模型和进行推理的API。
import 'package:flutter/material.dart';
import 'package:pytorch_dart/pytorch_dart.dart';
import 'model_inference.dart';
void main() {
runApp(MyApp());
}
class MyApp extends StatelessWidget {
@override
Widget build(BuildContext context) {
return MaterialApp(
home: Scaffold(
appBar: AppBar(
title: Text('Flutter PyTorch Demo'),
),
body: Center(
child: FutureBuilder<String>(
future: performModelInference(),
builder: (context, snapshot) {
if (snapshot.connectionState == ConnectionState.done) {
if (snapshot.hasError) {
return Text('Error: ${snapshot.error}');
} else {
return Text('Model Output: ${snapshot.data}');
}
} else {
return CircularProgressIndicator();
}
},
),
),
),
);
}
}
model_inference.dart
在model_inference.dart
文件中,我们封装了加载模型和进行推理的逻辑。
import 'package:pytorch_dart/pytorch_dart.dart';
Future<String> performModelInference() async {
// 假设 pytorch_dart 提供了 loadModel 方法来加载模型
final model = await PyTorchModel.loadModel('path/to/your/model.pt');
// 假设模型接受一个输入张量,并返回一个输出张量
final inputTensor = PyTorchTensor.fromList([
// 根据你的模型输入格式调整这个张量
[1.0, 2.0, 3.0],
]);
// 执行推理
final outputTensor = await model.infer(inputTensor);
// 假设输出是一个简单的标量,我们可以直接转换为字符串
final output = outputTensor.toList().first.toString();
return output;
}
注意事项
-
插件依赖:上述代码假设存在一个名为
pytorch_dart
的Flutter插件,该插件提供了加载模型和进行推理的API。实际上,你可能需要寻找或编写一个支持PyTorch的Flutter插件,或者使用原生平台通道来调用PyTorch的C++ API。 -
模型路径:在
model_inference.dart
中,确保模型路径'path/to/your/model.pt'
是正确的,并且该模型已经训练好且兼容你的推理代码。 -
张量格式:根据你的模型输入和输出格式,调整输入张量的创建和输出张量的处理。
-
错误处理:在实际应用中,你应该添加更多的错误处理逻辑,以确保在模型加载或推理失败时能够优雅地处理错误。
-
性能优化:对于复杂的模型或实时应用,你可能需要考虑性能优化,如使用GPU加速等。
由于pytorch_dart
可能是一个不存在的插件名,你可能需要寻找现有的Flutter与PyTorch集成的解决方案,如使用torchmobile
(如果存在)或者通过原生平台通道来实现自定义的集成。