Flutter机器学习教育插件teachable的使用

发布于 1周前 作者 sinazl 来自 Flutter

Flutter机器学习教育插件teachable的使用

teachable 是一个用于在 Flutter 应用程序中使用 Teachable Machine 的插件。它可以帮助开发者创建基于姿势分类器(Posenet)的应用程序。

必要部分

Teachable Machine 的 HTML

首先,将以下代码存储到您的 assets 目录中的 HTML 文件中:

<div>
  <canvas id="canvas"
          style="position:fixed;min-height:100%;min-width:100%;height:100%;width:100%;top:0%;left:0%;resize:none;">
  </canvas>
</div>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.3.1/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@teachablemachine/pose@0.8/dist/teachablemachine-pose.min.js"></script>
<script type="text/javascript">
  const URL = "Your URL comes here";
  let model, webcam, ctx, labelContainer, maxPredictions;

  async function init() {
    const modelURL = URL + "model.json";
    const metadataURL = URL + "metadata.json";
    model = await tmPose.load(modelURL, metadataURL);

    maxPredictions = model.getTotalClasses();

    // 设置摄像头
    const size = 600;
    const flip = true; // 是否翻转摄像头
    webcam = new tmPose.Webcam(size, size, flip); // 宽度, 高度, 翻转
    await webcam.setup(); // 请求访问摄像头
    await webcam.play();
    window.requestAnimationFrame(loop);

    // 添加/获取 DOM 元素
    const canvas = document.getElementById("canvas");
    canvas.width = size; canvas.height = size;
    ctx = canvas.getContext("2d");
  }

  async function loop(timestamp) {
    webcam.update(); // 更新摄像头帧
    await predict();
    window.requestAnimationFrame(loop);
  }

  async function predict() {
    // 预测 1: 使用 posenet 进行输入
    // estimatePose 可以接受图像、视频或 canvas 元素
    const { pose, posenetOutput } = await model.estimatePose(webcam.canvas);
    // 预测 2: 使用 Teachable Machine 分类模型
    const prediction = await model.predict(posenetOutput);

    let ans = 0, score = 0;
    for (let i = 0; i < maxPredictions; i++) {
      if (prediction[i].probability.toFixed(2) > score) {
        score = prediction[i].probability.toFixed(2);
        ans = prediction[i].className;
      }
    }

    // 最后绘制姿势
    drawPose(pose);

    try {
      // 调用 Flutter 代码
      window.flutter_inappwebview.callHandler('updater', prediction);
    } catch (e) {
      // 异常处理
    }
  }

  function drawPose(pose) {
    if (webcam.canvas) {
      ctx.drawImage(webcam.canvas, 0, 0);
      // 绘制关键点和骨骼
      if (pose) {
        const minPartConfidence = 0.5;
        tmPose.drawKeypoints(pose.keypoints, minPartConfidence, ctx);
        tmPose.drawSkeleton(pose.keypoints, minPartConfidence, ctx);
      }
    }
  }
  init();
</script>

请注意替换 Your URL comes here 为您的 Teachable Machine 模型的 URL。

Flutter 部分

在 Flutter 应用程序中添加权限请求:

void main() async {
  WidgetsFlutterBinding.ensureInitialized();
  // 请求相机和麦克风权限
  await Permission.camera.request();
  await Permission.microphone.request();
  runApp(MyApp());
}

Teachable 小部件的使用

在 Flutter 应用程序中使用 Teachable 小部件:

import 'dart:convert';

import 'package:flutter/material.dart';
import 'package:permission_handler/permission_handler.dart';
import 'package:teachable/teachable.dart';

void main() async {
  WidgetsFlutterBinding.ensureInitialized();

  await Permission.camera.request();
  await Permission.microphone.request();
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  [@override](/user/override)
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'Flutter Demo',
      theme: ThemeData(
        primarySwatch: Colors.blue,
      ),
      home: MyHomePage(title: 'Flutter Demo Home Page'),
    );
  }
}

class MyHomePage extends StatefulWidget {
  MyHomePage({Key? key, required this.title}) : super(key: key);

  final String title;

  [@override](/user/override)
  _MyHomePageState createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  String pose = "";

  [@override](/user/override)
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(title: Text("Pose Classifier")),
      body: Stack(
        children: [
          Container(
            child: Column(children: <Widget>[
              Expanded(
                child: Container(
                  child: Teachable(
                    path: "pose/index.html", // 替换为您的 HTML 文件路径
                    results: (res) {
                      var resp = jsonDecode(res);
                      setState(() {
                        pose = (resp['Tree Pose'] * 100.0).toString();
                      });
                    },
                  ),
                ),
              ),
            ]),
          ),
          Align(
            alignment: Alignment.bottomCenter,
            child: Container(
              width: double.infinity,
              height: 50,
              decoration: BoxDecoration(
                color: Colors.black.withOpacity(0.5),
              ),
              child: Column(
                mainAxisAlignment: MainAxisAlignment.spaceBetween,
                children: [
                  Row(
                    mainAxisAlignment: MainAxisAlignment.spaceEvenly,
                    children: [
                      Text(
                        "TREE POSE",
                        style: TextStyle(color: Colors.white),
                      ),
                      Text(
                        pose,
                        style: TextStyle(color: Colors.white),
                      ),
                    ],
                  ),
                ],
              ),
            ),
          )
        ],
      ),
    );
  }
}

权限设置

Android

AndroidManifest.xml 中添加以下权限:

<uses-permission android:name="android.permission.INTERNET"/>
<uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.RECORD_AUDIO" />
<uses-permission android:name="android.permission.MODIFY_AUDIO_SETTINGS" />
<uses-permission android:name="android.permission.VIDEO_CAPTURE" />
<uses-permission android:name="android.permission.AUDIO_CAPTURE" />
iOS

确保在 Info.plist 中进行相应的配置,可以参考 InAppWebView 文档WebRTC API 文档

训练模型

您可以使用 Teachable Machine 训练自己的模型,并将其 URL 添加到您的应用程序中。

示例

下面是一个完整的示例代码,展示了如何在 Flutter 应用程序中使用 Teachable 插件:

import 'dart:convert';

import 'package:flutter/material.dart';
import 'package:permission_handler/permission_handler.dart';
import 'package:teachable/teachable.dart';

void main() async {
  WidgetsFlutterBinding.ensureInitialized();

  await Permission.camera.request();
  await Permission.microphone.request();
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  [@override](/user/override)
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'Flutter Demo',
      theme: ThemeData(
        primarySwatch: Colors.blue,
      ),
      home: MyHomePage(title: 'Flutter Demo Home Page'),
    );
  }
}

class MyHomePage extends StatefulWidget {
  MyHomePage({Key? key, required this.title}) : super(key: key);

  final String title;

  [@override](/user/override)
  _MyHomePageState createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  String pose = "";

  [@override](/user/override)
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(title: Text("Pose Classifier")),
      body: Stack(
        children: [
          Container(
            child: Column(children: <Widget>[
              Expanded(
                child: Container(
                  child: Teachable(
                    path: "pose/index.html", // 替换为您的 HTML 文件路径
                    results: (res) {
                      var resp = jsonDecode(res);
                      setState(() {
                        pose = (resp['Tree Pose'] * 100.0).toString();
                      });
                    },
                  ),
                ),
              ),
            ]),
          ),
          Align(
            alignment: Alignment.bottomCenter,
            child: Container(
              width: double.infinity,
              height: 50,
              decoration: BoxDecoration(
                color: Colors.black.withOpacity(0.5),
              ),
              child: Column(
                mainAxisAlignment: MainAxisAlignment.spaceBetween,
                children: [
                  Row(
                    mainAxisAlignment: MainAxisAlignment.spaceEvenly,
                    children: [
                      Text(
                        "TREE POSE",
                        style: TextStyle(color: Colors.white),
                      ),
                      Text(
                        pose,
                        style: TextStyle(color: Colors.white),
                      ),
                    ],
                  ),
                ],
              ),
            ),
          )
        ],
      ),
    );
  }
}

更多关于Flutter机器学习教育插件teachable的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html

1 回复

更多关于Flutter机器学习教育插件teachable的使用的实战系列教程也可以访问 https://www.itying.com/category-92-b0.html


在Flutter中使用Teachable Machine插件来集成机器学习教育功能,可以帮助学生和开发者轻松创建和部署机器学习模型。Teachable Machine是一个由Google提供的工具,允许用户通过简单的拖放界面训练图像、声音和姿态识别模型。下面是一个基本的代码示例,展示了如何在Flutter项目中集成Teachable Machine插件来处理图像识别任务。

首先,你需要确保你的Flutter项目已经设置好,并且已经添加了必要的依赖项。Teachable Machine插件并不是官方的Flutter插件,但你可以使用TensorFlow Lite插件来加载和运行Teachable Machine导出的TensorFlow Lite模型。

  1. 添加依赖项

在你的pubspec.yaml文件中添加TensorFlow Lite的依赖项:

dependencies:
  flutter:
    sdk: flutter
  tflite_flutter: ^0.9.0+2  # 请检查最新版本号

然后运行flutter pub get来安装依赖项。

  1. 准备模型

使用Teachable Machine训练你的模型,并导出为TensorFlow Lite格式。你会得到一个.tflite文件和一些标签文件(通常是.txt.labels)。将这些文件放到你的Flutter项目的assets目录下(如果没有这个目录,请创建它)。

  1. 加载和运行模型

下面是一个示例代码,展示了如何在Flutter应用中加载Teachable Machine导出的TensorFlow Lite模型,并使用它来识别图像。

import 'package:flutter/material.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'dart:io';
import 'dart:typed_data';

void main() => runApp(MyApp());

class MyApp extends StatefulWidget {
  @override
  _MyAppState createState() => _MyAppState();
}

class _MyAppState extends State<MyApp> {
  late Interpreter _interpreter;
  late List<String> _labels;

  @override
  void initState() {
    super.initState();
    _loadModelAndLabels();
  }

  Future<void> _loadModelAndLabels() async {
    // 加载TensorFlow Lite模型
    await Tflite.loadModel(
      model: 'assets/model.tflite', // 这里是你的.tflite文件路径
      labels: 'assets/labels.txt',  // 这里是你的标签文件路径
    ).then((value) {
      setState(() {
        _interpreter = value;
      });
    }).catchError((e) {
      print(e);
    });

    // 异步加载标签
    _labels = await rootBundle.loadString('assets/labels.txt').then((value) {
      return value.split('\n');
    }).catchError((e) {
      print(e);
      return [];
    });
  }

  Future<void> _runModelOnImage(File image) async {
    // 预处理图像(例如,调整大小到224x224)
    final Uint8List imageBytes = await image.readAsBytes();
    final TensorImage input = TensorImage.fromBytes(
      imageBytes,
      inputShape: [1, 224, 224, 3],  // 根据你的模型输入形状调整
    );

    // 运行模型
    final List<dynamic> output = await _interpreter.run(input.buffer.asUint8List());

    // 获取预测结果
    final int bestResultIndex = output.mapIndexed((i, e) => MapEntry(i, e[0] as double)).toList()
        .indexWhere((element) => element.value == element.values.reduce((a, b) => math.max(a, b)));

    setState(() {
      print('Prediction: ${_labels[bestResultIndex]}');
    });
  }

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: Text('Teachable Machine Flutter Demo'),
        ),
        body: Center(
          child: ElevatedButton(
            onPressed: () async {
              // 这里选择图像的逻辑(例如,使用image_picker插件)
              final File image = await ImagePicker().pickImage(source: ImageSource.gallery);
              if (image != null && _interpreter != null) {
                _runModelOnImage(image);
              }
            },
            child: Text('Pick an Image'),
          ),
        ),
      ),
    );
  }
}

注意:上面的代码示例使用了image_picker插件来选择图像,但你需要自己添加这个插件的依赖项并在pubspec.yaml中配置它。此外,图像预处理部分(如调整大小)可能需要根据你的具体模型进行调整。

这个示例代码展示了如何在Flutter应用中加载TensorFlow Lite模型,并使用它来识别图像。你可以根据需要对代码进行扩展,例如添加UI元素来显示预测结果,或者处理不同类型的输入数据(如声音或姿态)。

回到顶部