返回文章列表

PyTorch糖尿病預測模型開發與驗證

本文利用 PyTorch 框架與 Pima 印第安人糖尿病資料集,構建一個深度學習分類別模型,用於預測糖尿病風險。文章涵蓋資料預處理、模型架構設計、訓練過程、K-Fold 交叉驗證以及 Spark 分散式計算的應用。透過 PySpark 進行資料探索和預處理,並使用 PyTorch

機器學習 深度學習

深度學習在醫療領域的應用日益廣泛,糖尿病預測是其中一個重要方向。本文使用 Pima 印第安人糖尿病資料集,結合 PyTorch 深度學習框架,構建一個糖尿病預測模型。資料預處理階段利用 Spark 的分散式計算能力,提升處理效率。模型採用兩層全連線網路結構,並使用 ReLU 啟用函式。為了評估模型的泛化能力,我們採用 K-Fold 交叉驗證方法,將資料集分成多個子集,進行多次訓練和驗證,以獲得更穩定的效能指標。

深度學習與 PyTorch 在分類別任務中的應用

在前兩章中,我們探討了深度學習在迴歸任務中的應用。從本章開始,我們將重點轉移到分類別任務上,這是深度學習中的另一個基本任務。迴歸和分類別是兩種不同型別的機器學習任務:迴歸旨在從輸入特徵中預測連續的數值,例如預測特斯拉的股票價格;而分類別則涉及將輸入資料分配到預定義的類別或類別中。在本章中,我們的目標是根據諸如懷孕次數、血糖水平、血壓、身體品質指數(BMI)、年齡和糖尿病家族史等屬性,預測一名女性被診斷出患有糖尿病的機率。

資料集介紹

本章使用的資料集是著名的 Pima Indians Diabetes Dataset。該資料集包含了 768 筆記錄,每筆記錄代表一位 Pima 印第安女性,並包含各種與健康相關的屬性,以及一個目標變數(結果),指示該女性是否患有糖尿病。該資料集適用於根據明確定義的屬性(如懷孕次數、血糖水平、血壓、身體品質指數(BMI)、年齡和糖尿病家族史)預測糖尿病結果的分類別任務。

資料集來源

  • 標題:Pima Indians Diabetes Database
  • 來源:Kaggle
  • 網址:www.kaggle.com/uciml/pima-indians-diabetes-database
  • 貢獻者:UCI Machine Learning
  • 日期:2016

資料預處理與探索

我們使用 Spark 進行資料預處理,以展示其在處理大型資料集時的分散式計算能力。儘管我們的資料集相對較小,但這些概念和程式碼同樣適用於更大的資料集,使其在生產環境中具有價值。我們將探討資料增強、遷移學習、重取樣、模型複雜度、資料品質和資料表示等重要主題。

# 使用 PySpark 進行資料探索的範例程式碼
from pyspark.sql import SparkSession

# 初始化 SparkSession
spark = SparkSession.builder.appName("DiabetesClassification").getOrCreate()

# 載入資料集
df = spark.read.csv("s3://your-bucket/pima-indians-diabetes.csv", header=True, inferSchema=True)

# 顯示資料集的前幾行
df.show()

# 統計每列缺失值的數量
df.select([count(when(isnull(c), c)).alias(c) for c in df.columns]).show()

#### 內容解密:
1. 初始化 `SparkSession` 是使用 Spark 進行資料處理的第一步
2. 使用 `read.csv` 方法從 S3 儲存桶中載入 CSV 檔案
3. `show` 方法用於顯示 DataFrame 的前幾行方便檢查資料
4. 統計缺失值的程式碼使用了 `count``when` 函式結合 `isnull` 檢查每列的缺失值數量

結構與訓練 PyTorch 分類別模型

我們將構建、訓練和評估一個 PyTorch 分類別模型,以預測糖尿病結果。首先,我們需要定義模型架構,然後進行訓練和評估。

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# 定義 PyTorch 模型
class DiabetesClassifier(nn.Module):
    def __init__(self):
        super(DiabetesClassifier, self).__init__()
        self.fc1 = nn.Linear(8, 128)  # 輸入層到隱藏層
        self.fc2 = nn.Linear(128, 2)  # 隱藏層到輸出層

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 使用 ReLU 啟用函式
        x = self.fc2(x)
        return x

#### 內容解密:
1. `DiabetesClassifier` 類別繼承自 `nn.Module`,是 PyTorch 中定義模型的標準做法
2. `__init__` 方法中定義了模型的層包括兩個全連線層(`fc1``fc2`)。
3. `forward` 方法定義了資料在模型中的前向傳播路徑使用了 ReLU 啟用函式

K-Fold 交叉驗證

為了確保模型的泛化能力,我們採用 K-Fold 交叉驗證方法。這涉及將資料集分成 K 個子集,然後進行 K 次訓練和驗證,每次使用不同的子集作為驗證集。

from sklearn.model_selection import KFold

# 初始化 KFold 物件
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# K-Fold 交叉驗證迴圈
for train_index, val_index in kf.split(df):
    train_df = df.iloc[train_index]
    val_df = df.iloc[val_index]
    # 在此處新增訓練和驗證模型的程式碼

#### 內容解密:
1. `KFold` 物件被初始化指定了折數(`n_splits`)和其他引數
2. 在 K-Fold 迴圈中資料被分成訓練集和驗證集用於模型的訓練和評估

探索與分析 Pima 印第安人糖尿病資料集

資料集探索的重要性

在進行機器學習或深度學習任務之前,瞭解資料集的特性至關重要。對於 Pima 印第安人糖尿病資料集的探索,可以揭示資料的分佈、潛在問題以及與目標變數之間的關聯。

資料集探索的內容

  • 零值計數:計算資料集中每個欄位的零值數量。這有助於發現潛在的問題,例如缺失或錯誤的資料,特別是在零值不應存在的特徵中(例如血壓、血糖水平)。
  • 資料摘要:提供資料集的摘要,包括數值欄位的平均值、標準差、最小值、最大值和四分位數等統計資料。這有幫助於瞭解不同特徵的分佈和範圍。
  • 結果值計數:計算資料集中每個結果值的出現次數。這有助於瞭解資料集中糖尿病病例的分佈情況。
  • 特徵分佈:繪製相關特徵的直方圖,以視覺化特徵的分佈。這可以提供對資料底層模式、潛在異常值以及資料是否偏態的洞察。
  • 特徵與目標變數的相關性計算:計算每個特徵與目標變數(結果)之間的相關性。這有助於瞭解哪些特徵可能與目標變數更強烈相關,這對於分類別等預測建模任務至關重要。

程式碼解析

步驟1:匯入必要的函式庫

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as pyspark_sum
import boto3

內容解密:

  1. from pyspark.sql import SparkSession:匯入 SparkSession 類別,這是進入 Spark 的入口點,提供分散式資料處理的功能。
  2. from pyspark.sql.functions import col, sum as pyspark_sum:匯入特定的函式,用於參照 DataFrame 中的欄位和計算欄位值的總和。
  3. import boto3:匯入 boto3 模組,這是 AWS 的官方 Python SDK,用於以程式設計方式與各種 AWS 服務(如 S3 和 EC2)互動。

步驟2:類別定義

class PimaDatasetExplorer:
    """
    用於使用 PySpark 探索和分析 Pima 印第安人糖尿病資料集的類別。
    """
    def __init__(self, file_path):
        """
        初始化 PimaDatasetExplorer 物件。
        
        Args:
            file_path (str):資料集檔案的路徑。
        """
        self.file_path = file_path
        self.data = None

內容解密:

  1. 定義了一個名為 PimaDatasetExplorer 的類別,用於封裝與 Pima 印第安人糖尿病資料集相關的操作。
  2. __init__ 方法初始化物件,接受資料集檔案的路徑作為引數。

步驟3:load_data 方法定義

def load_data(self):
    """
    從指定的檔案路徑載入資料集到 PySpark DataFrame。
    """
    try:
        self.data = spark.read.csv(
            self.file_path,
            header=True,
            inferSchema=True
        )
        self.data.cache()
    except Exception as e:
        print("載入資料時發生錯誤:", str(e))
        self.data = None

內容解密:

  1. load_data 方法負責從指定路徑載入資料集到 PySpark DataFrame。
  2. 使用 try-except 區塊處理載入過程中的異常。
  3. spark.read.csv 方法讀取 CSV 檔案,header=True 表示第一行包含欄位名稱,inferSchema=True 表示自動推斷欄位的資料型別。
  4. cache() 方法將 DataFrame 的內容儲存在記憶體或磁碟上,以提高後續操作的效能。

步驟4:preprocess_data 方法定義

def preprocess_data(self):
    """
    對資料集執行預處理步驟。
    """
    self.load_data()

內容解密:

  1. preprocess_data 方法目前僅呼叫 load_data 方法載入資料集。
  2. 未來可以擴充套件此方法以執行其他預處理步驟,例如處理缺失值、轉換資料型別等。

PimaDatasetExplorer 類別方法實作詳解

handle_missing_values 方法定義

此方法用於計算資料集中每個欄位的缺失值數量並顯示計數結果。

程式碼實作

def handle_missing_values(self):
    """
    計算資料集中每個欄位的缺失值數量並顯示計數結果。
    """
    if self.data is not None:
        missing_counts = self.data.select([
            pyspark_sum(col(c).isNull().cast("int")).alias(c)
            for c in self.data.columns
        ])
        print("缺失值計數:")
        missing_counts.show()
    else:
        print("資料未載入。")

內容解密:

  1. 方法首先檢查 self.data 是否已載入資料。
  2. 使用 isNull() 函式識別缺失值,並透過 cast("int") 將布林結果轉換為整數(0 或 1)。
  3. 利用 pyspark_sum 計算每個欄位的缺失值總和,並將結果以原始欄位名稱別名。
  4. 顯示缺失值計數的 DataFrame。

count_zeros 方法定義

此方法旨在統計資料集中每個欄位的零值數量並顯示計數結果。

程式碼實作

def count_zeros(self):
    """
    統計資料集中每個欄位的零值數量並顯示計數結果。
    """
    if self.data is not None:
        zero_counts = self.data.select([
            pyspark_sum((col(c) == 0).cast("int")).alias(c)
            for c in self.data.columns
        ])
        print("零值計數:")
        zero_counts.show()
    else:
        print("資料未載入。")

內容解密:

  1. 檢查 self.data 是否已載入資料。
  2. 使用 (col(c) == 0) 評估欄位值是否為零,並將布林結果轉換為整數。
  3. 利用 pyspark_sum 計算每個欄位的零值總和。
  4. 顯示零值計數的 DataFrame。

data_summary 方法定義

此方法提供糖尿病資料集相關特徵的摘要。

程式碼實作

def data_summary(self):
    """
    顯示資料集的摘要,排除特定欄位並過濾零值行。
    """
    if self.data is not None:
        print("資料摘要(排除 'Outcome', 'SkinThickness', 'Insulin'):")
        columns_to_exclude = ['Outcome', 'SkinThickness', 'Insulin']
        summary_cols = [c for c in self.data.columns if c not in columns_to_exclude]
        filtered_data = self.data.filter((col("Glucose") != 0) & (col("BloodPressure") != 0) & (col("BMI") != 0))
        filtered_data.select(summary_cols).describe().show()
    else:
        print("資料未載入。")

內容解密:

  1. 檢查資料是否已載入。
  2. 定義需要排除的欄位列表 columns_to_exclude
  3. 使用列表推導式建立需要摘要的欄位列表 summary_cols
  4. 過濾資料集,排除 ‘Glucose’、‘BloodPressure’ 和 ‘BMI’ 為零的行。
  5. 對過濾後的資料集進行描述性統計,並顯示結果。

PimaDatasetExplorer 類別方法實作詳解

方法實作概述

本章節將詳細介紹 PimaDatasetExplorer 類別中的多個關鍵方法,包括資料摘要統計、結果計數、特徵分佈探索以及特徵與目標變數的相關性計算。這些方法共同構成了資料探索和分析的核心功能。

Step 8:count_outcome 方法定義

方法功能

count_outcome 方法的主要功能是統計資料集中 Outcome 欄位的各個值的出現次數。

程式碼實作

def count_outcome(self):
    """
    統計資料集中 Outcome 欄位各個值的出現次數。
    """
    if self.data is not None:
        print("Outcome Counts:")
        outcome_counts = self.data.filter(
            (col("Glucose") != 0)
            & (col("BloodPressure") != 0)
            & (col("BMI") != 0)
        ).groupBy("Outcome").count()
        outcome_counts.show()
    else:
        print("Data not loaded.")

內容解密:

  1. 資料檢查:首先檢查 self.data 是否為 None,確保資料已載入。
  2. 資料過濾:使用 filter 方法排除 Glucose、BloodPressure 和 BMI 為 0 的無效資料列。
  3. 分組統計:使用 groupBycount 方法統計 Outcome 欄位各個值的出現次數。
  4. 結果顯示:使用 show 方法顯示統計結果。

Step 9:explore_feature_distributions 方法定義

方法功能

explore_feature_distributions 方法用於分析並視覺化資料集中每個特徵的分佈情況。

程式碼實作

def explore_feature_distributions(self):
    """
    繪製資料集中每個特徵的直方圖以分析其分佈。
    """
    if self.data is not None:
        print("Feature Distributions:")
        for column in self.data.columns:
            if column not in ['SkinThickness', 'Insulin']:
                print(f"Feature: {column}")
                if column in ["Glucose", "BloodPressure", "BMI"]:
                    filtered_data = self.data.filter(
                        (col(column) != 0)
                        & (col("Glucose") != 0)
                        & (col("BloodPressure") != 0)
                        & (col("BMI") != 0)
                    )
                else:
                    filtered_data = self.data
                plot_data = filtered_data.select(column).toPandas()
                plot_data.plot(kind='hist', title=column)
    else:
        print("Data not loaded.")

內容解密:

  1. 迴圈遍歷:遍歷資料集中的每個欄位,排除 ‘SkinThickness’ 和 ‘Insulin’。
  2. 條件過濾:對於特定的欄位(如 Glucose、BLOODPressure 和 BMI),過濾掉包含 0 值的列。
  3. 轉換為 Pandas DataFrame:使用 toPandas 方法將選定的資料轉換為 Pandas DataFrame。
  4. 繪製直方圖:使用 plot 方法繪製直方圖以視覺化特徵分佈。

Step 10:calculate_feature_target_correlation 方法定義

方法功能

calculate_feature_target_correlation 方法用於計算每個特徵與目標變數(Outcome)之間的相關性。

程式碼實作

def calculate_feature_target_correlation(self):
    """
    計算每個特徵與目標變數(Outcome)之間的相關性。
    """
    if self.data is not None:
        print("Feature-Target Correlation (excluding 'SkinThickness', 'Insulin'):")
        for column in self.data.columns:
            if column not in ['Outcome', 'SkinThickness', 'Insulin']:
                correlation = self.data.filter(
                    (col(column) != 0)
                    & (col("Glucose") != 0)
                    & (col("BloodPressure") != 0)
                    & (col("BMI") != 0)
                ).stat.corr(column, 'Outcome')
                print(f"{column}: {correlation}")
    else:
        print("Data not loaded.")

內容解密:

  1. 相關性計算:遍歷每個特徵欄位,排除 ‘Outcome’、‘SkinThickness’ 和 ‘Insulin’。
  2. 資料過濾:過濾掉包含特定欄位 0 值的列,以確保相關性計算的準確性。
  3. 統計相關性:使用 stat.corr 方法計算特徵與 Outcome 之間的相關性。
  4. 輸出結果:列印每個特徵與 Outcome 的相關性值。