golang通用机器学习库插件GoLearn的使用

GoLearn通用机器学习库插件使用指南

GoLearn是一个"开箱即用"的Go语言机器学习库,目标是简单且可定制化。目前正在积极开发中,欢迎用户反馈。

安装

请参考官方安装说明进行安装。

快速开始

数据以Instances的形式加载,您可以对它们执行类似矩阵的操作,并将它们传递给估计器。GoLearn实现了scikit-learn的Fit/Predict接口,因此您可以轻松地交换估计器进行试验。GoLearn还包括数据辅助函数,如交叉验证和训练测试分割。

示例代码

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/evaluation"
	"github.com/sjwhitworth/golearn/knn"
)

func main() {
	// 加载数据集,包含表头。表头属性将被存储。
	// 可以将instances视为R或Pandas中的Data Frame结构。
	// 您也可以从头创建instances。
	rawData, err := base.ParseCSVToInstances("datasets/iris.csv", true)
	if err != nil {
		panic(err)
	}

	// 打印数据的友好摘要
	fmt.Println(rawData)

	// 初始化一个新的KNN分类器
	cls := knn.NewKnnClassifier("euclidean", "linear", 2)

	// 进行训练-测试分割
	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.50)
	cls.Fit(trainData)

	// 计算欧几里得距离并返回最流行的标签
	predictions, err := cls.Predict(testData)
	if err != nil {
		panic(err)
	}

	// 打印精确度/召回率指标
	confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		panic(fmt.Sprintf("无法获取混淆矩阵: %s", err.Error()))
	}
	fmt.Println(evaluation.GetSummary(confusionMat))
}

示例输出:

Iris-virginica	28	2	  56	0.9333	0.9333  0.9333
Iris-setosa	    29	0	  59	1.0000  1.0000	1.0000
Iris-versicolor	27	2	  57	0.9310	0.9310  0.9310
总体准确率: 0.9545

更多示例

GoLearn附带了许多实用示例。您可以查看并运行这些示例来了解其工作原理:

cd $GOPATH/src/github.com/sjwhitworth/golearn/examples/knnclassifier
go run knnclassifier_iris.go

cd $GOPATH/src/github.com/sjwhitworth/golearn/examples/instances
go run instances.go

cd $GOPATH/src/github.com/sjwhitworth/golearn/examples/trees
go run trees.go

文档

  • 英文文档
  • 中文文档(简体)
  • 中文文档(繁体)

加入团队

如有兴趣加入开发团队,请发送邮件至stephenjameswhitworth@gmail.com


更多关于golang通用机器学习库插件GoLearn的使用的实战教程也可以访问 https://www.itying.com/category-94-b0.html

1 回复

更多关于golang通用机器学习库插件GoLearn的使用的实战系列教程也可以访问 https://www.itying.com/category-94-b0.html


GoLearn - Go语言的通用机器学习库

GoLearn是Go语言中一个简单但功能强大的机器学习库,它提供了许多常见的机器学习算法实现。下面我将介绍GoLearn的基本用法和示例代码。

安装GoLearn

go get -u github.com/sjwhitworth/golearn

基本功能

1. 加载数据集

GoLearn支持多种格式的数据集,包括CSV、ARFF等。

package main

import (
	"fmt"
	"github.com/sjwhitworth/golearn/base"
)

func main() {
	// 从CSV文件加载数据集
	rawData, err := base.ParseCSVToInstances("data.csv", true)
	if err != nil {
		panic(err)
	}
	
	// 打印数据集信息
	fmt.Println(rawData)
}

2. 数据预处理

import (
	"github.com/sjwhitworth/golearn/filters"
)

func preprocessData(data *base.Instances) *base.Instances {
	// 离散化连续属性
	discretizeFilter := filters.NewDiscretizeFilter(data, 10) // 分为10个区间
	for _, attr := range data.AllAttributes() {
		if attr.GetType() == base.Float64Type {
			discretizeFilter.AddAttribute(attr)
		}
	}
	discretizeFilter.Train()
	filteredData := base.NewLazilyFilteredInstances(data, discretizeFilter)
	
	return filteredData
}

3. 训练和评估模型

决策树示例

import (
	"github.com/sjwhitworth/golearn/trees"
	"github.com/sjwhitworth/golearn/evaluation"
)

func decisionTreeExample(data *base.Instances) {
	// 划分训练集和测试集
	trainData, testData := base.InstancesTrainTestSplit(data, 0.75)
	
	// 创建决策树模型
	tree := trees.NewID3DecisionTree(0.6) // 信息增益阈值0.6
	
	// 训练模型
	err := tree.Fit(trainData)
	if err != nil {
		panic(err)
	}
	
	// 预测
	predictions, err := tree.Predict(testData)
	if err != nil {
		panic(err)
	}
	
	// 评估模型
	confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		panic(err)
	}
	
	fmt.Println(evaluation.GetSummary(confusionMat))
}

KNN示例

import (
	"github.com/sjwhitworth/golearn/knn"
)

func knnExample(data *base.Instances) {
	trainData, testData := base.InstancesTrainTestSplit(data, 0.75)
	
	// 创建KNN分类器,k=10
	cls := knn.NewKnnClassifier("euclidean", "linear", 10)
	
	// 训练
	cls.Fit(trainData)
	
	// 预测
	predictions := cls.Predict(testData)
	
	// 评估
	confusionMat := evaluation.GetConfusionMatrix(testData, predictions)
	fmt.Println(evaluation.GetSummary(confusionMat))
}

4. 特征选择

import (
	"github.com/sjwhitworth/golearn/feature_selection"
)

func featureSelectionExample(data *base.Instances) {
	// 创建特征选择器
	selector := feature_selection.NewChiSquaredFeatureSelection(5) // 选择5个最佳特征
	
	// 应用特征选择
	selectedData, err := selector.Transform(data)
	if err != nil {
		panic(err)
	}
	
	return selectedData
}

完整示例

下面是一个完整的机器学习流程示例:

package main

import (
	"fmt"
	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/evaluation"
	"github.com/sjwhitworth/golearn/knn"
	"github.com/sjwhitworth/golearn/filters"
)

func main() {
	// 1. 加载数据
	rawData, err := base.ParseCSVToInstances("data/iris.csv", true)
	if err != nil {
		panic(err)
	}
	
	// 2. 数据预处理
	filteredData := preprocessData(rawData)
	
	// 3. 划分训练测试集
	trainData, testData := base.InstancesTrainTestSplit(filteredData, 0.75)
	
	// 4. 初始化KNN分类器
	cls := knn.NewKnnClassifier("euclidean", "linear", 2)
	
	// 5. 训练模型
	cls.Fit(trainData)
	
	// 6. 预测
	predictions, err := cls.Predict(testData)
	if err != nil {
		panic(err)
	}
	
	// 7. 评估
	confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		panic(err)
	}
	
	fmt.Println(evaluation.GetSummary(confusionMat))
}

func preprocessData(data *base.Instances) *base.Instances {
	// 这里可以添加各种预处理步骤
	return data
}

GoLearn支持的主要算法

  1. 分类算法:

    • KNN (k-近邻)
    • 决策树 (ID3)
    • 随机森林
    • 朴素贝叶斯
    • 线性回归 (也可用于分类)
  2. 聚类算法:

    • K-means
  3. 特征选择:

    • 卡方检验
    • 信息增益
  4. 模型评估:

    • 交叉验证
    • 混淆矩阵
    • 多种评估指标 (准确率、召回率、F1等)

优缺点

优点:

  • 纯Go实现,易于部署
  • 简单的API设计
  • 支持常见机器学习任务
  • 活跃的社区支持

缺点:

  • 相比Python生态(如scikit-learn)功能较少
  • 性能可能不如专用库
  • 深度学习支持有限

GoLearn适合在Go生态系统中快速实现机器学习原型或简单应用。对于更复杂的任务,可能需要考虑其他更专业的库或语言。

回到顶部