golang简化TensorFlow官方Go绑定使用的插件库tfgo
tfgo: TensorFlow in Go
tfgo 是一个简化 TensorFlow 官方 Go 绑定使用的插件库。TensorFlow 的 Go 绑定使用起来较为困难,而 tfgo 使其变得简单!
主要优势
- 解决作用域问题:每个新节点都会有新的唯一名称
- 自动类型转换:属性会自动转换为支持的类型,而不是在运行时抛出错误
- 使用方法链式调用,可以编写更优雅的 Go 代码
依赖
- TensorFlow-2.9.1 库
- TensorFlow 绑定 github.com/galeone/tensorflow
安装
go get github.com/galeone/tfgo
快速开始
基础示例
package main
import (
"fmt"
tg "github.com/galeone/tfgo"
tf "github.com/galeone/tensorflow/tensorflow/go"
)
func main() {
root := tg.NewRoot()
A := tg.NewTensor(root, tg.Const(root, [2][2]int32{{1, 2}, {-1, -2}}))
x := tg.NewTensor(root, tg.Const(root, [2][1]int64{{10}, {100}}))
b := tg.NewTensor(root, tg.Const(root, [2][1]int32{{-10}, {10}}))
Y := A.MatMul(x.Output).Add(b.Output)
// 注意:Y 只是指向 A 的指针!
// 如果想在图中创建不同的节点,需要克隆 Y 或 A
Z := A.Clone()
results := tg.Exec(root, []tf.Output{Y.Output, Z.Output}, nil, &tf.SessionOptions{})
fmt.Println("Y: ", results[0].Value(), "Z: ", results[1].Value())
fmt.Println("Y == A", Y == A) // => true
fmt.Println("Z == A", Z == A) // => false
}
输出:
Y: [[200] [-200]] Z: [[200] [-200]]
Y == A true
Z == A false
计算机视觉示例
package main
import (
tg "github.com/galeone/tfgo"
"github.com/galeone/tfgo/image"
"github.com/galeone/tfgo/image/filter"
"github.com/galeone/tfgo/image/padding"
tf "github.com/galeone/tensorflow/tensorflow/go"
"os"
)
func main() {
root := tg.NewRoot()
grayImg := image.Read(root, "/home/pgaleone/airplane.png", 1)
grayImg = grayImg.Scale(0, 255)
// 使用 Sobel 滤波器进行边缘检测:卷积
Gx := grayImg.Clone().Convolve(filter.SobelX(root), image.Stride{X: 1, Y: 1}, padding.SAME)
Gy := grayImg.Clone().Convolve(filter.SobelY(root), image.Stride{X: 1, Y: 1}, padding.SAME)
convoluteEdges := image.NewImage(root.SubScope("edge"), Gx.Square().Add(Gy.Square().Value()).Sqrt().Value()).EncodeJPEG()
Gx = grayImg.Clone().Correlate(filter.SobelX(root), image.Stride{X: 1, Y: 1}, padding.SAME)
Gy = grayImg.Clone().Correlate(filter.SobelY(root), image.Stride{X: 1, Y: 1}, padding.SAME)
correlateEdges := image.NewImage(root.SubScope("edge"), Gx.Square().Add(Gy.Square().Value()).Sqrt().Value()).EncodeJPEG()
results := tg.Exec(root, []tf.Output{convoluteEdges, correlateEdges}, nil, &tf.SessionOptions{})
file, _ := os.Create("convolved.png")
file.WriteString(results[0].Value().(string))
file.Close()
file, _ = os.Create("correlated.png")
file.WriteString(results[1].Value().(string))
file.Close()
}
训练与部署示例
Python 训练代码
import tensorflow as tf
model = tf.keras.Sequential(
[
tf.keras.layers.Conv2D(
8,
(3, 3),
strides=(2, 2),
padding="valid",
input_shape=(28, 28, 1),
activation=tf.nn.relu,
name="inputs",
), # 14x14x8
tf.keras.layers.Conv2D(
16, (3, 3), strides=(2, 2), padding="valid", activation=tf.nn.relu
), # 7x716
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10, name="logits"), # linear
]
)
tf.saved_model.save(model, "output/keras")
Go 部署代码
package main
import (
"fmt"
tg "github.com/galeone/tfgo"
tf "github.com/galeone/tensorflow/tensorflow/go"
)
func main() {
model := tg.LoadModel("test_models/output/keras", []string{"serve"}, nil)
fakeInput, _ := tf.NewTensor([1][28][28][1]float32{})
results := model.Exec([]tf.Output{
model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{
model.Op("serving_default_inputs_input", 0): fakeInput,
})
predictions := results[0]
fmt.Println(predictions.Value())
}
TensorFlow 安装
手动安装
curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.9.1.tar.gz" | sudo tar -C /usr/local -xz
sudo ldconfig
Docker 安装
docker pull tensorflow/tensorflow:2.9.1
tfgo 为 Go 开发者提供了更简单的方式来使用 TensorFlow,特别是在需要将训练好的模型部署到生产环境时。通过方法链式调用和自动类型转换等特性,它大大简化了 TensorFlow Go API 的使用。
更多关于golang简化TensorFlow官方Go绑定使用的插件库tfgo的实战教程也可以访问 https://www.itying.com/category-94-b0.html
1 回复
更多关于golang简化TensorFlow官方Go绑定使用的插件库tfgo的实战系列教程也可以访问 https://www.itying.com/category-94-b0.html
使用tfgo简化TensorFlow Go绑定
tfgo是一个简化TensorFlow官方Go绑定的库,它提供了更友好的API接口,让Go开发者能够更方便地使用TensorFlow的功能。
tfgo的主要特点
- 简化了TensorFlow Go API的复杂性
- 提供了更符合Go语言习惯的接口
- 支持模型加载和预测的便捷操作
- 简化了张量(Tensor)的创建和操作
安装tfgo
go get github.com/galeone/tfgo
基本使用示例
1. 创建简单计算图
package main
import (
"fmt"
tg "github.com/galeone/tfgo"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
func main() {
root := tg.NewRoot()
a := tg.NewTensor(root, tg.Const(root, [2][2]int32{{1, 2}, {3, 4}}))
b := tg.NewTensor(root, tg.Const(root, [2][2]int32{{5, 6}, {7, 8}}))
// 矩阵乘法
product := tg.MatMul(root, a, b)
// 创建会话并执行计算图
results := tg.Exec(root, []tf.Output{product.Output}, nil, &tf.SessionOptions{})
fmt.Println(results[0].Value()) // 输出: [[19 22] [43 50]]
}
2. 加载预训练模型进行预测
package main
import (
"fmt"
tg "github.com/galeone/tfgo"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
func main() {
// 加载保存的模型
model := tg.LoadModel("saved_model_dir", []string{"serve"}, nil)
// 准备输入张量
tensor, _ := tf.NewTensor([1][224][224][3]float32{ /* 图像数据 */ })
// 执行预测
results := model.Exec(
[]tf.Output{
model.Op("output_layer", 0),
}, map[tf.Output]*tf.Tensor{
model.Op("input_layer", 0): tensor,
},
)
fmt.Println(results[0].Value()) // 输出预测结果
}
3. 简化张量操作
package main
import (
"fmt"
tg "github.com/galeone/tfgo"
)
func main() {
// 直接从Go值创建张量
tensor := tg.NewTensor(nil, [3]int32{1, 2, 3})
// 张量运算
added := tensor.Add([3]int32{4, 5, 6})
multiplied := tensor.Mul(2)
fmt.Println(added.Value()) // 输出: [5 7 9]
fmt.Println(multiplied.Value()) // 输出: [2 4 6]
}
4. 训练简单模型
package main
import (
tg "github.com/galeone/tfgo"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
func main() {
root := tg.NewRoot()
// 定义占位符
x := tg.NewTensor(root, tg.Placeholder(root, tf.Float, tg.Shape{None, 1}))
y := tg.NewTensor(root, tg.Placeholder(root, tf.Float, tg.Shape{None, 1}))
// 定义模型参数
w := tg.Variable(root, [1][1]float32{{0}}, tg.Float)
b := tg.Variable(root, [1]float32{0}, tg.Float)
// 线性模型
pred := tg.Add(root, tg.MatMul(root, x, w), b)
// 损失函数
loss := tg.ReduceMean(root, tg.Square(root, tg.Sub(root, pred, y)), nil)
// 优化器
optimizer := tg.Train(root).GradientDescent(0.01)
trainOp := tg.Train(root).Minimize(optimizer, loss)
// 训练数据
xTrain, _ := tf.NewTensor([][]float32{{1}, {2}, {3}, {4}})
yTrain, _ := tf.NewTensor([][]float32{{2}, {4}, {6}, {8}})
// 创建会话并训练
sess, _ := tf.NewSession(root.Graph, nil)
for i := 0; i < 1000; i++ {
sess.Run(map[tf.Output]*tf.Tensor{
x.Output: xTrain,
y.Output: yTrain,
}, []tf.Output{trainOp}, nil)
}
}
优势对比
相比直接使用TensorFlow Go绑定,tfgo提供了以下优势:
- 更简洁的API:减少了样板代码
- 类型安全:利用Go的类型系统减少错误
- 更自然的Go语言风格:符合Go语言习惯的命名和用法
- 简化模型操作:更容易加载和运行预训练模型
tfgo特别适合需要在Go应用中集成TensorFlow模型的开发者,它大大降低了使用TensorFlow Go绑定的学习曲线和开发难度。