深度學習在醫療領域的應用日益廣泛,糖尿病預測是其中一個重要方向。本文使用 Pima 印第安人糖尿病資料集,結合 PyTorch 深度學習框架,構建一個糖尿病預測模型。資料預處理階段利用 Spark 的分散式計算能力,提升處理效率。模型採用兩層全連線網路結構,並使用 ReLU 啟用函式。為了評估模型的泛化能力,我們採用 K-Fold 交叉驗證方法,將資料集分成多個子集,進行多次訓練和驗證,以獲得更穩定的效能指標。
深度學習與 PyTorch 在分類別任務中的應用
在前兩章中,我們探討了深度學習在迴歸任務中的應用。從本章開始,我們將重點轉移到分類別任務上,這是深度學習中的另一個基本任務。迴歸和分類別是兩種不同型別的機器學習任務:迴歸旨在從輸入特徵中預測連續的數值,例如預測特斯拉的股票價格;而分類別則涉及將輸入資料分配到預定義的類別或類別中。在本章中,我們的目標是根據諸如懷孕次數、血糖水平、血壓、身體品質指數(BMI)、年齡和糖尿病家族史等屬性,預測一名女性被診斷出患有糖尿病的機率。
資料集介紹
本章使用的資料集是著名的 Pima Indians Diabetes Dataset。該資料集包含了 768 筆記錄,每筆記錄代表一位 Pima 印第安女性,並包含各種與健康相關的屬性,以及一個目標變數(結果),指示該女性是否患有糖尿病。該資料集適用於根據明確定義的屬性(如懷孕次數、血糖水平、血壓、身體品質指數(BMI)、年齡和糖尿病家族史)預測糖尿病結果的分類別任務。
資料集來源
- 標題:Pima Indians Diabetes Database
- 來源:Kaggle
- 網址:www.kaggle.com/uciml/pima-indians-diabetes-database
- 貢獻者:UCI Machine Learning
- 日期:2016
資料預處理與探索
我們使用 Spark 進行資料預處理,以展示其在處理大型資料集時的分散式計算能力。儘管我們的資料集相對較小,但這些概念和程式碼同樣適用於更大的資料集,使其在生產環境中具有價值。我們將探討資料增強、遷移學習、重取樣、模型複雜度、資料品質和資料表示等重要主題。
# 使用 PySpark 進行資料探索的範例程式碼
from pyspark.sql import SparkSession
# 初始化 SparkSession
spark = SparkSession.builder.appName("DiabetesClassification").getOrCreate()
# 載入資料集
df = spark.read.csv("s3://your-bucket/pima-indians-diabetes.csv", header=True, inferSchema=True)
# 顯示資料集的前幾行
df.show()
# 統計每列缺失值的數量
df.select([count(when(isnull(c), c)).alias(c) for c in df.columns]).show()
#### 內容解密:
1. 初始化 `SparkSession` 是使用 Spark 進行資料處理的第一步。
2. 使用 `read.csv` 方法從 S3 儲存桶中載入 CSV 檔案。
3. `show` 方法用於顯示 DataFrame 的前幾行,方便檢查資料。
4. 統計缺失值的程式碼使用了 `count` 和 `when` 函式結合 `isnull` 檢查每列的缺失值數量。
結構與訓練 PyTorch 分類別模型
我們將構建、訓練和評估一個 PyTorch 分類別模型,以預測糖尿病結果。首先,我們需要定義模型架構,然後進行訓練和評估。
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# 定義 PyTorch 模型
class DiabetesClassifier(nn.Module):
def __init__(self):
super(DiabetesClassifier, self).__init__()
self.fc1 = nn.Linear(8, 128) # 輸入層到隱藏層
self.fc2 = nn.Linear(128, 2) # 隱藏層到輸出層
def forward(self, x):
x = torch.relu(self.fc1(x)) # 使用 ReLU 啟用函式
x = self.fc2(x)
return x
#### 內容解密:
1. `DiabetesClassifier` 類別繼承自 `nn.Module`,是 PyTorch 中定義模型的標準做法。
2. `__init__` 方法中定義了模型的層,包括兩個全連線層(`fc1` 和 `fc2`)。
3. `forward` 方法定義了資料在模型中的前向傳播路徑,使用了 ReLU 啟用函式。
K-Fold 交叉驗證
為了確保模型的泛化能力,我們採用 K-Fold 交叉驗證方法。這涉及將資料集分成 K 個子集,然後進行 K 次訓練和驗證,每次使用不同的子集作為驗證集。
from sklearn.model_selection import KFold
# 初始化 KFold 物件
kf = KFold(n_splits=5, shuffle=True, random_state=42)
# K-Fold 交叉驗證迴圈
for train_index, val_index in kf.split(df):
train_df = df.iloc[train_index]
val_df = df.iloc[val_index]
# 在此處新增訓練和驗證模型的程式碼
#### 內容解密:
1. `KFold` 物件被初始化,指定了折數(`n_splits`)和其他引數。
2. 在 K-Fold 迴圈中,資料被分成訓練集和驗證集,用於模型的訓練和評估。
探索與分析 Pima 印第安人糖尿病資料集
資料集探索的重要性
在進行機器學習或深度學習任務之前,瞭解資料集的特性至關重要。對於 Pima 印第安人糖尿病資料集的探索,可以揭示資料的分佈、潛在問題以及與目標變數之間的關聯。
資料集探索的內容
- 零值計數:計算資料集中每個欄位的零值數量。這有助於發現潛在的問題,例如缺失或錯誤的資料,特別是在零值不應存在的特徵中(例如血壓、血糖水平)。
- 資料摘要:提供資料集的摘要,包括數值欄位的平均值、標準差、最小值、最大值和四分位數等統計資料。這有幫助於瞭解不同特徵的分佈和範圍。
- 結果值計數:計算資料集中每個結果值的出現次數。這有助於瞭解資料集中糖尿病病例的分佈情況。
- 特徵分佈:繪製相關特徵的直方圖,以視覺化特徵的分佈。這可以提供對資料底層模式、潛在異常值以及資料是否偏態的洞察。
- 特徵與目標變數的相關性計算:計算每個特徵與目標變數(結果)之間的相關性。這有助於瞭解哪些特徵可能與目標變數更強烈相關,這對於分類別等預測建模任務至關重要。
程式碼解析
步驟1:匯入必要的函式庫
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as pyspark_sum
import boto3
內容解密:
from pyspark.sql import SparkSession:匯入 SparkSession 類別,這是進入 Spark 的入口點,提供分散式資料處理的功能。from pyspark.sql.functions import col, sum as pyspark_sum:匯入特定的函式,用於參照 DataFrame 中的欄位和計算欄位值的總和。import boto3:匯入 boto3 模組,這是 AWS 的官方 Python SDK,用於以程式設計方式與各種 AWS 服務(如 S3 和 EC2)互動。
步驟2:類別定義
class PimaDatasetExplorer:
"""
用於使用 PySpark 探索和分析 Pima 印第安人糖尿病資料集的類別。
"""
def __init__(self, file_path):
"""
初始化 PimaDatasetExplorer 物件。
Args:
file_path (str):資料集檔案的路徑。
"""
self.file_path = file_path
self.data = None
內容解密:
- 定義了一個名為
PimaDatasetExplorer的類別,用於封裝與 Pima 印第安人糖尿病資料集相關的操作。 __init__方法初始化物件,接受資料集檔案的路徑作為引數。
步驟3:load_data 方法定義
def load_data(self):
"""
從指定的檔案路徑載入資料集到 PySpark DataFrame。
"""
try:
self.data = spark.read.csv(
self.file_path,
header=True,
inferSchema=True
)
self.data.cache()
except Exception as e:
print("載入資料時發生錯誤:", str(e))
self.data = None
內容解密:
load_data方法負責從指定路徑載入資料集到 PySpark DataFrame。- 使用
try-except區塊處理載入過程中的異常。 spark.read.csv方法讀取 CSV 檔案,header=True表示第一行包含欄位名稱,inferSchema=True表示自動推斷欄位的資料型別。cache()方法將 DataFrame 的內容儲存在記憶體或磁碟上,以提高後續操作的效能。
步驟4:preprocess_data 方法定義
def preprocess_data(self):
"""
對資料集執行預處理步驟。
"""
self.load_data()
內容解密:
preprocess_data方法目前僅呼叫load_data方法載入資料集。- 未來可以擴充套件此方法以執行其他預處理步驟,例如處理缺失值、轉換資料型別等。
PimaDatasetExplorer 類別方法實作詳解
handle_missing_values 方法定義
此方法用於計算資料集中每個欄位的缺失值數量並顯示計數結果。
程式碼實作
def handle_missing_values(self):
"""
計算資料集中每個欄位的缺失值數量並顯示計數結果。
"""
if self.data is not None:
missing_counts = self.data.select([
pyspark_sum(col(c).isNull().cast("int")).alias(c)
for c in self.data.columns
])
print("缺失值計數:")
missing_counts.show()
else:
print("資料未載入。")
內容解密:
- 方法首先檢查
self.data是否已載入資料。 - 使用
isNull()函式識別缺失值,並透過cast("int")將布林結果轉換為整數(0 或 1)。 - 利用
pyspark_sum計算每個欄位的缺失值總和,並將結果以原始欄位名稱別名。 - 顯示缺失值計數的 DataFrame。
count_zeros 方法定義
此方法旨在統計資料集中每個欄位的零值數量並顯示計數結果。
程式碼實作
def count_zeros(self):
"""
統計資料集中每個欄位的零值數量並顯示計數結果。
"""
if self.data is not None:
zero_counts = self.data.select([
pyspark_sum((col(c) == 0).cast("int")).alias(c)
for c in self.data.columns
])
print("零值計數:")
zero_counts.show()
else:
print("資料未載入。")
內容解密:
- 檢查
self.data是否已載入資料。 - 使用
(col(c) == 0)評估欄位值是否為零,並將布林結果轉換為整數。 - 利用
pyspark_sum計算每個欄位的零值總和。 - 顯示零值計數的 DataFrame。
data_summary 方法定義
此方法提供糖尿病資料集相關特徵的摘要。
程式碼實作
def data_summary(self):
"""
顯示資料集的摘要,排除特定欄位並過濾零值行。
"""
if self.data is not None:
print("資料摘要(排除 'Outcome', 'SkinThickness', 'Insulin'):")
columns_to_exclude = ['Outcome', 'SkinThickness', 'Insulin']
summary_cols = [c for c in self.data.columns if c not in columns_to_exclude]
filtered_data = self.data.filter((col("Glucose") != 0) & (col("BloodPressure") != 0) & (col("BMI") != 0))
filtered_data.select(summary_cols).describe().show()
else:
print("資料未載入。")
內容解密:
- 檢查資料是否已載入。
- 定義需要排除的欄位列表
columns_to_exclude。 - 使用列表推導式建立需要摘要的欄位列表
summary_cols。 - 過濾資料集,排除 ‘Glucose’、‘BloodPressure’ 和 ‘BMI’ 為零的行。
- 對過濾後的資料集進行描述性統計,並顯示結果。
PimaDatasetExplorer 類別方法實作詳解
方法實作概述
本章節將詳細介紹 PimaDatasetExplorer 類別中的多個關鍵方法,包括資料摘要統計、結果計數、特徵分佈探索以及特徵與目標變數的相關性計算。這些方法共同構成了資料探索和分析的核心功能。
Step 8:count_outcome 方法定義
方法功能
count_outcome 方法的主要功能是統計資料集中 Outcome 欄位的各個值的出現次數。
程式碼實作
def count_outcome(self):
"""
統計資料集中 Outcome 欄位各個值的出現次數。
"""
if self.data is not None:
print("Outcome Counts:")
outcome_counts = self.data.filter(
(col("Glucose") != 0)
& (col("BloodPressure") != 0)
& (col("BMI") != 0)
).groupBy("Outcome").count()
outcome_counts.show()
else:
print("Data not loaded.")
內容解密:
- 資料檢查:首先檢查
self.data是否為None,確保資料已載入。 - 資料過濾:使用
filter方法排除 Glucose、BloodPressure 和 BMI 為 0 的無效資料列。 - 分組統計:使用
groupBy和count方法統計 Outcome 欄位各個值的出現次數。 - 結果顯示:使用
show方法顯示統計結果。
Step 9:explore_feature_distributions 方法定義
方法功能
explore_feature_distributions 方法用於分析並視覺化資料集中每個特徵的分佈情況。
程式碼實作
def explore_feature_distributions(self):
"""
繪製資料集中每個特徵的直方圖以分析其分佈。
"""
if self.data is not None:
print("Feature Distributions:")
for column in self.data.columns:
if column not in ['SkinThickness', 'Insulin']:
print(f"Feature: {column}")
if column in ["Glucose", "BloodPressure", "BMI"]:
filtered_data = self.data.filter(
(col(column) != 0)
& (col("Glucose") != 0)
& (col("BloodPressure") != 0)
& (col("BMI") != 0)
)
else:
filtered_data = self.data
plot_data = filtered_data.select(column).toPandas()
plot_data.plot(kind='hist', title=column)
else:
print("Data not loaded.")
內容解密:
- 迴圈遍歷:遍歷資料集中的每個欄位,排除 ‘SkinThickness’ 和 ‘Insulin’。
- 條件過濾:對於特定的欄位(如 Glucose、BLOODPressure 和 BMI),過濾掉包含 0 值的列。
- 轉換為 Pandas DataFrame:使用
toPandas方法將選定的資料轉換為 Pandas DataFrame。 - 繪製直方圖:使用
plot方法繪製直方圖以視覺化特徵分佈。
Step 10:calculate_feature_target_correlation 方法定義
方法功能
calculate_feature_target_correlation 方法用於計算每個特徵與目標變數(Outcome)之間的相關性。
程式碼實作
def calculate_feature_target_correlation(self):
"""
計算每個特徵與目標變數(Outcome)之間的相關性。
"""
if self.data is not None:
print("Feature-Target Correlation (excluding 'SkinThickness', 'Insulin'):")
for column in self.data.columns:
if column not in ['Outcome', 'SkinThickness', 'Insulin']:
correlation = self.data.filter(
(col(column) != 0)
& (col("Glucose") != 0)
& (col("BloodPressure") != 0)
& (col("BMI") != 0)
).stat.corr(column, 'Outcome')
print(f"{column}: {correlation}")
else:
print("Data not loaded.")
內容解密:
- 相關性計算:遍歷每個特徵欄位,排除 ‘Outcome’、‘SkinThickness’ 和 ‘Insulin’。
- 資料過濾:過濾掉包含特定欄位 0 值的列,以確保相關性計算的準確性。
- 統計相關性:使用
stat.corr方法計算特徵與 Outcome 之間的相關性。 - 輸出結果:列印每個特徵與 Outcome 的相關性值。