Using flutter_d2go: a Flutter plugin for deploying deep learning models
flutter_d2go
A Flutter plugin for on-device object detection (Android and iOS), keypoint estimation (Android and iOS), and instance segmentation (Android only). The plugin is built on Facebook Research's d2go models.
Features
- Get classes and bounding boxes (Android and iOS)
- Get keypoints (Android and iOS)
- Get mask data (Android only)
- Real-time inference on camera stream images (Android only)
Preview
Real-time inference on camera stream images
Object detection and instance segmentation
Keypoint estimation
Installation
Add flutter_d2go to your pubspec.yaml, and put the d2go model and classes files in the assets directory.

```yaml
dependencies:
  flutter_d2go: ^latest_version

flutter:
  assets:
    - assets/models/d2go.pt
    - assets/models/classes.txt
```
Usage
1. Load the model and classes
The model format is PyTorch. The classes file format can be seen here; a minimal sketch of it follows the snippet below.
```dart
await FlutterD2go.loadModel(
  modelPath: 'assets/models/d2go.pt',     // required
  labelPath: 'assets/models/classes.txt', // required
);
```
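For reference, the classes file is plain text with one label per line, ordered to match the model's class indices. A minimal sketch (these labels are illustrative, following the common COCO ordering; check the file linked above for the exact format):

```text
person
bicycle
car
motorcycle
airplane
```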
2. Get predictions for a static image
```dart
List<Map<String, dynamic>> output = await FlutterD2go.getImagePrediction(
  image: image,            // required: a File (dart:io) image
  width: 320,              // defaults to 320
  height: 320,             // defaults to 320
  mean: [0.0, 0.0, 0.0],   // defaults to [0.0, 0.0, 0.0]
  std: [1.0, 1.0, 1.0],    // defaults to [1.0, 1.0, 1.0]
  minScore: 0.7,           // defaults to 0.5
);
```
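The mean and std values must match the normalization the model was trained with. As a sketch, assuming a model that expects ImageNet-style normalization (an assumption; models exported with the plugin's defaults use the values above), the call would look like:

```dart
// Assumption: the exported d2go model was trained with ImageNet statistics.
// Only override the defaults if your model actually expects this.
final List<Map<String, dynamic>> output = await FlutterD2go.getImagePrediction(
  image: image,
  mean: [0.485, 0.456, 0.406],
  std: [0.229, 0.224, 0.225],
  minScore: 0.7,
);
```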
3. Get predictions for a camera stream image
```dart
List<Map<String, dynamic>> output = await FlutterD2go.getStreamImagePrediction(
  imageBytesList: cameraImage.planes.map((plane) => plane.bytes).toList(), // required: List<Uint8List> of image bytes
  imageBytesPerPixel: cameraImage.planes.map((plane) => plane.bytesPerPixel).toList(), // defaults to [1, 2, 2]
  width: cameraImage.width,   // defaults to 720
  height: cameraImage.height, // defaults to 1280
  inputWidth: 320,            // defaults to 320
  inputHeight: 320,           // defaults to 320
  mean: [0.0, 0.0, 0.0],      // defaults to [0.0, 0.0, 0.0]
  std: [1.0, 1.0, 1.0],       // defaults to [1.0, 1.0, 1.0]
  minScore: 0.7,              // defaults to 0.5
  rotation: 90,               // defaults to 0
);
```
Prediction output format
rect is in the scale of the original image. mask and keypoints are present only when the d2go model includes masks and keypoints. mask is the bytes of a bitmap image as a Uint8List. keypoints is a list of 17 (x, y) pairs. A sketch of how to consume rect follows the sample output below.
```
[
  {
    "rect": {
      "left": 74.65713500976562,
      "top": 76.94147491455078,
      "right": 350.64324951171875,
      "bottom": 323.0279846191406
    },
    "mask": [66, 77, 122, 0, 0, 0, 0, 0, 0, 0, 122, ...],
    "keypoints": [[117.14504, 77.277405], [122.74037, 73.53044], [105.95437, 73.53044], ...],
    "confidenceInClass": 0.985002338886261,
    "detectedClass": "bicycle"
  }, // for each instance
  ...
]
```
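Because rect is in the coordinate space of the original image, it has to be scaled to the size at which the image is displayed before drawing boxes. A minimal sketch (the helper name and scale arguments are illustrative, not part of the plugin API):

```dart
import 'dart:ui';

/// Convert a detection's rect from original-image pixels to display pixels,
/// where scaleX = displayedWidth / originalImageWidth (likewise for Y).
Rect toDisplayRect(Map<String, dynamic> detection, double scaleX, double scaleY) {
  final rect = detection['rect'] as Map;
  return Rect.fromLTRB(
    rect['left'] * scaleX,
    rect['top'] * scaleY,
    rect['right'] * scaleX,
    rect['bottom'] * scaleY,
  );
}
```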
Issues
If you find a bug or would like a new feature added, please get in touch here.
Complete example code
The following complete example shows how to use the flutter_d2go plugin for real-time (camera stream) inference and static image inference.

```dart
import 'dart:async';
import 'dart:io';
import 'dart:typed_data';
import 'package:camera/camera.dart';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:flutter_d2go/flutter_d2go.dart';
import 'package:image_picker/image_picker.dart';
import 'package:path_provider/path_provider.dart';
List<CameraDescription> cameras = [];
Future<void> main() async {
try {
WidgetsFlutterBinding.ensureInitialized();
cameras = await availableCameras();
} on CameraException catch (e) {
debugPrint('Error: ${e.code}, Message: ${e.description}');
}
runApp(
const MaterialApp(
debugShowCheckedModeBanner: false,
home: MyApp(),
),
);
}
class MyApp extends StatefulWidget {
const MyApp({Key? key}) : super(key: key);
@override
State<MyApp> createState() => _MyAppState();
}
class _MyAppState extends State<MyApp> {
List<RecognitionModel>? _recognitions;
File? _selectedImage;
final List<String> _imageList = ['test1.png', 'test2.jpeg', 'test3.png'];
int _index = 0;
int? _imageWidth;
int? _imageHeight;
final ImagePicker _picker = ImagePicker();
CameraController? controller;
bool _isDetecting = false;
bool _isLiveModeOn = false;
@override
void initState() {
super.initState();
loadModel();
}
@override
void dispose() {
controller?.dispose();
super.dispose();
}
Future<void> live() async {
controller = CameraController(
cameras[0],
ResolutionPreset.high,
);
await controller!.initialize().then(
(_) {
if (!mounted) {
return;
}
setState(() {});
},
);
await controller!.startImageStream(
(CameraImage cameraImage) async {
if (_isDetecting) return;
_isDetecting = true;
await FlutterD2go.getStreamImagePrediction(
imageBytesList:
cameraImage.planes.map((plane) => plane.bytes).toList(),
width: cameraImage.width,
height: cameraImage.height,
minScore: 0.5,
rotation: 90,
).then(
(predictions) {
List<RecognitionModel>? recognitions;
if (predictions.isNotEmpty) {
recognitions = predictions.map(
(e) {
return RecognitionModel(
Rectangle(
e['rect']['left'],
e['rect']['top'],
e['rect']['right'],
e['rect']['bottom'],
),
e['mask'],
e['keypoints'] != null
? (e['keypoints'] as List)
.map((k) => Keypoint(k[0], k[1]))
.toList()
: null,
e['confidenceInClass'],
e['detectedClass']);
},
).toList();
}
setState(
() {
// With android, the inference result of the camera streaming image is tilted 90 degrees,
// so the vertical and horizontal directions are reversed.
_imageWidth = cameraImage.height;
_imageHeight = cameraImage.width;
_recognitions = recognitions;
},
);
},
).whenComplete(
() {
setState(() => _isDetecting = false);
},
);
},
);
}
Future loadModel() async {
String modelPath = 'assets/models/d2go.ptl';
String labelPath = 'assets/models/classes.txt';
try {
await FlutterD2go.loadModel(
modelPath: modelPath,
labelPath: labelPath,
);
setState(() {});
} on PlatformException {
debugPrint('Load model or label file failed.');
}
}
Future detect() async {
final image = _selectedImage ??
await getImageFileFromAssets('assets/images/${_imageList[_index]}');
final decodedImage = await decodeImageFromList(image.readAsBytesSync());
final predictions = await FlutterD2go.getImagePrediction(
image: image,
minScore: 0.8,
);
List<RecognitionModel>? recognitions;
if (predictions.isNotEmpty) {
recognitions = predictions.map(
(e) {
return RecognitionModel(
Rectangle(
e['rect']['left'],
e['rect']['top'],
e['rect']['right'],
e['rect']['bottom'],
),
e['mask'],
e['keypoints'] != null
? (e['keypoints'] as List)
.map((k) => Keypoint(k[0], k[1]))
.toList()
: null,
e['confidenceInClass'],
e['detectedClass']);
},
).toList();
}
setState(
() {
_imageWidth = decodedImage.width;
_imageHeight = decodedImage.height;
_recognitions = recognitions;
},
);
}
Future<File> getImageFileFromAssets(String path) async {
final byteData = await rootBundle.load(path);
final fileName = path.split('/').last;
final file = File('${(await getTemporaryDirectory()).path}/$fileName');
await file.writeAsBytes(byteData.buffer
.asUint8List(byteData.offsetInBytes, byteData.lengthInBytes));
return file;
}
@override
Widget build(BuildContext context) {
double screenWidth = MediaQuery.of(context).size.width;
List<Widget> stackChildren = [];
stackChildren.add(
Positioned(
top: 0.0,
left: 0.0,
width: screenWidth,
child: _selectedImage == null
? Image.asset(
'assets/images/${_imageList[_index]}',
)
: Image.file(_selectedImage!),
),
);
if (_isLiveModeOn) {
stackChildren.add(
Positioned(
top: 0.0,
left: 0.0,
width: screenWidth,
child: CameraPreview(controller!),
),
);
}
if (_recognitions != null) {
final aspectRatio = _imageHeight! / _imageWidth! * screenWidth;
final widthScale = screenWidth / _imageWidth!;
final heightScale = aspectRatio / _imageHeight!;
if (_recognitions!.first.mask != null) {
stackChildren.addAll(_recognitions!.map(
(recognition) {
return RenderSegments(
imageWidthScale: widthScale,
imageHeightScale: heightScale,
recognition: recognition,
);
},
).toList());
}
if (_recognitions!.first.keypoints != null) {
for (RecognitionModel recognition in _recognitions!) {
List<Widget> keypointChildren = [];
for (Keypoint keypoint in recognition.keypoints!) {
keypointChildren.add(
RenderKeypoints(
keypoint: keypoint,
imageWidthScale: widthScale,
imageHeightScale: heightScale,
),
);
}
stackChildren.addAll(keypointChildren);
}
}
stackChildren.addAll(_recognitions!.map(
(recognition) {
return RenderBoxes(
imageWidthScale: widthScale,
imageHeightScale: heightScale,
recognition: recognition,
);
},
).toList());
}
return Scaffold(
appBar: AppBar(
title: const Text('Flutter D2Go'),
backgroundColor: Colors.deepPurpleAccent,
),
body: Column(
mainAxisAlignment: MainAxisAlignment.center,
children: [
const SizedBox(height: 48),
Expanded(
child: Stack(
children: stackChildren,
),
),
const SizedBox(height: 48),
MyButton(
onPressed: !_isLiveModeOn ? detect : null,
text: 'Detect',
),
Padding(
padding: const EdgeInsets.symmetric(vertical: 48),
child: Row(
mainAxisAlignment: MainAxisAlignment.spaceAround,
children: [
MyButton(
onPressed: () {
setState(
() {
_recognitions = null;
if (_selectedImage == null) {
_index != 2 ? _index += 1 : _index = 0;
} else {
_selectedImage = null;
}
},
);
},
text: 'Test Image\n${_index + 1}/${_imageList.length}'),
MyButton(
onPressed: () async {
final XFile? pickedFile =
await _picker.pickImage(source: ImageSource.gallery);
if (pickedFile == null) return;
setState(
() {
_recognitions = null;
_selectedImage = File(pickedFile.path);
},
);
},
text: 'Select'),
MyButton(
onPressed: () async {
_isLiveModeOn
? await controller!.stopImageStream()
: await live();
setState(
() {
_isLiveModeOn = !_isLiveModeOn;
_recognitions = null;
_selectedImage = null;
},
);
},
text: 'Live'),
],
),
),
],
),
);
}
}
class MyButton extends StatelessWidget {
const MyButton({Key? key, required this.onPressed, required this.text})
: super(key: key);
final VoidCallback? onPressed;
final String text;
@override
Widget build(BuildContext context) {
return SizedBox(
width: 96,
height: 42,
child: ElevatedButton(
onPressed: onPressed,
child: Text(
text,
textAlign: TextAlign.center,
style: const TextStyle(
color: Colors.black,
fontSize: 12,
),
),
style: ElevatedButton.styleFrom(
primary: Colors.grey[300],
elevation: 0,
),
),
);
}
}
class RenderBoxes extends StatelessWidget {
const RenderBoxes({
Key? key,
required this.recognition,
required this.imageWidthScale,
required this.imageHeightScale,
}) : super(key: key);
final RecognitionModel recognition;
final double imageWidthScale;
final double imageHeightScale;
@override
Widget build(BuildContext context) {
final left = recognition.rect.left * imageWidthScale;
final top = recognition.rect.top * imageHeightScale;
final right = recognition.rect.right * imageWidthScale;
final bottom = recognition.rect.bottom * imageHeightScale;
return Positioned(
left: left,
top: top,
width: right - left,
height: bottom - top,
child: Container(
decoration: BoxDecoration(
borderRadius: const BorderRadius.all(Radius.circular(8.0)),
border: Border.all(
color: Colors.yellow,
width: 2,
),
),
child: Text(
"${recognition.detectedClass} ${(recognition.confidenceInClass * 100).toStringAsFixed(0)}%",
style: TextStyle(
background: Paint()..color = Colors.yellow,
color: Colors.black,
fontSize: 15.0,
),
),
),
);
}
}
class RenderSegments extends StatelessWidget {
const RenderSegments({
Key? key,
required this.recognition,
required this.imageWidthScale,
required this.imageHeightScale,
}) : super(key: key);
final RecognitionModel recognition;
final double imageWidthScale;
final double imageHeightScale;
@override
Widget build(BuildContext context) {
final left = recognition.rect.left * imageWidthScale;
final top = recognition.rect.top * imageHeightScale;
final right = recognition.rect.right * imageWidthScale;
final bottom = recognition.rect.bottom * imageHeightScale;
final mask = recognition.mask!;
return Positioned(
left: left,
top: top,
width: right - left,
height: bottom - top,
child: Image.memory(
mask,
fit: BoxFit.fill,
),
);
}
}
class RenderKeypoints extends StatelessWidget {
const RenderKeypoints({
Key? key,
required this.keypoint,
required this.imageWidthScale,
required this.imageHeightScale,
}) : super(key: key);
final Keypoint keypoint;
final double imageWidthScale;
final double imageHeightScale;
@override
Widget build(BuildContext context) {
final x = keypoint.x * imageWidthScale;
final y = keypoint.y * imageHeightScale;
return Positioned(
left: x,
top: y,
child: Container(
width: 8,
height: 8,
decoration: const BoxDecoration(
color: Colors.red,
shape: BoxShape.circle,
),
),
);
}
}
class RecognitionModel {
RecognitionModel(
this.rect,
this.mask,
this.keypoints,
this.confidenceInClass,
this.detectedClass,
);
Rectangle rect;
Uint8List? mask;
List<Keypoint>? keypoints;
double confidenceInClass;
String detectedClass;
}
class Rectangle {
Rectangle(this.left, this.top, this.right, this.bottom);
double left;
double top;
double right;
double bottom;
}
class Keypoint {
Keypoint(this.x, this.y);
double x;
double y;
}
```

More hands-on tutorials on deploying deep learning models with the flutter_d2go plugin are available at https://www.itying.com/category-92-b0.html
Sure, here is sample code showing how to deploy a deep learning model in a Flutter project with the flutter_d2go plugin. flutter_d2go runs d2go (Detectron2Go) models exported for PyTorch Mobile, covering object detection, keypoint estimation, and instance segmentation.
Step 1: Add the dependency
First, add the flutter_d2go dependency to your pubspec.yaml file:
```yaml
dependencies:
  flutter:
    sdk: flutter
  flutter_d2go: ^latest_version # replace with the latest published version
```
Then run flutter pub get to install the dependency.
Step 2: Configure the Android project
flutter_d2go runs its models with PyTorch Mobile rather than TensorFlow Lite, and the plugin pulls in its own native dependencies, so you do not need to add TensorFlow Lite artifacts to your Gradle file. If you want to limit which native ABIs are packaged into the APK, add an abiFilters block to android/app/build.gradle:

```gradle
android {
    ...
    defaultConfig {
        ...
        ndk {
            abiFilters 'armeabi-v7a', 'arm64-v8a', 'x86_64'
        }
    }
}
```
Step 3: Load the model and run inference
The following minimal example shows how to use flutter_d2go to load a d2go model and run object detection on a photo taken with the camera.
Dart code
```dart
import 'dart:io';

import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:flutter_d2go/flutter_d2go.dart';
import 'package:image_picker/image_picker.dart';

void main() {
  runApp(const MyApp());
}

class MyApp extends StatelessWidget {
  const MyApp({Key? key}) : super(key: key);

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: const Text('Flutter D2Go Example'),
        ),
        body: const Center(
          child: MyHomePage(),
        ),
      ),
    );
  }
}

class MyHomePage extends StatefulWidget {
  const MyHomePage({Key? key}) : super(key: key);

  @override
  State<MyHomePage> createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  final ImagePicker _picker = ImagePicker();
  String? _result;

  @override
  void initState() {
    super.initState();
    loadModel();
  }

  Future<void> loadModel() async {
    try {
      // Load the d2go model and the classes file from the Flutter assets.
      await FlutterD2go.loadModel(
        modelPath: 'assets/models/d2go.pt',
        labelPath: 'assets/models/classes.txt',
      );
    } on PlatformException catch (e) {
      debugPrint('Failed to load model: $e');
    }
  }

  Future<void> pickImage() async {
    final XFile? image = await _picker.pickImage(source: ImageSource.camera);
    if (image == null) return;

    // Run inference on the picked image, keeping detections above 70%.
    final List<Map<String, dynamic>> predictions =
        await FlutterD2go.getImagePrediction(
      image: File(image.path),
      minScore: 0.7,
    );

    // Summarize the detected classes and their confidences.
    setState(() {
      _result = predictions.isEmpty
          ? 'No objects detected'
          : predictions
              .map((e) =>
                  "${e['detectedClass']} ${(e['confidenceInClass'] * 100).toStringAsFixed(0)}%")
              .join(', ');
    });
  }

  @override
  Widget build(BuildContext context) {
    return Column(
      mainAxisAlignment: MainAxisAlignment.center,
      children: [
        TextButton(
          onPressed: pickImage,
          child: const Text('Pick an image'),
        ),
        if (_result != null)
          Text(
            'Prediction: $_result',
            style: const TextStyle(fontSize: 24),
          ),
      ],
    );
  }
}
```
Notes
- Model files: make sure the model file (e.g. d2go.pt) and classes.txt are placed under assets/models/ in your Flutter project and declared in the assets section of pubspec.yaml, as in the installation step above.
- Permissions: because images are taken with the camera or picked from the gallery, add camera and storage permissions to android/app/src/main/AndroidManifest.xml:
```xml
<uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
```
- Image picker: the example above uses the image_picker plugin to pick images, so you also need an image_picker dependency in pubspec.yaml:
```yaml
dependencies:
  flutter:
    sdk: flutter
  flutter_d2go: ^latest_version
  image_picker: ^latest_version
```
With the steps above, you can deploy a deep learning model in your Flutter app and run inference on device.