Flutter神经网络处理插件loredart_nn的使用

发布于 1周前 作者 yibo5220 来自 Flutter

Flutter神经网络处理插件loredart_nn的使用

简介

简单库用于创建和训练深度神经网络,完全用Dart编写。

概念上,该库具有教育和娱乐目的。希望你会觉得使用Dart编写的神经网络很有趣。


开始使用

首先,将库导入项目中。

import 'package:loredart_nn/loredart_nn.dart';

使用NeuralNetwork

loredart_nn中,数据在NeuralNetwork中以(mini)batch矩阵形式存储,其中每一列代表一个数据记录,每一行代表一个特征。因此,NeuralNetwork使用基于列的数据表示法。

但用户输入和模型预测是List<List<double>>,其中每一行是一个数据记录。

以下是一个创建MNIST分类深度神经网络的小例子:


加载数据

// 每个数字的784像素列表
List<List<double>> xTrain = data[0].sublist(0,30000);
// 一位有效编码的数字标签
List<List<double>> yTrain = data[1].sublist(0,30000);
List<List<double>> xTest = data[0].sublist(30000);
List<List<double>> yTest = data[1].sublist(30000);

这里使用的是MNIST的扁平版本,即将数字图像展平为784像素的记录。

定义模型

var model = NeuralNetwork(
  784, // 输入记录长度=784像素
  [
    Dense(128, activation: Activation.softplus()), // 全连接层
    LayerNormalization(), // 批量规范化
    Dense(10, activation: Activation.softmax()) // 输出层,softmax函数用于概率计算
  ],
  loss: Loss.crossEntropy(), // 交叉熵损失,适用于一位有效编码的目标值
  optimizer: SGD(learningRate: 0.01, momentum: 0.9), // 可选地自定义带有动量的SGD优化器
  useAccuracyMetric: true // 对于分类任务可以使用‘准确率’度量
);

训练模型

var history = model.fit(xTrain, yTrain, epochs: 4, batchSize: 256, verbose: true);

verbose == true时,你会看到每次批次更新后模型的变化情况,以及每个周期的摘要信息,如下所示:

epoch 1/4 |118/118| -> mean time per batch: 388.01ms, mean loss [cross_entropy]: 1.478636, mean accuracy: 68.05%
epoch 2/4 |118/118| -> mean time per batch: 400.87ms, mean loss [cross_entropy]: 0.868227, mean accuracy: 83.50%
epoch 3/4 |118/118| -> mean time per batch: 390.92ms, mean loss [cross_entropy]: 0.691678, mean accuracy: 85.72%
epoch 4/4 |118/118| -> mean time per batch: 386.92ms, mean loss [cross_entropy]: 0.606761, mean accuracy: 86.84%

你可以控制批次大小和训练周期数。模型的fit方法返回history——一个包含每个周期损失信息的Map。

print(history);
// {cross_entropy: [1.4786360357701471, 0.868226552200313, 0.6916779963409534, 0.6067611970768912],
//  accuracy: [68.04819915254238, 83.49995586158192, 85.72342867231639, 86.83902718926552]}

测试模型

var metrics = model.evaluate(xTest, yTest, verbose: true);
print(metrics);
// {mean cross_entropy: 0.5742549607727949, mean accuracy: 0.8719166666666667}

同样地,当verbose == true时,你会看到更多的信息:

// evaluating batch 100/100 -> mean time per batch: 68.52ms, mean loss [cross_entropy]: 0.574255, mean accuracy: 87.19%

evaluate方法返回一个包含平均损失信息的Map对象。

使用模型

预测是针对多行数据执行的,并且输出是一个包含每个输入的预测结果的列表。

// 数据是一些输入列表
List<List<double>> data = ...;

List<List<double>> prediction = model.predict(data); // model.predict 返回数据中每行的预测结果
print(prediction[0]);
// 打印类似于 [0.00, 0.00, 0.00, 0.99, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]

保存和加载模型权重

你可以将可训练(即Dense)层的权重和偏置保存到某个目录中。

// 将参数保存到 `mnist_classifier/model_weights.bin` 文件中
model.saveWeights('mnist_classifier');

请注意,saveWeights 方法仅保存Dense层的权重和偏置。

然后你可以将权重和偏置加载到模型中,但要确保使用适当的架构,否则加载后的模型结果可能不符合预期。

model.loadWeights('mnist_classifier');

如果你想从Flutter资源中加载权重,可以使用loadWeightsFromBytes 方法。

var model = NeuralNetwork(...);
rootBundle.load('assets/model/model_weights.bin').then((value) {
  model.loadWeightsFromBytes(value.buffer);
});

更多关于Flutter神经网络处理插件loredart_nn的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html

1 回复

更多关于Flutter神经网络处理插件loredart_nn的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html


loredart_nn 是一个用于在 Flutter 中进行神经网络处理的插件。它允许你在 Flutter 应用中集成和运行神经网络模型,进行推理、训练等操作。以下是如何使用 loredart_nn 插件的基本步骤。

1. 添加依赖

首先,你需要在 pubspec.yaml 文件中添加 loredart_nn 插件的依赖。

dependencies:
  flutter:
    sdk: flutter
  loredart_nn: ^1.0.0  # 请确保使用最新版本

然后运行 flutter pub get 来获取依赖。

2. 导入插件

在你的 Dart 文件中导入 loredart_nn 插件。

import 'package:loredart_nn/loredart_nn.dart';

3. 初始化神经网络

你可以使用 loredart_nn 提供的 API 来初始化一个神经网络模型。假设你已经有一个训练好的模型文件,你可以加载它。

void initNeuralNetwork() async {
  // 假设你有一个模型文件 'model.tflite'
  String modelPath = 'assets/model.tflite';
  LoredartNN nn = LoredartNN();

  // 加载模型
  await nn.loadModel(modelPath);
}

4. 进行推理

一旦模型加载完成,你可以使用它进行推理。假设你有一些输入数据,你可以将其传递给模型并获取输出。

void runInference() async {
  // 假设你有一些输入数据
  List<double> inputData = [1.0, 2.0, 3.0, 4.0];

  // 进行推理
  List<double> output = await nn.runModel(inputData);

  // 输出结果
  print('Inference result: $output');
}

5. 处理输出

根据你的应用需求,你可能需要对模型的输出进行进一步处理或显示在 UI 上。

void displayResult(List<double> output) {
  // 假设你有一个 Text widget 来显示结果
  Text('Result: $output');
}

6. 释放资源

当你不再需要模型时,记得释放资源。

void disposeNeuralNetwork() {
  nn.dispose();
}

7. 完整示例

以下是一个完整的示例,展示如何在 Flutter 应用中使用 loredart_nn 插件。

import 'package:flutter/material.dart';
import 'package:loredart_nn/loredart_nn.dart';

void main() {
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  [@override](/user/override)
  Widget build(BuildContext context) {
    return MaterialApp(
      home: NeuralNetworkDemo(),
    );
  }
}

class NeuralNetworkDemo extends StatefulWidget {
  [@override](/user/override)
  _NeuralNetworkDemoState createState() => _NeuralNetworkDemoState();
}

class _NeuralNetworkDemoState extends State<NeuralNetworkDemo> {
  LoredartNN nn = LoredartNN();
  List<double> result = [];

  [@override](/user/override)
  void initState() {
    super.initState();
    initNeuralNetwork();
  }

  void initNeuralNetwork() async {
    String modelPath = 'assets/model.tflite';
    await nn.loadModel(modelPath);
  }

  void runInference() async {
    List<double> inputData = [1.0, 2.0, 3.0, 4.0];
    List<double> output = await nn.runModel(inputData);
    setState(() {
      result = output;
    });
  }

  [@override](/user/override)
  void dispose() {
    nn.dispose();
    super.dispose();
  }

  [@override](/user/override)
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('Neural Network Demo'),
      ),
      body: Center(
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: [
            ElevatedButton(
              onPressed: runInference,
              child: Text('Run Inference'),
            ),
            SizedBox(height: 20),
            Text('Result: $result'),
          ],
        ),
      ),
    );
  }
}
回到顶部
AI 助手
你好,我是IT营的 AI 助手
您可以尝试点击下方的快捷入口开启体验!