golang神经网络实现与训练插件库gobrain的使用
golang神经网络实现与训练插件库gobrain的使用
简介
gobrain是一个用Go语言实现的神经网络库,主要包含前馈神经网络和Elman循环神经网络功能。
快速开始
以下是一个完整的使用gobrain实现XOR逻辑的示例:
package main
import (
"github.com/goml/gobrain"
"math/rand"
)
func main() {
// 设置随机种子为0
rand.Seed(0)
// 创建XOR训练数据
patterns := [][][]float64{
{{0, 0}, {0}},
{{0, 1}, {1}},
{{1, 0}, {1}},
{{1, 1}, {0}},
}
// 实例化前馈神经网络
ff := &gobrain.FeedForward{}
// 初始化神经网络结构
// 2个输入节点,2个隐藏节点,1个输出节点
ff.Init(2, 2, 1)
// 训练网络
// 训练1000次,学习率0.6,动量因子0.4,显示训练误差
ff.Train(patterns, 1000, 0.6, 0.4, true)
// 测试网络
ff.Test(patterns)
}
运行上述代码后,网络将被训练并准备好使用。
测试网络
可以使用Test
方法测试网络性能:
ff.Test(patterns)
测试结果会显示类似如下输出:
[0 0] -> [0.057503945708445] : [0]
[0 1] -> [0.930100635071210] : [1]
[1 0] -> [0.927809966227284] : [1]
[1 1] -> [0.097408795324620] : [0]
其中箭头->
前是输入值,箭头后是网络输出值,冒号:
后是期望输出值。
预测新数据
使用Update
方法可以进行预测:
inputs := []float64{1, 1}
ff.Update(inputs)
输出将是一个0到1之间的值向量。
循环神经网络(RNN)
gobrain实现了Elman简单循环网络。可以通过SetContexts
函数使用这一功能:
ff.SetContexts(1, nil)
上面的例子创建了一个初始值为0.5的上下文。也可以创建自定义初始化的上下文:
contexts := [][]float64{
{0.5, 0.8, 0.1}
}
注意:自定义上下文的长度必须等于隐藏节点数+1(偏置节点)。在上面的例子中,隐藏节点数是2,因此上下文有3个值。
持久化
在示例文件夹中有训练好的网络持久化到文件的完整示例:
- example/02 将网络保存到文件
- example/03 从文件加载网络
要运行示例,进入文件夹并执行:
go run main.go
版本变更
- 1.0.0 - 添加了具有Elman RNN上下文的前馈神经网络
更多关于golang神经网络实现与训练插件库gobrain的使用的实战教程也可以访问 https://www.itying.com/category-94-b0.html
更多关于golang神经网络实现与训练插件库gobrain的使用的实战系列教程也可以访问 https://www.itying.com/category-94-b0.html
使用gobrain实现Golang神经网络
gobrain是一个轻量级的Golang神经网络库,它实现了前馈神经网络和简单的训练算法。下面我将介绍如何使用gobrain库来构建和训练神经网络。
安装gobrain
首先需要安装gobrain库:
go get github.com/goml/gobrain
基本使用示例
1. 创建简单的前馈神经网络
package main
import (
"fmt"
"github.com/goml/gobrain"
)
func main() {
// 创建新的前馈神经网络
// 参数:输入层节点数、隐藏层节点数、输出层节点数
ff := gobrain.NewFeedForward(2, 2, 1)
// 训练数据 - XOR问题
patterns := [][][]float64{
{{0, 0}, {0}},
{{0, 1}, {1}},
{{1, 0}, {1}},
{{1, 1}, {0}},
}
// 训练神经网络
// 参数:训练数据、训练次数、学习率、动量因子、是否打印错误
ff.Train(patterns, 10000, 0.6, 0.4, true)
// 测试训练结果
inputs := [][]float64{
{0, 0},
{0, 1},
{1, 0},
{1, 1},
}
for _, input := range inputs {
result := ff.Update(input)
fmt.Printf("%f XOR %f => %f\n", input[0], input[1], result[0])
}
}
2. 自定义神经网络结构
gobrain允许你自定义更复杂的网络结构:
package main
import (
"fmt"
"github.com/goml/gobrain"
)
func main() {
// 创建自定义神经网络
ff := &gobrain.FeedForward{}
// 初始化网络结构
// 参数:输入层、隐藏层、输出层节点数
ff.Init(3, []int{4, 4}, 2) // 3输入,2个隐藏层(各4节点),2输出
// 训练数据 - 模拟分类问题
patterns := [][][]float64{
{{0.1, 0.2, 0.3}, {1, 0}},
{{0.4, 0.5, 0.6}, {0, 1}},
{{0.7, 0.8, 0.9}, {1, 0}},
{{0.2, 0.3, 0.4}, {0, 1}},
}
// 训练参数
ff.Train(patterns, 5000, 0.5, 0.4, true)
// 测试
testInput := []float64{0.3, 0.4, 0.5}
result := ff.Update(testInput)
fmt.Printf("Test result: %v\n", result)
}
3. 持久化模型
gobrain支持将训练好的模型保存到文件以及从文件加载:
package main
import (
"fmt"
"github.com/goml/gobrain"
"os"
)
func main() {
ff := gobrain.NewFeedForward(2, 3, 1)
// 训练数据
patterns := [][][]float64{
{{0, 0}, {0}},
{{0, 1}, {1}},
{{1, 0}, {1}},
{{1, 1}, {0}},
}
ff.Train(patterns, 1000, 0.6, 0.4, false)
// 保存模型到文件
err := ff.Save("xor_model.json")
if err != nil {
fmt.Println("保存模型失败:", err)
return
}
// 从文件加载模型
newFF := &gobrain.FeedForward{}
err = newFF.Load("xor_model.json")
if err != nil {
fmt.Println("加载模型失败:", err)
return
}
// 测试加载的模型
result := newFF.Update([]float64{1, 0})
fmt.Println("1 XOR 0 =>", result[0])
// 清理测试文件
os.Remove("xor_model.json")
}
高级功能
自定义激活函数
gobrain默认使用sigmoid激活函数,但你可以自定义:
package main
import (
"fmt"
"github.com/goml/gobrain"
"math"
)
// 自定义激活函数
func customActivation(x float64) float64 {
// 使用tanh激活函数
return math.Tanh(x)
}
func main() {
ff := gobrain.NewFeedForward(2, 4, 1)
// 设置自定义激活函数
ff.SetActivationFunction(customActivation)
// 训练数据
patterns := [][][]float64{
{{0, 0}, {0}},
{{0, 1}, {1}},
{{1, 0}, {1}},
{{1, 1}, {0}},
}
ff.Train(patterns, 5000, 0.6, 0.4, true)
// 测试
result := ff.Update([]float64{1, 1})
fmt.Println("1 XOR 1 =>", result[0])
}
实际应用案例:手写数字识别
下面是一个简化的手写数字识别示例(实际应用中需要更复杂的网络和更多数据):
package main
import (
"fmt"
"github.com/goml/gobrain"
)
func main() {
// 假设我们有8x8像素的手写数字图像,展平为64维向量
// 输出是10个数字的概率分布
// 创建网络:64输入,128个隐藏节点,10输出
ff := gobrain.NewFeedForward(64, 128, 10)
// 模拟训练数据 (实际应用中应从文件加载真实数据)
var patterns [][][]float64
for i := 0; i < 10; i++ {
// 模拟输入数据 (实际应为图像像素值)
input := make([]float64, 64)
for j := range input {
input[j] = float64(i) * 0.1 // 简单模拟数据
}
// 创建期望输出 (one-hot编码)
output := make([]float64, 10)
output[i] = 1.0
patterns = append(patterns, [][]float64{input, output})
}
// 训练
ff.Train(patterns, 1000, 0.5, 0.4, true)
// 测试
testInput := make([]float64, 64)
for i := range testInput {
testInput[i] = 0.5 // 模拟测试输入
}
result := ff.Update(testInput)
fmt.Println("Predicted probabilities:", result)
// 找出最可能的数字
maxIndex := 0
for i, val := range result {
if val > result[maxIndex] {
maxIndex = i
}
}
fmt.Printf("Predicted number: %d (confidence: %.2f%%)\n",
maxIndex, result[maxIndex]*100)
}
注意事项
- gobrain是一个相对简单的神经网络实现,适合学习和简单任务
- 对于复杂任务,可能需要考虑更强大的库如Gorgonia或GoML
- 训练数据需要适当归一化(通常在0到1或-1到1之间)
- 学习率和动量参数需要根据具体问题调整
- 网络结构(隐藏层数量和节点数)需要实验确定
gobrain提供了神经网络的基本功能,虽然不如TensorFlow或PyTorch强大,但对于Golang中的简单机器学习任务来说是一个不错的选择。