在深度學習專案中,資料準備往往佔據了大部分時間和精力。利用 PySpark 的分散式運算能力,可以大幅提升資料處理效率。本文將會介紹如何使用 PySpark 從 AWS S3 讀取資料,進行資料清洗、轉換、特徵工程等準備工作,最後將資料轉換成適合深度學習框架(如 PyTorch 和 TensorFlow)使用的格式。同時也會說明如何在 Databricks 環境中安全地掛載 S3 儲存桶,避免直接在程式碼中暴露存取金鑰等敏感資訊。透過 IAM 角色的設定,可以讓 Databricks 叢集自動取得臨時憑證,更安全地存取 S3 儲存桶中的資料。
在AWS上設定深度學習環境:掛載S3儲存桶
在AWS上使用Databricks進行深度學習時,掛載S3儲存桶是至關重要的一步。本章節將詳細介紹如何使用dbutils.fs.mount函式掛載S3儲存桶,以及如何使用IAM角色提高安全性。
使用dbutils.fs.mount掛載S3儲存桶
dbutils.fs.mount函式用於將S3儲存桶掛載到Databricks環境中。該函式需要多個引數,包括S3儲存桶的URL、掛載點和額外的組態。
程式碼範例
AWS_BUCKET_NAME = "your-bucket-name"
ACCESS_KEY = "your-access-key"
SECRET_KEY = "your-secret-key"
MOUNT_NAME = "/mnt/datalake-central"
dbutils.fs.mount(
source=f"s3a://{AWS_BUCKET_NAME}/",
mount_point=MOUNT_NAME,
extra_configs={
f"s3a.access.key.{AWS_BUCKET_NAME}": ACCESS_KEY,
f"s3a.secret.key.{AWS_BUCKET_NAME}": SECRET_KEY
}
)
內容解密:
AWS_BUCKET_NAME、ACCESS_KEY和SECRET_KEY變數分別儲存S3儲存桶的名稱、存取金鑰和秘密金鑰。MOUNT_NAME變數定義了掛載點的路徑。dbutils.fs.mount函式將S3儲存桶掛載到指定的掛載點。source引數指定了S3儲存桶的URL,mount_point引數指定了掛載點,extra_configs引數提供了額外的組態,包括存取金鑰和秘密金鑰。
取得存取金鑰和秘密金鑰
要取得存取金鑰和秘密金鑰,需要在AWS Management Console中進行以下步驟:
- 登入AWS Management Console。
- 導航到IAM服務。
- 選擇“使用者”並找到需要建立存取金鑰的使用者。
- 在“安全憑證”標籤中,點選“建立存取金鑰”。
- 記錄下存取金鑰ID和秘密存取金鑰。
使用IAM角色提高安全性
直接在程式碼中硬編碼存取金鑰和秘密金鑰存在安全風險。使用IAM角色可以提高安全性。
建立IAM角色
- 登入AWS Management Console並導航到IAM服務。
- 點選“角色”並建立新角色。
- 選擇EC2作為信任實體。
- 將
AmazonS3FullAccess策略附加到角色。
將IAM角色附加到Databricks叢集
- 在AWS Management Console中導航到EC2儀錶板。
- 找到組成Databricks叢集的EC2例項。
- 將建立的IAM角色附加到這些例項。
使用IAM角色掛載S3儲存桶
AWS_BUCKET_NAME = "your-bucket-name"
MOUNT_NAME = "/mnt/datalake-central"
dbutils.fs.mount(
source=f"s3a://{AWS_BUCKET_NAME}/",
mount_point=MOUNT_NAME
)
內容解密:
- 無需提供存取金鑰和秘密金鑰,因為IAM角色提供了臨時憑證。
dbutils.fs.mount函式使用IAM角色掛載S3儲存桶。
驗證掛載是否成功
MOUNT_NAME = "/mnt/datalake-central"
mounts = dbutils.fs.mounts()
mount_point_exists = any(mount.mountPoint == MOUNT_NAME for mount in mounts)
if mount_point_exists:
print("S3儲存桶掛載成功。")
else:
print("S3儲存桶掛載失敗。")
內容解密:
dbutils.fs.mounts()函式檢索所有已掛載的檔案系統。any()函式檢查指定的掛載點是否存在。- 根據檢查結果列印相應的訊息。
透過使用IAM角色和dbutils.fs.mount函式,可以安全地將S3儲存桶掛載到Databricks環境中,從而方便地進行深度學習相關的操作。
在 Databricks 中掛載 S3 儲存桶並進行資料讀寫操作
在前一章節中,我們探討瞭如何在 Amazon Web Services(AWS)上設定深度學習環境並組態必要的工具和服務。本章節將重點介紹如何在 Databricks 中掛載 S3 儲存桶,並進行資料的讀寫操作。
使用 dbutils.fs.mount 函式掛載 S3 儲存桶
首先,我們需要使用 dbutils.fs.mount 函式將 S3 儲存桶掛載到 Databricks。以下程式碼展示瞭如何實作這一步驟:
MOUNT_NAME = "/mnt/datalake-central"
SOURCE_URL = "s3a://your-s3-bucket-name"
ACCESS_KEY = "your-access-key"
SECRET_KEY = "your-secret-key"
try:
dbutils.fs.mount(
source=SOURCE_URL,
mount_point=MOUNT_NAME,
extra_configs={
"fs.s3a.access.key": ACCESS_KEY,
"fs.s3a.secret.key": SECRET_KEY
}
)
print(f"S3 bucket mounted successfully at {MOUNT_NAME}")
except Exception as e:
if "Directory already mounted" in str(e):
print(f"{MOUNT_NAME} is already mounted.")
else:
print(f"Failed to mount S3 bucket: {str(e)}")
# 檢查掛載點是否存在
if MOUNT_NAME in [mount.mountPoint for mount in dbutils.fs.mounts()]:
print(f"{MOUNT_NAME} exists among the mounted file systems.")
else:
print("S3 bucket mount failed.")
程式碼解析:
- 定義變數:首先定義了掛載名稱
MOUNT_NAME、S3 儲存桶的 URLSOURCE_URL、存取金鑰ACCESS_KEY和秘密金鑰SECRET_KEY。 - 掛載 S3 儲存桶:使用
dbutils.fs.mount函式將 S3 儲存桶掛載到指定的掛載點。過程中需要提供 S3 的存取金鑰和秘密金鑰。 - 錯誤處理:程式碼中包含了錯誤處理機制。如果掛載點已經存在,則輸出相應的訊息;如果發生其他錯誤,則輸出錯誤訊息。
- 檢查掛載狀態:最後,檢查掛載點是否存在於已掛載的檔案系統中,以確認掛載是否成功。
在掛載的 S3 儲存桶中進行資料讀寫操作
成功掛載 S3 儲存桶後,我們可以在 Databricks 中進行資料的讀寫操作。以下範例展示瞭如何建立一個 DataFrame,將其儲存到掛載的 S3 儲存桶中,並再次讀取:
# 建立 DataFrame
data_to_write = [("Chicken Taco", 5.99), ("Beef Taco", 6.49), ("Vegetarian Taco", 4.99)]
df = spark.createDataFrame(data_to_write, ["Taco", "Price"])
df = df.repartition(1)
# 將 DataFrame 儲存到 S3
MOUNT_NAME = "/mnt/datalake-central"
df.write.mode("overwrite").parquet(MOUNT_NAME + "/taco_data_parquet")
# 從 S3 讀取 DataFrame
df_read = spark.read.parquet(MOUNT_NAME + "/taco_data_parquet")
df_read.show()
程式碼解析:
- 建立 DataFrame:使用
spark.createDataFrame方法建立一個包含 taco 資料的 DataFrame。 - 儲存 DataFrame:將 DataFrame 重新分割為單一分割區後,以 Parquet 格式儲存到掛載的 S3 儲存桶中。如果檔案已存在,則覆寫它。
- 讀取 DataFrame:使用
spark.read.parquet方法從 S3 讀取剛才儲存的 Parquet 檔案,並將其載入為 DataFrame。 - 顯示 DataFrame:最後,使用
show方法顯示 DataFrame 的內容。
使用PySpark進行深度學習的資料準備
介紹
本章節將探討如何利用PySpark進行深度學習的資料準備工作。我們將介紹資料前處理、特徵工程以及如何將資料轉換為PyTorch和TensorFlow相容的張量格式。同時,我們也會討論如何利用PySpark的平行處理能力來提高資料處理的效率。
資料集介紹
我們將使用Tesla股票價格的歷史資料集作為範例。這個資料集將用於示範如何使用PySpark進行資料準備,並將其應用於PyTorch和TensorFlow的深度學習任務中。
載入必要的函式庫
首先,我們需要載入必要的Python函式庫:
import boto3
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import pyspark.sql.functions as F
import numpy as np
函式庫說明
boto3:用於與AWS服務互動,下載S3儲存桶中的Tesla股票CSV檔案。matplotlib.pyplot:用於根據DataFrame建立視覺化圖表。SparkSession:PySpark的DataFrame API入口點,用於建立SparkSession物件以進行資料處理操作。col:用於在check_for_null_values方法中參照DataFrame欄位。F:pyspark.sql.functions的別名,用於縮寫對PySpark函式的呼叫。numpy:用於數值運算和調整視覺化中的刻度位置。
資料處理器類別定義
接下來,我們定義了一個名為DataProcessor的類別,用於處理各種資料處理任務:
class DataProcessor:
def __init__(self, spark_session):
self.spark = spark_session
載入資料方法
在DataProcessor類別中,我們定義了load_data方法,用於從CSV檔案載入Tesla股票價格資料:
def load_data(self, file_path: str):
"""
使用SparkSession從CSV檔案載入股票價格資料。
"""
try:
df = self.spark.read.csv(
file_path,
header=True,
inferSchema=True
)
return df
except Exception as e:
print(f"載入資料時發生錯誤:{str(e)}")
return None
列印DataFrame的前N列方法
我們還定義了print_first_n_rows方法,用於列印DataFrame的前N列:
def print_first_n_rows(self, df, n=10):
"""列印DataFrame的前n列。"""
print(f"DataFrame的前{n}列:")
df.show(n)
內容解密:
load_data方法使用SparkSession物件的read.csv方法讀取指定路徑的CSV檔案,並根據檔案內容推斷DataFrame的結構描述。- 如果資料載入成功,該方法傳回包含股票價格資料的DataFrame;如果發生例外狀況,則捕捉例外並傳回None。
print_first_n_rows方法用於列印DataFrame的前N列,預設值為10。
資料探索與視覺化
在載入資料後,我們可以進行資料探索,包括列印前幾列資料、計算描述性統計資料、檢查缺失值以及視覺化資料。
此圖示說明瞭使用PySpark進行深度學習資料準備的流程:
@startuml
skinparam backgroundColor #FEFEFE
skinparam componentStyle rectangle
title PySpark深度學習資料準備與S3整合
package "機器學習流程" {
package "資料處理" {
component [資料收集] as collect
component [資料清洗] as clean
component [特徵工程] as feature
}
package "模型訓練" {
component [模型選擇] as select
component [超參數調優] as tune
component [交叉驗證] as cv
}
package "評估部署" {
component [模型評估] as eval
component [模型部署] as deploy
component [監控維護] as monitor
}
}
collect --> clean : 原始資料
clean --> feature : 乾淨資料
feature --> select : 特徵向量
select --> tune : 基礎模型
tune --> cv : 最佳參數
cv --> eval : 訓練模型
eval --> deploy : 驗證模型
deploy --> monitor : 生產模型
note right of feature
特徵工程包含:
- 特徵選擇
- 特徵轉換
- 降維處理
end note
note right of eval
評估指標:
- 準確率/召回率
- F1 Score
- AUC-ROC
end note
@enduml
內容解密:
- 資料準備的第一步是載入資料,使用PySpark的
read.csv方法可以輕鬆實作這一步驟。 - 資料探索包括列印前幾列資料、計算描述性統計資料和檢查缺失值等步驟,可以幫助我們瞭解資料的基本情況。
- 資料視覺化可以幫助我們更直觀地瞭解資料的分佈和特徵。
- 資料前處理和特徵工程是準備好資料的重要步驟,可以根據具體的深度學習任務進行相應的處理。
- 將資料轉換為PyTorch或TensorFlow相容的張量格式,以便進行後續的模型訓練。
使用 PySpark 進行資料準備的關鍵方法實作細節
在資料準備過程中,DataProcessor 類別提供了多個關鍵方法來處理和探索資料。這些方法涵蓋了資料檢視、描述性統計計算、資料視覺化以及空值檢查等功能。
資料檢視方法:show_data
def show_data(self, df, n):
"""顯示 DataFrame 的前 n 筆資料。"""
print(f"顯示前 {n} 筆資料:")
df.show(n)
內容解密:
show_data方法用於顯示 DataFrame 的前n筆資料。- 透過
print陳述式輸出提示訊息,告知使用者即將顯示的資料筆數。 - 使用 DataFrame 的
show方法來顯示指定數量的資料列。
描述性統計計算:calculate_descriptive_statistics
def calculate_descriptive_statistics(self, df):
"""計算 DataFrame 的描述性統計。"""
print("計算描述性統計:")
df.summary().show()
內容解密:
calculate_descriptive_statistics方法負責計算 DataFrame 的描述性統計。- 首先輸出提示訊息,說明正在進行描述性統計的計算。
- 使用 DataFrame 的
summary方法來計算統計資料,包括計數、平均值、標準差、最小值和最大值等。
資料視覺化:visualize_data
def visualize_data(self, df):
"""根據 DataFrame 建立視覺化圖表。"""
print("進行資料視覺化:")
try:
df_pd = df.toPandas()
plt.figure(figsize=(10, 6))
plt.plot(df_pd['Date'], df_pd['Close'])
plt.xlabel('日期')
plt.ylabel('收盤價')
plt.title('Tesla 股票收盤價變化趨勢')
plt.xticks(rotation=45, ha='right')
plt.gca().invert_xaxis()
plt.xticks(np.arange(0, len(df_pd['Date']), step=max(len(df_pd['Date']) // 10, 1)))
plt.show()
except Exception as e:
print(f"視覺化過程中發生錯誤:{str(e)}")
內容解密:
visualize_data方法將 Spark DataFrame 轉換為 Pandas DataFrame,以便使用 Matplotlib 繪製圖表。- 使用
plt.plot繪製 Tesla 股票收盤價隨時間變化的折線圖。 - 設定圖表的標籤、標題以及 x 軸刻度的旋轉,以提高可讀性。
- 反轉 x 軸以確保日期按時間順序排列。
- 調整 x 軸刻度的間隔,以避免過於密集。
空值檢查:check_for_null_values
def check_for_null_values(self, df):
"""檢查 DataFrame 中的空值。"""
print("進行空值檢查:")
null_counts = df.select([col(c).isNull().cast("int").alias(c) for c in df.columns]).agg(*[F.sum(c).alias(c) for c in df.columns]).toPandas()
print(null_counts)
內容解密:
check_for_null_values方法檢查 DataFrame 中的空值數量。- 使用 Spark DataFrame 的操作來計算每列的空值數量,並將結果轉換為 Pandas DataFrame 以便顯示。
- 輸出每列的空值計數結果。
從 S3 複製檔案:copy_file_from_s3
def copy_file_from_s3(bucket_name: str, file_key: str, local_file_path: str):
"""
從 S3 bucket 複製檔案到本地路徑。
"""
try:
s3 = boto3.client('s3')
s3.download_file(bucket_name, file_key, local_file_path)
print(f"檔案已從 S3 bucket {bucket_name} 下載到 {local_file_path}")
except Exception as e:
print(f"下載檔案過程中發生錯誤:{str(e)}")
內容解密:
copy_file_from_s3函式使用 boto3 函式庫從指定的 S3 bucket 下載檔案到本地路徑。- 在下載過程中,若發生任何例外狀況,均會被捕捉並輸出錯誤訊息。