返回文章列表

RLlib強化學習環境建置與課程學習應用

本文探討如何使用 Ray RLlib 構建和訓練多智慧體強化學習環境,並講解如何應用課程學習策略提升訓練效率。文章涵蓋了多智慧體環境的建立、自定義策略、策略伺服器與客戶端應用、高階迷宮環境建置、以及離線資料處理等關鍵技術,提供程式碼範例和詳細說明,幫助讀者理解並應用 RLlib 進行強化學習任務。

機器學習 強化學習

隨著強化學習應用日益廣泛,如何有效地建置和訓練複雜環境成為關鍵挑戰。本文以迷宮環境為例,逐步展示如何使用 Ray RLlib 這個強大的強化學習框架構建多智慧體環境、應用課程學習策略,並利用離線資料進行訓練。從基礎的多智慧體環境建置到進階的課程學習與離線資料處理,本文提供完整的程式碼範例和技術說明,幫助讀者掌握 RLlib 的核心概念和應用技巧。特別是針對複雜環境的訓練,課程學習提供了一種循序漸進的策略,有效提升了學習效率。此外,本文也探討瞭如何使用策略伺服器和客戶端進行分散式訓練,以及如何利用離線資料進行模型訓練和評估,為實際應用場景提供更靈活的解決方案。

多智慧體強化學習與RLlib環境的整合應用

在多智慧體強化學習(MARL)中,決定一個回合何時結束是至關重要的,這完全取決於所處理的問題和期望達到的目標。在本文中,我們將探討如何使用Ray RLlib來處理多智慧體環境,並介紹如何使用策略伺服器和客戶端進行分散式強化學習。

多智慧體環境的建立

首先,我們需要定義一個多智慧體環境。在這個例子中,我們建立了一個名為MultiAgentMaze的環境,其中包含多個智慧體和一個目標。環境的render方法被修改為在螢幕上列印迷宮,並用ID標記每個智慧體。

def render(self, *args, **kwargs):
    os.system('cls' if os.name == 'nt' else 'clear')
    grid = [['| ' for _ in range(5)] + ["|\n"] for _ in range(5)]
    grid[self.goal[0]][self.goal[1]] = '|G'
    grid[self.agents[1][0]][self.agents[1][1]] = '|1'
    grid[self.agents[2][0]][self.agents[2][1]] = '|2'
    print(''.join([''.join(grid_row) for grid_row in grid]))

內容解密:

  • os.system('cls' if os.name == 'nt' else 'clear'):根據作業系統的不同,清除終端螢幕。
  • grid:初始化一個5x5的迷宮網格,用於顯示智慧體和目標的位置。
  • grid[self.goal[0]][self.goal[1]] = '|G':將目標位置標記為|G
  • grid[self.agents[1][0]][self.agents[1][1]] = '|1'grid[self.agents[2][0]][self.agents[2][1]] = '|2':分別將兩個智慧體的位置標記為|1|2

多智慧體強化學習訓練

在定義了多智慧體環境之後,我們可以使用RLlib的DQNConfig來訓練一個DQN演算法。預設情況下,兩個智慧體將分享同一個策略。

from ray.rllib.algorithms.dqn import DQNConfig
simple_trainer = DQNConfig().environment(env=MultiAgentMaze).build()
simple_trainer.train()

內容解密:

  • DQNConfig().environment(env=MultiAgentMaze):組態DQN演算法以使用MultiAgentMaze環境。
  • .build():構建DQN訓練器。
  • .train():開始訓練過程。

自定義多智慧體策略

我們可以透過呼叫.multi_agent方法並設定policiespolicy_mapping_fn引數來為不同的智慧體分配不同的策略。

algo = DQNConfig()\
    .environment(env=MultiAgentMaze)\
    .multi_agent(
        policies={
            "policy_1": (None, env.observation_space, env.action_space, {"gamma": 0.80}),
            "policy_2": (None, env.observation_space, env.action_space, {"gamma": 0.95}),
        },
        policy_mapping_fn=lambda agent_id: f"policy_{agent_id}",
    ).build()
print(algo.train())

內容解密:

  • .multi_agent:啟用多智慧體組態。
  • policies:定義兩個不同的策略,分別具有不同的gamma值。
  • policy_mapping_fn:將每個智慧體對映到對應的策略。

策略伺服器和客戶端的應用

在某些情況下,我們可能需要在不同的機器上執行環境和RLlib演算法。這可以透過使用策略伺服器和客戶端來實作。伺服器負責執行RLlib演算法,而客戶端則與伺服器通訊以取得下一步的動作。

定義策略伺服器

# policy_server.py
import ray
from ray.rllib.agents.dqn import DQNConfig
from ray.rllib.env.policy_server_input import PolicyServerInput
import gym

ray.init()

def policy_input(context):
    return PolicyServerInput(context, "localhost", 9900)

config = DQNConfig()\
    .environment(env=None, action_space=gym.spaces.Discrete(4), observation_space=gym.spaces.Discrete(5*5))\
    .debugging(log_level="INFO")\
    .rollouts(num_rollout_workers=0)\
    .offline_data(input=policy_input, input_evaluation=[])

algo = config.build()

內容解密:

  • PolicyServerInput(context, "localhost", 9900):建立一個在本地主機上監聽9900埠的策略伺服器輸入。
  • .environment(env=None, ...):由於伺服器不直接與環境互動,因此將環境設定為None,並手動指定動作空間和觀察空間。

定義策略客戶端

客戶端連線到伺服器,並請求下一步的動作。

# policy_client.py
from ray.rllib.env.policy_client import PolicyClient

client = PolicyClient("http://localhost:9900")

內容解密:

  • PolicyClient("http://localhost:9900"):建立一個連線到本地主機9900埠上的策略伺服器的客戶端。

強化學習的高階環境建置與課程學習應用

在前面的章節中,我們已經探討瞭如何使用 Ray RLlib 處理簡單的強化學習環境。然而,在實際應用中,我們經常需要面對更複雜的環境和挑戰。本章節將介紹如何建置一個更具挑戰性的迷宮環境,並探討一些高階概念,例如課程學習,以幫助解決這些複雜問題。

建置高階迷宮環境

首先,我們將現有的 GymEnvironment 升級為一個更具挑戰性的 AdvancedEnv。具體改進包括:

  1. 增加迷宮大小:將迷宮大小從 5x5 擴充套件到 11x11。
  2. 引入障礙:在迷宮中加入障礙物,智慧體碰到這些障礙物時會受到懲罰(負獎勵)。
  3. 隨機初始位置:智慧體的初始位置將被隨機化。

初始化高階環境

from gym.spaces import Discrete
import random
import os

class AdvancedEnv(GymEnvironment):
    def __init__(self, seeker=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.maze_len = 11
        self.action_space = Discrete(4)
        self.observation_space = Discrete(self.maze_len * self.maze_len)
        if seeker:
            assert 0 <= seeker[0] < self.maze_len and 0 <= seeker[1] < self.maze_len
            self.seeker = seeker
        else:
            self.reset()
        self.goal = (self.maze_len-1, self.maze_len-1)
        self.info = {'seeker': self.seeker, 'goal': self.goal}
        self.punish_states = [(i, j) for i in range(self.maze_len) for j in range(self.maze_len) if i % 2 == 1 and j % 2 == 0]

環境重置與獎勵機制

def reset(self):
    """隨機重置尋找者的位置,並傳回觀察值"""
    self.seeker = (random.randint(0, self.maze_len - 1), random.randint(0, self.maze_len - 1))
    return self.get_observation()

def get_observation(self):
    """將尋找者的位置編碼為整數"""
    return self.maze_len * self.seeker[0] + self.seeker[1]

def get_reward(self):
    """獎勵找到目標,懲罰進入禁止狀態"""
    reward = -1 if self.seeker in self.punish_states else 0
    reward += 5 if self.seeker == self.goal else 0
    return reward

def render(self, *args, **kwargs):
    """渲染環境,例如透過列印其表示"""
    os.system('cls' if os.name == 'nt' else 'clear')
    grid = [['| ' for _ in range(self.maze_len)] + ["|\n"] for _ in range(self.maze_len)]
    for punish in self.punish_states:
        grid[punish[0]][punish[1]] = '|X'
    grid[self.goal[0]][self.goal[1]] = '|G'
    grid[self.seeker[0]][self.seeker[1]] = '|S'
    print(''.join([''.join(grid_row) for grid_row in grid]))

內容解密:

  1. reset 方法:用於隨機重置智慧體的位置,以增加學習的多樣性。
  2. get_observation 方法:將智慧體的位置轉換為整數觀察值。
  3. get_reward 方法:定義獎勵機制,智慧體到達目標獲得正獎勵,碰到障礙物獲得負獎勵。
  4. render 方法:用於視覺化環境狀態,包括智慧體、目標和障礙物的位置。

課程學習的應用

課程學習是一種有效的強化學習策略,透過逐步引入更困難的任務來加速學習過程。RLlib 支援為演算法提供一個課程來學習。

如何應用課程學習

  1. 定義簡單的初始狀態:選擇容易學習的初始狀態開始訓練。
  2. 逐步增加難度:隨著訓練的進行,逐漸引入更困難的狀態。

課程學習的核心在於選擇合適的初始狀態和調整策略。這需要對環境有深入的理解,並設計出合理的課程結構。

使用客戶端進行遠端推斷

以下是一個使用 PolicyClient 連線到遠端伺服器的客戶端程式碼示例,用於進行強化學習推斷:

# policy_client.py
import gym
from ray.rllib.env.policy_client import PolicyClient
from maze_gym_env import GymEnvironment

if __name__ == "__main__":
    env = GymEnvironment()
    client = PolicyClient("http://localhost:9900", inference_mode="remote")
    obs = env.reset()
    episode_id = client.start_episode(training_enabled=True)
    while True:
        action = client.get_action(episode_id, obs)
        obs, reward, done, info = env.step(action)
        client.log_returns(episode_id, reward, info=info)
        if done:
            client.end_episode(episode_id, obs)
            exit(0)

內容解密:

  1. 建立客戶端:連線到指定的遠端伺服器地址,並啟用遠端推斷模式。
  2. 開始新的一輪:透過 start_episode 方法開始新的訓練輪次。
  3. 取得動作:根據當前觀察值從伺服器取得動作。
  4. 執行動作並記錄反饋:在環境中執行動作,並將獎勵和資訊記錄到伺服器。
  5. 結束本輪訓練:當環境狀態達到終止條件時,結束本輪訓練並離開。

本章節介紹瞭如何構建更具挑戰性的強化學習環境,以及如何應用課程學習來提高學習效率。同時,我們還展示瞭如何使用客戶端進行遠端推斷,為實際應用提供了參考。透過這些技術,我們可以更好地應對複雜的強化學習任務。

強化學習中的課程學習與離線資料處理

在強化學習中,課程學習是一種透過逐步增加任務難度來提升學習效率的方法。為了實作課程學習,我們需要定義一個可以動態調整難度的環境。在本章中,我們將以一個進階的迷宮環境為例,展示如何使用 Ray RLlib 實作課程學習。

課程學習環境的定義

首先,我們需要定義一個 CurriculumEnv 類別,該類別繼承自 AdvancedEnvTaskSettableEnvTaskSettableEnv 是 RLlib 提供的一個介面,用於定義如何取得和設定任務難度。

from ray.rllib.env.apis.task_settable_env import TaskSettableEnv

class CurriculumEnv(AdvancedEnv, TaskSettableEnv):
    def __init__(self, *args, **kwargs):
        AdvancedEnv.__init__(self)

    def difficulty(self):
        return abs(self.seeker[0] - self.goal[0]) + abs(self.seeker[1] - self.goal[1])

    def get_task(self):
        return self.difficulty()

    def set_task(self, task_difficulty):
        while not self.difficulty() <= task_difficulty:
            self.reset()

內容解密:

  • difficulty 方法計算當前狀態的難度,即尋找者與目標之間的曼哈頓距離。
  • get_task 方法傳回當前難度。
  • set_task 方法重置環境,直到難度小於或等於指定的任務難度。

課程學習函式的定義

接下來,我們需要定義一個課程學習函式 curriculum_fn,該函式根據訓練的進度動態調整任務難度。

def curriculum_fn(train_results, task_settable_env, env_ctx):
    time_steps = train_results.get("timesteps_total")
    difficulty = time_steps // 1000
    print(f"Current difficulty: {difficulty}")
    return difficulty

內容解密:

  • curriculum_fn 函式根據總訓練步數 time_steps 計算當前難度。
  • 每 1000 步,難度增加 1。

使用課程學習進行訓練

為了使用課程學習,我們需要在 RLlib 演算法組態中設定 env_task_fn 屬性為我們的 curriculum_fn

from ray.rllib.algorithms.dqn import DQNConfig
import tempfile

temp = tempfile.mkdtemp()
trainer = (
    DQNConfig()
    .environment(env=CurriculumEnv, env_task_fn=curriculum_fn)
    .offline_data(output=temp)
    .build()
)

for i in range(15):
    trainer.train()

內容解密:

  • 建立一個臨時資料夾來儲存訓練資料。
  • 設定 CurriculumEnv 為環境,並將 curriculum_fn 指定給 env_task_fn 屬性。
  • 使用 offline_data 方法將輸出儲存到臨時資料夾。

離線資料處理

在前面的課程學習範例中,我們將訓練資料儲存到了一個臨時資料夾。現在,我們可以使用這些資料進行離線訓練。

imitation_algo = (
    DQNConfig()
    .environment(env=AdvancedEnv)
    .evaluation(off_policy_estimation_methods={})
    .offline_data(input_=temp)
    .exploration(explore=False)
    .build()
)

for i in range(10):
    imitation_algo.train()

imitation_algo.evaluate()

內容解密:

  • 建立一個新的 DQNConfig,並將輸入設定為之前儲存的臨時資料夾。
  • 設定 exploreFalse,以便在訓練過程中不進行探索。
  • 進行 10 次迭代訓練,並評估演算法的表現。

其他進階主題

本章介紹了 Ray RLlib 中的課程學習和離線資料處理。除此之外,RLlib 還提供了許多其他進階功能,例如支援多種不同的環境、組態實驗、訓練課程學習和模仿學習等。這些功能使得 RLlib 成為一個非常靈活和強大的強化學習框架。