Flutter文本分类插件tflite_text_classification的使用
Flutter文本分类插件tflite_text_classification的使用
开发者寄语
你好👋, 这个包支持使用TensorFlow Lite模型进行文本分类。当我想要在我的Flutter应用中集成一些从TensorFlow Model Maker生成的模型时,我开发了这个插件。
是的,毫无疑问,给我一个免费的👍或⭐会鼓励我继续更新这个插件。
包描述
这是一个用于通过tflite模型进行文本分类的Flutter插件。
注意: 该项目利用tensorflow-lite-task-text库来通过模型对文本进行分类。
特性
- 支持Android 5.0(API级别21)及更高版本。
- 只需四行代码即可运行。
入门指南
在pubspec.yaml
文件中,添加此依赖项:
dependencies:
tflite_text_classification:
在项目中导入此包:
import 'package:tflite_text_classification/tflite_text_classification.dart';
基本用法
ClassificationResult? result = await TfliteTextClassification().classifyText(
params: TextClassifierParams(
text: 'aaj me bahut khush hu', // 示例文本
modelPath: 'path/mobilebert.tflite', // 模型路径
modelType: ModelType.mobileBert, // 模型类型
delegate: 0, // 解释器委托
),
);
示例代码
要运行示例项目,请下载以下模型压缩包并将其解压到示例项目的assets
文件夹中。由于这些模型的大小原因,我单独提供了这些样本模型。
Sample Average Word Vec Model.zip
提供的模型执行Hinglish(印度常见语言)文本情感分类,并且是使用TensorFlow Lite Model Maker生成的,但它们的准确性不高,仅用于演示目的。请替换为你自己的模型。
完整示例代码
import 'dart:developer';
import 'dart:io';
import 'package:flutter/material.dart';
import 'dart:async';
import 'package:flutter/services.dart';
import 'package:path_provider/path_provider.dart';
import 'package:tflite_text_classification/tflite_text_classification.dart';
/// 注意:这个示例使用了assets文件夹中的两个提供的模型。
/// 提供的模型执行Hinglish(印度常见语言)文本情感分类,并且是使用TensorFlow Lite Model Maker生成的,
/// 但它们的准确性不高,仅用于演示目的。
///
/// 请替换为你自己的模型。
void main() {
runApp(const MyApp());
}
class MyApp extends StatefulWidget {
const MyApp({super.key});
[@override](/user/override)
State<MyApp> createState() => _MyAppState();
}
class _MyAppState extends State<MyApp> {
String testMessage = '未知'; // 测试消息
String? predictedEmotion; // 预测的情感
final _tfliteTextClassificationPlugin = TfliteTextClassification();
[@override](/user/override)
void initState() {
super.initState();
initPlatformState();
}
// 平台消息是异步的,因此我们在异步方法中初始化。
Future<void> initPlatformState() async {
ClassificationResult? result;
// 平台消息可能会失败,因此我们使用try/catch PlatformException。
// 我们还处理消息可能返回null的情况。
try {
TextClassifierParams params = TextClassifierParams(
text: 'aaj me bahut khush hu', // 示例文本
modelPath: await copyAssetFileToCacheDirectory('assets/mobilebert.tflite'), // 模型路径
modelType: ModelType.mobileBert, // 模型类型
delegate: 0, // 解释器委托
);
result = await _tfliteTextClassificationPlugin.classifyText(params: params);
} on PlatformException catch (e) {
log(e.toString());
} catch (e) {
log(e.toString());
}
// 如果在异步平台消息还在飞行时小部件被从树中移除,我们需要丢弃回复而不是调用setState来更新我们的非存在的外观。
if (!mounted) return;
setState(() {
if (result != null) {
predictedEmotion = getPredictedEmotion(result); // 获取预测的情感
log(predictedEmotion.toString());
}
});
}
[@override](/user/override)
Widget build(BuildContext context) {
return MaterialApp(
home: Scaffold(
appBar: AppBar(
title: const Text('插件示例应用'),
),
),
);
}
}
/// 辅助函数,从结果中获取最高分的情感。
String? getPredictedEmotion(ClassificationResult result) {
String? predictedEmotion;
double maxScore = 0.0;
for (var category in result.categories) {
if (category.score > maxScore) {
maxScore = category.score;
predictedEmotion = category.label;
}
}
return predictedEmotion;
}
/// 辅助函数,将资产文件复制到缓存目录以供原生代码使用。
Future<String> copyAssetFileToCacheDirectory(String assetPath) async {
// 获取缓存目录路径。
Directory cacheDir = await getTemporaryDirectory();
// 在缓存目录中创建一个同名的新文件。
String fileName = assetPath.split('/').last;
File cacheFile = File('${cacheDir.path}/$fileName');
// 将资产文件复制到缓存目录。
ByteData assetData = await rootBundle.load(assetPath);
await cacheFile.writeAsBytes(assetData.buffer.asUint8List());
return cacheFile.path;
}
更多关于Flutter文本分类插件tflite_text_classification的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html
更多关于Flutter文本分类插件tflite_text_classification的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html
当然,下面是一个关于如何在Flutter项目中使用tflite_text_classification
插件进行文本分类的示例代码。这个示例将展示如何加载TensorFlow Lite模型、进行文本分类并显示结果。
首先,确保你已经添加了tflite_text_classification
插件到你的pubspec.yaml
文件中:
dependencies:
flutter:
sdk: flutter
tflite_text_classification: ^x.y.z # 替换为最新版本号
然后,运行flutter pub get
来安装依赖。
接下来,你需要一个预训练的TensorFlow Lite模型文件(.tflite
)和标签文件(通常是.txt
)。确保这些文件已经放置在你的项目中的合适位置,比如assets
文件夹。
1. 创建Flutter项目结构
假设你的项目结构如下:
my_flutter_app/
assets/
model.tflite
labels.txt
lib/
main.dart
pubspec.yaml
2. 编写main.dart
文件
下面是一个完整的示例main.dart
文件,展示了如何使用tflite_text_classification
插件:
import 'package:flutter/material.dart';
import 'package:tflite_text_classification/tflite_text_classification.dart';
import 'dart:typed_data/uint8list.dart';
import 'dart:convert';
void main() {
runApp(MyApp());
}
class MyApp extends StatelessWidget {
@override
Widget build(BuildContext context) {
return MaterialApp(
title: 'Flutter Text Classification Demo',
theme: ThemeData(
primarySwatch: Colors.blue,
),
home: MyHomePage(),
);
}
}
class MyHomePage extends StatefulWidget {
@override
_MyHomePageState createState() => _MyHomePageState();
}
class _MyHomePageState extends State<MyHomePage> {
late TfliteTextClassification _interpreter;
late List<String> _labels;
late String _result;
@override
void initState() {
super.initState();
loadModelAndLabels();
}
Future<void> loadModelAndLabels() async {
// Load the model
_interpreter = await TfliteTextClassification.loadModel(
model: "assets/model.tflite",
labels: "assets/labels.txt",
numThreads: 1, // Number of threads to use for inference
);
// Load the labels
final labelsFile = await rootBundle.loadString("assets/labels.txt");
_labels = labelsFile.split('\n').toList();
setState(() {
_result = "Model loaded successfully!";
});
}
Future<void> classifyText(String text) async {
List<List<double>> input = [[text.codeUnits.toDoubleList()]];
var output = await _interpreter.classifyText(input);
int bestLabelIndex = output.argmax();
setState(() {
_result = "Classified as: ${_labels[bestLabelIndex]}";
});
}
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
title: Text("Flutter Text Classification Demo"),
),
body: Padding(
padding: const EdgeInsets.all(16.0),
child: Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: <Widget>[
TextField(
decoration: InputDecoration(
labelText: "Enter Text",
),
onSubmitted: (value) {
classifyText(value);
},
),
SizedBox(height: 20),
Text(
_result,
style: TextStyle(fontSize: 18),
),
],
),
),
);
}
}
3. 确保标签文件格式正确
你的labels.txt
文件应该包含每个类别的名称,每行一个,例如:
Category1
Category2
Category3
...
4. 运行应用
确保你的设备和开发环境配置正确,然后运行应用:
flutter run
现在你应该能够在Flutter应用中看到一个文本输入框,当你输入文本并提交时,应用将使用TensorFlow Lite模型进行文本分类,并显示分类结果。
注意:这个示例假设你已经有一个预训练的TensorFlow Lite模型,并且该模型接受Unicode码点列表作为输入。如果你的模型输入格式不同,你可能需要调整输入处理逻辑。