Flutter机器学习自定义模型插件firebase_ml_custom的使用

Flutter机器学习自定义模型插件firebase_ml_custom的使用

简介

pub package

firebase_ml_custom 是一个用于在 Flutter 中使用 Firebase ML 自定义模型 API 的插件。通过该插件,您可以轻松地将自定义机器学习模型集成到您的 Flutter 应用程序中。

有关其他 Firebase 产品的 Flutter 插件,请参阅 FlutterFire 官方仓库


使用方法

要使用此插件,请在 pubspec.yaml 文件中添加以下依赖项:

dependencies:
  firebase_ml_custom: ^latest_version

然后运行 flutter pub get 来安装依赖。

此外,您还需要为每个平台(Android 和 iOS)配置 Firebase。具体步骤可以参考 官方文档 或示例项目中的说明。


平台配置

Android

为了使用与 FirebaseModelManager 相关的方法,最低 SDK 版本要求为 24。如果低于 24,则最低版本为 21。可以在应用级 build.gradle 文件中指定此版本。

iOS

需要最低部署目标为 9.0。可以在 iOS 项目的 Podfile 中添加以下行:

platform :ios, '9.0'

同时,建议在 Xcode 中将应用程序的部署目标更新为 9.0,以避免编译错误。


使用 Firebase 模型管理器

以下是使用 firebase_ml_custom 插件加载和管理自定义模型的基本步骤。


1. 创建一个 FirebaseCustomRemoteModel

首先,创建一个 FirebaseCustomRemoteModel 对象。确保您已经在 Firebase 控制台中上传了模型,并使用模型名称进行初始化。

FirebaseCustomRemoteModel remoteModel = FirebaseCustomRemoteModel('myModelName');

2. 创建一个 FirebaseModelDownloadConditions

接下来,创建一个 FirebaseModelDownloadConditions 对象,并设置下载条件。例如,您可以限制仅在 Wi-Fi 下下载模型。

FirebaseModelDownloadConditions conditions = FirebaseModelDownloadConditions(
    androidRequireWifi: true,
    iosAllowCellularAccess: false,
);

默认情况下,除了 iosAllowCellularAccess 之外的所有参数都为 falseiosAllowCellularAccess 默认为 true


3. 创建一个 FirebaseModelManager 实例

创建一个与默认 FirebaseApp 实例关联的 FirebaseModelManager 对象。

FirebaseModelManager modelManager = FirebaseModelManager.instance;

4. 调用 download() 方法

调用 download() 方法以开始下载远程模型。如果模型已下载或正在下载,则不会重复执行下载任务。

await modelManager.download(remoteModel, conditions);

5. 检查模型是否已下载

使用 isModelDownloaded() 方法检查模型是否已成功下载。

if (await modelManager.isModelDownloaded(remoteModel)) {
    // 模型已下载,可以使用它进行推理
} else {
    // 模型未下载,可能需要回退到本地模型或采取其他措施
}

您还可以通过 try-catch 块捕获下载过程中的异常。


6. 获取最新模型文件

使用 getLatestModelFile() 方法获取模型文件的路径。

File modelFile = await modelManager.getLatestModelFile(remoteModel);

您可以将此文件直接传递给解释器,或者根据需要对其进行预处理。


示例代码

以下是一个完整的示例代码,展示了如何在 Flutter 中使用 firebase_ml_custom 插件加载和使用自定义模型。

// ignore_for_file: require_trailing_commas
// Copyright 2020, the Chromium project authors.  Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.

import 'dart:async';
import 'dart:io';

import 'package:firebase_ml_custom/firebase_ml_custom.dart';
import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';

void main() {
  runApp(
    MaterialApp(
      home: MyApp(),
    ),
  );
}

class MyApp extends StatefulWidget {
  [@override](/user/override)
  _MyAppState createState() => _MyAppState();
}

class _MyAppState extends State<MyApp> {
  File _image;
  List<Map<dynamic, dynamic>> _labels;

  Future<String> _loaded = loadModel();

  Future<void> getImageLabels() async {
    try {
      final pickedFile = await ImagePicker().getImage(source: ImageSource.gallery);
      final image = File(pickedFile.path);
      if (image == null) {
        return;
      }
      var labels = List<Map>.from([]);
      setState(() {
        _labels = labels;
        _image = image;
      });
    } catch (exception) {
      print("Failed on getting your image and it's labels: $exception");
      print('Continuing with the program...');
      rethrow;
    }
  }

  static Future<String> loadModel() async {
    final modelFile = await loadModelFromFirebase();
    return loadTFLiteModel(modelFile);
  }

  static Future<File> loadModelFromFirebase() async {
    try {
      final model = FirebaseCustomRemoteModel('mobilenet_v1_1_0_224');
      final conditions = FirebaseModelDownloadConditions(
          androidRequireWifi: true, iosAllowCellularAccess: false);
      final modelManager = FirebaseModelManager.instance;
      await modelManager.download(model, conditions);
      assert(await modelManager.isModelDownloaded(model) == true);
      var modelFile = await modelManager.getLatestModelFile(model);
      assert(modelFile != null);
      return modelFile;
    } catch (exception) {
      print('Failed on loading your model from Firebase: $exception');
      print('The program will not be resumed');
      rethrow;
    }
  }

  static Future<String> loadTFLiteModel(File modelFile) async {
    try {
      return 'Model is loaded';
    } catch (exception) {
      print('Failed on loading your model to the TFLite interpreter: $exception');
      print('The program will not be resumed');
      rethrow;
    }
  }

  Widget readyScreen() {
    return Scaffold(
      appBar: AppBar(
        title: const Text('Firebase ML Custom example app'),
      ),
      body: Column(
        children: [
          if (_image != null)
            Image.file(_image)
          else
            const Text('Please select image to analyze.'),
          Column(
            children: _labels != null
                ? _labels.map((label) {
                    return Text("${label["label"]}");
                  }).toList()
                : [],
          ),
        ],
      ),
      floatingActionButton: FloatingActionButton(
        onPressed: getImageLabels,
        child: const Icon(Icons.add),
      ),
    );
  }

  Widget errorScreen() {
    return const Scaffold(
      body: Center(
        child: Text('Error loading model. Please check the logs.'),
      ),
    );
  }

  Widget loadingScreen() {
    return Scaffold(
      body: Center(
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: const [
            Padding(
              padding: EdgeInsets.only(bottom: 20),
              child: CircularProgressIndicator(),
            ),
            Text('Please make sure that you are using wifi.'),
          ],
        ),
      ),
    );
  }

  [@override](/user/override)
  Widget build(BuildContext context) {
    return DefaultTextStyle(
      style: Theme.of(context).textTheme.headline2,
      textAlign: TextAlign.center,
      child: FutureBuilder<String>(
        future: _loaded,
        builder: (BuildContext context, AsyncSnapshot<String> snapshot) {
          if (snapshot.hasData) {
            return readyScreen();
          } else if (snapshot.hasError) {
            return errorScreen();
          } else {
            return loadingScreen();
          }
        },
      ),
    );
  }
}

更多关于Flutter机器学习自定义模型插件firebase_ml_custom的使用的实战教程也可以访问 https://www.itying.com/category-92-b0.html

1 回复

更多关于Flutter机器学习自定义模型插件firebase_ml_custom的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html


Flutter 提供了 firebase_ml_custom 插件,允许你在 Flutter 应用中使用 Firebase ML Kit 的自定义模型。通过这个插件,你可以将训练好的 TensorFlow Lite 模型集成到你的应用中,并在设备上进行推理。

以下是使用 firebase_ml_custom 插件的基本步骤:

1. 添加依赖

首先,在 pubspec.yaml 文件中添加 firebase_ml_custom 插件的依赖:

dependencies:
  flutter:
    sdk: flutter
  firebase_core: latest_version
  firebase_ml_custom: latest_version

然后运行 flutter pub get 来安装依赖。

2. 初始化 Firebase

在使用 Firebase ML Kit 之前,你需要初始化 Firebase。在 main.dart 文件中进行初始化:

import 'package:firebase_core/firebase_core.dart';
import 'package:flutter/material.dart';

void main() async {
  WidgetsFlutterBinding.ensureInitialized();
  await Firebase.initializeApp();
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  [@override](/user/override)
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'Flutter Demo',
      theme: ThemeData(
        primarySwatch: Colors.blue,
      ),
      home: MyHomePage(),
    );
  }
}

3. 加载自定义模型

接下来,你需要加载自定义的 TensorFlow Lite 模型。假设你已经将模型文件 model.tflite 放在 assets 文件夹中。

首先,在 pubspec.yaml 文件中声明模型文件:

flutter:
  assets:
    - assets/model.tflite

然后,在代码中加载模型:

import 'package:firebase_ml_custom/firebase_ml_custom.dart';

Future<CustomRemoteModel> loadModel() async {
  final model = await FirebaseCustomRemoteModel.getModel('my_model');
  return model;
}

4. 创建模型解释器

加载模型后,你需要创建一个模型解释器来进行推理:

Future<CustomModelInterpreter> createInterpreter(CustomRemoteModel model) async {
  final options = CustomModelInterpreterOptions(model);
  final interpreter = await CustomModelInterpreter.create(options);
  return interpreter;
}

5. 进行推理

现在你可以使用模型解释器来进行推理。假设你的模型输入是一个 1x224x224x3 的张量,输出是一个 1x1000 的张量:

Future<List<double>> runInference(CustomModelInterpreter interpreter) async {
  final input = List<double>.filled(1 * 224 * 224 * 3, 0.0); // 假设输入数据
  final inputBuffer = CustomModelInputOutputOptions();
  inputBuffer.setInput(0, [1, 224, 224, 3], input);

  final outputBuffer = CustomModelInputOutputOptions();
  outputBuffer.setOutput(0, [1, 1000], List<double>.filled(1 * 1000, 0.0));

  final result = await interpreter.run(inputBuffer, outputBuffer);
  return result.getOutput(0);
}

6. 完整示例

以下是一个完整的示例,展示了如何加载模型、创建解释器并进行推理:

import 'package:firebase_core/firebase_core.dart';
import 'package:firebase_ml_custom/firebase_ml_custom.dart';
import 'package:flutter/material.dart';

void main() async {
  WidgetsFlutterBinding.ensureInitialized();
  await Firebase.initializeApp();
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  [@override](/user/override)
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'Flutter Demo',
      theme: ThemeData(
        primarySwatch: Colors.blue,
      ),
      home: MyHomePage(),
    );
  }
}

class MyHomePage extends StatefulWidget {
  [@override](/user/override)
  _MyHomePageState createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  List<double> _output = [];

  [@override](/user/override)
  void initState() {
    super.initState();
    _loadModelAndRunInference();
  }

  Future<void> _loadModelAndRunInference() async {
    final model = await FirebaseCustomRemoteModel.getModel('my_model');
    final interpreter = await CustomModelInterpreter.create(CustomModelInterpreterOptions(model));
    final output = await _runInference(interpreter);
    setState(() {
      _output = output;
    });
  }

  Future<List<double>> _runInference(CustomModelInterpreter interpreter) async {
    final input = List<double>.filled(1 * 224 * 224 * 3, 0.0); // 假设输入数据
    final inputBuffer = CustomModelInputOutputOptions();
    inputBuffer.setInput(0, [1, 224, 224, 3], input);

    final outputBuffer = CustomModelInputOutputOptions();
    outputBuffer.setOutput(0, [1, 1000], List<double>.filled(1 * 1000, 0.0));

    final result = await interpreter.run(inputBuffer, outputBuffer);
    return result.getOutput(0);
  }

  [@override](/user/override)
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('Flutter ML Custom Model'),
      ),
      body: Center(
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: <Widget>[
            Text('Inference Output:'),
            Text(_output.toString()),
          ],
        ),
      ),
    );
  }
}
回到顶部