返回文章列表

PyTorch深度學習模型預測Tesla股價

本文示範使用 PyTorch 建立深度學習模型,預測 Tesla 股票價格。利用從 Yahoo! Finance 取得的 Tesla 股票歷史資料,結合 Spark 進行資料預處理,並使用 AWS S3 和 EC2

機器學習 深度學習

深度學習在金融領域的應用日益普及,尤其在股票價格預測方面。本文將示範如何使用 PyTorch 建立一個深度學習模型,並以預測 Tesla 股票價格為例,探討如何利用開盤價、最高價、最低價和成交量等特徵預測收盤價。利用 PySpark 從 AWS S3 載入資料,並在 EC2 上進行預處理和模型訓練,展現了雲端運算與深度學習技術的結合應用。程式碼中使用 SparkSession 讀取 CSV 檔案,並運用 VectorAssembler 和 StandardScaler 進行特徵工程,確保模型輸入資料的品質。模型訓練過程包含資料載入器、損失函式和最佳化器的設定,以及訓練迴圈的執行,最後評估模型的預測準確性。

使用 PyTorch 進行迴歸分析的深度學習模型實作

本章節將展示如何使用 PyTorch 建立、訓練和評估一個用於迴歸任務的深度學習模型,並以預測 Tesla 股票價格為例。準確預測股票價格在金融領域至關重要,可以為投資決策和風險管理策略提供重要參考。

資料集準備

本範例使用從 Yahoo! Finance 取得的 Tesla 股票歷史日資料集。資料以 CSV 格式下載並儲存在 AWS S3 的 “instance1bucket” 儲存桶中。資料集的詳細內容在前一章節已進行過探討,本文將載入 CSV 檔案,將其從 S3 儲存桶複製到 EC2 例項的本地目錄,並列印前五行以確認特徵和目標變數的內容。

載入必要模組

import subprocess
from pyspark.sql import SparkSession
import logging
  • import subprocess:匯入 subprocess 模組,用於執行 shell 命令或與系統 shell 互動。在本例中,用於執行 aws s3 cp 命令,將 CSV 檔案從 S3 儲存桶複製到本地目錄。
  • from pyspark.sql import SparkSession:匯入 SparkSession 類別,用於建立 Spark 應用程式並載入 CSV 檔案到 DataFrame 中。
  • import logging:匯入 logging 模組,用於生成日誌訊息,協助除錯和監控 Python 應用程式。

定義資料載入函式

def load_data(file_path: str):
    """
    使用 SparkSession 載入股票價格資料從 CSV 檔案。
    """
    spark = (SparkSession.builder
             .appName("StockPricePrediction")
             .getOrCreate())
    df = spark.read.csv(
        file_path,
        header=True,
        inferSchema=True
    )
    return df
  • def load_data(file_path: str):定義 load_data 函式,接受一個字串引數 file_path,表示 CSV 檔案的路徑。
  • spark = SparkSession.builder.appName("StockPricePrediction").getOrCreate():建立或取得一個名為 “StockPricePrediction” 的 SparkSession 物件。
  • df = spark.read.csv(file_path, header=True, inferSchema=True):使用 SparkSession 的 read.csv 方法讀取 CSV 檔案到 DataFrame 中,並自動推斷欄位資料型別。

定義資料複製和列印函式

def copy_and_print_data():
    """
    將 CSV 檔案從 S3 儲存桶複製到本地目錄並列印前 5 筆觀測值。
    """
    s3_bucket_path = "s3://instance1bucket/TSLA_stock.csv"
    local_file_path = "/home/ubuntu/airflow/dags/TSLA_stock.csv"
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    try:
        subprocess.run(
            [
                "aws", "s3", "cp",
                s3_bucket_path,
                local_file_path
            ],
            check=True
        )
        df = load_data(local_file_path)
        print("DataFrame 的前 5 筆觀測值:")
        df.show(5)
    except FileNotFoundError:
        logger.error(f"在 {s3_bucket_path} 未找到資料檔案")
    except subprocess.CalledProcessError as e:
        logger.error(f"從 S3 複製資料時出錯:{e}")
  • def copy_and_print_data():定義 copy_and_print_data 函式,無引數,負責將 CSV 檔案從 S3 複製到本地並列印前五筆資料。
  • subprocess.run(["aws", "s3", "cp", s3_bucket_path, local_file_path], check=True):使用 subprocess 執行 aws s3 cp 命令,將檔案從 S3 複製到本地。

程式碼解析

  1. load_data 函式:此函式利用 SparkSession 載入 CSV 檔案到 DataFrame。它設定了應用程式名稱為 “StockPricePrediction”,這有助於在 Spark 的監控工具中識別應用程式。
  2. copy_and_print_data 函式:此函式首先組態 logging 基本設定,然後嘗試將 CSV 檔案從指定的 S3 路徑複製到本地路徑。如果成功,它會呼叫 load_data 函式載入本地 CSV 檔案到 DataFrame,並列印前五筆觀測值。如果過程中出現錯誤,如檔案未找到或複製過程出錯,它會記錄錯誤訊息。
程式碼最佳化與擴充套件

系統架構圖示

@startuml
skinparam backgroundColor #FEFEFE
skinparam defaultTextAlignment center
skinparam rectangleBackgroundColor #F5F5F5
skinparam rectangleBorderColor #333333
skinparam arrowColor #333333

title 程式碼解析

rectangle "下載" as node1
rectangle "複製" as node2
rectangle "載入" as node3
rectangle "處理" as node4
rectangle "輸出" as node5

node1 --> node2
node2 --> node3
node3 --> node4
node4 --> node5

@enduml

此圖示說明瞭資料從 CSV 檔案下載到 S3,再複製到 EC2 本地,最終被載入到 Spark DataFrame 中進行處理,並輸入到 PyTorch 模型進行訓練和預測的流程。

詳細解析

  1. CSV 檔案下載至 S3:首先,從 Yahoo! Finance 下載 Tesla 股票的歷史資料,並將其儲存在 AWS S3 的儲存桶中。
  2. 複製至 EC2 本地目錄:使用 aws s3 cp 命令將 CSV 檔案從 S3 複製到 EC2 例項的本地目錄中。
  3. 載入至 Spark DataFrame:透過 SparkSession 將本地的 CSV 檔案載入到 DataFrame 中,以便進行進一步的資料處理。
  4. PyTorch 模型訓練:將處理好的資料輸入到 PyTorch 模型中進行訓練,以建立能夠預測股票價格的深度學習模型。
  5. 輸出預測結果:最終,訓練好的模型將用於預測 Tesla 股票的未來價格,並輸出預測結果。

使用 PySpark 處理 S3 上的 Tesla 股價資料

本章節將介紹如何使用 PySpark 將儲存在 AWS S3 儲存桶中的 Tesla 股價 CSV 檔案複製到本地目錄,並列印出 DataFrame 的前五筆觀察資料。

程式碼解析

以下為實作此功能的 Python 指令碼:

import logging
import subprocess
from pyspark.sql import SparkSession

def load_data(file_path):
    # 初始化 SparkSession
    spark = SparkSession.builder.appName("Tesla Stock Data").getOrCreate()
    # 讀取 CSV 檔案
    df = spark.read.csv(file_path, header=True, inferSchema=True)
    return df

def copy_and_print_data():
    """
    從 S3 儲存桶複製 CSV 檔案到本地目錄,並列印前五筆觀察資料。
    """
    s3_bucket_path = "s3://instance1bucket/TSLA_stock.csv"
    local_file_path = "/home/ubuntu/airflow/dags/TSLA_stock.csv"

    # 組態日誌模組
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    try:
        # 使用 AWS CLI 將檔案從 S3 複製到本地
        subprocess.run(["aws", "s3", "cp", s3_bucket_path, local_file_path], check=True)
        
        # 載入資料到 DataFrame
        df = load_data(local_file_path)
        
        # 列印前五筆觀察資料
        print("DataFrame 的前五筆觀察資料:")
        df.show(5)
    
    except FileNotFoundError:
        logger.error(f"在 {s3_bucket_path} 找不到資料檔案")
    
    except subprocess.CalledProcessError as e:
        logger.error(f"從 S3 複製資料時發生錯誤:{e}")

if __name__ == "__main__":
    copy_and_print_data()

內容解密:

  1. load_data 函式:初始化一個 SparkSession,並使用 spark.read.csv 方法讀取指定的 CSV 檔案,將其載入到 DataFrame 中。

    • header=True 表示 CSV 檔案包含標題列。
    • inferSchema=True 表示自動推斷欄位的資料型別。
  2. copy_and_print_data 函式

    • 定義了 S3 上的檔案路徑和本地儲存路徑。
    • 組態日誌模組以記錄執行過程中的資訊。
    • 使用 subprocess.run 執行 AWS CLI 命令,將 S3 上的 CSV 檔案複製到本地。
    • 載入本地 CSV 檔案到 DataFrame,並列印前五筆資料。
    • 使用 try-except 區塊處理可能發生的異常,如檔案未找到或命令執行錯誤。
  3. if __name__ == "__main__"::判斷指令碼是否直接被執行,如果是,則呼叫 copy_and_print_data 函式。

執行指令碼

將上述指令碼儲存為 spark_data_processing.py,並放置於 /home/ubuntu/airflow/dags 目錄下。然後,透過以下步驟執行:

  1. 切換到指令碼所在目錄:cd /home/ubuntu/airflow/dags
  2. 啟動虛擬環境:source myenv/bin/activate
  3. 執行指令碼:python3 spark_data_processing.py

輸出結果

執行後,輸出結果如下:

DataFrame 的前五筆觀察資料:
+
---
-
---
---
+
---
-
---
+
---
-
---
+
---
-
---
+
---
-
---
+
---
-
---
---
+
|      Date|   Open|   High|    Low|  Close|    Volume|
+
---
-
---
---
+
---
-
---
+
---
-
---
+
---
-
---
+
---
-
---
+
---
-
---
---
+
|  2/23/24| 195.31| 197.57| 191.50| 191.97|  78670300|
|  2/22/24| 194.00| 198.32| 191.36| 197.41|  92739500|
|  2/21/24| 193.36| 199.44| 191.95| 194.77| 103844000|
|  2/20/24| 196.13| 198.60| 189.13| 193.76| 104545800|
|  2/16/24| 202.06| 203.17| 197.40| 199.95| 111173600|
+
---
-
---
---
+
---
-
---
+
---
-
---
+
---
-
---
+
---
-
---
+
---
-
---
---
+

該資料集包含 Tesla 股價從2019年2月26日至2024年2月23日的歷史資料,共1,258筆觀察記錄。輸出的前五筆資料展示了每個交易日的開盤價(Open)、最高價(High)、最低價(Low)、收盤價(Close)和交易量(Volume)。這些指標能夠反映 Tesla 股價在特定時間段內的波動情況。

使用PyTorch預測Tesla股價

本章節利用先前探索的資料集,實作深度學習演算法來預測Tesla的股價。資料集中,模型的特徵包含開盤價(Open)、最高價(High)、最低價(Low)和成交量(Volume),而目標變數是收盤價(Close)。股票的開盤價能夠提供交易日開始時市場情緒的寶貴資訊,反映投資者對新聞、事件或隔夜發展的初始反應。同樣地,成交量也很重要,因為高成交量通常表示投資者興趣增加,並可能預示潛在的價格變動。成交量也可以用作流動性的代理,這對於準確的價格發現至關重要。至於交易期間達到的最高和最低價格,它們提供了價格波動和交易範圍的見解。最高價格表示買盤興趣的最高水平,而最低價格代表賣盤興趣的最低水平。分析最高和最低價格之間的範圍有助於識別趨勢、支撐和阻力水平。

深度學習模型的建立與訓練

透過將這些特徵納入預測模型,我們可以捕捉基本的市場動態(如開盤情緒和交易活動)和技術層面(如價格波動和交易範圍),這對於準確預測股價至關重要。然而,值得注意的是,我們的模型旨在說明而非提供完整的解釋。雖然這些特徵很有價值,但納入額外的特徵,如技術指標、基本面資料或市場情緒,可能會進一步增強模型的預測能力。然而,這超出了本章的範圍。

程式碼實作步驟

整個建模過程包括以下十個步驟:

  1. 匯入必要的函式庫
  2. 設定日誌系統
  3. 載入資料
  4. 資料預處理
  5. 建立DataLoader
  6. 訓練模型
  7. 評估模型
  8. 主函式
  9. 輔助函式
  10. 執行
步驟1:匯入必要的函式庫
import logging
import subprocess
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.sql import DataFrame

內容解密:

  • logging:用於在程式執行過程中生成日誌訊息,幫助追蹤程式進度、除錯問題和監控活動。
  • subprocess:允許程式執行系統命令並與其他程式互動,用於從AWS S3複製資料到EC2本地目錄。
  • typing:用於指定函式引數和傳回值的預期型別,提高程式碼的可讀性和健壯性。
  • numpy:用於科學計算,提供對大型多維陣列和矩陣的支援,以及一系列數學函式,用於將Spark DataFrame欄位轉換為NumPy陣列。
  • torchtorch.nntorch.optim:PyTorch的核心元件,用於定義和訓練神經網路模型。
  • DataLoaderTensorDataset:PyTorch中的類別,用於建立資料載入器,以便在訓練和測試神經網路模型時批次載入資料。
  • pyspark.sqlpyspark.ml.featureDataFrame:PySpark的元件,用於載入、預處理和操作結構化資料,並將其轉換為PyTorch張量進行神經網路模型的訓練。
步驟2:設定日誌系統
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

內容解密:

  • logging.basicConfig(level=logging.INFO):設定日誌系統顯示INFO級別或更高的重要訊息,確保程式執行過程中顯示足夠重要的日誌訊息。
  • logger = logging.getLogger(__name__):建立一個日誌記錄器,用於記錄程式執行過程中的資訊和錯誤訊息。