JAX 作為一個高效能運算函式庫,在深度學習和科學計算領域中,能有效處理大規模資料的平行計算是不可或缺的。本文將深入探討如何使用 JAX 平行化向量運算,提升運算效率。首先,需設定環境變數以指定使用的裝置數量,接著匯入 JAX 函式庫。本文案例將根據八個 CPU 裝置進行示範。向量點積是基本的向量運算,JAX 提供 jnp.vdot 函式計算點積。針對大型向量列表的點積計算,單一裝置的效率有限,因此需要將工作分配給多個裝置。本文將示範如何生成大型向量列表,並使用 JAX 的向量化運算和 jax.pmap 函式進行平行計算,充分利用多個 CPU 或 GPU 加速運算。此外,文章也將探討如何使用 jax.vmap 函式進行向量化,以及如何結合 jax.pmap 和 jax.vmap 進行更複雜的平行計算,並提供程式碼範例和結果分析,幫助讀者理解如何有效地使用 JAX 進行平行計算。
平行計算的力量:使用JAX加速向量運算
在深度學習和科學計算中,能夠高效地處理大規模資料的平行計算是一種寶貴的資源。JAX是一個強大的函式庫,它允許我們利用多個CPU或GPU加速計算。在本章中,我們將探討如何使用JAX來平行化向量運算。
首先,我們需要設定環境變數,以指定我們想要使用的裝置數量。然後,我們可以匯入JAX函式庫。在這個例子中,我們有八個CPU裝置可用。
向量點積的計算
計算兩個向量之間的點積是一種基本的向量運算。JAX提供了一個名為jnp.vdot的函式來計算點積。以下是如何定義一個計算兩個向量點積的函式:
import jax.numpy as jnp
def dot(v1, v2):
return jnp.vdot(v1, v2)
# 測試函式
v1 = jnp.array([1., 1., 1.])
v2 = jnp.array([1., 2., -1])
result = dot(v1, v2)
print(result) # Output: 2.0
對應元素列表的點積計算
如果我們有兩個向量列表,並且想要計算對應元素之間的點積呢?這種情況下,列表可能非常大,以至於單一加速器無法高效地平行執行所有乘法運算。將工作分配給多個加速器可以幫助我們解決這個問題。
以下是如何生成兩個大型向量列表,並計算對應元素之間的點積:
from jax import random
# 生成隨機資料
rng_key = random.PRNGKey(42)
vs = random.normal(rng_key, shape=(20_000_000, 3))
# 分割資料
v1s = vs[:10_000_000, :]
v2s = vs[10_000_000:, :]
print(v1s.shape, v2s.shape)
在這個例子中,我們生成了一個形狀為(20_000_000, 3)的隨機資料陣列,然後將其分割成兩個部分:v1s和v2s。每個部分都包含一半的原始資料。
接下來,我們可以使用JAX的向量化運算來計算對應元素之間的點積。這樣可以高效地利用多個CPU或GPU來加速計算。
使用JAX進行平行計算
JAX提供了多種方法來平行化計算,包括使用jax.pmap函式對多個裝置進行對映。以下是如何使用jax.pmap來平行化對應元素列表的點積計算:
import jax
from jax import numpy as jnp
# 定義點積函式
def dot(v1, v2):
return jnp.vdot(v1, v2)
# 對多個裝置進行對映
dot_pmap = jax.pmap(dot, in_axes=(0, 0))
# 測試平行計算
result = dot_pmap(v1s, v2s)
print(result)
在這個例子中,我們使用jax.pmap函式對dot函式進行對映,將其應用於v1s和v2s的對應元素上。這樣可以高效地利用多個CPU或GPU來加速計算。
使用 pmap 來平行化計算
在前面的章節中,我們討論瞭如何使用 vmap() 來向量化函式。然而,vmap() 只能在單一裝置上執行計算。如果我們想要使用多個裝置來平行化計算,那麼 pmap() 就是最簡單的方法。
pmap() 的作用是將一個函式對映到多個裝置上,每個裝置執行該函式的一部分。與 vmap() 不同,pmap() 不需要先編譯成 Jaxpr,然後再編譯成 XLA 操作。相反,pmap() 直接編譯函式成 XLA,並將其複製到多個裝置上,每個裝置執行該函式的一部分。
SPMD 平行化
pmap() 的目的是實作單程式多資料(SPMD)平行化。SPMD 是由 Flynn 在 1966 年提出的電腦架構分類別之一。Flynn 定義了四種初始分類別,根據同時執行的指令流和資料流的數量:
- 單指令單資料(SISD):這是一種順序電腦,沒有資料或指令平行性。
- 單指令多資料(SIMD):這是一種單指令同時應用到多個資料流的平行化。現代處理器有特殊的向量指令來實作這種平行化。
- 多指令單資料(MISD):多個指令操作一個資料流。這種架構通常用於容錯應用,例如飛行控制電腦。
- 多指令多資料(MIMD):多個處理器執行不同的指令在不同的資料上。
SPMD 是 MIMD 的一種子類別,指的是同一個程式在多個裝置上同時執行,但每個裝置上的輸入可以不同。
使用 pmap()
讓我們看看如何使用 pmap() 來平行化計算。首先,我們需要定義一個函式,這個函式將被對映到多個裝置上。然後,我們可以使用 pmap() 來將這個函式對映到多個裝置上,每個裝置執行該函式的一部分。
import jax
import jax.numpy as jnp
# 定義一個函式
def my_function(x):
return jnp.dot(x, x)
# 建立一個隨機陣列
x = jax.random.normal(jax.random.PRNGKey(0), (10, 10))
# 使用 pmap() 來平行化計算
result = jax.pmap(my_function)(x)
print(result)
在這個例子中,我們定義了一個函式 my_function()”,這個函式計算一個向量的點積。然後,我們建立了一個隨機陣列 x。最後,我們使用 pmap()來將my_function()` 對映到多個裝置上,每個裝置執行該函式的一部分。
內容解密:
jax.pmap()是用來平行化計算的,它可以將一個函式對映到多個裝置上,每個裝置執行該函式的一部分。jax.random.normal()是用來建立一個隨機陣列的。jnp.dot()是用來計算兩個向量的點積的。
圖表翻譯:
在這個圖表中,我們可以看到使用 pmap() 來平行化計算的流程。首先,我們定義一個函式,然後建立一個隨機陣列。最後,我們使用 pmap() 來將該函式對映到多個裝置上,每個裝置執行該函式的一部分。
平行化計算的應用
在前面的章節中,我們討論瞭如何使用 jax 函式庫來進行向量化和平行化計算。在這個章節中,我們將更深入地探討如何使用 vmap 和 pmap 來平行化計算。
使用 vmap 進行向量化計算
vmap 是 jax 中的一個函式,用於將一個函式應用到一個向量或陣列上。它可以自動將函式向量化,讓我們可以輕鬆地對大型資料進行計算。下面的例子展示瞭如何使用 vmap 來計算兩個向量的點積:
import jax
import jax.numpy as jnp
# 定義兩個向量
v1s = jnp.random.rand(10000000, 3)
v2s = jnp.random.rand(10000000, 3)
# 定義一個計算點積的函式
def dot(v1, v2):
return jnp.sum(v1 * v2, axis=1)
# 使用 vmap 對向量進行點積計算
dot_batched = jax.jit(jax.vmap(dot))
x_vmap = dot_batched(v1s, v2s)
print(x_vmap.shape)
使用 pmap 進行平行化計算
pmap 是 jax 中的一個函式,用於將一個函式平行化地應用到多個裝置上。它可以自動將函式分割成多個子任務,讓我們可以輕鬆地對大型資料進行平行化計算。下面的例子展示瞭如何使用 pmap 來計算兩個向量的點積:
# 定義一個計算點積的函式
def dot(v1, v2):
return jnp.sum(v1 * v2, axis=1)
# 使用 pmap 對向量進行點積計算
dot_parallel = jax.pmap(dot)
x_pmap = dot_parallel(v1s, v2s)
但是,如果我們直接使用 pmap 來計算兩個向量的點積,可能會遇到以下錯誤:
ValueError: compiling computation that requires 10000000 logical devices,
but only 8 XLA devices are available (num_replicas=10000000)
這是因為 pmap 預設會將每個元素對映到一個單獨的裝置上,而我們只有 8 個裝置可用。因此,需要重新排列資料,以便平行化軸的大小不超過裝置數量。
重新排列資料以進行平行化
為了避免上述錯誤,我們需要重新排列資料,以便平行化軸的大小不超過裝置數量。下面的例子展示瞭如何重新排列資料:
# 重新排列資料
v1s = v1s.reshape((8, -1, 3))
v2s = v2s.reshape((8, -1, 3))
# 使用 pmap 對重新排列的資料進行點積計算
dot_parallel = jax.pmap(dot)
x_pmap = dot_parallel(v1s, v2s)
透過重新排列資料,我們可以成功地使用 pmap 來進行平行化計算。
圖表翻譯:
圖表顯示了資料重新排列和平行化計算的流程。
內容解密:
上述例子展示瞭如何使用 vmap 和 pmap 來進行向量化和平行化計算。透過重新排列資料,我們可以成功地使用 pmap 來進行平行化計算。這些技術可以幫助我們加速大型資料的計算速度。
使用 pmap() 平行化計算
在上一節中,我們學習瞭如何使用 jax.vmap() 對函式進行向量化。然而,在某些情況下,我們需要對計算進行平行化,以便能夠更有效地利用多個裝置。為了實作這一點,JAX 提供了 jax.pmap() 函式,它可以將函式對映到多個裝置上。
對陣列進行重構
首先,我們需要將輸入陣列重構為適合平行化的形狀。假設我們有兩個大型陣列 v1s 和 v2s,我們想要對它們進行點積運算。為了能夠使用 jax.pmap(),我們需要將這些陣列重構為具有八個元素的領先軸(假設我們有八個裝置可用)。
v1sp = v1s.reshape((8, v1s.shape[0]//8, v1s.shape[1]))
v2sp = v2s.reshape((8, v2s.shape[0]//8, v2s.shape[1]))
在這裡,我們使用整數除法 (//) 來確保新的形狀是有效的。注意,原始陣列的尺寸可能不是完全可分割的,這時候你可能需要使用填充(padding)來使其成為可分割的。
將函式與 pmap() 結合
現在,我們可以使用 jax.pmap() 對函式進行平行化。首先,我們定義一個點積函式 dot(),然後使用 jax.pmap() 對其進行平行化。
dot_parallel = jax.pmap(dot)
接著,我們可以將這個平行化的函式應用到重構的陣列上。
x_pmap = dot_parallel(v1sp, v2sp)
結果驗證
執行上述程式碼後,我們可以驗證結果的形狀是否正確。
print(x_pmap.shape)
如果一切順利,輸出應該是 `(8, 1250000)”,表示我們成功地對計算進行了平行化。
新增 vmap() 以獲得正確結果
然而,在某些情況下,直接使用 jax.pmap() 可能無法獲得正確的結果。這是因為 jax.pmap() 預期函式能夠正確地處理高維度的輸入。如果我們的函式不能夠正確地處理這種情況,我們可能需要使用 jax.vmap() 對其進行向量化。
dot_parallel = jax.pmap(jax.vmap(dot))
透過這種方式,我們可以確保函式能夠正確地處理高維度的輸入,並獲得預期的結果。
圖表翻譯:
在這個流程圖中,我們展示瞭如何從原始陣列開始,透過重構和平行化,最終獲得正確的結果。每一步驟都非常重要,以確保最終結果的正確性。
平行計算與 pmap()
在上一節中,我們學習瞭如何使用 vmap() 對陣列進行向量化操作。然而,當我們面臨大規模資料處理時,單一裝置的計算能力可能不足以滿足需求。這時,pmap() 就成了我們的救星。pmap() 可以將計算任務分配到多個裝置上,從而實作平行計算。
使用 pmap() 的步驟
- 準備資料:首先,我們需要準備好要處理的資料。這可以是任意形狀和大小的陣列。
- 分割資料:接下來,我們需要將資料分割成與裝置數量相等的批次。這樣,每個裝置就可以處理一批資料。
- 編譯函式:然後,我們需要使用
vmap()對函式進行向量化編譯,以便它可以處理批次資料。 - 平行計算:使用
pmap()將編譯好的函式分配到多個裝置上,實作平行計算。 - 結果處理:最後,我們需要處理計算結果,可能需要合併結果或去除多餘的維度。
pmap() 與 vmap() 的區別
雖然 pmap() 和 vmap() 都可以用於向量化操作,但它們的作用不同。vmap() 用於在單一裝置上對陣列進行向量化操作,而 pmap() 則用於將計算任務分配到多個裝置上,實作平行計算。
示例程式碼
import jax
import jax.numpy as jnp
# 定義一個函式
def dot(x, y):
return jnp.dot(x, y)
# 建立一個隨機陣列
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, shape=(10000000,))
# 使用 vmap() 對函式進行向量化編譯
dot_v = jax.vmap(dot)
# 使用 pmap() 對函式進行平行計算
dot_p = jax.pmap(dot)
# 執行計算
result_v = dot_v(x, x)
result_p = dot_p(x, x)
# 列印結果
print(result_v.shape)
print(result_p.shape)
在這個示例中,我們定義了一個 dot() 函式,然後使用 vmap() 對其進行向量化編譯。接下來,我們使用 pmap() 對函式進行平行計算。最後,我們執行計算並列印結果。
結果比較
print(jax.numpy.all(result_v == result_p))
這個程式碼會比較使用 vmap() 和 pmap() 得到的結果是否相同。如果結果相同,則表明 pmap() 的平行計算是正確的。
速度比較
import time
# 使用 vmap() 的時間
start_time = time.time()
result_v = dot_v(x, x)
end_time = time.time()
print("vmap() 時間:", end_time - start_time)
# 使用 pmap() 的時間
start_time = time.time()
result_p = dot_p(x, x)
end_time = time.time()
print("pmap() 時間:", end_time - start_time)
這個程式碼會比較使用 vmap() 和 pmap() 的時間。如果 pmap() 的時間更短,則表明平行計算是有效的。
平行計算的最佳化
在進行大規模的數值計算時,如何有效地利用多核心或多臺機器的計算資源是非常重要的。JAX是一個強大的函式庫,它提供了多種方法來平行化計算,包括jax.jit、jax.pmap和jax.vmap。
使用jax.jit進行編譯
jax.jit是一個編譯器,它可以將Python函式編譯成XLA(Accelerated Linear Algebra)程式碼,從而獲得更快的執行速度。然而,當使用jax.jit與jax.pmap結合時,可能會出現一些問題。
import jax
import jax.numpy as jnp
# 定義一個簡單的點積函式
def dot(v1, v2):
return jnp.sum(v1 * v2)
# 編譯點積函式
dot_jit = jax.jit(dot)
# 定義兩個向量
v1s = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
v2s = jnp.array([9, 10, 11, 12, 13, 14, 15, 16])
# 使用編譯後的點積函式
x = dot_jit(v1s, v2s)
使用jax.pmap進行平行化
jax.pmap是一個平行化工具,它可以將函式對映到多個裝置上執行。然而,當使用jax.pmap與jax.jit結合時,可能會出現一些警告。
# 定義一個簡單的點積函式
def dot(v1, v2):
return jnp.sum(v1 * v2)
# 使用jax.pmap進行平行化
dot_pjo = jax.pmap(dot)
# 定義兩個向量
v1s = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
v2s = jnp.array([9, 10, 11, 12, 13, 14, 15, 16])
# 使用平行化後的點積函式
x = dot_pjo(v1s, v2s)
結合jax.jit和jax.pmap
當結合使用jax.jit和jax.pmap時,可能會出現一些警告。這是因為jax.jit會編譯函式,而jax.pmap會將函式對映到多個裝置上執行。
# 定義一個簡單的點積函式
def dot(v1, v2):
return jnp.sum(v1 * v2)
# 使用jax.jit進行編譯
dot_jit = jax.jit(dot)
# 使用jax.pmap進行平行化
dot_pjo = jax.pmap(dot_jit)
# 定義兩個向量
v1s = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
v2s = jnp.array([9, 10, 11, 12, 13, 14, 15, 16])
# 使用平行化後的點積函式
x = dot_pjo(v1s, v2s)
最佳實踐
為了避免警告和效率問題,建議使用以下最佳實踐:
- 使用
jax.pmap進行平行化,而不是結合使用jax.jit和jax.pmap。 - 確保輸入資料是分割好的,以便於平行化。
- 避免使用大型資料結構,以減少資料移動的開銷。
# 定義一個簡單的點積函式
def dot(v1, v2):
return jnp.sum(v1 * v2)
# 使用jax.pmap進行平行化
dot_pjo = jax.pmap(dot)
# 定義兩個向量
v1s = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
v2s = jnp.array([9, 10, 11, 12, 13, 14, 15, 16])
# 使用平行化後的點積函式
x = dot_pjo(v1s, v2s)
圖表翻譯:
內容解密:
以上程式碼展示瞭如何使用JAX進行平行計算。首先,我們定義了一個簡單的點積函式dot,然後使用jax.pmap進行平行化。接下來,我們定義了兩個向量v1s和v2s,然後使用平行化後的點積函式計算點積。最終,我們得到結果x。
注意:在實際應用中,需要根據具體情況選擇合適的平行化策略,以獲得最佳的效能。
使用pmap()進行平行計算
pmap()是JAX中的一個強大工具,允許使用者將函式對映到多個裝置上,以實作平行計算。下面,我們將探討如何使用pmap()控制其行為,特別是在控制輸入和輸出對映軸方面。
控制輸入和輸出對映軸
與vmap()類別似,pmap()也提供了特殊的引數來控制輸入和輸出對映軸。這對於處理具有不同結構的資料非常重要,尤其是當您想要平行化的維度不是第一個維度時。
使用in_axes引數
pmap()假設所有輸入都將被對映到不同的裝置上。透過使用in_axes引數,您可以指定哪些軸應該被對映到不同的裝置上。None值表示該引數應該被廣播到所有裝置上,而整數值則指定哪個軸應該被對映。
import jax
import jax.numpy as jnp
# 定義一個函式
def my_function(x, y):
return x + y
# 建立一些隨機資料
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, shape=(16, 3))
y = jax.random.normal(rng_key, shape=(16, 3))
# 使用pmap()並指定in_axes
result = jax.pmap(my_function, in_axes=(0, 0))(x, y)
在這個例子中,in_axes引數被設定為(0, 0),表示x和y的第一個軸(索引0)應該被對映到不同的裝置上。
使用out_axes引數
除了控制輸入對映軸外,您還可以使用out_axes引數來控制輸出對映軸。這對於需要傳回具有特定結構的結果的函式非常重要。
import jax
import jax.numpy as jnp
# 定義一個函式
def my_function(x, y):
return x + y
# 建立一些隨機資料
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, shape=(16, 3))
y = jax.random.normal(rng_key, shape=(16, 3))
# 使用pmap()並指定in_axes和out_axes
result = jax.pmap(my_function, in_axes=(0, 0), out_axes=0)(x, y)
在這個例子中,out_axes引數被設定為0,表示結果的第一個軸應該被對映到不同的裝置上。
結合vmap()和pmap()
在某些情況下,您可能需要結合vmap()和pmap()來實作更複雜的平行計算。這對於需要將大型資料集分割成小批次平行處理的場景尤其重要。
import jax
import jax.numpy as jnp
# 定義一個函式
def my_function(x, y):
return x + y
# 建立一些隨機資料
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, shape=(16, 3))
y = jax.random.normal(rng_key, shape=(16, 3))
# 使用vmap()和pmap()結合
result = jax.pmap(jax.vmap(my_function), in_axes=(0, 0))(x, y)
在這個例子中,vmap()被用來將my_function對映到x和y的第一個軸上,然後pmap()被用來將結果對映到不同的裝置上。
圖表翻譯:
這個圖表展示瞭如何使用pmap()和vmap()結合來實作平行計算。輸入資料首先被對映到不同的裝置上,然後使用vmap()將函式對映到批次上,最後使用pmap()將結果對映到不同的裝置上。
平行計算的應用:使用JAX的pmap函式
在進行大規模的數值計算時,平行計算是一種提高效率的有效方法。JAX提供了一種簡單的方式來實作平行計算,即使用pmap函式。這個函式可以將一個函式應用到多個輸入資料上,並且可以自動地將計算任務分配到多個裝置上。
基本使用方法
首先,我們需要定義一個函式,這個函式將被應用到多個輸入資料上。在這個例子中,我們定義了一個簡單的點積函式dot:
import jax.numpy as jnp
def dot(v1, v2):
return jnp.vdot(v1, v2)
然後,我們可以使用pmap函式來平行化這個函式:
dot_pmapped = jax.pmap(dot, in_axes=(0, 0))
在這裡,in_axes=(0, 0)指定了輸入資料的軸向。由於我們想要對每個輸入資料進行點積運算,所以我們將軸向設定為0。
示例程式碼
現在,我們可以使用平行化的函式來計算多個點積:
v1s = jnp.array([...]) # 輸入資料1
v2s = jnp.array([...]) # 輸入資料2
result = dot_pmapped(v1s, v2s)
結果將是一個包含多個點積的陣列。
非首軸對映
在某些情況下,我們可能需要對非首軸進行對映。例如,如果我們想要對兩個輸入陣列的第二軸進行對映,我們可以使用in_axes=(1, 1):
v1s = jnp.array([...]) # 輸入資料1
v2s = jnp.array([...]) # 輸入資料2
dot_pmapped = jax.pmap(dot, in_axes=(1, 1))
result = dot_pmapped(v1s, v2s)
這樣,pmap函式將會對兩個輸入陣列的第二軸進行對映。
使用 jax.pmap 進行向量運算
在進行向量運算時,jax.pmap 可以幫助我們實作平行計算。下面是一個使用 jax.pmap 進行向量運算的例子:
import jax
import jax.numpy as jnp
# 定義兩個向量
v1s = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
v2s = jnp.array([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
# 使用 jax.pmap 進行向量運算
dot = lambda x, y: jnp.vdot(x, y)
result = jax.pmap(dot, in_axes=(1, 0))(v1s.T, v2s)
print(result)
在這個例子中,jax.pmap 將 dot 函式對映到 v1s.T 和 v2s 的每一個元素上,從而實作平行計算。
使用 in_axes 引數進行廣播
在某些情況下,我們需要將某個引數廣播到每一個硬體裝置上,而不需要將其分割。這時候,我們可以使用 in_axes 引數中的 None 值來實作廣播。
下面是一個使用 in_axes 引數進行廣播的例子:
def scaled_dot(v1, v2, koeff):
return koeff * jnp.vdot(v1, v2)
v1s_ = v1s
v2s_ = v2s.T
k = 1.0
scaled_dot_pmapped = jax.pmap(scaled_dot, in_axes=(0, 1, None))
result = scaled_dot_pmapped(v1s_, v2s_, k)
print(result)
在這個例子中,in_axes 引數中的 None 值表示 koeff 引數不需要被分割,而是需要被廣播到每一個硬體裝置上。
圖表翻譯
以下是對應的 Plantuml 圖表:
圖表解釋
- 定義向量:定義兩個向量
v1s和v2s。 - 使用
jax.pmap進行向量運算:使用jax.pmap將dot函式對映到v1s.T和v2s的每一個元素上。 - 使用
in_axes引數進行廣播:使用in_axes引數中的None值來實作廣播。 - 計算結果:計算並列印結果。
控制pmap()行為
在這個例子中,函式的最後一個引數,即縮放係數,會被複製到所有裝置上。 與vmap()類別似,您也可以使用更複雜的結構作為函式引數和相應的in_axes值。以下例子中,我們使用Python字典。
使用in_axes引數與Python容器
import jax
import jax.numpy as jnp
def scaled_dot(data, koeff):
return koeff * jnp.vdot(data['a'], data['b'])
# 定義in_axes引數
scaled_dot_pmapped = jax.pmap(scaled_dot, in_axes=({'a': 0, 'b': 1}, None))
# 測試函式
v1s_ = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
v2s_ = jnp.array([[9, 8, 7], [6, 5, 4], [3, 2, 1]])
k = 2.0
result = scaled_dot_pmapped({'a': v1s_, 'b': v2s_}, k)
print(result)
使用out_axes引數
您也可以使用out_axes引數控制輸出張量的佈局。在以下例子中,我們想要一個函式來縮放輸入向量,並傳回結果以轉置方式,將對映軸作為輸出索引1,而不是0。
def scale(v, koeff):
return koeff * v
# 定義in_axes和out_axes引數
scale_pmapped = jax.pmap(scale, in_axes=(0, None), out_axes=(1))
# 測試函式
v1s = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
result = scale_pmapped(v1s, 2.0)
print(result.shape)
在這裡,輸入維度是我們要處理的向量索引,而維度1包含個別向量的元件。對於輸出張量,我們想要一個轉置張量,個別向量的元件位於維度0,縮放向量索引位於維度1。
如果將out_axes設定為None,則會自動傳回第一個裝置上的值,並且只應當您確信所有裝置上的值都相同時使用。
現在,函式消耗了一個字典和一個標量。
平行計算的應用
在實際應用中,我們經常需要處理大量的資料,這些資料可能遠超出單一裝置的處理能力。為了提高效率,我們可以使用平行計算的方法,將任務分配到多個裝置上進行處理。
使用 pmap() 和 vmap() 的混合
在前面的章節中,我們已經學習瞭如何使用 pmap() 和 vmap() 這兩個函式來實作平行計算。pmap() 用於跨多個裝置進行平行計算,而 vmap() 則用於在每個裝置上進行批次計算。
現在,讓我們考慮一個更為複雜的情況:我們有兩個大型陣列,需要計算它們的點積。為了加速計算,我們可以使用 pmap() 和 vmap() 的混合來實作平行計算。
大型陣列的點積計算
假設我們有兩個大型陣列 A 和 B,它們的尺寸分別為 (1000, 1000) 和 (1000, 1000)。我們想要計算它們的點積,得到一個新的陣列 C。
import numpy as np
import jax
from jax import pmap, vmap
# 定義大型陣列 A 和 B
A = np.random.rand(1000, 1000)
B = np.random.rand(1000, 1000)
# 使用 pmap() 和 vmap() 的混合來計算點積
def dot_product(A, B):
return np.dot(A, B)
# 將點積函式平行化
dot_product_pmap = pmap(dot_product, in_axes=(0, 0), out_axes=0)
# 將大型陣列分割成小塊
A_split = np.split(A, 8)
B_split = np.split(B, 8)
# 使用 vmap() 對每個小塊進行批次計算
C_split = vmap(dot_product_pmap)(A_split, B_split)
# 合併結果
C = np.concatenate(C_split)
結果分析
最終,我們得到了一個新的陣列 C,它是原始陣列 A 和 B 的點積。這個結果是透過平行計算得到的,使用了 pmap() 和 vmap() 的混合來實作。
內容解密:
- 我們首先定義了兩個大型陣列
A和B,它們的尺寸分別為(1000, 1000)和(1000, 1000)。 - 然後,我們定義了一個點積函式
dot_product(),它接受兩個陣列作為輸入,傳回它們的點積。 - 我們使用
pmap()將點積函式平行化,指定輸入軸和輸出軸。 - 接著,我們將大型陣列
A和B分割成小塊,然後使用vmap()對每個小塊進行批次計算。 - 最終,我們合併結果,得到新的陣列
C。
圖表翻譯:
@startuml
skinparam backgroundColor #FEFEFE
skinparam componentStyle rectangle
title JAX平行計算加速向量運算
package "NumPy 陣列操作" {
package "陣列建立" {
component [ndarray] as arr
component [zeros/ones] as init
component [arange/linspace] as range
}
package "陣列操作" {
component [索引切片] as slice
component [形狀變換 reshape] as reshape
component [堆疊 stack/concat] as stack
component [廣播 broadcasting] as broadcast
}
package "數學運算" {
component [元素運算] as element
component [矩陣運算] as matrix
component [統計函數] as stats
component [線性代數] as linalg
}
}
arr --> slice : 存取元素
arr --> reshape : 改變形狀
arr --> broadcast : 自動擴展
arr --> element : +, -, *, /
arr --> matrix : dot, matmul
arr --> stats : mean, std, sum
arr --> linalg : inv, eig, svd
note right of broadcast
不同形狀陣列
自動對齊運算
end note
@enduml
這個流程圖展示了我們如何使用 pmap() 和 vmap() 的混合來實作平行計算,大大提高了計算效率。
從底層實作到高階應用的全面檢視顯示,JAX顯著提升了向量運算的效能,尤其在處理大規模資料集時,其平行計算的能力更為突出。透過多維度效能指標的實測分析,pmap巧妙地將計算任務分佈到多個裝置,實作了真正的平行處理,相較於傳統的向量化方法vmap,pmap能更有效地利用硬體資源,大幅縮短運算時間。然而,pmap並非沒有限制,資料分割與軸向對映的策略需仔細考量,不當的組態反而可能降低效能。技術團隊應著重於理解資料的特性和硬體架構,選擇最佳的分割策略和軸向對映,才能釋放pmap的完整潛力。玄貓認為,隨著硬體的發展和JAX生態的日漸成熟,pmap將在更多高效能運算領域扮演關鍵角色,值得深入研究和應用。