Flutter文本分类插件tflite_text_classification的使用

发布于 1周前 作者 songsunli 来自 Flutter

Flutter文本分类插件tflite_text_classification的使用

pub package wakatime

开发者寄语

你好👋, 这个包支持使用TensorFlow Lite模型进行文本分类。当我想要在我的Flutter应用中集成一些从TensorFlow Model Maker生成的模型时,我开发了这个插件。

是的,毫无疑问,给我一个免费的👍或⭐会鼓励我继续更新这个插件。

包描述

这是一个用于通过tflite模型进行文本分类的Flutter插件。

注意: 该项目利用tensorflow-lite-task-text库来通过模型对文本进行分类。

特性

  • 支持Android 5.0(API级别21)及更高版本。
  • 只需四行代码即可运行。

入门指南

pubspec.yaml文件中,添加此依赖项:

dependencies:
  tflite_text_classification: 

在项目中导入此包:

import 'package:tflite_text_classification/tflite_text_classification.dart';

基本用法

ClassificationResult? result = await TfliteTextClassification().classifyText(
  params: TextClassifierParams(
    text: 'aaj me bahut khush hu', // 示例文本
    modelPath: 'path/mobilebert.tflite', // 模型路径
    modelType: ModelType.mobileBert, // 模型类型
    delegate: 0, // 解释器委托
  ),
);

示例代码

要运行示例项目,请下载以下模型压缩包并将其解压到示例项目的assets文件夹中。由于这些模型的大小原因,我单独提供了这些样本模型。

Sample Average Word Vec Model.zip

Sample Mobilebert Model.zip

提供的模型执行Hinglish(印度常见语言)文本情感分类,并且是使用TensorFlow Lite Model Maker生成的,但它们的准确性不高,仅用于演示目的。请替换为你自己的模型。

完整示例代码

import 'dart:developer';
import 'dart:io';

import 'package:flutter/material.dart';
import 'dart:async';

import 'package:flutter/services.dart';
import 'package:path_provider/path_provider.dart';
import 'package:tflite_text_classification/tflite_text_classification.dart';

/// 注意:这个示例使用了assets文件夹中的两个提供的模型。
/// 提供的模型执行Hinglish(印度常见语言)文本情感分类,并且是使用TensorFlow Lite Model Maker生成的,
/// 但它们的准确性不高,仅用于演示目的。
///
/// 请替换为你自己的模型。

void main() {
  runApp(const MyApp());
}

class MyApp extends StatefulWidget {
  const MyApp({super.key});

  [@override](/user/override)
  State<MyApp> createState() => _MyAppState();
}

class _MyAppState extends State<MyApp> {
  String testMessage = '未知'; // 测试消息
  String? predictedEmotion; // 预测的情感
  final _tfliteTextClassificationPlugin = TfliteTextClassification();

  [@override](/user/override)
  void initState() {
    super.initState();
    initPlatformState();
  }

  // 平台消息是异步的,因此我们在异步方法中初始化。
  Future<void> initPlatformState() async {
    ClassificationResult? result;
    // 平台消息可能会失败,因此我们使用try/catch PlatformException。
    // 我们还处理消息可能返回null的情况。
    try {
      TextClassifierParams params = TextClassifierParams(
        text: 'aaj me bahut khush hu', // 示例文本
        modelPath: await copyAssetFileToCacheDirectory('assets/mobilebert.tflite'), // 模型路径
        modelType: ModelType.mobileBert, // 模型类型
        delegate: 0, // 解释器委托
      );

      result = await _tfliteTextClassificationPlugin.classifyText(params: params);
    } on PlatformException catch (e) {
      log(e.toString());
    } catch (e) {
      log(e.toString());
    }

    // 如果在异步平台消息还在飞行时小部件被从树中移除,我们需要丢弃回复而不是调用setState来更新我们的非存在的外观。
    if (!mounted) return;

    setState(() {
      if (result != null) {
        predictedEmotion = getPredictedEmotion(result); // 获取预测的情感
        log(predictedEmotion.toString());
      }
    });
  }

  [@override](/user/override)
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: const Text('插件示例应用'),
        ),
      ),
    );
  }
}

/// 辅助函数,从结果中获取最高分的情感。
String? getPredictedEmotion(ClassificationResult result) {
  String? predictedEmotion;

  double maxScore = 0.0;
  for (var category in result.categories) {
    if (category.score > maxScore) {
      maxScore = category.score;
      predictedEmotion = category.label;
    }
  }

  return predictedEmotion;
}

/// 辅助函数,将资产文件复制到缓存目录以供原生代码使用。
Future<String> copyAssetFileToCacheDirectory(String assetPath) async {
  // 获取缓存目录路径。
  Directory cacheDir = await getTemporaryDirectory();

  // 在缓存目录中创建一个同名的新文件。
  String fileName = assetPath.split('/').last;
  File cacheFile = File('${cacheDir.path}/$fileName');

  // 将资产文件复制到缓存目录。
  ByteData assetData = await rootBundle.load(assetPath);
  await cacheFile.writeAsBytes(assetData.buffer.asUint8List());

  return cacheFile.path;
}

更多关于Flutter文本分类插件tflite_text_classification的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html

1 回复

更多关于Flutter文本分类插件tflite_text_classification的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html


当然,下面是一个关于如何在Flutter项目中使用tflite_text_classification插件进行文本分类的示例代码。这个示例将展示如何加载TensorFlow Lite模型、进行文本分类并显示结果。

首先,确保你已经添加了tflite_text_classification插件到你的pubspec.yaml文件中:

dependencies:
  flutter:
    sdk: flutter
  tflite_text_classification: ^x.y.z  # 替换为最新版本号

然后,运行flutter pub get来安装依赖。

接下来,你需要一个预训练的TensorFlow Lite模型文件(.tflite)和标签文件(通常是.txt)。确保这些文件已经放置在你的项目中的合适位置,比如assets文件夹。

1. 创建Flutter项目结构

假设你的项目结构如下:

my_flutter_app/
  assets/
    model.tflite
    labels.txt
  lib/
    main.dart
  pubspec.yaml

2. 编写main.dart文件

下面是一个完整的示例main.dart文件,展示了如何使用tflite_text_classification插件:

import 'package:flutter/material.dart';
import 'package:tflite_text_classification/tflite_text_classification.dart';
import 'dart:typed_data/uint8list.dart';
import 'dart:convert';

void main() {
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'Flutter Text Classification Demo',
      theme: ThemeData(
        primarySwatch: Colors.blue,
      ),
      home: MyHomePage(),
    );
  }
}

class MyHomePage extends StatefulWidget {
  @override
  _MyHomePageState createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  late TfliteTextClassification _interpreter;
  late List<String> _labels;
  late String _result;

  @override
  void initState() {
    super.initState();
    loadModelAndLabels();
  }

  Future<void> loadModelAndLabels() async {
    // Load the model
    _interpreter = await TfliteTextClassification.loadModel(
      model: "assets/model.tflite",
      labels: "assets/labels.txt",
      numThreads: 1,  // Number of threads to use for inference
    );

    // Load the labels
    final labelsFile = await rootBundle.loadString("assets/labels.txt");
    _labels = labelsFile.split('\n').toList();

    setState(() {
      _result = "Model loaded successfully!";
    });
  }

  Future<void> classifyText(String text) async {
    List<List<double>> input = [[text.codeUnits.toDoubleList()]];
    var output = await _interpreter.classifyText(input);
    int bestLabelIndex = output.argmax();
    setState(() {
      _result = "Classified as: ${_labels[bestLabelIndex]}";
    });
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text("Flutter Text Classification Demo"),
      ),
      body: Padding(
        padding: const EdgeInsets.all(16.0),
        child: Column(
          crossAxisAlignment: CrossAxisAlignment.start,
          children: <Widget>[
            TextField(
              decoration: InputDecoration(
                labelText: "Enter Text",
              ),
              onSubmitted: (value) {
                classifyText(value);
              },
            ),
            SizedBox(height: 20),
            Text(
              _result,
              style: TextStyle(fontSize: 18),
            ),
          ],
        ),
      ),
    );
  }
}

3. 确保标签文件格式正确

你的labels.txt文件应该包含每个类别的名称,每行一个,例如:

Category1
Category2
Category3
...

4. 运行应用

确保你的设备和开发环境配置正确,然后运行应用:

flutter run

现在你应该能够在Flutter应用中看到一个文本输入框,当你输入文本并提交时,应用将使用TensorFlow Lite模型进行文本分类,并显示分类结果。

注意:这个示例假设你已经有一个预训练的TensorFlow Lite模型,并且该模型接受Unicode码点列表作为输入。如果你的模型输入格式不同,你可能需要调整输入处理逻辑。

回到顶部