返回文章列表

機器學習系統推論管線最佳化策略

本文探討機器學習系統推論管線的最佳化策略,強調效能分析的重要性,並探討模型最佳化、服務最佳化、程式碼最佳化和硬體最佳化等不同層面的方法。文章以 Supermegaretail 和 PhotoStock Inc. 兩個案例,分別闡述批次處理和即時查詢的系統設計,並提供程式碼範例和圖表說明,以更具體地展現最佳化策

機器學習 系統設計

在機器學習系統工程中,推論管線的最佳化至關重要。最佳化核心在於平衡模型速度、準確性和資源容量,並依模型特性和架構而異。最佳化的起始步驟是效能分析,藉此找出系統瓶頸,而非直接進行模型調整。分析需涵蓋整個系統,包含資料預處理、網路互動等環節,並針對最可最佳化的部分著手,而非僅關注最慢的部分。由於機器學習系統大量使用 GPU 和非同步執行等特性,分析過程需藉助多種工具,從 Python 內建的 cProfile 到專門的 GPU 分析器,才能取得全面的效能資料。

最佳化推論管線

「過早的最佳化是萬惡之源。」—— 唐納德·克努斯

最佳化推論管線是一個廣泛的話題,它本身並不是機器學習系統設計的一部分,但仍是機器學習系統工程中至關重要的部分,值得單獨設一章節,至少可以列出常見的方法和工具作為全景概覽。其核心往往歸結為模型速度與準確性以及所需資源容量之間的權衡,這意味著眾多的最佳化技術主要取決於模型的特性和架構。

一個合理的問題可能會在此時出現:「最佳化的起始程式是什麼?」我們在面試中曾問過許多機器學習工程師這個問題,並收到了各種各樣的答案,這些答案提到了諸如模型修剪(model pruning)、量化(quantization,見「Pruning and Quantization for Deep Neural Network Acceleration: A Survey」 https://arxiv.org/abs/2101.09671)和蒸餾(distillation,見「Knowledge Distillation: A Survey」 https://arxiv.org/abs/2006.05525)等術語,並參考了一些論文的狀態(我們只提到了幾篇調查文章,以便您將它們作為起點)。這些技術在機器學習研究社群中廣為人知,它們很有用且經常適用,但它們僅僅關注模型最佳化,而沒有讓我們控制整個畫面。最務實的答案是,「我會從分析(profiling)開始。」

從分析(Profiling)開始

分析是一種測量系統效能並識別瓶頸的過程。與前面提到的技術相比,分析是一種更通用的方法,可以應用於整個系統。就像戰略先於戰術一樣,分析是一個很好的起點,可以識別出最有前途的最佳化方向,並選擇最合適的技術。

分析實務經驗

你可能會對看似最明顯的因素並非模型的弱點感到驚訝。以延遲(latency)為例,有些情況下它並不是瓶頸(尤其是當服務的模型不是最近的生成式模型,而是更傳統的模型時),問題可能隱藏在其他地方(例如,資料預處理或網路互動)。此外,即使模型是管線中最慢的部分,也不自動意味著它應該是最佳化的目標。這看起來可能違反直覺,但我們應該尋找系統中最可最佳化的部分,而不是最慢的部分。

程式碼範例與解析
import cProfile

def my_function():
    # 模擬某些工作
    result = sum(i for i in range(1000000))
    return result

# 使用cProfile進行效能分析
cProfile.run('my_function()')

程式碼解析:

  1. 匯入 cProfile 模組:用於對 Python 程式碼進行效能分析。
  2. 定義 my_function():這是一個範例函式,計算從 0 到 999999 的總和。
  3. 使用 cProfile.run():對 my_function() 進行效能分析,並輸出分析結果。

透過這種方式,可以瞭解程式碼的哪一部分佔用了最多的執行時間,從而針對性地進行最佳化。

另一個例子來自 Arseny 的經驗;他曾被要求減少一個系統的延遲。該系統是一個相對簡單的管線,按順序執行數十個簡單模型。改進時間的第一個想法是用批次推論(batch inference)取代順序執行。然而,分析證明事實恰恰相反:推論本身是微不足道的(約 5%),相比於花費在資料預處理上的時間,而批次推論並不能提供幫助。最終最佳化的是預處理步驟,這一步驟無法從批次處理中受益,最終與最初設計的順序執行相結合。最終,Arseny 在未觸及核心模型推論的情況下,將系統速度提高了 40%,而諸如執行緒管理、IO 和序列化/反序列化函式以及級聯快取(cascade caching)等元素才是真正的唾手可得的成果。

這裡我們應該提到,對機器學習系統進行分析與對常規軟體進行分析略有不同,因為機器學習系統大量使用 GPU、非同步執行、使用高效能原生程式碼上的薄 Python 包裝器等原因。由於整個過程可能包含更多的變數,因此在解釋分析結果時應該小心,因為很容易被問題的複雜性所混淆。

GPU 執行的情況可能特別令人困惑。典型的 CPU 負載很簡單:資料被載入記憶體,CPU 執行程式碼:完成。GPU 是獨立的裝置,通常具有內建記憶體,資料必須在執行前複製到 GPU 記憶體。複製過程本身可能是一個瓶頸,而且如何測量它並不總是顯而易見的。GPU 執行的高度平行性質也使得效果呈現非線性。例如,如果你在單張圖片上執行模型,可能需要 100 毫秒,但如果你在 64 張圖片上執行,處理時間只會增加到 200 毫秒。這是因為在處理單一專案時,GPU 未被充分利用,從而導致將資料複製到 GPU 記憶體的開銷很大。同樣,在模型架構的更低層級上,減少卷積層中的濾波器數量可能不會減少延遲,因為 GPU 未被充分利用,並且在底層使用相同的 CUDA 核心。

因此,一種適當的分析方法需要準備多種工具,從基本的剖析器如 Python 標準函式庫中的 cProfile,更先進的第三方工具如 Scalene(https://github.com/plasma-umass/scalene)、memray(https://github.com/bloomberg/memray)、py-spy(https://github.com/benfred/py-spy)和特定於 ML 框架的工具(如 PyTorch Profiler),到低階 GPU 分析器如 nvprof(https://mng.bz/OmqE)。

總之,在進行機器學習系統最佳化時,首先進行全面的效能分析是非常重要的。透過分析,可以找出系統的瓶頸所在,並選擇合適的最佳化策略,從而達到更好的最佳化效果。

設計檔案:服務與推論最佳化

在開發與佈署機器學習(ML)系統時,服務與推論階段的最佳化至關重要。根據系統的不同需求,可能需要採用特定的硬體或軟體最佳化技術。例如,在使用張量處理單元(Tensor Processing Units, TPUs)或物聯網(IoT)處理器等特殊硬體時,可能需要使用供應商特定的工具。

在分析效能分析結果後,我們可以開始最佳化系統,解決最明顯的瓶頸。一般的最佳化方法涵蓋多個層面:

  • 與模型相關的最佳化,如架構變更、剪枝(pruning)、量化(quantization)、蒸餾(distillation)、特徵選擇等。
  • 與服務相關的最佳化,如批次處理、快取、預先計算等。
  • 與程式碼相關的最佳化,如更有效的低階演算法、使用更有效的正規化(如向量化演算法取代 for 迴圈),或重寫為更快的函式庫/框架(如用於數值計算的 Numba),甚至是更換程式語言(如用 C++ 或 Rust 取代 Python 瓶頸)。
  • 與硬體相關的最佳化,如使用更強大的硬體、垂直/水平擴充套件等。

最佳最佳化即最小最佳化

從整體設計的角度來看待最佳化問題,可以避免在維護階段出現一些問題。如果我們在設計階段就根據原始需求對系統進行徹底的處理,就可以避免許多問題。如果我們意識到嚴格的延遲需求,就應該最初選擇足夠快的模型。雖然我們可以透過量化或剪枝減少記憶體佔用,或透過蒸餾減少延遲,但從一開始就選擇接近目標需求的模型總是更好的。

設計檔案:Supermegaretail 的服務與推論

根據 Supermegaretail 的關鍵特徵和需求,解決方案不需要即時參與,可以批次修改模型,但仍涉及大量工作。

XII. 服務與推論

服務與推論的主要考慮因素是:

  • 高效的批次處理量,因為預測將在每日、每週和每月對大量資料進行。
  • 保護敏感的庫存和銷售資料。
  • 可擴充套件且具成本效益的架構,能夠處理批次作業。
  • 監控資料和預測品質。

I. 服務架構

我們將使用 Docker 容器來執行批次需求預測作業,這些容器由 AWS Batch 在 EC2 機器上進行協調。AWS Batch 將允許定義資源需求、動態調整所需容器數量以及排隊大量工作負載。

批次作業將按照排程從 S3 處理輸入資料,執行推論,並將結果輸出回 S3。如果需要,可以透過簡單的 Flask API 允許按需批次推論請求。

所有資料傳輸和處理都將在安全的 AWS 基礎設施上進行,與外部隔離。將使用適當的憑證進行身份驗證和授權。

II. 基礎設施

批次伺服器將使用自動擴充套件群組來比對工作負載需求。可以利用 Spot 例項來降低靈活批次作業的成本。

目前不需要專門的硬體或最佳化,因為優先考慮批次處理量,而批次性質允許充分的平行化。我們將利用 AWS Batch 和 S3 提供的水平擴充套件選項。

III. 監控

要追蹤的批次作業關鍵指標包括:

  • 作業成功率、持續時間和失敗率。
  • 每個作業處理的行數。
  • 伺服器利用率:CPU、記憶體、磁碟空間。
  • 與實際需求相比的預測準確度。
  • 資料驗證檢查和警示。

這些監控有助於確保批次處理過程保持高效和可擴充套件,並產生高品質的預測。未來,我們可以根據生產資料評估最佳化的需求。

程式碼範例與解析

以下是一個簡單的 Flask API 範例,用於提供按需批次推論請求:

from flask import Flask, request, jsonify
import pandas as pd
from sklearn.externals import joblib

app = Flask(__name__)

# 載入已訓練的模型
model = joblib.load('model.joblib')

@app.route('/predict', methods=['POST'])
def predict():
    # 從請求中取得資料
    data = request.get_json()
    df = pd.DataFrame(data)
    
    # 執行預測
    predictions = model.predict(df)
    
    # 傳回預測結果
    return jsonify(predictions.tolist())

if __name__ == '__main__':
    app.run(debug=True)

內容解密:

  1. 載入已訓練的機器學習模型:model = joblib.load('model.joblib')。這行程式碼使用 joblib 函式庫載入之前儲存的模型,以便在 API 中使用。
  2. 定義 Flask API 的路由:@app.route('/predict', methods=['POST'])。這行程式碼定義了一個名為 /predict 的 API 路由,接受 POST 請求,用於接收資料並傳回預測結果。
  3. 從請求中取得資料:data = request.get_json()。這行程式碼從 POST 請求中取得 JSON 資料,並將其轉換為 Python 物件。
  4. 將資料轉換為 Pandas DataFrame:df = pd.DataFrame(data)。這行程式碼將取得的資料轉換為 Pandas DataFrame,以便進行預測。
  5. 使用模型執行預測:predictions = model.predict(df)。這行程式碼使用載入的模型對 DataFrame 中的資料進行預測。
  6. 傳回預測結果:return jsonify(predictions.tolist())。這行程式碼將預測結果轉換為 JSON 格式並傳回給客戶端。

圖表說明

以下是一個 Plantuml 圖表,用於呈現 Supermegaretail 的服務架構:

@startuml
skinparam backgroundColor #FEFEFE
skinparam defaultTextAlignment center
skinparam rectangleBackgroundColor #F5F5F5
skinparam rectangleBorderColor #333333
skinparam arrowColor #333333

title 圖表說明

rectangle "請求" as node1
rectangle "處理請求" as node2
rectangle "執行批次作業" as node3
rectangle "儲存結果" as node4
rectangle "傳回結果" as node5

node1 --> node2
node2 --> node3
node3 --> node4
node4 --> node5

@enduml

此圖示說明瞭 Supermegaretail 的服務架構,包括客戶端請求、Flask API 處理請求、AWS Batch 執行批次作業、S3 儲存結果等步驟。

圖表內容解密:

  1. 圖表中的節點代表不同的元件,包括客戶端、Flask API、AWS Batch 和 S3。
  2. 箭頭代表元件之間的互動流程,例如客戶端傳送請求給 Flask API,Flask API 將請求轉發給 AWS Batch 等。
  3. 圖表呈現了整個服務架構的工作流程,有助於理解系統的運作方式。

為 PhotoStock Inc. 設計檔案:服務與推理

XII. 服務與推理

由於我們的搜尋引擎根據向量相似性,因此我們需要關心兩個方面:為可搜尋專案生成向量(更新索引)和搜尋使用者查詢(查詢索引)。這兩個方面有不同的需求和限制,因此我們將分別設計它們。

I. 索引更新

更新索引是一種批次處理,通常每天發生一次(正如第13章所述)。除了定期更新外,我們還需要支援初始索引建立或在核心模型更新時重新建立索引。雖然這是一個相對罕見的事件,但擁有一個可以按需執行的流程非常重要。

這兩種情況具有相同的特點:

  • 緩和的延遲要求
  • 嚴格的吞吐量要求

我們需要在合理的時間內以合理的成本處理大量的專案。對於粗略估計,我們應該在幾個小時內支援每天重新索引約10^5個專案。如果核心模型更新,我們還需要在合理時間內重新索引約10^8個專案。

II. 索引查詢

查詢索引是一種實時處理,涉及每個使用者查詢。我們需要最小化查詢的延遲,以便我們的吞吐量要求不會太高,因為我們的平均搜尋次數是每天150,000次(請參閱第12章),大約是每秒2次查詢。然而,查詢次數並不是均勻分佈的,我們需要能夠處理峰值負載約每秒100次查詢,並且能夠在流量激增的情況下快速擴充套件和縮減。

我們建議對批次和實時推理使用轉換為ONNX的相同模型。這不是一個硬性要求,但它將簡化系統的設計和維護。然而,批次和實時推理的處理過程應該根據不同的需求分開。

III. 框架與硬體

從軟體角度來看,我們將使用Nvidia Triton Inference Server作為服務框架。它是一款高效能的開源推理服務軟體,支援ONNX,並具有許多簡化服務流程的功能。我們將使用Triton Inference Server的HTTP API與應用程式進行通訊。

對於批次推理,我們將使用雲端供應商提供的預設解決方案:AWS Sagemaker。它是一種受管理的服務,允許我們在可組態的例項上執行批次推理,如果需要,可以輕鬆擴充套件,並且與我們使用的其他AWS服務整合良好。我們可以考慮在底層使用Spot例項來降低批次推理的成本。批次作業本身將是一個簡單的指令碼,位於實時推理之上,增加了一個IO層,從佇列中讀取資料並將結果寫入S3和資料函式庫。

對於實時推理,最好有一個無伺服器的解決方案,可以在沒有查詢時縮小到零。然而,考慮到我們的高負載和低延遲要求,使用像AWS Lambda這樣的主要供應商可能很難實作;因此,我們將採用更傳統的方法,在負載平衡器後面使用一組伺服器。我們將使用AWS EC2例項作為伺服器,並使用AWS Application Load Balancer作為負載平衡器。我們可以在底層使用Spot例項來降低實時推理的成本,因為每個工作節點都是無狀態的,如果被AWS終止,我們可以輕鬆地用新的替換它。我們需要確保系統具有合理數量的可用工作節點,並在需要時啟用額外的擴充套件。

IV. 輔助基礎設施

我們將從預設的float32精確度開始模型服務,但稍後將嘗試使用較低的精確度(例如float16)來降低服務成本。最佳化模型本身以降低延遲也可以稍後進行,儘管目前我們不期望特定的瓶頸,因為CLIP模型相對簡單。

由於某些查詢比其他查詢更受歡迎,因此我們可以使用快取來減少對推理伺服器的負載。我們可以使用AWS Elasticache來實作這一點;它是一種受管理的服務,支援Redis和Memcached。我們可以使用簡單的鍵值快取,具有較低的存活時間(具體數字取決於資料分析)。快取對於執行時推理很有用,但對於批次推理則不是;批次推理應該負責更新快取,如果鍵已更改。

詳細說明

為了滿足 PhotoStock Inc. 的搜尋引擎需求,我們設計了一個系統來處理索引更新和查詢。這兩個過程有不同的需求:索引更新需要高吞吐量,而查詢需要低延遲。

索引更新過程

索引更新是一個每天發生的批次過程,需要處理大量的專案。我們設計了一個系統,可以在幾個小時內重新索引約10^5個專案,並在核心模型更新時重新索引約10^8個專案。

import boto3
from botocore.exceptions import ClientError

def update_index():
    # 初始化Sagemaker客戶端
    sagemaker = boto3.client('sagemaker')

    try:
        # 建立Sagemaker任務來更新索引
        response = sagemaker.create_processing_job(
            ProcessingJobName='index-update-job',
            ProcessingResources={
                'ClusterConfig': {
                    'InstanceCount': 1,
                    'InstanceType': 'ml.m5.xlarge',
                    'VolumeSizeInGB': 30
                }
            },
            AppSpecification={
                'ImageUri': 'your-docker-image-uri'
            }
        )
        print("索引更新任務已建立:", response['ProcessingJobArn'])
    except ClientError as e:
        print("建立索引更新任務失敗:", e)

#### 內容解密:
1. **初始化Sagemaker客戶端**使用boto3函式庫初始化一個Sagemaker客戶端以便與Sagemaker服務進行互動
2. **建立Sagemaker任務**呼叫`create_processing_job`方法建立一個Sagemaker任務來更新索引
3. **組態任務資源**指定任務所需的資源包括例項數量例項型別和儲存大小
4. **指定Docker映像**提供用於執行任務的Docker映像URI
5. **錯誤處理**捕捉並處理可能發生的ClientError異常以確保任務建立失敗時能夠正確報告錯誤

##### 索引查詢過程
查詢索引是一個實時過程需要最小化延遲以滿足使用者需求我們設計了一個系統可以處理每秒100次查詢的峰值負載並能夠根據流量變化快速擴充套件或縮減

```python
from tritonclient.http import InferenceServerClient, InferInput

def query_index(query_vector):
    # 初始化Triton Inference Server客戶端
    client = InferenceServerClient(url='your-triton-server-url')

    # 準備輸入資料
    input_data = InferInput(name='input_vector', shape=[1, -1], datatype='FP32')
    input_data.set_data_from_numpy(query_vector)

    try:
        # 傳送推理請求
        response = client.infer(model_name='your-model-name', inputs=[input_data])
        print("推理結果:", response.as_numpy('output'))
    except Exception as e:
        print("傳送推理請求失敗:", e)

#### 內容解密:
1. **初始化Triton客戶端**建立一個Triton Inference Server客戶端用於與Triton伺服器互動
2. **準備輸入資料**建立InferInput物件將查詢向量轉換為適合Triton輸入的格式
3. **傳送推理請求**呼叫`infer`方法向Triton伺服器傳送推理請求並取得結果
4. **錯誤處理**捕捉並處理可能發生的異常以確保請求失敗時能夠正確報告錯誤