golang神经网络实现与训练插件库gobrain的使用

golang神经网络实现与训练插件库gobrain的使用

简介

gobrain是一个用Go语言实现的神经网络库,主要包含前馈神经网络和Elman循环神经网络功能。

快速开始

以下是一个完整的使用gobrain实现XOR逻辑的示例:

package main

import (
	"github.com/goml/gobrain"
	"math/rand"
)

func main() {
	// 设置随机种子为0
	rand.Seed(0)

	// 创建XOR训练数据
	patterns := [][][]float64{
		{{0, 0}, {0}},
		{{0, 1}, {1}},
		{{1, 0}, {1}},
		{{1, 1}, {0}},
	}

	// 实例化前馈神经网络
	ff := &gobrain.FeedForward{}

	// 初始化神经网络结构
	// 2个输入节点,2个隐藏节点,1个输出节点
	ff.Init(2, 2, 1)

	// 训练网络
	// 训练1000次,学习率0.6,动量因子0.4,显示训练误差
	ff.Train(patterns, 1000, 0.6, 0.4, true)

	// 测试网络
	ff.Test(patterns)
}

运行上述代码后,网络将被训练并准备好使用。

测试网络

可以使用Test方法测试网络性能:

ff.Test(patterns)

测试结果会显示类似如下输出:

[0 0] -> [0.057503945708445]  :  [0]
[0 1] -> [0.930100635071210]  :  [1]
[1 0] -> [0.927809966227284]  :  [1]
[1 1] -> [0.097408795324620]  :  [0]

其中箭头->前是输入值,箭头后是网络输出值,冒号:后是期望输出值。

预测新数据

使用Update方法可以进行预测:

inputs := []float64{1, 1}
ff.Update(inputs)

输出将是一个0到1之间的值向量。

循环神经网络(RNN)

gobrain实现了Elman简单循环网络。可以通过SetContexts函数使用这一功能:

ff.SetContexts(1, nil)

上面的例子创建了一个初始值为0.5的上下文。也可以创建自定义初始化的上下文:

contexts := [][]float64{
	{0.5, 0.8, 0.1}
}

注意:自定义上下文的长度必须等于隐藏节点数+1(偏置节点)。在上面的例子中,隐藏节点数是2,因此上下文有3个值。

持久化

在示例文件夹中有训练好的网络持久化到文件的完整示例:

  • example/02 将网络保存到文件
  • example/03 从文件加载网络

要运行示例,进入文件夹并执行:

go run main.go

版本变更

  • 1.0.0 - 添加了具有Elman RNN上下文的前馈神经网络

更多关于golang神经网络实现与训练插件库gobrain的使用的实战教程也可以访问 https://www.itying.com/category-94-b0.html

1 回复

更多关于golang神经网络实现与训练插件库gobrain的使用的实战系列教程也可以访问 https://www.itying.com/category-94-b0.html


使用gobrain实现Golang神经网络

gobrain是一个轻量级的Golang神经网络库,它实现了前馈神经网络和简单的训练算法。下面我将介绍如何使用gobrain库来构建和训练神经网络。

安装gobrain

首先需要安装gobrain库:

go get github.com/goml/gobrain

基本使用示例

1. 创建简单的前馈神经网络

package main

import (
	"fmt"
	"github.com/goml/gobrain"
)

func main() {
	// 创建新的前馈神经网络
	// 参数:输入层节点数、隐藏层节点数、输出层节点数
	ff := gobrain.NewFeedForward(2, 2, 1)

	// 训练数据 - XOR问题
	patterns := [][][]float64{
		{{0, 0}, {0}},
		{{0, 1}, {1}},
		{{1, 0}, {1}},
		{{1, 1}, {0}},
	}

	// 训练神经网络
	// 参数:训练数据、训练次数、学习率、动量因子、是否打印错误
	ff.Train(patterns, 10000, 0.6, 0.4, true)

	// 测试训练结果
	inputs := [][]float64{
		{0, 0},
		{0, 1},
		{1, 0},
		{1, 1},
	}

	for _, input := range inputs {
		result := ff.Update(input)
		fmt.Printf("%f XOR %f => %f\n", input[0], input[1], result[0])
	}
}

2. 自定义神经网络结构

gobrain允许你自定义更复杂的网络结构:

package main

import (
	"fmt"
	"github.com/goml/gobrain"
)

func main() {
	// 创建自定义神经网络
	ff := &gobrain.FeedForward{}
	
	// 初始化网络结构
	// 参数:输入层、隐藏层、输出层节点数
	ff.Init(3, []int{4, 4}, 2) // 3输入,2个隐藏层(各4节点),2输出

	// 训练数据 - 模拟分类问题
	patterns := [][][]float64{
		{{0.1, 0.2, 0.3}, {1, 0}},
		{{0.4, 0.5, 0.6}, {0, 1}},
		{{0.7, 0.8, 0.9}, {1, 0}},
		{{0.2, 0.3, 0.4}, {0, 1}},
	}

	// 训练参数
	ff.Train(patterns, 5000, 0.5, 0.4, true)

	// 测试
	testInput := []float64{0.3, 0.4, 0.5}
	result := ff.Update(testInput)
	fmt.Printf("Test result: %v\n", result)
}

3. 持久化模型

gobrain支持将训练好的模型保存到文件以及从文件加载:

package main

import (
	"fmt"
	"github.com/goml/gobrain"
	"os"
)

func main() {
	ff := gobrain.NewFeedForward(2, 3, 1)
	
	// 训练数据
	patterns := [][][]float64{
		{{0, 0}, {0}},
		{{0, 1}, {1}},
		{{1, 0}, {1}},
		{{1, 1}, {0}},
	}
	
	ff.Train(patterns, 1000, 0.6, 0.4, false)

	// 保存模型到文件
	err := ff.Save("xor_model.json")
	if err != nil {
		fmt.Println("保存模型失败:", err)
		return
	}

	// 从文件加载模型
	newFF := &gobrain.FeedForward{}
	err = newFF.Load("xor_model.json")
	if err != nil {
		fmt.Println("加载模型失败:", err)
		return
	}

	// 测试加载的模型
	result := newFF.Update([]float64{1, 0})
	fmt.Println("1 XOR 0 =>", result[0])
	
	// 清理测试文件
	os.Remove("xor_model.json")
}

高级功能

自定义激活函数

gobrain默认使用sigmoid激活函数,但你可以自定义:

package main

import (
	"fmt"
	"github.com/goml/gobrain"
	"math"
)

// 自定义激活函数
func customActivation(x float64) float64 {
	// 使用tanh激活函数
	return math.Tanh(x)
}

func main() {
	ff := gobrain.NewFeedForward(2, 4, 1)
	
	// 设置自定义激活函数
	ff.SetActivationFunction(customActivation)
	
	// 训练数据
	patterns := [][][]float64{
		{{0, 0}, {0}},
		{{0, 1}, {1}},
		{{1, 0}, {1}},
		{{1, 1}, {0}},
	}
	
	ff.Train(patterns, 5000, 0.6, 0.4, true)
	
	// 测试
	result := ff.Update([]float64{1, 1})
	fmt.Println("1 XOR 1 =>", result[0])
}

实际应用案例:手写数字识别

下面是一个简化的手写数字识别示例(实际应用中需要更复杂的网络和更多数据):

package main

import (
	"fmt"
	"github.com/goml/gobrain"
)

func main() {
	// 假设我们有8x8像素的手写数字图像,展平为64维向量
	// 输出是10个数字的概率分布
	
	// 创建网络:64输入,128个隐藏节点,10输出
	ff := gobrain.NewFeedForward(64, 128, 10)
	
	// 模拟训练数据 (实际应用中应从文件加载真实数据)
	var patterns [][][]float64
	for i := 0; i < 10; i++ {
		// 模拟输入数据 (实际应为图像像素值)
		input := make([]float64, 64)
		for j := range input {
			input[j] = float64(i) * 0.1 // 简单模拟数据
		}
		
		// 创建期望输出 (one-hot编码)
		output := make([]float64, 10)
		output[i] = 1.0
		
		patterns = append(patterns, [][]float64{input, output})
	}
	
	// 训练
	ff.Train(patterns, 1000, 0.5, 0.4, true)
	
	// 测试
	testInput := make([]float64, 64)
	for i := range testInput {
		testInput[i] = 0.5 // 模拟测试输入
	}
	
	result := ff.Update(testInput)
	fmt.Println("Predicted probabilities:", result)
	
	// 找出最可能的数字
	maxIndex := 0
	for i, val := range result {
		if val > result[maxIndex] {
			maxIndex = i
		}
	}
	fmt.Printf("Predicted number: %d (confidence: %.2f%%)\n", 
		maxIndex, result[maxIndex]*100)
}

注意事项

  1. gobrain是一个相对简单的神经网络实现,适合学习和简单任务
  2. 对于复杂任务,可能需要考虑更强大的库如Gorgonia或GoML
  3. 训练数据需要适当归一化(通常在0到1或-1到1之间)
  4. 学习率和动量参数需要根据具体问题调整
  5. 网络结构(隐藏层数量和节点数)需要实验确定

gobrain提供了神经网络的基本功能,虽然不如TensorFlow或PyTorch强大,但对于Golang中的简单机器学习任务来说是一个不错的选择。

回到顶部