Flutter神经网络插件tfann的使用
tfann
是一个用于Flutter的轻量级快速人工神经网络库。它利用内部的tiny SIMD矩阵库,能够将网络结构保存到文件,并且可以从网络生成纯 Dart 代码而没有任何依赖。
开始使用
典型的用法如下:
import 'package:tfann/tfann.dart';
...
// Prepare the training set: each sample maps a 3-value input to a 4-value
// expected output (inputs use -1/1, targets use 0/1).
List<TrainSetInputOutput> xor_data = [
TrainSetInputOutput.lists([-1, -1, -1], [0, 0, 0, 0]),
TrainSetInputOutput.lists([1, 1, -1], [0, 0, 1, 1]),
TrainSetInputOutput.lists([1, -1, -1], [1, 0, 1, 0]),
TrainSetInputOutput.lists([-1, 1, -1], [1, 0, 1, 0]),
TrainSetInputOutput.lists([-1, -1, 1], [1, 0, 1, 0]),
TrainSetInputOutput.lists([1, 1, 1], [1, 1, 1, 0]),
TrainSetInputOutput.lists([1, -1, 1], [0, 0, 1, 1]),
TrainSetInputOutput.lists([-1, 1, 1], [0, 0, 1, 1]),
];
// Create a fully-connected 3-4-4 network; both layers use the uscsls
// activation function.
final xor_net = TfannNetwork.full([3, 4, 4], [ActivationFunctionType.uscsls, ActivationFunctionType.uscsls]);
// Print the (untrained) network's predictions as a baseline.
print("before training...");
xor_data.forEach((data) => print(
"in: ${data.input.toList()} out: ${xor_net.feedForward(data.input).toList()} expected: ${data.output.toList()}"));
// Train: 10000 passes over the whole training set.
for (int i = 0; i < 10000; ++i) {
xor_data.forEach((data) {
xor_net.train(data, learningRate: 0.04);
});
}
// Print the predictions again after training, for comparison.
print("after training...");
xor_data.forEach((data) => print(
"in: ${data.input.toList()} out: ${xor_net.feedForward(data.input).toList()} expected: ${data.output.toList()}"));
保存和加载网络
要保存网络到文件:
await xor_net.save("binary.net");
要从文件加载网络:
var xor_net = TfannNetwork.fromFile("binary.net")!;
编译网络为纯 Dart 代码
在开发过程中,可以使用 --enable-asserts
标志来捕获错误。编译后的网络代码没有依赖关系,适合生产阶段使用。
print(compileNetwork(xor_net));
输出如下:
import 'dart:typed_data';
import 'dart:math';
// Generated SIMD constants: each _SIMDx holds the scalar x broadcast to all
// four float32 lanes (names encode the value, `m` prefix meaning minus).
// Only a subset is referenced in this excerpt; the code generator emits the
// full set unconditionally.
final Float32x4 _SIMD0 = Float32x4.zero();
final Float32x4 _SIMD0_75 = Float32x4.splat(0.75);
final Float32x4 _SIMD0_5 = Float32x4.splat(0.5);
final Float32x4 _SIMD0_25 = Float32x4.splat(0.25);
final Float32x4 _SIMD0_125 = Float32x4.splat(0.125);
final Float32x4 _SIMD0_375 = Float32x4.splat(0.375);
final Float32x4 _SIMD0_625 = Float32x4.splat(0.625);
final Float32x4 _SIMD0_0625 = Float32x4.splat(0.0625);
final Float32x4 _SIMD0_03 = Float32x4.splat(0.03);
final Float32x4 _SIMD0_65625 = Float32x4.splat(0.65625);
final Float32x4 _SIMD0_065 = Float32x4.splat(0.065);
final Float32x4 _SIMD0_185 = Float32x4.splat(0.185);
final Float32x4 _SIMD0_104 = Float32x4.splat(0.104);
final Float32x4 _SIMD0_208 = Float32x4.splat(0.208);
final Float32x4 _SIMD0_704 = Float32x4.splat(0.704);
final Float32x4 _SIMDm0_8 = Float32x4.splat(-0.8);
final Float32x4 _SIMDm1_5 = Float32x4.splat(-1.5);
final Float32x4 _SIMD0_28125 = Float32x4.splat(0.28125);
final Float32x4 _SIMD1 = Float32x4.splat(1);
final Float32x4 _SIMD1_47 = Float32x4.splat(1.47);
final Float32x4 _SIMD1_6 = Float32x4.splat(1.6);
final Float32x4 _SIMD4 = Float32x4.splat(4);
final Float32x4 _SIMD8 = Float32x4.splat(8);
final Float32x4 _SIMDm2 = Float32x4.splat(-2);
final Float32x4 _SIMD0_875 = Float32x4.splat(0.875);
final Float32x4 _SIMD0_4 = Float32x4.splat(0.4);
final Float32x4 _SIMDm0_16 = Float32x4.splat(-0.16);
// Table of 16 sign vectors: entry `index` negates lane i exactly when bit i
// of `index` is set. Not referenced in this excerpt; presumably used by other
// generated activation functions — TODO confirm.
final Float32x4List _SimdSignMaskVector = Float32x4List.fromList(List.generate(
16,
(index) => Float32x4(
(index & 1) != 0 ? -1.0 : 1.0,
(index & 2) != 0 ? -1.0 : 1.0,
(index & 4) != 0 ? -1.0 : 1.0,
(index & 8) != 0 ? -1.0 : 1.0)));
/// Scalar uscsls activation: a piecewise function with two linear tails and
/// a cubic middle segment on (-0.8, 1.6).
///
/// The pieces join continuously at x = -0.8 and x = 1.6.
double uscsls(double x) {
  // Upper linear tail.
  if (x >= 1.6) {
    return 0.065 * x + 0.704;
  }
  // Lower linear tail.
  if (x <= -0.8) {
    return 0.185 * x - 0.208;
  }
  // Cubic middle segment: 0.125 * (x^2 - x^3) + 0.625 * x.
  final squared = x * x;
  return 0.125 * (squared - squared * x) + 0.625 * x;
}
/// SIMD uscsls activation: applies the same piecewise function as the scalar
/// `uscsls`, independently on each of the four float32 lanes.
Float32x4 uscslsX4(Float32x4 x) {
  // Lane-wise masks selecting the upper (x > 1.6) and lower (x <= -0.8)
  // linear segments.
  final Int32x4 upperMask = x.greaterThan(_SIMD1_6);
  final Int32x4 lowerMask = x.lessThanOrEqual(_SIMDm0_8);
  // All three segment results are computed branchlessly, then blended.
  final Float32x4 squared = x * x;
  final Float32x4 cubed = squared * x;
  final Float32x4 upper = x.scale(0.065) + _SIMD0_704;
  final Float32x4 lower = x.scale(0.185) - _SIMD0_208;
  final Float32x4 middle = (squared - cubed).scale(0.125) + x.scale(0.625);
  return upperMask.select(upper, lowerMask.select(lower, middle));
}
// Trained layer weights and biases, serialized as raw IEEE-754 bit patterns:
// each Uint32List holds the 32-bit representations of four float32 values and
// is reinterpreted in place via asFloat32x4List(). Layer 0 maps 3 inputs to
// 4 neurons (the 4th lane of each row is 0, matching the zero-padded input);
// layer 1 maps 4 neurons to 4 outputs.
final List<Float32x4List> Lweight_tfann_evaluate_0 = [Uint32List.fromList([1065924784, 3218828940, 3218824008, 0]).buffer.asFloat32x4List(), Uint32List.fromList([1074832170, 3207276024, 3207270630, 0]).buffer.asFloat32x4List(), Uint32List.fromList([3218045595, 1058827529, 1058838751, 0]).buffer.asFloat32x4List(), Uint32List.fromList([3213025879, 3213257327, 3213261317, 0]).buffer.asFloat32x4List()];
final Float32x4List Lbias_tfann_evaluate_0 = Uint32List.fromList([1051252787, 3212525348, 3213439945, 1049866728]).buffer.asFloat32x4List();
final List<Float32x4List> Lweight_tfann_evaluate_1 = [Uint32List.fromList([3232711821, 1078727539, 3223330061, 1083118854]).buffer.asFloat32x4List(), Uint32List.fromList([3220807383, 3217432562, 3229760405, 3194501247]).buffer.asFloat32x4List(), Uint32List.fromList([3223501112, 1079543989, 1069180988, 3181878151]).buffer.asFloat32x4List(), Uint32List.fromList([1078650051, 1071470358, 1085387923, 3224445642]).buffer.asFloat32x4List()];
final Float32x4List Lbias_tfann_evaluate_1 = Uint32List.fromList([1070831670, 3197145344, 1083611721, 1076128681]).buffer.asFloat32x4List();
/// Generated feed-forward evaluation of the trained 3-4-4 network.
///
/// [inData] must hold exactly 3 values (checked only via `assert`, so only
/// in debug mode). Returns the network's 4 output values.
List<double> tfann_evaluate(List<double> inData) {
assert(inData.length == 3);
// Pack the 3 inputs into a 4-lane float32 buffer; lane 3 stays 0.
Float32List input = Float32List(4);
for (int i = 0; i < 3; ++i) input[i] = inData[i];
Float32x4List currentTensor = input.buffer.asFloat32x4List();
Float32List outputTensor;
// Layer 0: each of the 4 neurons is a dot product of the input with its
// weight row.
outputTensor = Float32List(4);
for (int r = 0; r < 4; ++r) {
Float32x4List weightRow = Lweight_tfann_evaluate_0[r];
Float32x4 sum = currentTensor[0] * weightRow[0];
// Only x+y+z are summed: lane w is the zero pad of the 3-value input.
outputTensor[r] = sum.x + sum.y + sum.z;
}
// NOTE: currentTensor aliases outputTensor's underlying buffer, so the
// in-place bias-add and activation below write straight into outputTensor.
currentTensor = outputTensor.buffer.asFloat32x4List();
currentTensor[0] += Lbias_tfann_evaluate_0[0];
currentTensor[0] = uscslsX4(currentTensor[0]);
// Layer 1: same pattern, but all 4 lanes carry real values now.
outputTensor = Float32List(4);
for (int r = 0; r < 4; ++r) {
Float32x4List weightRow = Lweight_tfann_evaluate_1[r];
Float32x4 sum = currentTensor[0] * weightRow[0];
outputTensor[r] = sum.x + sum.y + sum.z + sum.w;
}
currentTensor = outputTensor.buffer.asFloat32x4List();
currentTensor[0] += Lbias_tfann_evaluate_1[0];
currentTensor[0] = uscslsX4(currentTensor[0]);
// Copy the 4 output lanes out as a fresh growable List<double>.
return currentTensor.buffer.asFloat32List(0, 4).toList();
}
更多关于Flutter神经网络插件tfann的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html
tfann
是一个用于 Flutter 的轻量级纯 Dart 神经网络库(并非 TensorFlow Lite 的封装)。注意:下文示例中的 `Tfann`/`loadModel` 等 API 与上文的 `TfannNetwork` API 并不一致,可能有误,请以 pub.dev 上 tfann 的官方文档为准。
以下是使用 tfann
插件的基本步骤:
1. 添加依赖
首先,在 pubspec.yaml
文件中添加 tfann
插件的依赖:
dependencies:
flutter:
sdk: flutter
tfann: ^0.0.1 # 请检查最新版本
然后运行 flutter pub get
来获取依赖。
2. 导入库
在你的 Dart 文件中导入 tfann
:
import 'package:tfann/tfann.dart';
3. 加载模型
你需要将 TensorFlow Lite 模型文件(.tflite
)放在 assets
目录下,并在 pubspec.yaml
中声明:
flutter:
assets:
- assets/model.tflite
然后在代码中加载模型:
// NOTE(review): the `Tfann` class and `loadModel` API shown here do not match
// the `TfannNetwork` API documented earlier in this article — verify against
// the package's official documentation before use.
Tfann tfann = Tfann();
void loadModel() async {
await tfann.loadModel("assets/model.tflite");
}
4. 进行推理
加载模型后,你可以使用 tfann
进行推理。假设你的模型接受一个输入并产生一个输出:
// Feeds one sample input through the loaded model and prints the result.
void runInference() async {
List<double> input = [1.0, 2.0, 3.0]; // input data
List<double> output = await tfann.run(input);
print(output); // inference result
}
5. 释放资源
在使用完模型后,记得释放资源:
// Releases the model's resources once inference is no longer needed.
void dispose() {
tfann.dispose();
}
6. 完整示例
以下是一个完整的示例,展示如何加载模型并进行推理:
import 'package:flutter/material.dart';
import 'package:tfann/tfann.dart';
/// Application entry point.
void main() => runApp(MyApp());
/// Root widget: hosts a [MaterialApp] with [HomePage] as its home screen.
class MyApp extends StatelessWidget {
  // Fixed: the `@override` annotation had been mangled into a markdown link
  // ("[@override](/user/override)"), which is not valid Dart.
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: HomePage(),
    );
  }
}
/// Stateful page that loads the model on startup and runs a sample inference.
class HomePage extends StatefulWidget {
  // Fixed: the `@override` annotation had been mangled into a markdown link
  // ("[@override](/user/override)"), which is not valid Dart.
  @override
  _HomePageState createState() => _HomePageState();
}
/// State for [HomePage]: owns the model handle and its lifecycle.
///
/// Fixed: all four `@override` annotations had been mangled into markdown
/// links ("[@override](/user/override)"), which is not valid Dart.
class _HomePageState extends State<HomePage> {
  Tfann tfann = Tfann();

  @override
  void initState() {
    super.initState();
    // Kick off model loading as soon as the page is created.
    loadModel();
  }

  /// Loads the bundled model, then immediately runs a sample inference.
  void loadModel() async {
    await tfann.loadModel("assets/model.tflite");
    runInference();
  }

  /// Feeds a fixed sample input through the model and prints the output.
  void runInference() async {
    List<double> input = [1.0, 2.0, 3.0]; // input data
    List<double> output = await tfann.run(input);
    print(output); // inference result
  }

  @override
  void dispose() {
    // Release the model's native resources before the widget goes away.
    tfann.dispose();
    super.dispose();
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('TFANN Example'),
      ),
      body: Center(
        child: Text('Check the console for inference output.'),
      ),
    );
  }
}