Flutter神经网络插件tfann的使用

Flutter神经网络插件tfann的使用

tfann 是一个用于Flutter的轻量级快速人工神经网络库。它利用内部的tiny SIMD矩阵库,能够将网络结构保存到文件,并且可以从网络生成纯 Dart 代码而没有任何依赖。

开始使用

典型的用法如下:

import 'package:tfann/tfann.dart';

...

// 准备训练数据
List<TrainSetInputOutput> xor_data = [
  TrainSetInputOutput.lists([-1, -1, -1], [0, 0, 0, 0]),
  TrainSetInputOutput.lists([1, 1, -1], [0, 0, 1, 1]),
  TrainSetInputOutput.lists([1, -1, -1], [1, 0, 1, 0]),
  TrainSetInputOutput.lists([-1, 1, -1], [1, 0, 1, 0]),
  TrainSetInputOutput.lists([-1, -1, 1], [1, 0, 1, 0]),
  TrainSetInputOutput.lists([1, 1, 1], [1, 1, 1, 0]),
  TrainSetInputOutput.lists([1, -1, 1], [0, 0, 1, 1]),
  TrainSetInputOutput.lists([-1, 1, 1], [0, 0, 1, 1]),
];

// 创建神经网络
final xor_net = TfannNetwork.full([3, 4, 4], [ActivationFunctionType.uscsls, ActivationFunctionType.uscsls]);

// 在训练前输出网络的预测结果
print("before training...");
xor_data.forEach((data) => print(
  "in: ${data.input.toList()} out: ${xor_net.feedForward(data.input).toList()} expected: ${data.output.toList()}"));

// 训练网络
for (int i = 0; i < 10000; ++i) {
  xor_data.forEach((data) {
    xor_net.train(data, learningRate: 0.04);
  });
}

// 在训练后输出网络的预测结果
print("after training...");

xor_data.forEach((data) => print(
  "in: ${data.input.toList()} out: ${xor_net.feedForward(data.input).toList()} expected: ${data.output.toList()}"));

保存和加载网络

要保存网络到文件:

await xor_net.save("binary.net");

要从文件加载网络:

var xor_net = TfannNetwork.fromFile("binary.net")!;

编译网络为纯 Dart 代码

在开发过程中,可以使用 --enable-asserts 标志来捕获错误。编译后的网络代码没有依赖关系,适合生产阶段使用。

print(compileNetwork(xor_net));

输出如下:

import 'dart:typed_data';
import 'dart:math';

final Float32x4 _SIMD0 = Float32x4.zero();
final Float32x4 _SIMD0_75 = Float32x4.splat(0.75);
final Float32x4 _SIMD0_5 = Float32x4.splat(0.5);
final Float32x4 _SIMD0_25 = Float32x4.splat(0.25);
final Float32x4 _SIMD0_125 = Float32x4.splat(0.125);
final Float32x4 _SIMD0_375 = Float32x4.splat(0.375);
final Float32x4 _SIMD0_625 = Float32x4.splat(0.625);
final Float32x4 _SIMD0_0625 = Float32x4.splat(0.0625);
final Float32x4 _SIMD0_03 = Float32x4.splat(0.03);
final Float32x4 _SIMD0_65625 = Float32x4.splat(0.65625);
final Float32x4 _SIMD0_065 = Float32x4.splat(0.065);
final Float32x4 _SIMD0_185 = Float32x4.splat(0.185);
final Float32x4 _SIMD0_104 = Float32x4.splat(0.104);
final Float32x4 _SIMD0_208 = Float32x4.splat(0.208);
final Float32x4 _SIMD0_704 = Float32x4.splat(0.704);
final Float32x4 _SIMDm0_8 = Float32x4.splat(-0.8);
final Float32x4 _SIMDm1_5 = Float32x4.splat(-1.5);
final Float32x4 _SIMD0_28125 = Float32x4.splat(0.28125);
final Float32x4 _SIMD1 = Float32x4.splat(1);
final Float32x4 _SIMD1_47 = Float32x4.splat(1.47);
final Float32x4 _SIMD1_6 = Float32x4.splat(1.6);
final Float32x4 _SIMD4 = Float32x4.splat(4);
final Float32x4 _SIMD8 = Float32x4.splat(8);
final Float32x4 _SIMDm2 = Float32x4.splat(-2);
final Float32x4 _SIMD0_875 = Float32x4.splat(0.875);
final Float32x4 _SIMD0_4 = Float32x4.splat(0.4);
final Float32x4 _SIMDm0_16 = Float32x4.splat(-0.16);
final Float32x4List _SimdSignMaskVector = Float32x4List.fromList(List.generate(
    16,
    (index) => Float32x4(
        (index & 1) != 0 ? -1.0 : 1.0,
        (index & 2) != 0 ? -1.0 : 1.0,
        (index & 4) != 0 ? -1.0 : 1.0,
        (index & 8) != 0 ? -1.0 : 1.0)));

double uscsls(double x) {
  if (x >= 1.6) return 0.065 * x + 0.704;
  if (x > -0.8) {
    var x2 = x * x;
    var x3 = x2 * x;
    return 0.125 * (x2 - x3) + 0.625 * x;
  }
  return 0.185 * x - 0.208;
}

Float32x4 uscslsX4(Float32x4 x) {
  Int32x4 greater1_6 = x.greaterThan(_SIMD1_6);
  Float32x4 x2 = x * x;

  Float32x4 branch1Result = x.scale(0.065) + _SIMD0_704;
  Float32x4 x3 = x2 * x;

  Int32x4 lessThanMinus0_8 = x.lessThanOrEqual(_SIMDm0_8);
  Float32x4 branch3Result = x.scale(0.185) - _SIMD0_208;

  return greater1_6.select(
      branch1Result,
      lessThanMinus0_8.select(
          branch3Result,
           (x2 - x3).scale(0.125) +  x.scale(0.625)));
}

final List<Float32x4List> Lweight_tfann_evaluate_0 = [Uint32List.fromList([1065924784, 3218828940, 3218824008, 0]).buffer.asFloat32x4List(), Uint32List.fromList([1074832170, 3207276024, 3207270630, 0]).buffer.asFloat32x4List(), Uint32List.fromList([3218045595, 1058827529, 1058838751, 0]).buffer.asFloat32x4List(), Uint32List.fromList([3213025879, 3213257327, 3213261317, 0]).buffer.asFloat32x4List()];
final Float32x4List Lbias_tfann_evaluate_0 = Uint32List.fromList([1051252787, 3212525348, 3213439945, 1049866728]).buffer.asFloat32x4List();
final List<Float32x4List> Lweight_tfann_evaluate_1 = [Uint32List.fromList([3232711821, 1078727539, 3223330061, 1083118854]).buffer.asFloat32x4List(), Uint32List.fromList([3220807383, 3217432562, 3229760405, 3194501247]).buffer.asFloat32x4List(), Uint32List.fromList([3223501112, 1079543989, 1069180988, 3181878151]).buffer.asFloat32x4List(), Uint32List.fromList([1078650051, 1071470358, 1085387923, 3224445642]).buffer.asFloat32x4List()];
final Float32x4List Lbias_tfann_evaluate_1 = Uint32List.fromList([1070831670, 3197145344, 1083611721, 1076128681]).buffer.asFloat32x4List();

List<double> tfann_evaluate(List<double> inData) {
  assert(inData.length == 3);
  Float32List input = Float32List(4);
  for (int i = 0; i < 3; ++i) input[i] = inData[i];
  Float32x4List currentTensor = input.buffer.asFloat32x4List();
  Float32List outputTensor;
  outputTensor = Float32List(4);
  for (int r = 0; r < 4; ++r) {
    Float32x4List weightRow = Lweight_tfann_evaluate_0[r];
    Float32x4 sum = currentTensor[0] * weightRow[0];
    outputTensor[r] = sum.x + sum.y + sum.z;
  }
  currentTensor = outputTensor.buffer.asFloat32x4List();
  currentTensor[0] += Lbias_tfann_evaluate_0[0];
  currentTensor[0] = uscslsX4(currentTensor[0]);
  outputTensor = Float32List(4);
  for (int r = 0; r < 4; ++r) {
    Float32x4List weightRow = Lweight_tfann_evaluate_1[r];
    Float32x4 sum = currentTensor[0] * weightRow[0];
    outputTensor[r] = sum.x + sum.y + sum.z + sum.w;
  }
  currentTensor = outputTensor.buffer.asFloat32x4List();
  currentTensor[0] += Lbias_tfann_evaluate_1[0];
  currentTensor[0] = uscslsX4(currentTensor[0]);
  return currentTensor.buffer.asFloat32List(0, 4).toList();
}

更多关于Flutter神经网络插件tfann的使用的实战教程也可以访问 https://www.itying.com/category-92-b0.html

1 回复

更多关于Flutter神经网络插件tfann的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html


tfann 是一个用于在 Flutter 应用中使用 TensorFlow Lite 进行神经网络推理的插件。它允许你将训练好的 TensorFlow Lite 模型集成到 Flutter 应用中,以便在移动设备上进行推理。

以下是使用 tfann 插件的基本步骤:

1. 添加依赖

首先,在 pubspec.yaml 文件中添加 tfann 插件的依赖:

dependencies:
  flutter:
    sdk: flutter
  tfann: ^0.0.1  # 请检查最新版本

然后运行 flutter pub get 来获取依赖。

2. 导入库

在你的 Dart 文件中导入 tfann

import 'package:tfann/tfann.dart';

3. 加载模型

你需要将 TensorFlow Lite 模型文件(.tflite)放在 assets 目录下,并在 pubspec.yaml 中声明:

flutter:
  assets:
    - assets/model.tflite

然后在代码中加载模型:

Tfann tfann = Tfann();

void loadModel() async {
  await tfann.loadModel("assets/model.tflite");
}

4. 进行推理

加载模型后,你可以使用 tfann 进行推理。假设你的模型接受一个输入并产生一个输出:

void runInference() async {
  List<double> input = [1.0, 2.0, 3.0];  // 输入数据
  List<double> output = await tfann.run(input);
  print(output);  // 输出结果
}

5. 释放资源

在使用完模型后,记得释放资源:

void dispose() {
  tfann.dispose();
}

6. 完整示例

以下是一个完整的示例,展示如何加载模型并进行推理:

import 'package:flutter/material.dart';
import 'package:tfann/tfann.dart';

void main() {
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  [@override](/user/override)
  Widget build(BuildContext context) {
    return MaterialApp(
      home: HomePage(),
    );
  }
}

class HomePage extends StatefulWidget {
  [@override](/user/override)
  _HomePageState createState() => _HomePageState();
}

class _HomePageState extends State<HomePage> {
  Tfann tfann = Tfann();

  [@override](/user/override)
  void initState() {
    super.initState();
    loadModel();
  }

  void loadModel() async {
    await tfann.loadModel("assets/model.tflite");
    runInference();
  }

  void runInference() async {
    List<double> input = [1.0, 2.0, 3.0];  // 输入数据
    List<double> output = await tfann.run(input);
    print(output);  // 输出结果
  }

  [@override](/user/override)
  void dispose() {
    tfann.dispose();
    super.dispose();
  }

  [@override](/user/override)
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('TFANN Example'),
      ),
      body: Center(
        child: Text('Check the console for inference output.'),
      ),
    );
  }
}
回到顶部