golang简化TensorFlow官方Go绑定使用的插件库tfgo

tfgo: TensorFlow in Go

tfgo 是一个简化 TensorFlow 官方 Go 绑定使用的插件库。TensorFlow 的 Go 绑定使用起来较为困难,而 tfgo 使其变得简单!

主要优势

  • 解决作用域问题:每个新节点都会有新的唯一名称
  • 自动类型转换:属性会自动转换为支持的类型,而不是在运行时抛出错误
  • 使用方法链式调用,可以编写更优雅的 Go 代码

依赖

  1. TensorFlow-2.9.1 库
  2. TensorFlow 绑定 github.com/galeone/tensorflow

安装

go get github.com/galeone/tfgo

快速开始

基础示例

package main

import (
    "fmt"
    tg "github.com/galeone/tfgo"
    tf "github.com/galeone/tensorflow/tensorflow/go"
)

func main() {
    root := tg.NewRoot()
    A := tg.NewTensor(root, tg.Const(root, [2][2]int32{{1, 2}, {-1, -2}}))
    x := tg.NewTensor(root, tg.Const(root, [2][1]int64{{10}, {100}}))
    b := tg.NewTensor(root, tg.Const(root, [2][1]int32{{-10}, {10}}))
    Y := A.MatMul(x.Output).Add(b.Output)
    
    // 注意:Y 只是指向 A 的指针!
    
    // 如果想在图中创建不同的节点,需要克隆 Y 或 A
    Z := A.Clone()
    results := tg.Exec(root, []tf.Output{Y.Output, Z.Output}, nil, &tf.SessionOptions{})
    fmt.Println("Y: ", results[0].Value(), "Z: ", results[1].Value())
    fmt.Println("Y == A", Y == A) // => true
    fmt.Println("Z == A", Z == A) // => false
}

输出:

Y:  [[200] [-200]] Z:  [[200] [-200]]
Y == A true
Z == A false

计算机视觉示例

package main

import (
    tg "github.com/galeone/tfgo"
    "github.com/galeone/tfgo/image"
    "github.com/galeone/tfgo/image/filter"
    "github.com/galeone/tfgo/image/padding"
    tf "github.com/galeone/tensorflow/tensorflow/go"
    "os"
)

func main() {
    root := tg.NewRoot()
    grayImg := image.Read(root, "/home/pgaleone/airplane.png", 1)
    grayImg = grayImg.Scale(0, 255)

    // 使用 Sobel 滤波器进行边缘检测:卷积
    Gx := grayImg.Clone().Convolve(filter.SobelX(root), image.Stride{X: 1, Y: 1}, padding.SAME)
    Gy := grayImg.Clone().Convolve(filter.SobelY(root), image.Stride{X: 1, Y: 1}, padding.SAME)
    convoluteEdges := image.NewImage(root.SubScope("edge"), Gx.Square().Add(Gy.Square().Value()).Sqrt().Value()).EncodeJPEG()

    Gx = grayImg.Clone().Correlate(filter.SobelX(root), image.Stride{X: 1, Y: 1}, padding.SAME)
    Gy = grayImg.Clone().Correlate(filter.SobelY(root), image.Stride{X: 1, Y: 1}, padding.SAME)
    correlateEdges := image.NewImage(root.SubScope("edge"), Gx.Square().Add(Gy.Square().Value()).Sqrt().Value()).EncodeJPEG()

    results := tg.Exec(root, []tf.Output{convoluteEdges, correlateEdges}, nil, &tf.SessionOptions{})

    file, _ := os.Create("convolved.png")
    file.WriteString(results[0].Value().(string))
    file.Close()

    file, _ = os.Create("correlated.png")
    file.WriteString(results[1].Value().(string))
    file.Close()
}

训练与部署示例

Python 训练代码

import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Conv2D(
            8,
            (3, 3),
            strides=(2, 2),
            padding="valid",
            input_shape=(28, 28, 1),
            activation=tf.nn.relu,
            name="inputs",
        ),  # 14x14x8
        tf.keras.layers.Conv2D(
            16, (3, 3), strides=(2, 2), padding="valid", activation=tf.nn.relu
        ),  # 7x716
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, name="logits"),  # linear
    ]
)

tf.saved_model.save(model, "output/keras")

Go 部署代码

package main

import (
    "fmt"
    tg "github.com/galeone/tfgo"
    tf "github.com/galeone/tensorflow/tensorflow/go"
)

func main() {
    model := tg.LoadModel("test_models/output/keras", []string{"serve"}, nil)

    fakeInput, _ := tf.NewTensor([1][28][28][1]float32{})
    results := model.Exec([]tf.Output{
        model.Op("StatefulPartitionedCall", 0),
    }, map[tf.Output]*tf.Tensor{
        model.Op("serving_default_inputs_input", 0): fakeInput,
    })

    predictions := results[0]
    fmt.Println(predictions.Value())
}

TensorFlow 安装

手动安装

curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.9.1.tar.gz" | sudo tar -C /usr/local -xz
sudo ldconfig

Docker 安装

docker pull tensorflow/tensorflow:2.9.1

tfgo 为 Go 开发者提供了更简单的方式来使用 TensorFlow,特别是在需要将训练好的模型部署到生产环境时。通过方法链式调用和自动类型转换等特性,它大大简化了 TensorFlow Go API 的使用。


更多关于golang简化TensorFlow官方Go绑定使用的插件库tfgo的实战教程也可以访问 https://www.itying.com/category-94-b0.html

1 回复

更多关于golang简化TensorFlow官方Go绑定使用的插件库tfgo的实战系列教程也可以访问 https://www.itying.com/category-94-b0.html


使用tfgo简化TensorFlow Go绑定

tfgo是一个简化TensorFlow官方Go绑定的库,它提供了更友好的API接口,让Go开发者能够更方便地使用TensorFlow的功能。

tfgo的主要特点

  1. 简化了TensorFlow Go API的复杂性
  2. 提供了更符合Go语言习惯的接口
  3. 支持模型加载和预测的便捷操作
  4. 简化了张量(Tensor)的创建和操作

安装tfgo

go get github.com/galeone/tfgo

基本使用示例

1. 创建简单计算图

package main

import (
	"fmt"
	tg "github.com/galeone/tfgo"
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

func main() {
	root := tg.NewRoot()
	a := tg.NewTensor(root, tg.Const(root, [2][2]int32{{1, 2}, {3, 4}}))
	b := tg.NewTensor(root, tg.Const(root, [2][2]int32{{5, 6}, {7, 8}}))
	
	// 矩阵乘法
	product := tg.MatMul(root, a, b)
	
	// 创建会话并执行计算图
	results := tg.Exec(root, []tf.Output{product.Output}, nil, &tf.SessionOptions{})
	
	fmt.Println(results[0].Value()) // 输出: [[19 22] [43 50]]
}

2. 加载预训练模型进行预测

package main

import (
	"fmt"
	tg "github.com/galeone/tfgo"
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

func main() {
	// 加载保存的模型
	model := tg.LoadModel("saved_model_dir", []string{"serve"}, nil)
	
	// 准备输入张量
	tensor, _ := tf.NewTensor([1][224][224][3]float32{ /* 图像数据 */ })
	
	// 执行预测
	results := model.Exec(
		[]tf.Output{
			model.Op("output_layer", 0),
		}, map[tf.Output]*tf.Tensor{
			model.Op("input_layer", 0): tensor,
		},
	)
	
	fmt.Println(results[0].Value()) // 输出预测结果
}

3. 简化张量操作

package main

import (
	"fmt"
	tg "github.com/galeone/tfgo"
)

func main() {
	// 直接从Go值创建张量
	tensor := tg.NewTensor(nil, [3]int32{1, 2, 3})
	
	// 张量运算
	added := tensor.Add([3]int32{4, 5, 6})
	multiplied := tensor.Mul(2)
	
	fmt.Println(added.Value())     // 输出: [5 7 9]
	fmt.Println(multiplied.Value()) // 输出: [2 4 6]
}

4. 训练简单模型

package main

import (
	tg "github.com/galeone/tfgo"
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

func main() {
	root := tg.NewRoot()
	
	// 定义占位符
	x := tg.NewTensor(root, tg.Placeholder(root, tf.Float, tg.Shape{None, 1}))
	y := tg.NewTensor(root, tg.Placeholder(root, tf.Float, tg.Shape{None, 1}))
	
	// 定义模型参数
	w := tg.Variable(root, [1][1]float32{{0}}, tg.Float)
	b := tg.Variable(root, [1]float32{0}, tg.Float)
	
	// 线性模型
	pred := tg.Add(root, tg.MatMul(root, x, w), b)
	
	// 损失函数
	loss := tg.ReduceMean(root, tg.Square(root, tg.Sub(root, pred, y)), nil)
	
	// 优化器
	optimizer := tg.Train(root).GradientDescent(0.01)
	trainOp := tg.Train(root).Minimize(optimizer, loss)
	
	// 训练数据
	xTrain, _ := tf.NewTensor([][]float32{{1}, {2}, {3}, {4}})
	yTrain, _ := tf.NewTensor([][]float32{{2}, {4}, {6}, {8}})
	
	// 创建会话并训练
	sess, _ := tf.NewSession(root.Graph, nil)
	for i := 0; i < 1000; i++ {
		sess.Run(map[tf.Output]*tf.Tensor{
			x.Output: xTrain,
			y.Output: yTrain,
		}, []tf.Output{trainOp}, nil)
	}
}

优势对比

相比直接使用TensorFlow Go绑定,tfgo提供了以下优势:

  1. 更简洁的API:减少了样板代码
  2. 类型安全:利用Go的类型系统减少错误
  3. 更自然的Go语言风格:符合Go语言习惯的命名和用法
  4. 简化模型操作:更容易加载和运行预训练模型

tfgo特别适合需要在Go应用中集成TensorFlow模型的开发者,它大大降低了使用TensorFlow Go绑定的学习曲线和开发难度。

回到顶部