返回文章列表

Rust機器學習資料處理與模型訓練實作

本文介紹如何使用 Rust 語言和 rusty-machine 套件進行機器學習任務,包含資料生成、預處理、神經網路模型訓練和預測等步驟。文章詳細說明瞭如何使用 Standardizer 進行資料正規化,以及如何組態和訓練 NeuralNet 模型。

機器學習 Rust

Rust 的機器學習生態系統正在蓬勃發展,rusty-machine 提供了基礎的機器學習演算法和工具。本文示範如何使用 rusty-machine 建立一個簡單的神經網路模型,並用於分類別任務。首先,我們需要產生訓練和測試資料,並將其儲存為 CSV 格式。接著,使用 rusty-machineStandardizer 對資料進行正規化,以提升模型訓練效率。然後,我們建立一個包含輸入層、隱藏層和輸出層的神經網路模型,並使用隨機梯度下降最佳化器進行訓練。最後,使用訓練好的模型對測試資料進行預測,並評估模型的效能。這個流程涵蓋了機器學習任務的核心步驟,包含資料處理、模型訓練和預測。

建構神經網路模型的前置作業與實作細節

在進行神經網路模型的訓練與預測之前,需要先產生訓練與測試資料。本章節將介紹如何使用 Rust 語言結合相關套件來完成這些任務。

產生訓練與測試資料

首先,需要執行以下指令來產生訓練與測試資料:

$ cargo run --bin generate -- \
--config-file config/generate_cats_and_dogs.toml \
> training_nn.csv
$ cargo run --bin generate -- \
--config-file config/generate_cats_and_dogs.toml \
> testing_nn.csv

這些指令會根據 generate_cats_and_dogs.toml 組態檔的設定,產生 training_nn.csvtesting_nn.csv 兩個 CSV 檔案,分別用於模型的訓練和測試。

設定神經網路模型

在產生訓練與測試資料後,需要建立一個新的 binary 檔案 src/bin/train_and_predict.rs,並在其中實作以下功能:

  • 讀取並解析訓練資料至 Vec,然後轉換成 Array 型別。
  • 對訓練資料進行正規化處理。
  • 初始化神經網路模型。
  • 將正規化後的訓練資料輸入模型進行訓練。
  • 讀取並解析測試資料至 Vec,然後轉換成 Array 型別。
  • 使用相同的正規化引數對測試資料進行正規化處理。
  • 使用訓練好的模型對測試資料進行預測。

由於 Rust 沒有原生支援神經網路的套件,因此本範例使用 rusty-machine 套件來實作神經網路模型。

新增必要的套件

執行以下指令來新增 rusty-machine 套件:

$ cargo add rusty-machine

讀取訓練與測試資料

本範例使用 clap 套件來解析命令列引數,並讀取指定的 CSV 檔案。定義了一個 Args 結構體來儲存命令列引數:

use clap::Parser;
use std::error::Error;

#[derive(Parser)]
struct Args {
    #[arg(short = 'r', long = "train")]
    /// Training data CSV file
    training_data_csv: std::path::PathBuf,
    #[arg(short = 't', long = "test")]
    /// Testing data CSV file
    testing_data_csv: std::path::PathBuf,
}

定義了一個 SampleRow 結構體來反序列化 CSV 資料:

use serde::Deserialize;
use rusty_machine::linalg::Matrix;

#[derive(Debug, Deserialize)]
struct SampleRow {
    height: f64,
    length: f64,
    category_id: usize,
}

實作了一個 read_data_from_csv 函式來讀取 CSV 檔案並轉換成 Matrix 型別:

fn read_data_from_csv(
    file_path: std::path::PathBuf,
) -> Result<(Matrix<f64>, Matrix<f64>), Box<dyn Error>> {
    let mut input_data = vec![];
    let mut label_data = vec![];
    let mut sample_count = 0;
    let mut reader = csv::Reader::from_path(file_path)?;
    for raw_row in reader.deserialize() {
        let row: SampleRow = raw_row?;
        input_data.push(row.height);
        input_data.push(row.length);
        label_data.push(row.category_id as f64);
        sample_count += 1
    }
    let inputs = Matrix::new(sample_count, 2, input_data);
    let targets = Matrix::new(sample_count, 1, label_data);
    return Ok((inputs, targets));
}

程式碼解析

  • read_data_from_csv 函式讀取指定的 CSV 檔案,並將資料反序列化成 SampleRow 結構體。
  • 將每個樣本的 heightlength 特徵存入 input_data 向量,將 category_id 存入 label_data 向量。
  • input_datalabel_data 向量轉換成 Matrix 型別,以供神經網路模型使用。

資料正規化

在將資料輸入神經網路模型之前,需要對資料進行正規化處理,以提高模型的訓練效率和準確度。正規化的目標是將資料轉換成均值為 0、標準差為 1 的分佈。

正規化的步驟

  1. 計算資料的均值和標準差。
  2. 將資料減去均值,然後除以標準差。

正規化可以使梯度下降法更穩定、更快速地收斂,從而提高模型的效能。

為何需要正規化?

  • 正規化可以避免某些特徵因數值範圍較大而主導梯度下降的過程。
  • 正規化可以使成本函式的形狀更平滑,從而使梯度下降法更穩定、更快速地收斂。

正規化的過程

@startuml
skinparam backgroundColor #FEFEFE
skinparam componentStyle rectangle

title Rust機器學習資料處理與模型訓練實作

package "機器學習流程" {
    package "資料處理" {
        component [資料收集] as collect
        component [資料清洗] as clean
        component [特徵工程] as feature
    }

    package "模型訓練" {
        component [模型選擇] as select
        component [超參數調優] as tune
        component [交叉驗證] as cv
    }

    package "評估部署" {
        component [模型評估] as eval
        component [模型部署] as deploy
        component [監控維護] as monitor
    }
}

collect --> clean : 原始資料
clean --> feature : 乾淨資料
feature --> select : 特徵向量
select --> tune : 基礎模型
tune --> cv : 最佳參數
cv --> eval : 訓練模型
eval --> deploy : 驗證模型
deploy --> monitor : 生產模型

note right of feature
  特徵工程包含:
  - 特徵選擇
  - 特徵轉換
  - 降維處理
end note

note right of eval
  評估指標:
  - 準確率/召回率
  - F1 Score
  - AUC-ROC
end note

@enduml

此圖表展示了正規化的過程:首先計算資料的均值和標準差,然後將資料減去均值,最後除以標準差,得到正規化後的資料。

資料標準化與神經網路模型訓練

在進行機器學習任務時,資料預處理是至關重要的一步。資料標準化(Data Standardization)是其中一種常見的預處理技術,目的是將資料縮放到相同的尺度,以提高模型的訓練效率和準確性。

資料標準化的重要性

資料標準化涉及兩個主要步驟:

  1. 計算資料集的平均值並從所有資料點中減去它:這使得資料集的平均值變為0。
  2. 計算資料集的標準差並將每個資料點的座標除以標準差:這使得資料集的標準差變為1。

在進行資料標準化時,必須保留訓練資料的平均值和標準差。當標準化測試資料時,將使用訓練資料的平均值和標準差。這是因為神經網路模型中的所有引數都是針對標準化的訓練資料進行訓練的。如果測試資料具有不同的平均值和標準差,模型的預測可能會出現偏差。

使用Rusty-Machine進行資料標準化

Rusty-Machine提供了一個方便的Standardizer結構體來執行資料標準化。Standardizer實作了Transformer特徵,定義了常用資料預處理轉換的分享介面。

Standardizer的使用方法

  1. 初始化Standardizer:使用new()函式初始化,指定期望的平均值和標準差。通常,平均值設為0,標準差設為1。
  2. 擬合(Fit):使用fit()函式計算輸入資料的平均值和標準差,並將其儲存在Standardizer例項中。
  3. 轉換(Transform):使用transform()函式對提供的資料進行轉換,使用在fit()步驟中學習到的平均值和標準差。
use rusty_machine::data::transforms::Transformer;
use rusty_machine::data::transforms::Standardizer;

fn main() -> Result<(), Box<dyn Error>> {
    // ...
    let mut standardizer = Standardizer::new(0.0, 1.0);
    standardizer.fit(&training_inputs).unwrap();
    let normalized_training_inputs = standardizer.transform(training_inputs).unwrap();
    // ...
    let normalized_test_cases = standardizer.transform(testing_inputs.clone())?;
    // ...
}

訓練神經網路模型

在完成資料預處理後,可以開始訓練神經網路模型。神經網路模型的組態比K-means更為複雜,需要設定:

  • 層數和每層節點數
  • 判別準則(Criterion),包括啟用函式和損失函式
  • 最佳化演算法

使用Rusty-Machine訓練神經網路

Rusty-Machine的NeuralNet結構體提供了一個預設組態。以下是一個範例:

use rusty_machine::learning::nnet::{NeuralNet, BCECriterion};
use rusty_machine::learning::optim::grad_desc::StochasticGD;
use rusty_machine::learning::SupModel;

fn main() -> Result<(), Box<dyn Error>> {
    // ... 載入訓練資料和預處理 ...
    let layers = &[2, 2, 1];
    let criterion = BCECriterion::default();
    let gradient_descent = StochasticGD::new(0.1, 0.1, 20);
    let mut model = NeuralNet::new(layers, criterion, gradient_descent);
    model.train(&normalized_training_inputs, &training_label_data)?;
    // ... 測試 ...
}

組態解析

  1. 層組態:定義了一個三層架構,第一層有2個輸入神經元,中間層有2個神經元,輸出層有1個神經元。
  2. 判別準則:使用了預設的二元交叉熵判別準則(BCECriterion),它採用Sigmoid啟用函式和交叉熵誤差作為損失函式。
  3. 最佳化演算法:選擇了隨機梯度下降(StochasticGD)作為最佳化演算法,具有動量(Momentum)、學習率(Learning Rate)和迭代次數(Number of Iterations)等引數。

#### 內容解密:

  • let layers = &[2, 2, 1];:定義神經網路的層結構,第一層2個輸入節點,第二層2個隱藏節點,第三層1個輸出節點。
  • let criterion = BCECriterion::default();:選擇二元交叉熵作為損失函式,使用Sigmoid作為啟用函式。
  • let gradient_descent = StochasticGD::new(0.1, 0.1, 20);:初始化隨機梯度下降最佳化器,引數包括學習率、動量和迭代次數。
  • model.train(&normalized_training_inputs, &training_label_data)?;:使用標準化後的訓練資料訓練神經網路模型。

使用神經網路進行預測與機器學習實務

在前面的章節中,我們已經完成了神經網路模型的訓練。現在,我們將使用這個模型對新的資料進行預測,以驗證其準確性。首先,我們利用 generate_data.rs 指令碼生成了4,000個新的資料點。這些資料僅包含高度和長度,模型將根據這些輸入計算輸出結果。

進行預測

神經網路模型的輸出結果是一個介於0和1之間的值,其中0表示模型認為輸入資料最可能是狗,而1表示是貓。我們可以將模型的預測結果與實際答案進行比較,以評估模型的準確性。

在程式碼實作上,我們首先從CSV檔案中載入測試資料,並使用在訓練過程中建立的 Standardizer 對資料進行標準化處理。接著,我們呼叫 model.predict() 方法對標準化後的測試資料進行預測,得到預測結果(清單9-24)。

清單9-24:使用神經網路進行預測

use rusty_machine::linalg::BaseMatrix;
use std::io;

fn main() -> Result<(), Box<dyn Error>> {
    // 訓練模型
    // 測試 ====================
    let (testing_inputs, expected) = read_data_from_csv(options.testing_data_csv)?;
    
    // 使用訓練資料的平均值和變異數對測試資料進行標準化
    let normalized_test_cases = standardizer.transform(testing_inputs.clone())?;
    let res = model.predict(&normalized_test_cases)?;
    
    let mut writer = csv::Writer::from_writer(io::stdout());
    writer.write_record(&[
        "height",
        "length",
        "estimated_category_id",
        "true_category_id",
    ])?;
    
    for row in testing_inputs.iter_rows().zip(res.into_vec().into_iter()).zip(expected.into_vec().into_iter()) {
        writer.serialize((row.0.0[0], row.0.0[1], row.0.1, row.1))?;
    }
    Ok(())
}

內容解密:

  1. read_data_from_csv:從指定的CSV檔案中讀取測試資料。
  2. standardizer.transform:對測試資料進行標準化處理,以確保資料的一致性。
  3. model.predict:使用訓練好的神經網路模型對標準化後的測試資料進行預測。
  4. csv::Writer:將預測結果輸出到CSV檔案中,以便於後續分析。

執行以下命令以訓練模型並進行預測:

$ cargo run --bin train_and_predict -- \
--train training_nn.csv \
--test testing_nn.csv > results.csv

預測結果分析

透過比較 resexpected 的結果,我們發現幾乎所有的預測都是正確的。這在現實應用中並不常見,主要原因是我們的訓練和測試資料是人工生成的,易於被神經網路模型區分,並且沒有雜訊幹擾。儘管如此,這個例子仍然展示了使用 rusty-machine 訓練監督式神經網路模型的關鍵步驟。

機器學習的替代方案與未來發展

從本章的例子可以看出,機器學習不僅僅是訓練模型,還涉及許多與資料相關的操作,包括讀寫CSV檔案、資料預處理、模型組態和引數載入,以及資料視覺化等。為了高效實作機器學習應用,我們需要一個健全的生態系統,包含許多預先構建的crate,以避免從頭開始編寫所有程式碼。

Rust的機器學習生態系統仍在發展中,但已經有一些領先的crate,如 nalgebrandarray,它們提供了線性代數和陣列/矩陣操作功能。此外,smartcorelinfa crate在傳統機器學習領域表現突出,實作了多種常用的機器學習模型。

對於深度學習,目前尚未有完全使用Rust從頭構建的成熟函式庫。因此,使用Rust繫結其他語言編寫的成熟函式庫(如TensorFlow和PyTorch)是目前的最佳選擇。