Rust 的機器學習生態系統正在蓬勃發展,雖然仍處於早期階段,但已有一些 crate 可用於構建機器學習應用。本文使用 rusty-machine crate 示範瞭如何建立一個簡單的神經網路模型,包含資料預處理、模型訓練和預測等步驟。rusty-machine 提供了基本的機器學習演算法和工具,方便入門和學習。不過,需要注意的是,rusty-machine 並非最活躍的專案,對於更進階的應用,可能需要考慮其他方案,例如 linfa 或 smartcore 等更活躍的 crate,或是使用 Rust 繫結到其他語言的深度學習框架,例如 TensorFlow 或 PyTorch。資料的處理流程包含從 CSV 檔案讀取資料,並使用 Standardizer 進行正規化,以提升模型訓練效率和準確度。模型訓練部分則使用隨機梯度下降法進行最佳化,並使用二元交叉熵作為損失函式。
設定神經網路模型
在生成訓練和測試資料後,需要建立模型訓練和預測的程式碼。這些程式碼將被放入一個新的二進位制檔案src/bin/train_and_predict.rs中。該二進位制檔案需要完成以下任務:
- 讀取和解析訓練資料到
Vec中,並將其轉換為Array。 - 標準化訓練資料。
- 初始化神經網路模型。
- 將標準化的訓練資料輸入模型進行訓練。
- 讀取和解析測試資料到
Vec中,並將其轉換為Array。 - 使用相同的引數標準化測試資料。
- 使用訓練好的模型對測試資料進行預測。
讀取訓練和測試資料
在K-means範例(清單9-15)中,從STDIN讀取CSV輸入。然而,在監督模型中,需要兩個輸入檔案:訓練資料和測試資料。因此,這次將透過CLI引數給出CSV檔案的路徑,並直接從檔案中讀取。使用clap crate和第2章中介紹的程式碼,可以建立兩個引數:training_data_csv和testing_data_csv(清單9-20)。
清單9-20:神經網路CLI引數解析
use clap::Parser;
use std::error::Error;
#[derive(Parser)]
struct Args {
#[arg(short = 'r', long = "train")]
/// 訓練資料CSV檔案
training_data_csv: std::path::PathBuf,
#[arg(short = 't', long = "test")]
/// 測試資料CSV檔案
testing_data_csv: std::path::PathBuf,
}
fn main() -> Result<(), Box<dyn Error>> {
let args = Args::parse();
// ...
Ok(())
}
清單9-21:從CSV讀取訓練資料
use serde::Deserialize;
use rusty_machine::linalg::Matrix;
// ...
#[derive(Debug, Deserialize)]
struct SampleRow {
height: f64,
length: f64,
category_id: usize,
}
fn read_data_from_csv(
file_path: std::path::PathBuf,
) -> Result<(Matrix<f64>, Matrix<f64>), Box<dyn Error>> {
let mut input_data = vec![];
let mut label_data = vec![];
let mut sample_count = 0;
let mut reader = csv::Reader::from_path(file_path)?;
for raw_row in reader.deserialize() {
let row: SampleRow = raw_row?;
input_data.push(row.height);
input_data.push(row.length);
label_data.push(row.category_id as f64);
sample_count += 1
}
let inputs = Matrix::new(sample_count, 2, input_data);
let targets = Matrix::new(sample_count, 1, label_data);
return Ok((inputs, targets));
}
fn main() -> Result<(), Box<dyn Error>> {
let options = Args::parse();
// ...
}
內容解密:
- 定義了一個名為
SampleRow的結構體,用於反序列化CSV檔案中的每一行資料。該結構體包含三個欄位:height、length和category_id,分別對應CSV檔案中的高度、長度和類別ID。 - 使用
csv::Reader從指定的檔案路徑讀取CSV檔案,並將其反序列化為SampleRow結構體的例項。 - 將反序列化後的資料儲存在
input_data和label_data向量中,分別用於儲存輸入資料和標籤資料。 - 將向量轉換為
Matrix型別,以便於後續的神經網路訓練。
新增依賴
由於Rust沒有純Rust實作的神經網路主要crate,因此這裡使用了相對較舊的rusty-machine crate。雖然該crate及其伴隨的線性代數crate rulinalg 目前尚未積極更新,但其介面對於學習Rust中的神經網路來說是最簡單的。
要取得必要的crate,請在命令列中執行以下命令:
$ cargo add rusty-machine
正規化訓練資料與神經網路模型訓練
在將資料輸入神經網路模型之前,有一個重要的步驟可以顯著提高模型的訓練速度和準確性,這個步驟稱為正規化(Normalization)。正規化的目標是將輸入資料進行偏移和縮放,使其具有平均值 0 和標準差 1。這對於神經網路等模型來說非常重要,因為在進行梯度下降(Gradient Descent)時,正規化的資料集可以避免最佳化過程被某個維度過大的數值所主導。同時,這也可以使成本函式(Cost Function)具有更平滑的形狀,從而使梯度下降過程更快更穩定。
正規化的步驟
正規化的過程包括以下步驟:
- 計算資料集的平均值,並從所有資料點中減去該平均值,以使資料集的平均值變為 0。
- 計算資料集的標準差,並將每個資料點的座標除以標準差,以將資料集縮放到標準差為 1。
在實際操作中,我們需要保留訓練資料的平均值和標準差,以便在正規化測試資料時使用相同的引數。這是因為神經網路模型的所有引數都是針對正規化的訓練資料進行訓練的,如果測試資料具有不同的平均值和標準差,模型的預測可能會出現偏差。
使用 Standardizer 進行正規化
幸運的是,我們不需要自己編寫這部分程式碼。rusty-machine 提供了 Standardizer 結構體,可以方便地進行正規化。Standardizer 實作了 Transformer 特性,定義了常用資料預處理轉換的分享介面。
Standardizer 的使用方法
- 使用
new()函式初始化Standardizer,並指定期望的平均值和標準差。 - 使用
fit()函式計算輸入資料的平均值和標準差,並將其儲存在Standardizer例項中。 - 使用
transform()函式對提供的資料進行轉換,使用在fit()步驟中學習到的平均值和標準差。
use rusty_machine::data::transforms::Transformer;
use rusty_machine::data::transforms::Standardizer;
fn main() -> Result<(), Box<dyn Error>> {
let options = Args::parse();
let (training_inputs, training_label_data) = read_data_from_csv(options.training_data_csv)?;
let mut standardizer = Standardizer::new(0.0, 1.0);
standardizer.fit(&training_inputs).unwrap();
let normalized_training_inputs = standardizer.transform(training_inputs).unwrap();
// ... 使用 normalized_training_inputs 訓練模型 ...
let (testing_inputs, expected) = read_data_from_csv(options.testing_data_csv)?;
let normalized_test_cases = standardizer.transform(testing_inputs.clone())?;
// ... 使用 normalized_test_cases 進行預測 ...
Ok(())
}
程式碼解密:
- 初始化
Standardizer:使用Standardizer::new(0.0, 1.0)初始化一個Standardizer例項,指定平均值為 0,標準差為 1。 - 擬合訓練資料:呼叫
standardizer.fit(&training_inputs),計算訓練資料的平均值和標準差,並儲存在Standardizer例項中。 - 轉換訓練資料:呼叫
standardizer.transform(training_inputs),對訓練資料進行正規化轉換。 - 轉換測試資料:使用相同的
Standardizer例項對測試資料進行正規化轉換,確保測試資料與訓練資料具有相同的正規化引數。
神經網路模型的訓練與預測
在完成資料正規化之後,我們就可以開始訓練神經網路模型。與 K-means 模型相比,神經網路模型的組態更加複雜,需要設定以下引數:
- 網路層數和每層的神經元數量
- 評估標準,包括啟用函式和損失函式
- 最佳化演算法
設定神經網路模型的引數
在設定神經網路模型的引數時,需要考慮以下因素:
- 網路層數和每層的神經元數量:需要根據具體問題和資料集進行調整。
- 評估標準:需要選擇合適的啟用函式和損失函式,以滿足具體問題的需求。
- 最佳化演算法:需要選擇合適的最佳化演算法,以確保模型的收斂性和準確性。
@startuml
skinparam backgroundColor #FEFEFE
skinparam componentStyle rectangle
title Rust機器學習模型訓練與預測
package "機器學習流程" {
package "資料處理" {
component [資料收集] as collect
component [資料清洗] as clean
component [特徵工程] as feature
}
package "模型訓練" {
component [模型選擇] as select
component [超參數調優] as tune
component [交叉驗證] as cv
}
package "評估部署" {
component [模型評估] as eval
component [模型部署] as deploy
component [監控維護] as monitor
}
}
collect --> clean : 原始資料
clean --> feature : 乾淨資料
feature --> select : 特徵向量
select --> tune : 基礎模型
tune --> cv : 最佳參數
cv --> eval : 訓練模型
eval --> deploy : 驗證模型
deploy --> monitor : 生產模型
note right of feature
特徵工程包含:
- 特徵選擇
- 特徵轉換
- 降維處理
end note
note right of eval
評估指標:
- 準確率/召回率
- F1 Score
- AUC-ROC
end note
@enduml
此圖示展示了神經網路模型訓練的流程,包括資料準備、正規化、模型訓練、模型評估和模型預測等步驟。
綜上所述,正規化是神經網路模型訓練中的重要步驟,可以提高模型的訓練速度和準確性。透過使用 Standardizer 結構體,可以方便地進行正規化。同時,需要根據具體問題和資料集進行神經網路模型的引數設定,以確保模型的收斂性和準確性。
使用 Rusty-Machine 訓練神經網路模型
建立神經網路模型
Rusty-machine 提供了一個 NeuralNet 結構體來建立神經網路模型。這個結構體有一個 ::default() 函式,可以用來初始化模型。然而,如果我們檢視其內部實作,會發現它其實是根據特定的組態來建立模型的,如下所示:
程式碼範例:訓練神經網路
use rusty_machine::learning::nnet::{NeuralNet, BCECriterion};
use rusty_machine::learning::optim::grad_desc::StochasticGD;
use rusty_machine::learning::SupModel;
fn main() -> Result<(), Box<dyn Error>> {
// ... 載入訓練資料並進行預處理 ...
let layers = &[2, 2, 1]; // 定義神經網路層數
let criterion = BCECriterion::default(); // 使用二元交叉熵作為損失函式
let gradient_descent = StochasticGD::new(0.1, 0.1, 20); // 使用隨機梯度下降法
let mut model = NeuralNet::new(
layers,
criterion,
gradient_descent
);
model.train(
&normalized_training_inputs,
&training_label_data
)?;
// ... 測試模型 ...
Ok(())
}
內容解密:
let layers = &[2, 2, 1];:定義了一個三層的神經網路架構,第一層有兩個輸入神經元,第二層有兩個神經元,第三層有一個輸出神經元。let criterion = BCECriterion::default();:預設使用二元交叉熵(BCECriterion)作為損失函式,該函式內部使用了 Sigmoid 啟動函式和交叉熵誤差作為損失計算方法。let gradient_descent = StochasticGD::new(0.1, 0.1, 20);:選擇隨機梯度下降法(StochasticGD)作為最佳化演算法,並設定了動量(momentum)、學習率(learning rate)和迭代次數(iterations)。- 動量(預設:0.1)
- 學習率(預設:0.1,但實際上是原始學習率的平方根)
- 迭代次數(預設:20)
這些引數都會影響神經網路模型的效能。雖然這裡使用了預設值,但在實際應用中,調整這些引數以獲得最佳模型效能是非常重要的。
訓練模型
透過呼叫 model.train() 方法並傳入訓練資料和對應的標籤,可以開始訓練模型。這個過程會進行複雜的數學運算,並將學習到的權重和其他引數儲存在模型中。
程式碼範例:訓練模型
model.train(&normalized_training_inputs, &training_label_data)?;
內容解密:
model.train()方法接收兩個引數:輸入資料和目標標籤。它實作了SupModel特徵,該特徵要求提供真實標籤以進行監督式學習。
進行預測
在訓練完成後,可以使用測試資料來評估模型的效能。透過呼叫 model.predict() 方法,可以獲得模型的預測結果。
程式碼範例:使用神經網路進行預測
let res = model.predict(&normalized_test_cases)?;
let mut writer = csv::Writer::from_writer(io::stdout());
writer.write_record(&[
"height",
"length",
"estimated_category_id",
"true_category_id",
])?;
for row in testing_inputs.iter_rows().zip(res.into_vec().into_iter()).zip(expected.into_vec().into_iter()) {
writer.serialize((row.0.0[0], row.0.0[1], row.1, row.0.1))?;
}
內容解密:
let res = model.predict(&normalized_test_cases)?;:使用訓練好的模型對標準化後的測試資料進行預測,獲得預測結果。- 將預測結果與原始測試資料一起輸出到 CSV 檔案中,以便比較預測值和真實值。
9.6 替代方案與現狀
從本章的範例中可以看出,機器學習並不僅僅是訓練模型。在訓練模型之前和之後,還有許多與資料相關的操作。這些操作包括:
- 讀取和寫入CSV或其他結構化資料格式
- 預處理資料(如標準化)
- 設定和載入模型組態和引數
- 視覺化資料
每次進行機器學習應用開發時,從頭開始撰寫所有程式碼並不是一個實際的做法。我們需要一個強大的生態系統,具備許多預先建置的crate,以幫助快速有效地實作學習部分,而不必擔心線性代數和資料操作等基本任務。與Rust的其他領域類別似,有一個「Are you learning yet?」頁面在追蹤生態系統的發展狀況。如該頁面所述,Rust中的機器學習領域「適合實驗,但生態系統尚未非常完整」。此外,還有Awesome-Rust-MachineLearning的Github倉函式庫,它是一個濃縮且更新的crate列表,目前可用於Rust中各種機器學習目的。
對於基礎數學crate,nalgebra和ndarray已經成為事實上的標準。它們提供線性代數和陣列/矩陣運算,類別似於Python中的numpy。許多機器學習演算法也依賴於高效能運算(HPC)的程式碼,這些程式碼能夠更好地利用硬體(CPU、GPU等)和平行處理能力。在這個領域有很多實驗性的專案,如std::simd、RustCUDA和rayon等。
如果考慮傳統的機器學習(這裡的「傳統」指的是非深度學習),smartcore和linfa crate都是領先且相對全面的。它們都實作了幾種常用的傳統機器學習模型,並繼續這樣做,同時也非常重視與ndarray等crate的互操作性。
深度學習的現狀
對於深度學習,目前尚未有從頭開始使用Rust構建的成熟函式庫。我們在這裡使用了rusty-machine,因為它易於設定和使用,適合學習目的,但該函式庫已不再被積極維護。因此,要進入深度學習領域,目前最好的選擇是使用Rust繫結到其他語言編寫的成熟函式庫。有針對TVM專案的Rust繫結,它是一個開源的深度學習編譯器堆積疊。還有tensorflow/rust用於TensorFlow,以及tch-rs用於PyTorch,這兩者都是主流的深度學習框架,也可能是目前使用Rust進行深度學習最受歡迎的工具。
程式碼範例:使用nalgebra進行線性代數運算
use nalgebra::Matrix2;
fn main() {
let matrix = Matrix2::new(1.0, 2.0, 3.0, 4.0);
println!("Matrix: {}", matrix);
}
內容解密:
- 引入nalgebra函式庫:首先,我們引入nalgebra函式庫以使用其提供的矩陣運算功能。
- 建立矩陣:使用
Matrix2::new函式建立一個2x2的矩陣。 - 列印矩陣:使用
println!宏列印預出建立的矩陣。
Rust在實作高效能和安全的機器學習應用方面具有極大的潛力,但仍需要更多的工作來使生態系統準備好投入生產使用。如果你感興趣,我們鼓勵你聯絡一個開源專案並開始參與;通常,深入理解某件事情最好的方法就是親自實作它的一部分。