Golang基础CSV操作与线性代数库 - RocketC

Golang基础CSV操作与线性代数库 - RocketC 我正在开发一个库,旨在提供从零开始编写机器学习算法的功能。目前它还处于非常早期的阶段,仅包含一些基础功能,我正在积极开发中。请花时间看一下,并提供宝贵的反馈和建议。

GitHub

aryanmaurya1/RocketC

头像

用于CSV数据操作和训练线性回归模型的简单库。 - aryanmaurya1/RocketC


更多关于Golang基础CSV操作与线性代数库 - RocketC的实战教程也可以访问 https://www.itying.com/category-94-b0.html

3 回复

感谢您提供的反馈,我将根据您的建议进行修改。

更多关于Golang基础CSV操作与线性代数库 - RocketC的实战系列教程也可以访问 https://www.itying.com/category-94-b0.html


我希望在GitHub上评论已提交的代码能有更简单的方法!

我快速浏览了5分钟,注意到以下几点:

  • 当你不修改 d 时,不要使用 func (d *DataFrame),直接使用 func (d DataFrame) —— DataFrame 是一个切片,传递它的成本很低。这也会让你的代码更清晰!(Matrix 同理)
  • make(DataFrame, row, row) 是多余的,直接用 make(DataFrame, row) 就可以了
  • 与其先写 n := m.Row() 然后 for i := 0; i < n; i++ {,我更倾向于直接写 for i := range m {,但这取决于你的习惯
  • ReadCsvMatrix 应该命名为 ReadCSVMatrix
  • ReadCsvMatrix 应该返回 (Matrix, error) —— 不要打印错误并返回 nil,应该把错误返回!
  • dropFrist 拼写错误 —— 应该是 dropFirst
  • 进阶提示:在你向 f 写入数据的地方使用 defer f.Close() 可能会掩盖错误 —— 你应该在这里检查错误

祝你好运!

作为Go语言开发者,看到你开始构建机器学习基础库RocketC,这是一个很有价值的项目。以下是我对当前代码的专业分析:

CSV操作部分

你的ReadCSV函数基础实现不错,但需要更多错误处理和功能扩展:

// 建议增强的CSV读取函数
func ReadCSVWithOptions(filename string, hasHeader bool, delimiter rune) ([][]string, []string, error) {
    file, err := os.Open(filename)
    if err != nil {
        return nil, nil, fmt.Errorf("打开文件失败: %w", err)
    }
    defer file.Close()

    reader := csv.NewReader(file)
    reader.Comma = delimiter
    reader.FieldsPerRecord = -1 // 允许可变字段数
    
    records, err := reader.ReadAll()
    if err != nil {
        return nil, nil, fmt.Errorf("读取CSV失败: %w", err)
    }
    
    if len(records) == 0 {
        return nil, nil, errors.New("CSV文件为空")
    }
    
    var header []string
    var data [][]string
    
    if hasHeader {
        header = records[0]
        data = records[1:]
    } else {
        data = records
    }
    
    return data, header, nil
}

线性回归实现

当前实现可以优化为更高效的矩阵运算:

type LinearRegression struct {
    Weights []float64
    Bias    float64
}

func (lr *LinearRegression) Train(X [][]float64, y []float64, learningRate float64, epochs int) {
    if len(X) == 0 || len(X[0]) == 0 {
        return
    }
    
    nFeatures := len(X[0])
    nSamples := len(X)
    
    // 初始化权重
    lr.Weights = make([]float64, nFeatures)
    lr.Bias = 0.0
    
    for epoch := 0; epoch < epochs; epoch++ {
        totalError := 0.0
        
        for i := 0; i < nSamples; i++ {
            // 预测值
            prediction := lr.Bias
            for j := 0; j < nFeatures; j++ {
                prediction += lr.Weights[j] * X[i][j]
            }
            
            // 计算误差
            error := y[i] - prediction
            totalError += error * error
            
            // 更新权重和偏置
            lr.Bias += learningRate * error
            for j := 0; j < nFeatures; j++ {
                lr.Weights[j] += learningRate * error * X[i][j]
            }
        }
        
        // 可选:打印每轮损失
        if epoch%100 == 0 {
            fmt.Printf("Epoch %d, MSE: %f\n", epoch, totalError/float64(nSamples))
        }
    }
}

// 批量梯度下降版本(更高效)
func (lr *LinearRegression) TrainBatch(X [][]float64, y []float64, learningRate float64, epochs int, batchSize int) {
    nSamples := len(X)
    nFeatures := len(X[0])
    
    for epoch := 0; epoch < epochs; epoch++ {
        for start := 0; start < nSamples; start += batchSize {
            end := start + batchSize
            if end > nSamples {
                end = nSamples
            }
            
            // 计算批次梯度
            gradWeights := make([]float64, nFeatures)
            gradBias := 0.0
            
            for i := start; i < end; i++ {
                prediction := lr.Bias
                for j := 0; j < nFeatures; j++ {
                    prediction += lr.Weights[j] * X[i][j]
                }
                
                error := y[i] - prediction
                gradBias += error
                for j := 0; j < nFeatures; j++ {
                    gradWeights[j] += error * X[i][j]
                }
            }
            
            // 更新参数
            batchLen := float64(end - start)
            lr.Bias += learningRate * gradBias / batchLen
            for j := 0; j < nFeatures; j++ {
                lr.Weights[j] += learningRate * gradWeights[j] / batchLen
            }
        }
    }
}

性能优化建议

// 使用gonum进行矩阵运算(如果考虑添加依赖)
import "gonum.org/v1/gonum/mat"

func TrainWithGonum(X *mat.Dense, y *mat.VecDense) *mat.VecDense {
    // 使用正规方程 (X^T X)^-1 X^T y
    var XT, XTX, XTXInv mat.Dense
    XT.CloneFrom(X.T())
    XTX.Mul(&XT, X)
    
    var inv mat.Dense
    inv.Inverse(&XTX)
    
    var XTXInvXT mat.Dense
    XTXInvXT.Mul(&inv, &XT)
    
    var weights mat.VecDense
    weights.MulVec(&XTXInvXT, y)
    
    return &weights
}

测试示例

func TestLinearRegression() {
    // 示例数据
    X := [][]float64{
        {1.0, 2.0},
        {2.0, 3.0},
        {3.0, 4.0},
        {4.0, 5.0},
    }
    y := []float64{3.0, 5.0, 7.0, 9.0}
    
    lr := &LinearRegression{}
    lr.Train(X, y, 0.01, 1000)
    
    // 预测
    testX := []float64{5.0, 6.0}
    prediction := lr.Bias
    for i := 0; i < len(testX); i++ {
        prediction += lr.Weights[i] * testX[i]
    }
    fmt.Printf("预测结果: %f\n", prediction)
}

当前库的基础架构合理,建议下一步添加:

  1. 数据标准化/归一化功能
  2. 模型持久化(保存/加载权重)
  3. 交叉验证支持
  4. 性能指标计算(R²、MSE等)
回到顶部