返回文章列表

PySpark 生存分析實戰:使用 AFT 模型精準預測客戶生命週期價值 (CLV)

本文為一份完整的 PySpark 資料科學實戰教學,指導您如何利用加速失效時間 (AFT) 生存分析模型,來精準預測客戶生命週期價值 (CLV)。內容涵蓋 AFT 模型理論、RFM 特徵工程、PySpark 模型訓練與評估,提供一套從商業問題到數據洞察的完整解決方案。

機器學習 資料分析

在競爭激烈的市場中,準確預測客戶生命週期價值 (Customer Lifetime Value, CLV) 對於企業制定行銷策略、優化資源分配至關重要。傳統的迴歸模型在處理帶有「刪失」(censored) 特性的客戶流失資料時常常力不從心。本文將深入探討如何應用生存分析中的加速失效時間 (Accelerated Failure Time, AFT) 模型,結合 PySpark 的大數據處理能力,建立一個更穩健、更準確的 CLV 預測系統。

第一部分:理論基礎 - 為何選擇 AFT 模型?

生存分析是專門用來分析「事件發生時間」的統計方法。在 CLV 場景中,這個「事件」就是客戶流失。與更廣為人知的 Cox 比例風險 (PH) 模型相比,AFT 模型具有獨特的優勢。

  • Cox PH 模型: 假設變數對「風險率」(Hazard Rate) 的影響是成比例且不隨時間變化的。其係數解釋的是風險比,相對不直觀。
  • AFT 模型: 不依賴比例風險假設,它直接模擬變數對「生存時間」本身的影響。其係數可以直接解釋為對生存時間的加速或減速因子,更易於理解和應用。

AFT vs. Cox PH 模型對比圖:此圖簡要對比了 AFT 模型與 Cox PH 模型在核心假設與係數解釋上的主要差異。AFT 模型直接模擬變數對生存時間的影響,使其結果更易於業務解釋;而 Cox 模型則關注於風險比,在某些情境下較為抽象。

@startuml
!theme _none_
skinparam dpi auto
skinparam defaultFontName "Microsoft JhengHei UI"
skinparam minClassWidth 100
skinparam defaultFontSize 16
title AFT vs. Cox PH 模型對比

package "AFT 模型" {
  [核心假設: \n協變數對生存時間有\n**乘法效應** (加速/減速)]
  [係數解釋: \n直接量化對**生存時間**的影響]
}

package "Cox PH 模型" {
  [核心假設: \n協變數對風險函數有\n**比例效應**]
  [係數解釋: \n量化對**風險比 (Hazard Ratio)** 的影響]
}
@enduml

PySpark 中的 RFM 分析實作

RFM(Recency、Frequency、Monetary)分析是一種常見的客戶細分方法,以下是使用 PySpark 生成合成資料並進行 RFM 分析的範例:

# 載入必要函式庫
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, sum, max, lit, datediff, when, concat
import datetime
import random

# 初始化 Spark session
spark = SparkSession.builder.appName("RFM Analysis").getOrCreate()

# 生成合成資料
num_customers = 100
num_transactions = 1000
customers = [f"customer_{i}" for i in range(1, num_customers + 1)]
transaction_dates = [datetime.date.today() - datetime.timedelta(days=random.randint(1, 365)) for _ in range(num_transactions)]
transaction_amounts = [random.uniform(10.0, 1000.0) for _ in range(num_transactions)]
data = [(random.choice(customers), transaction_dates[i], transaction_amounts[i]) for i in range(num_transactions)]
columns = ["CustomerID", "TransactionDate", "Amount"]

# 建立 DataFrame
df = spark.createDataFrame(data, columns)

#### 程式碼解析:
1. 載入必要的 PySpark 函式庫和資料型別用於資料處理和分析
2. 初始化一個名為 "RFM Analysis" 的 Spark session
3. 生成包含客戶交易資料的合成資料包括客戶 ID交易日期和交易金額
4. 將生成的資料轉換為 Spark DataFrame以便進行後續的 RFM 分析

# 計算 Recency、Frequency 和 Monetary 值
current_date = datetime.date.today()
df = df.withColumn("CurrentDate", lit(current_date).cast("date"))
recency_df = df.groupBy("CustomerID").agg(datediff(max("CurrentDate"), max("TransactionDate")).alias("Recency"))
frequency_df = df.groupBy("CustomerID").agg(count("TransactionDate").alias("Frequency"))
monetary_df = df.groupBy("CustomerID").agg(sum("Amount").alias("Monetary"))

#### 程式碼解析:
1. 新增一欄位 "CurrentDate" 以代表當前日期用於計算客戶最近一次交易日期與當前日期的差值Recency)。
2. 分別計算每個客戶的 Recency最近一次交易距今的天數)、Frequency交易次數和 Monetary總交易金額)。

# 合併 RFM 值並進行分段
rfm_df = recency_df.join(frequency_df, "CustomerID").join(monetary_df, "CustomerID")
quantiles = [0.25, 0.50, 0.75]
recency_quantiles = rfm_df.approxQuantile("Recency", quantiles, 0.0)
frequency_quantiles = rfm_df.approxQuantile("Frequency", quantiles, 0.0)
monetary_quantiles = rfm_df.approxQuantile("Monetary", quantiles, 0.0)

rfm_df = rfm_df.withColumn("R_Segment", when(col("Recency") <= recency_quantiles[0], "R1")
                                      .when(col("Recency") <= recency_quantiles[1], "R2")
                                      .when(col("Recency") <= recency_quantiles[2], "R3")
                                      .otherwise("R4"))

#### 程式碼解析:
1. 合併每個客戶的 RecencyFrequency 和 Monetary 值到一個 DataFrame 中
2. 使用 quantile 方法計算 RecencyFrequency 和 Monetary 的四分位數並根據這些分位數將客戶分為不同的 RFM 區段

### PySpark 中的 AFT 生存分析實作

以下是使用 PySpark 中的 AFTSurvivalRegression 進行 CLV 生存分析的範例

```python
# 載入必要函式庫
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import AFTSurvivalRegression
from pyspark.sql.functions import col

# 初始化 Spark session
spark = SparkSession.builder.appName("Survival Analysis for CLV with AFT").getOrCreate()

# 建立範例資料(實際應用中應使用真實資料)
# ...

# 使用 VectorAssembler 組裝特徵向量
assembler = VectorAssembler(inputCols=["feature1", "feature2"], outputCol="features")
data = assembler.transform(data)

# 建立 AFT 生存迴歸模型
aft_model = AFTSurvivalRegression(labelCol="label", censorCol="censor", featuresCol="features")
model = aft_model.fit(data)

#### 程式碼解析:
1. 載入必要的 PySpark ML 函式庫用於建立和訓練 AFT 生存分析模型
2. 使用 VectorAssembler 將輸入特徵組裝成一個特徵向量供模型訓練使用
3. 初始化並訓練一個 AFTSurvivalRegression 模型用於生存分析

#### AFT 模型訓練與預測時序圖
此循序圖展示了使用 PySpark 進行 AFT 模型訓練與預測的過程
```plantuml
@startuml
!theme _none_
skinparam dpi auto
skinparam defaultFontName "Microsoft JhengHei UI"
skinparam minClassWidth 100
skinparam defaultFontSize 16
title AFT 模型訓練與預測時序

actor User as 使用者
participant "PySpark" as Spark
participant "AFTSurvivalRegression" as AFT

使用者 -> Spark : 準備訓練資料 (包含 features, label, censor)
Spark -> AFT : 建立 AFT 模型實例
使用者 -> AFT : 設定模型參數
AFT -> Spark : 呼叫 .fit(train_data)
Spark -> Spark : 迭代計算擬合模型
Spark --> AFT : 回傳已訓練的模型 (model)
使用者 -> Spark : 準備測試資料
Spark -> model : 呼叫 .transform(test_data)
model -> Spark : 進行預測
Spark --> 使用者 : 回傳包含預測結果的 DataFrame
@enduml

生存分析在客戶終身價值(CLV)預測中的應用

客戶終身價值(CLV)是企業評估客戶長期價值的重要指標。透過生存分析(Survival Analysis)結合客戶交易資料,可以更準確地預測客戶的終身價值。本篇文章將介紹如何使用生存分析模型來預測客戶流失並計算客戶終身價值。

資料生成與準備

首先,我們需要生成模擬的客戶交易資料和客戶流失資料。這些資料將用於後續的分析和建模。

程式碼範例:生成模擬資料

# 生成模擬資料
num_customers = 100
num_transactions = 1000
customers = [f"customer_{i}" for i in range(1, num_customers + 1)]
transaction_dates = [datetime.date.today() - datetime.timedelta(days=random.randint(1, 365)) for _ in range(num_transactions)]
transaction_amounts = [random.uniform(10.0, 1000.0) for _ in range(num_transactions)]
churn_dates = [random.choice([datetime.date.today() - datetime.timedelta(days=random.randint(30, 365)), None]) for _ in range(num_customers)]
transaction_data = [(random.choice(customers), transaction_dates[i], transaction_amounts[i]) for i in range(num_transactions)]
churn_data = [(customers[i], churn_dates[i]) for i in range(num_customers)]
transaction_columns = ["CustomerID", "TransactionDate", "Amount"]
churn_columns = ["CustomerID", "ChurnDate"]

# 建立DataFrames
transaction_df = spark.createDataFrame(transaction_data, transaction_columns)
churn_df = spark.createDataFrame(churn_data, churn_columns)

# 顯示生成的資料
transaction_df.show(5)
churn_df.show(5)

程式碼解析:

  1. 模擬資料生成:我們首先生成了 100 位客戶和 1000 筆交易資料,交易日期和金額都是隨機生成的。
  2. 客戶流失日期:客戶的流失日期也是隨機生成的,部分客戶可能尚未流失(右刪失資料),因此其流失日期為 None
  3. DataFrame 建立:使用 Spark 將生成的資料轉換為 DataFrame,以便於後續的資料處理。

計算RFM指標

RFM(Recency、Frequency、Monetary)是衡量客戶價值的三個重要指標。我們需要計算每個客戶的RFM值。

程式碼範例:計算RFM指標

# 計算Recency、Frequency和Monetary值
from pyspark.sql.types import DateType # 需要匯入 DateType
current_date = datetime.date.today()
transaction_df = transaction_df.withColumn("CurrentDate", lit(current_date).cast(DateType()))
recency_df = transaction_df.groupBy("CustomerID").agg(datediff(max("CurrentDate"), max("TransactionDate")).alias("Recency"))
frequency_df = transaction_df.groupBy("CustomerID").agg(count("TransactionDate").alias("Frequency"))
monetary_df = transaction_df.groupBy("CustomerID").agg(sum("Amount").alias("Monetary"))

# 合併RFM資料
rfm_df = recency_df.join(frequency_df, "CustomerID").join(monetary_df, "CustomerID")

# 將RFM資料與客戶流失資料合併
rfm_churn_df = rfm_df.join(churn_df, "CustomerID")

程式碼解析:

  1. Recency 計算:計算客戶最近一次交易的日期與當前日期的差值。
  2. Frequency 計算:統計每個客戶的交易次數。
  3. Monetary 計算:計算每個客戶的交易總金額。
  4. RFM 合併:將三個指標合併到一個 DataFrame 中,並與客戶流失資料進行連結。

生存分析模型建立與預測

使用加速失效時間(AFT)生存迴歸模型來預測客戶的生存時間,並進一步計算客戶終身價值。

程式碼範例:建立AFT模型並進行預測

# 資料預處理
future_date = current_date + datetime.timedelta(days=365)
rfm_churn_df = rfm_churn_df.withColumn("ChurnDate", when(col("ChurnDate").isNull(), lit(future_date)).otherwise(col("ChurnDate")))
tenure_df = transaction_df.groupBy("CustomerID").agg(min("TransactionDate").alias("FirstTransactionDate"))
rfm_churn_df = rfm_churn_df.join(tenure_df, "CustomerID")
rfm_churn_df = rfm_churn_df.withColumn("Tenure", datediff(col("ChurnDate"), col("FirstTransactionDate")))
rfm_churn_df = rfm_churn_df.withColumn("Event", when(col("ChurnDate") < future_date, 1).otherwise(0))
rfm_churn_df = rfm_churn_df.filter(col("Tenure") > 0)

# 組裝特徵向量
assembler = VectorAssembler(inputCols=["Recency", "Frequency", "Monetary"], outputCol="features")
rfm_churn_df = assembler.transform(rfm_churn_df)

# 分割訓練和測試資料
train_df, test_df = rfm_churn_df.randomSplit([0.8, 0.2], seed=12345)

# 建立AFT模型
aft = AFTSurvivalRegression(featuresCol="features", labelCol="Tenure", censorCol="Event")
aft_model = aft.fit(train_df)

# 進行預測
predictions = aft_model.transform(test_df)
predictions.select("CustomerID", "features", "Event", "Tenure", "prediction").show(5)

程式碼解析:

  1. 資料預處理:填充缺失的流失日期,計算客戶的存續時間(Tenure),並標記事件是否發生(Event,1 表示已流失,0 表示未流失)。
  2. 特徵向量組裝:將 RFM 指標組裝成特徵向量,作為模型的輸入。
  3. 模型訓練與預測:使用 AFT 模型進行訓練和預測,輸出客戶的預測存續時間。

客戶終身價值(CLV)計算

根據生存分析的預測結果,計算每個客戶的預期終身價值。

程式碼範例:計算CLV

# 計算平均月收入
rfm_churn_df = rfm_churn_df.withColumn("AvgMonthlyRevenue", col("Monetary") / (col("Tenure") / 30))

# 定義CLV計算函式
def calculate_clv(rfm_churn_df, predictions):
    rfm_churn_df = rfm_churn_df.join(predictions.select("CustomerID", "prediction"), "CustomerID")
    rfm_churn_df = rfm_churn_df.withColumn("PredictedCLV", col("AvgMonthlyRevenue") * col("prediction") / 30)
    return rfm_churn_df

# 應用CLV計算
clv_df = calculate_clv(rfm_churn_df, predictions)

# 計算實際CLV
clv_df = clv_df.withColumn("ActualCLV", col("AvgMonthlyRevenue") * (col("Tenure") / 30))

程式碼解析:

  1. 平均月收入計算:根據客戶的交易總金額和存續時間,計算平均月收入。
  2. CLV 計算:結合預測的存續時間,計算每個客戶的預期終身價值 (Predicted CLV)。
  3. 實際 CLV 計算:使用實際的存續時間,計算實際的終身價值 (Actual CLV),用於評估模型的準確性。

模型評估

使用迴歸評估器來評估CLV預測的準確性。

程式碼範例:評估CLV預測結果

# 評估CLV預測結果
evaluator_mae = RegressionEvaluator(labelCol="ActualCLV", predictionCol="PredictedCLV", metricName="mae")
mae = evaluator_mae.evaluate(clv_df)
print(f"Mean Absolute Error (MAE): {mae}")

evaluator_mse = RegressionEvaluator(labelCol="ActualCLV", predictionCol="PredictedCLV", metricName="mse")
mse = evaluator_mse.evaluate(clv_df)
print(f"Mean Squared Error (MSE): {mse}")

evaluator_rmse = RegressionEvaluator(labelCol="ActualCLV", predictionCol="PredictedCLV", metricName="rmse")
rmse = evaluator_rmse.evaluate(clv_df)
print(f"Root Mean Squared Error (RMSE): {rmse}")

程式碼解析:

  1. MAE、MSE、RMSE 計算:使用 RegressionEvaluator(需從 pyspark.ml.evaluation 匯入)計算平均絕對誤差(MAE)、均方誤差(MSE)和均方根誤差(RMSE),以評估模型的預測準確性。

RFM-CLV 分析完整流程圖

此圖展示了從資料生成到模型評估的完整客戶生命週期價值 (CLV) 分析流程。

@startuml
!theme _none_
skinparam dpi auto
skinparam defaultFontName "Microsoft JhengHei UI"
skinparam minClassWidth 100
skinparam defaultFontSize 16
title RFM-CLV 分析完整流程

start
:生成合成交易與流失資料;
:計算 RFM 指標\n(Recency, Frequency, Monetary);
:資料預處理\n(計算 Tenure, Event);
:特徵工程\n(VectorAssembler);
:分割訓練/測試資料集;
:使用訓練集擬合 AFT 生存模型;
:對測試集進行預測\n(預測客戶生存時間);
:計算平均月收入;
:計算預測 CLV 與實際 CLV;
:使用迴歸評估器評估模型\n(MAE, MSE, RMSE);
stop
@enduml