golang高性能决策树梯度提升库插件catboost-cgo的使用

Golang高性能决策树梯度提升库插件catboost-cgo的使用

CatBoost-Cgo

CatBoost-Cgo是一个Golang库,提供了对CatBoost机器学习库的C API接口访问。CatBoost是一个高性能的梯度提升决策树库,特别适合处理类别特征。

兼容性

操作系统 CPU架构 CUDA GPU支持
MacOS ✅ (x86_64) 🚫
Linux ✅ (x86_64) ✅ (x86_64)
Windows 10/11 🚫 🚫

功能特性

支持的模型类型:

  • CatBoostRegressor ✅
  • CatBoostClassifier ✅
  • CatBoostRanker ✅

支持的预测类型:

  • RawFormulaVal ✅
  • Probability ✅
  • Class ✅
  • RMSEWithUncertainty ✅
  • Exponent ✅

安装

  1. 安装catboost-cgo库:
go get github.com/mirecl/catboost-cgo
  1. 从CatBoost发布页面下载共享库文件(.so或.dylib)

  2. 将共享库放在以下位置之一:

    • /usr/local/lib
    • 或设置环境变量CATBOOST_LIBRARY_PATH
    • 或在代码中手动设置路径:
import (
 cb "github.com/mirecl/catboost-cgo/catboost"
)

func main(){
  cb.SetSharedLibraryPath(...)
}

使用示例

回归模型示例

package main

import (
	"fmt"
	"log"

	cb "github.com/mirecl/catboost-cgo/catboost"
)

func main() {
	// 1. 加载模型
	model, err := cb.LoadModelFromFile("model.cbm")
	if err != nil {
		log.Fatal(err)
	}
	defer model.Free()

	// 2. 准备特征数据
	floatFeatures := []float32{1.2, 3.4, 5.6}
	catFeatures := []string{"red", "small"}

	// 3. 进行预测
	result, err := model.Predict(floatFeatures, catFeatures, cb.RawFormulaVal)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Prediction result: %f\n", result)
}

分类模型示例

package main

import (
	"fmt"
	"log"

	cb "github.com/mirecl/catboost-cgo/catboost"
)

func main() {
	// 1. 加载模型
	model, err := cb.LoadModelFromFile("classifier.cbm")
	if err != nil {
		log.Fatal(err)
	}
	defer model.Free()

	// 2. 准备特征数据
	floatFeatures := []float32{0.5, 1.8, 2.3}
	catFeatures := []string{"blue", "medium"}

	// 3. 进行概率预测
	prob, err := model.Predict(floatFeatures, catFeatures, cb.Probability)
	if err != nil {
		log.Fatal(err)
	}

	// 4. 获取类别预测
	class, err := model.Predict(floatFeatures, catFeatures, cb.Class)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Probability: %f, Class: %f\n", prob, class)
}

排序模型示例

package main

import (
	"fmt"
	"log"

	cb "github.com/mirecl/catboost-cgo/catboost"
)

func main() {
	// 1. 加载排序模型
	model, err := cb.LoadModelFromFile("ranker.cbm")
	if err != nil {
		log.Fatal(err)
	}
	defer model.Free()

	// 2. 准备特征数据
	floatFeatures := []float32{0.8, 1.2, 0.5}
	catFeatures := []string{"premium", "high"}

	// 3. 进行排序预测
	score, err := model.Predict(floatFeatures, catFeatures, cb.RawFormulaVal)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Ranking score: %f\n", score)
}

注意事项

  1. 目前仅支持Linux和MacOS系统
  2. GPU支持仅限于Linux平台且仅支持设备0
  3. 支持数值和类别特征,但不支持文本和嵌入特征

更多使用示例可以参考库中的examples目录,包括回归、分类、排序等多种应用场景。


更多关于golang高性能决策树梯度提升库插件catboost-cgo的使用的实战教程也可以访问 https://www.itying.com/category-94-b0.html

1 回复

更多关于golang高性能决策树梯度提升库插件catboost-cgo的使用的实战系列教程也可以访问 https://www.itying.com/category-94-b0.html


CatBoost-CGO: Golang高性能决策树梯度提升库使用指南

CatBoost是Yandex开发的高性能梯度提升决策树(GBDT)库,而catboost-cgo是其Golang绑定。下面我将详细介绍如何在Golang中使用catboost-cgo进行机器学习建模。

安装与准备

首先需要安装CatBoost C++库和Go绑定:

# 安装CatBoost C++库
git clone https://github.com/catboost/catboost.git
cd catboost/catboost/libs/model_interface
make -j

# 安装Go绑定
go get github.com/dmitryikh/leaves

基本使用示例

1. 训练模型

package main

import (
	"fmt"
	"github.com/dmitryikh/leaves"
)

func main() {
	// 1. 准备训练数据
	trainData := [][]float64{
		{1.0, 3.0, 5.0},
		{2.0, 4.0, 6.0},
		{3.0, 5.0, 7.0},
	}
	trainLabels := []float64{10.0, 20.0, 30.0}

	// 2. 创建CatBoost训练器
	trainer := leaves.CatBoostTrainer{
		Iterations:    100,
		LearningRate: 0.1,
		LossFunction: "RMSE",
		ThreadCount:  4,
	}

	// 3. 训练模型
	model, err := trainer.Fit(trainData, trainLabels, nil)
	if err != nil {
		panic(err)
	}

	// 4. 保存模型
	err = model.Save("model.cbm")
	if err != nil {
		panic(err)
	}
}

2. 加载模型并进行预测

package main

import (
	"fmt"
	"github.com/dmitryikh/leaves"
)

func main() {
	// 1. 加载已训练模型
	model, err := leaves.CatBoostFromFile("model.cbm")
	if err != nil {
		panic(err)
	}

	// 2. 准备测试数据
	testData := [][]float64{
		{1.5, 3.5, 5.5},
		{2.5, 4.5, 6.5},
	}

	// 3. 进行预测
	predictions := make([]float64, len(testData))
	for i, sample := range testData {
		predictions[i] = model.PredictSingle(sample)
	}

	fmt.Println("Predictions:", predictions)
}

高级功能

1. 使用分类特征

CatBoost的一个主要优势是能直接处理分类特征:

// 定义分类特征索引
catFeatures := []int{0, 2} // 第0和第2列是分类特征

// 训练时指定分类特征
model, err := trainer.Fit(trainData, trainLabels, catFeatures)

2. 自定义评估指标

trainer := leaves.CatBoostTrainer{
	Iterations:    100,
	LearningRate: 0.1,
	LossFunction: "Logloss",
	EvalMetric:   "AUC",
	ThreadCount:  4,
}

3. 早停机制

trainer := leaves.CatBoostTrainer{
	Iterations:    1000,
	LearningRate: 0.1,
	LossFunction: "RMSE",
	EarlyStoppingRounds: 50, // 50轮无改进则停止
}

性能优化技巧

  1. 批量预测:对于大批量数据,使用批量预测更高效
predictions := model.Predict(testData)
  1. 调整线程数:根据CPU核心数调整线程数
trainer.ThreadCount = runtime.NumCPU()
  1. 使用Float32:如果精度允许,使用float32而非float64
model.PredictSingleFloat32(sampleFloat32)

常见问题解决

  1. 内存不足:减少迭代次数或使用更小的模型

  2. 预测结果不稳定:增加迭代次数或降低学习率

  3. 加载模型失败:确保模型文件路径正确且完整

结论

catboost-cgo为Golang开发者提供了强大的梯度提升决策树实现,特别适合处理包含分类特征的表格数据。通过合理配置参数和利用其高级功能,可以在保持Golang性能优势的同时获得优秀的机器学习模型效果。

如需更复杂的功能,建议直接使用CatBoost的Python接口,并通过Go调用Python脚本实现。

回到顶部