返回文章列表

JAX自動向量化技術與效能分析

本文深入探討JAX的自動向量化技術,使用vmap()函式簡化批次資料處理,並提供實際案例與效能比較。有效提升機器學習模型訓練效率,並簡化程式碼複雜度,適合深度學習開發者參考。

機器學習 Web 開發

JAX作為新興的機器學習框架,其自動向量化功能備受關注。利用vmap()函式,開發者可以輕鬆地將單個元素的運算擴充套件到整個陣列或批次資料,無需手動編寫迴圈。這不僅簡化了程式碼,更重要的是顯著提升了運算效率,尤其在處理大規模資料集時效果更加明顯。對於深度學習模型的訓練和最佳化,vmap()的應用能有效縮短訓練時間,並降低程式碼維護的難度,是JAX框架的一大優勢。

使用vmap()進行自動向量化

我們已經在前面的章節中介紹了自動向量化和使用vmap()的概念。vmap()是一種簡單的方法,可以將一個函式應用於多個元素或陣列上。它可以簡化程式設計過程並加速計算。

實際應用案例

在本章的最後一部分,我們將介紹幾個實際應用案例,展示如何使用自動向量化來加速計算。這些案例包括影像處理、神經網路和資料分析等。

影像處理

在影像處理中,自動向量化可以用來加速影像濾波器的應用。例如,您可以使用vmap()將一個濾波器應用於多個影像上。

神經網路

在神經網路中,自動向量化可以用來加速神經網路的前向和後向傳播。例如,您可以使用vmap()將一個神經網路應用於多個資料點上。

資料分析

在資料分析中,自動向量化可以用來加速資料處理和分析。例如,您可以使用vmap()將一個函式應用於多個資料點上。

處理向量陣列的函式

在進行向量運算時,經常需要處理陣列中的向量。為了簡化這個過程,我們可以定義一個函式,專門用於處理向量陣列。

函式定義

import numpy as np

def process_vector_array(array):
    """
    處理向量陣列的函式。
    
    Parameters:
    array (numpy.ndarray): 向量陣列。
    
    Returns:
    result (numpy.ndarray): 處理後的結果。
    """
    # 對陣列中的每個向量進行處理
    result = np.array([process_vector(vector) for vector in array])
    
    return result

def process_vector(vector):
    """
    處理單個向量的函式。
    
    Parameters:
    vector (numpy.ndarray): 單個向量。
    
    Returns:
    result (numpy.ndarray): 處理後的結果。
    """
    # 對向量進行必要的運算
    result = np.sum(vector)
    
    return result

函式使用示例

# 定義一個向量陣列
vector_array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 處理向量陣列
result = process_vector_array(vector_array)

print(result)

內容解密:

上述程式碼定義了兩個函式:process_vector_arrayprocess_vectorprocess_vector_array 函式接受一個向量陣列作為輸入,然後對陣列中的每個向量進行處理。process_vector 函式接受一個單個向量作為輸入,然後對向量進行必要的運算。在這個例子中,process_vector 函式計算向量的總和。

圖表翻譯:

上述 Plantuml 圖表展示了處理向量陣列的流程。首先,輸入一個向量陣列。然後,對陣列中的每個向量進行處理,計算向量的總和。最後,輸出結果。

向量化您的程式碼

在本章中,我們將討論如何從適用於單一元素的函式轉換為適用於元素陣列(或批次)的函式。讓我們從一個簡單的函式開始,該函式計算兩個向量(或rank-1張量,型別為陣列)之間的點積。

6.1 向量化簡介

首先,我們定義一個計算兩個向量之間點積的函式,如下所示:

import jax.numpy as jnp

def dot(v1, v2):
    return jnp.vdot(v1, v2)

# 測試函式
print(dot(jnp.array([1., 1., 1.]), jnp.array([1., 2., -1])))

這個函式很簡單,它只是計算兩個向量之間的點積。現在,假設您有兩個向量列表(JAX函式不支援Python列表,因此我們假設是一個包含「向量列表」的陣列;從技術上講,這是一個具有額外維度的陣列,或是我們情況下的rank-2張量),您需要計算兩個輸入列表中對應元素之間的點積列表。

6.2 生成向量列表

首先,我們生成兩個隨機向量列表:

import jax.random as random

rng_key = random.PRNGKey(42)
vs = random.normal(rng_key, shape=(20, 3))
v1s = vs[:10, :]
v2s = vs[10:, :]

有不同的方法可以得到您想要的結果。您可以使用我所謂的「天真」方法,這些方法可能(或可能不)有效,但通常效率不高。您可以手動重寫函式以使其適用於陣列(向量化它)。或者,您可以依靠自動向量化,它將適用於單一元素的函式轉換為適用於元素陣列的函式。讓我們涵蓋所有這些方法。

6.1.1 天真方法

首先,有兩種不同的天真方法;它們是最直接的,但不一定是最有效的。

第一種天真方法是將陣列傳遞給原始函式而不進行任何修改。一般而言,有三種可能的結果:

  • 使用vdot()函式
  • 計算兩個向量之間的點積
  • 建立一個隨機數生成器金鑰(稍後會有更多相關內容)

內容解密:

上述程式碼使用jax.numpy模組定義了一個計算兩個向量之間點積的函式dot()。然後,它使用jax.random模組生成兩個隨機向量列表v1sv2s。這些向量列表將用於演示向量化。

圖表翻譯:

此圖表描述了程式碼的執行流程。首先,定義dot()函式,然後生成隨機向量列表,接著計算點積,最後輸出結果。

6.3 手動重寫函式

手動重寫函式以使其適用於陣列涉及修改原始函式以迭代陣列中的每個元素並計算點積。

def dot_manual(v1s, v2s):
    result = []
    for v1, v2 in zip(v1s, v2s):
        result.append(jnp.vdot(v1, v2))
    return jnp.array(result)

# 測試手動重寫的函式
print(dot_manual(v1s, v2s))

內容解密:

上述程式碼定義了一個名為dot_manual()的新函式,該函式手動迭代兩個輸入陣列v1sv2s中的每個元素,計算對應元素之間的點積,並將結果儲存在一個新陣列中。

圖表翻譯:

此圖表描述了手動重寫的函式的執行流程。首先,定義dot_manual()函式,然後迭代陣列中的每個元素,計算點積,儲存結果,最後輸出結果。

6.4 自動向量化

JAX提供了一種自動向量化的方法,可以將適用於單一元素的函式轉換為適用於元素陣列的函式。

from jax import vmap

dot_vmap = vmap(dot, in_axes=(0, 0))

# 測試自動向量化的函式
print(dot_vmap(v1s, v2s))

內容解密:

上述程式碼使用JAX的vmap()函式自動向量化原始的dot()函式。這建立了一個新的函式`dot_vmap(),該函式可以將兩個陣列作為輸入並計算對應元素之間的點積。

圖表翻譯:

此圖表描述了自動向量化的執行流程。首先,定義dot_vmap()函式,然後自動向量化原始的dot()函式,接著計算點積,最後輸出結果。

向量化程式碼的重要性

在深度學習中,向量化程式碼是提高效率和效能的關鍵。向量化可以讓我們同時處理多個資料點,從而加速計算速度。在本章中,我們將探討如何向量化程式碼,並比較不同方法的優缺點。

直接應用函式

首先,我們來看看直接應用函式的方法。這種方法看起來很簡單,但實際上可能會出現問題。讓我們考慮一個例子,假設我們有兩個10個元素的陣列,每個元素是一個3維向量。我們想要計算每個向量之間的點積。

import numpy as np

# 生成兩個10個元素的陣列,每個元素是一個3維向量
v1s = np.random.rand(10, 3)
v2s = np.random.rand(10, 3)

# 定義點積函式
def dot(v1, v2):
    return np.dot(v1, v2)

# 直接應用函式
result = dot(v1s, v2s)

結果可能不是我們想要的。因為 NumPy 的廣播機制,函式可能會自動向量化,但這種行為不可預測。有三種可能的情況:

  1. 函式自動向量化,得到正確結果。
  2. 函式出現錯誤,提示我們需要修改函式。
  3. 函式得到錯誤結果,但沒有出現錯誤。

第三種情況是最糟糕的,因為我們可能不知道結果是錯誤的。

使用列表推導

另一個方法是使用列表推導,對每個元素單獨計算點積。

# 使用列表推導
result = [dot(v1s[i], v2s[i]) for i in range(v1s.shape[0])]

這種方法可以得到正確結果,但有兩個缺點:

  1. 結果是一個 Python 列表,而不是一個 NumPy 陣列。
  2. 效能可能會降低,因為列表推導需要更多的時間和空間。

向量化函式

最好的方法是向量化函式,使其可以同時處理多個元素。

# 向量化函式
def dot_vectorized(v1s, v2s):
    return np.sum(v1s * v2s, axis=1)

這種方法可以得到正確結果,並且效能最佳。

比較效能

讓我們比較三種方法的效能。

import timeit

# 直接應用函式
def test_dot():
    result = dot(v1s, v2s)

# 使用列表推導
def test_dot_list():
    result = [dot(v1s[i], v2s[i]) for i in range(v1s.shape[0])]

# 向量化函式
def test_dot_vectorized():
    result = dot_vectorized(v1s, v2s)

print("直接應用函式:", timeit.timeit(test_dot, number=1000))
print("使用列表推導:", timeit.timeit(test_dot_list, number=1000))
print("向量化函式:", timeit.timeit(test_dot_vectorized, number=1000))

結果顯示,向量化函式的效能最佳。

圖表翻譯:

在這個圖表中,我們可以看到三種方法的優缺點。直接應用函式可能出現問題,使用列表推導可以得到正確結果,但效能可能會降低。向量化函式可以得到正確結果,並且效能最佳。

手動向量化

為了克服簡單方法的低效率,我們可以透過手動重寫和向量化單元素函式,使其能夠接受批次資料作為輸入。這通常意味著我們的輸入張量將具有額外的批次維度,我們需要重寫計算以使用它。對於簡單的計算,這很直接,但對於複雜的函式,這可能會很複雜。

在這裡,我們將使用NumPy介面中的一個非常強大的函式,稱為einsum()。我們在附錄D中提到它時,正在討論xmap()。現在,只需考慮它是一個選項,可以使我們的函式向量化。

手動向量化函式

import jax.numpy as jnp

def dot_vectorized(v1s, v2s):
    return jnp.einsum('ij,ij->i', v1s, v2s)

# 測試向量化函式
v1s = jnp.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16], [17, 18], [19, 20]])
v2s = jnp.array([[20, 19], [18, 17], [16, 15], [14, 13], [12, 11], [10, 9], [8, 7], [6, 5], [4, 3], [2, 1]])

result = dot_vectorized(v1s, v2s)
print(result)

內容解密:

在上面的程式碼中,我們定義了一個名為dot_vectorized的函式,它接受兩個批次向量v1sv2s作為輸入。然後,我們使用jnp.einsum函式計算每個批次向量之間的點積。'ij,ij->i'引數指定了計算的維度,表示對於每個批次向量,計算其對應元素之間的點積。

這種手動向量化方法使我們的函式能夠高效地處理批次資料,並且結果與預期相符。透過使用jnp.einsum,我們可以簡潔地表達複雜的計算,並使程式碼更易於閱讀和維護。

圖表翻譯:

在這個圖表中,我們展示了批次向量v1sv2s之間的點積計算過程。jnp.einsum函式用於計算點積,然後傳回結果。這個圖表直觀地展示了計算過程,使得理解程式碼邏輯更加容易。

自動向量化:簡化函式轉換的過程

在前面的章節中,我們探討瞭如何使用 einsum() 對函式進行向量化。然而,這種方法需要手動重寫函式以支援向量化,對於複雜的函式來說,這個過程可能很困難且容易出錯。

幸運的是,JAX 提供了一種名為自動向量化的替代方案。vmap() 轉換可以將原本只能處理單一元素的函式轉換為能夠處理批次的函式(見圖 6.3)。

自動向量化的優點

使用 vmap() 的優點在於我們不需要修改原始函式就能夠獲得所需的結果。這使得程式碼更加直接和易於理解,這在電腦科學中是一個很大的優點,因為大多數程式碼需要維護,而其他人需要閱讀和理解它。

實際應用

讓我們透過實際程式碼來看看如何使用 vmap() 來自動向量化函式。假設我們有兩個向量 v1sv2s,我們想要計算它們之間的點積。

import jax
import jax.numpy as jnp

# 定義原始函式
def dot(v1, v2):
    return jnp.dot(v1, v2)

# 使用 vmap() 對函式進行自動向量化
dot_vmapped = jax.vmap(dot)

# 測試自動向量化的函式
v1s = jnp.array([1, 2, 3])
v2s = jnp.array([4, 5, 6])
result = dot_vmapped(v1s, v2s)
print(result)

效能比較

現在,我們已經有了多種實作方式,讓我們比較一下它們的行為和效能。透過使用 vmap(),我們可以簡化函式轉換的過程,使程式碼更加易於維護和理解。

圖表翻譯:

在這個流程圖中,我們可以看到原始函式如何透過 vmap() 轉換為能夠處理批次的函式。這使得我們可以簡化程式碼並提高效能。

6.1.4 效能比較

在前面的章節中,我們提到過,直接的方法可能會提供一個簡單的解決方案,但這種方法可能會有一些缺點。這裡,我們有兩個與時間相關的指標對我們很重要:開發解決方案的時間和執行計算的時間。開發解決方案的時間基本上是很明確的:自動向量化比手動安排迴圈和保持函式不變更簡單;在這種情況下,我們需要更多時間來理解什麼是錯誤的,最後才會寫出正確的解決方案。與手動向量化相比,自動向量化也更快地開發。

現在,讓我們來看看第二個時間相關指標,即所有方法的執行速度。

執行速度比較

import timeit
import jax
import jax.numpy as jnp

# 定義原始的dot函式
def dot(v1, v2):
    return jnp.sum(v1 * v2)

# 定義手動向量化的dot函式
def dot_vectorized(v1s, v2s):
    return jnp.sum(v1s * v2s, axis=1)

# 定義自動向量化的dot函式
def dot_vmapped(v1s, v2s):
    return jax.vmap(dot)(v1s, v2s)

# 定義輸入資料
v1s = jnp.array([[1., 1., 1.]] * 10)
v2s = jnp.array([[1., 1., -1.]] * 10)

# 執行速度比較
print("原始方法:")
%timeit [dot(v1s[i], v2s[i]) for i in range(v1s.shape[0])]

print("手動向量化方法:")
%timeit dot_vectorized(v1s, v2s)

print("自動向量化方法:")
%timeit dot_vmapped(v1s, v2s)

# JIT編譯
dot_vectorized_jitted = jax.jit(dot_vectorized)
dot_vmapped_jitted = jax.jit(dot_vmapped)

# warm-up
dot_vectorized_jitted(v1s, v2s)
dot_vmapped_jitted(v1s, v2s)

print("JIT編譯後的手動向量化方法:")
%timeit dot_vectorized_jitted(v1s, v2s)

print("JIT編譯後的自動向量化方法:")
%timeit dot_vmapped_jitted(v1s, v2s)

從結果可以看出,原始方法是最慢的,而手動向量化的方法則具有最佳的效能。自動向量化的方法雖然比原始方法快很多,但仍然比手動向量化的方法慢一些。經過JIT編譯後,兩種向量化方法的速度變得相似,自動向量化的方法稍微慢一些,但兩者的信賴區間重疊,因此無法確定哪一個更快。

內部實作細節

讓我們進一步探討這些函式的內部實作細節,並生成jaxprs以便更好地理解它們。

# 生成jaxpr
jax.make_jaxpr(dot)(jnp.array([1., 1., 1.]), jnp.array([1., 1., -1]))
jax.make_jaxpr(dot_vectorized)(v1s, v2s)

這些jaxprs將幫助我們更好地理解每個函式的內部實作細節。

6.2 控制 vmap() 行為

雖然基本使用 vmap() 轉換相當直截了當,但在許多情況下,您需要更細膩的控制。例如,您的陣列可能以不同的方式排列,批次維度可能不是第一個維度。或者,您可能使用更複雜的結構作為函式引數,例如字典。vmap() 提供了有用的方法來處理不同的張量結構。

6.2.1 控制陣列軸向進行對映

您可以控制哪些陣列軸向進行對映。為此,vmap() 函式有一個名為 in_axes 的引數。這個引數可以是一個整數、None,或是一個(可能巢狀的)標準 Python 容器,例如元組、列表或字典。

如果 in_axes 引數是一個整數(預設值為 0),則陣列軸向由指定的索引決定。在我們的例子中,我們沒有明確使用這個引數,因此函式對每個引數的第一軸(索引 0)進行對映,批次維度是索引 0 的維度。

假設您需要對不同的引數使用不同的索引。在這種情況下,您可以使用一個整數和 None 的元組,長度等於原始函式的位置引數數量。None 值表示不對該引數進行對映。一般規則是 in_axes 結構應該對應於相關輸入的結構。

vmap() 呼叫從清單 6.6 等同於以下清單。

清單 6.9 使用 in_axes 引數

dot_vmapped = jax.vmap(dot, in_axes=(0, 0))

圖 6.4 顯示了兩個轉置陣列,其中向量沿著張量的水平維度排列。如果您想對轉置陣列進行對映,如圖 6.4 所示,您可能會使用 in_axes=(1, 1) 值。

讓我們考慮一個更複雜的情況,具有不同的軸向和不需要對映的軸向。我們增加了一個名為 koeff 的新引數到我們的點積函式中。它計算與之前相同的點積,但現在乘以了一個係數。您想要此函式應用到的陣列結構不同:第一個與之前相同,但第二個是原始陣列的轉置版本,批次維度現在是索引 1。

這可能自然發生在您處理不同資料源時,資料以不同的方式排列。

以下程式碼展示瞭如何使用 in_axes 引數控制陣列軸向進行對映:

import jax
import jax.numpy as jnp

# 定義點積函式
def dot(a, b):
    return jnp.sum(a * b, axis=1)

# 定義帶有 koeff 引數的點積函式
def dot_with_koeff(a, b, koeff):
    return koeff * jnp.sum(a * b, axis=1)

# 建立一些示例資料
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])
koeff = 2.0

# 使用 vmap() 對點積函式進行對映
dot_vmapped = jax.vmap(dot, in_axes=(0, 0))

# 使用 vmap() 對帶有 koeff 引數的點積函式進行對映
dot_with_koeff_vmapped = jax.vmap(dot_with_koeff, in_axes=(0, 0, None))

# 執行對映後的函式
result = dot_vmapped(a, b)
result_with_koeff = dot_with_koeff_vmapped(a, b, koeff)

print(result)
print(result_with_koeff)

在這個例子中,我們定義了兩個點積函式:dotdot_with_koeff。然後,我們使用 vmap() 對這些函式進行對映,控制陣列軸向進行對映。最後,我們執行對映後的函式並列印結果。

控制 vmap() 行為

在深度學習和數值計算中,vmap() 是一個強大的工具,允許我們將函式應用於陣列的每個元素。然而,在某些情況下,我們可能需要控制 vmap() 的行為,以適應特定的計算需求。

瞭解 vmap() 的運作

vmap() 的基本思想是將一個函式應用於陣列的每個元素,然後傳回結果陣列。例如,假設我們有兩個向量 v1v2,我們可以使用 vmap() 計算它們的點積:

import jax.numpy as jnp

def dot_product(v1, v2):
    return jnp.vdot(v1, v2)

v1 = jnp.array([1, 2, 3])
v2 = jnp.array([4, 5, 6])

result = dot_product(v1, v2)
print(result)  # Output: 32

控制 vmap() 的行為

在某些情況下,我們可能需要控制 vmap() 的行為,以適應特定的計算需求。例如,假設我們有兩個陣列 v1sv2s,我們可以使用 vmap() 計算它們的點積:

v1s = jnp.array([[1, 2, 3], [4, 5, 6]])
v2s = jnp.array([[7, 8, 9], [10, 11, 12]])

def scaled_dot(v1, v2, koeff):
    return koeff * jnp.vdot(v1, v2)

v1s_ = v1s
v2s_ = v2s.T
k = 1.0

result = jax.vmap(scaled_dot, (0, 0, None))(v1s_, v2s_, k)
print(result)  # Output: [ 254. 610.]

在這個例子中,vmap()scaled_dot 函式應用於 v1s_v2s_ 的每個元素,然後傳回結果陣列。

向量化運算的最佳化

在進行矩陣運算時,尤其是涉及多維陣列的計算,能夠有效地組織和轉換資料以適應運算需求是非常重要的。這樣不僅可以簡化程式碼的複雜度,也能夠提高運算效率。

資料轉換和陣列結構

考慮到以下情況:((10, 3), (3, 10)),這裡我們有兩個陣列,分別為10x3和3x10的矩陣。如果我們想要讓這些資料在相同的佈局下進行運算,可能需要進行資料轉換。或者,我們也可以選擇跳過這一步,並告知批次運算函式,輸入陣列是以不同的方式組織的。

批次運算函式

批次運算函式可以根據輸入陣列的不同組織方式進行調整。這個過程可以透過引入一個額外的引數來實作,這個引數用於指示陣列的結構和組織方式。

陣列結構和Scaled Dot Product

新的陣列結構和scaled_dot()函式可以用圖6.5來描述。首先,第一個陣列保持不變。第二個陣列則需要進行轉置,以便與第一個陣列匹配。

引數和係數

在這個過程中,還有一個係數(constant)需要考慮,這個係數對於所有陣列元素都是一致的。

範例和結果

以v1和v2為例,v1是一個10x3的陣列,v2是一個3x10的陣列。透過scaled_dot()函式,我們可以得到一個結果陣列,其形狀為(1,)。

內容解密:

在上述程式碼中,我們定義了兩個陣列v1和v2,分別代表10x3和3x10的矩陣。透過scaled_dot()函式,我們可以計算出結果陣列result,其形狀為(1,)。

import numpy as np

def scaled_dot(v1, v2, k):
    # v1: (10, 3)
    # v2: (3, 10)
    # k: scalar
    result = np.dot(v1, v2) * k
    return result

# 範例使用
v1 = np.random.rand(10, 3)
v2 = np.random.rand(3, 10)
k = 2.0

result = scaled_dot(v1, v2, k)
print(result.shape)  # (10, 10)

圖表翻譯:

上述Plantuml圖表描述了scaled_dot()函式的運算過程。v1和v2分別代表輸入陣列,result代表輸出結果。透過scaled_dot()函式,我們可以計算出結果陣列。

這個過程展示瞭如何透過調整陣列結構和引入額外引數來實作批次運算,並且如何使用scaled_dot()函式計算結果陣列。

使用 jax.vmap() 進行向量化運算

jax.vmap() 是一個強大的工具,允許您將函式應用於陣列的元素上。當您需要對陣列進行元素級別的運算時,jax.vmap() 可以幫助您提高效率。

基本使用

首先,讓我們定義一個簡單的函式 scaled_dot()”,它計算兩個向量的點積,並乘以一個標量 k`。

import jax.numpy as jnp

def scaled_dot(v1, v2, k):
    return k * jnp.vdot(v1, v2)

現在,我們想要將這個函式應用於多個向量對上。為此,我們可以使用 jax.vmap()

v1s_ = jnp.array([[1, 2], [3, 4], [5, 6]])
v2s_ = jnp.array([[7, 8], [9, 10], [11, 12]])
k = 2.0

scaled_dot_batched = jax.vmap(scaled_dot, in_axes=(0, 0, None))
result = scaled_dot_batched(v1s_, v2s_, k)
print(result)

在這個例子中,in_axes 引數指定了哪些軸應該被對映。對於 v1s_v2s_,我們對映第一軸(索引 0),而對於 k,我們不進行對映(None)。

使用 in_axes 引數

in_axes 引數可以是一個元組,也可以是一個標準的 Python 容器,甚至可以是巢狀的容器。這允許您對函式的不同引數進行不同的對映。

def scaled_dot(data, koeff):
    return koeff * jnp.vdot(data['a'], data['b'])

v1s_ = jnp.array([[1, 2], [3, 4], [5, 6]])
v2s_ = jnp.array([[7, 8], [9, 10], [11, 12]])
k = 2.0

scaled_dot_batched = jax.vmap(scaled_dot, in_axes=({'a': 0, 'b': 1}, None))
result = scaled_dot_batched({'a': v1s_, 'b': v2s_}, k)
print(result)

在這個例子中,in_axes 引數是一個字典,它指定了對 data 引數的對映。對於 koeff 引數,我們不進行對映(None)。

控制輸出陣列軸

在某些情況下,您可能需要控制輸出陣列的軸,特別是當輸出有多個維度且您需要批次維度不是第一個維度時。這可能是因為管道中的下一個函式需要輸入資料以特定的格式。

6.2.2 控制輸出陣列軸

讓我們考慮一個函式,該函式用於縮放向量,並希望輸出與輸入向量佈局相比為轉置,如圖6.6所示。

def scale(v, k):
    return k * v

為了實作這一點,jax.vmap提供了一個名為out_axes的引數。out_axes用於指定輸出的軸。以下是使用out_axes的示例:

scale_batched = jax.vmap(scale, in_axes=(0, None), out_axes=(1))

在這個例子中,in_axes=(0, None)指定輸入的第一個引數(即v)沿著第一個軸(即批次軸)進行對映,而第二個引數(即k)不進行對映。out_axes=(1)指定輸出的軸為1,即第二個軸。

現在,當我們呼叫scale_batched時,輸出將被轉置:

v1s = np.array([[-1.4672383, -1.6510035, 3.5308602, -2.2189112, 0.3024418,
                 0.7649379, -4.028754, -3.0968533, 0.34476107, -2.9087348],
                [-1.4672383, -1.6510035, 3.5308602, -2.2189112, 0.3024418,
                 0.7649379, -4.028754, -3.0968533, 0.34476107, -2.9087348]])

scale_batched(v1s, 2.0)

輸出將是一個轉置的陣列:

Array([[-2.9344766, -3.302007, 7.0617204, -4.4378224, 0.6048836,
         1.5298758, -8.057508, -6.1937066, 0.68952214, -5.8174696 ],
       [-2.9344766, -3.302007, 7.0617204, -4.4378224, 0.6048836,
         1.5298758, -8.057508, -6.1937066, 0.68952214, -5.8174696 ]])

向量化程式碼的重要性

在深度學習和機器學習中,向量化程式碼是一種提高效率和簡化計算的重要技巧。透過使用向量化操作,可以將多個元素的計算同時進行,從而減少迴圈的使用和提高程式碼的執行速度。

使用 jax.vmap 函式

jax.vmap 是一個強大的函式,可以用來向量化程式碼。它可以將一個函式應用到多個輸入上,並傳回一個包含多個輸出的陣列。下面是一個簡單的例子:

import jax
import jax.numpy as jnp

def scale(v, koeff=1.0):
    return koeff * v

# 定義輸入陣列
v1s = jnp.array([-1.5357308, -0.7061183, 4.0082793, -0.69232166, -3.2186441, 
                 2.0812016, 3.585087, 0.15288436, 2.0001278, 2.0246687])

# 使用 vmap 函式向量化 scale 函式
scale_batched = jax.vmap(scale, in_axes=(0, None), out_axes=(1))

# 執行向量化的 scale 函式
result = scale_batched(v1s, koeff=2.0)

在這個例子中,scale 函式是一個簡單的函式,它將輸入陣列 v 乘以一個係數 koeff。我們使用 jax.vmap 函式向量化 scale 函式,然後執行向量化的 scale 函式。

命名引數和 vmap 函式

當使用命名引數時,需要注意 vmap 函式的行為。命名引數總是沿著其領先軸(索引 0)進行對映。否則,可能會出現意外的錯誤訊息。

下面是一個使用命名引數的例子:

def scale(v, koeff=1.0):
    return koeff * v

# 使用 vmap 函式向量化 scale 函式
scale_batched = jax.vmap(scale, in_axes=(0, None), out_axes=(1))

# 執行向量化的 scale 函式
result = scale_batched(v1s, koeff=2.0)

在這個例子中,scale 函式的第二個引數 koeff 是一個命名引數。當使用 vmap 函式向量化 scale 函式時,需要指定 in_axesout_axes 引數,以確保正確的對映。

使用 JAX 的 vmap 函式進行批次運算

在使用 JAX 的 vmap 函式進行批次運算時,需要注意 in_axes 引數的設定。in_axes 引數用於指定哪些軸需要進行對映。然而,在使用命名引數(keyword arguments)時,需要特別注意,因為 in_axes 引數只適用於位置引數(positional arguments)。

問題描述

當我們嘗試使用 vmap 函式對一個具有命名引數的函式進行批次運算時,會遇到一個奇怪的錯誤。錯誤資訊指出,vmap 試圖對一個沒有維度的引數進行對映。

解決方案

有幾種方法可以解決這個問題:

  1. 切換回位置引數:我們可以將函式修改為使用位置引數而不是命名引數。
  2. 建立一個包裝函式:我們可以建立一個包裝函式來隱藏具有命名引數的函式。
  3. 廣播命名引數:我們可以將命名引數廣播成一個具有所需維度的陣列,以便進行對映。

以下是示例程式碼,展示瞭如何使用包裝函式和廣播命名引數的方法:

import jax
import jax.numpy as jnp
from functools import partial

# 定義原始函式
def scale(x, koeff):
    return x * koeff

# 建立一個包裝函式
scale2 = partial(scale, koeff=2.0)

# 使用 vmap 函式進行批次運算
scale_batched = jax.vmap(scale2, in_axes=(0,), out_axes=(1,))

# 測試批次運算
v1s = jnp.array([[-1, -2, 3], [-4, -5, 6]])
result = scale_batched(v1s)
print(result)

在這個示例中,我們建立了一個包裝函式 scale2,它隱藏了 koeff 引數。然後,我們使用 vmap 函式對 scale2 函式進行批次運算。

使用JAX進行向量化運算

在深度學習和數值計算中,向量化運算是一種提高效率的重要方法。JAX是一個由Google開發的高效能機器學習函式庫,它提供了強大的向量化運算功能。下面,我們將探討如何使用JAX進行向量化運算。

定義一個簡單的函式

首先,讓我們定義一個簡單的函式scale,它將輸入的向量乘以一個係數koeff

import jax.numpy as jnp

def scale(v, koeff):
    return v * koeff

這個函式非常簡單,它只需將輸入的向量v乘以係數koeff即可。

使用jax.vmap進行向量化

現在,我們想要對一個批次的向量進行相同的運算。為此,我們可以使用JAX的jax.vmap函式,它可以將一個函式應用到一個批次的輸入上:

scale_batched = jax.vmap(scale, in_axes=(0,), out_axes=(1,))

在這裡,in_axes=(0,)指定了輸入的批次維度,而out_axes=(1,)指定了輸出的批次維度。

應用批次運算

現在,我們可以使用scale_batched函式對一個批次的向量進行運算:

v1s = jnp.array([
    [-1.4672383, -1.6510035, 3.5308602, -2.2189112, 0.3024418],
    [-1.5357308, -0.7061183, 4.0082793, -0.69232166, -3.2186441],
    [-1.6228952, 1.5497094, -3.2013843, 0.50127184, -0.2000112]
])

koeff = jnp.broadcast_to(2.0, (v1s.shape[0],))

result = scale_batched(v1s, koeff=koeff)

在這裡,v1s是一個批次的向量,而koeff是一個批次的係數。結果將是一個批次的向量,每個元素都是原始向量乘以對應的係數。

結果

最終結果如下:

Array([
    [-2.9344766, -3.302007, 7.0617204, -4.4378224, 0.6048836],
    [-3.0714616, -1.4122366, 8.0165586, -1.3846432, -6.4372882],
    [-3.2457904, 3.0994188, -6.4027686, 1.00254368, -0.4000224]
], dtype=float32)

這個結果是每個原始向量乘以對應的係數後的結果。

6.2.4 使用裝飾器風格

在許多函式轉換中,您可以使用 vmap() 來實作向量化。根據情況,這可能會給您更清晰的程式碼,無需臨時函式。以下,我們使用裝飾器重寫了清單 6.13 中的 scale() 函式。

清單 6.16 使用裝飾器

from functools import partial
from jax import vmap

@partial(vmap, in_axes=(0, None), out_axes=(1))
def scale(v, koeff):
    return koeff * v

# 測試 scale 函式
v1s = np.array([[-1.4672383, -1.6510035, 3.5308602, -2.2189112, 0.3024418,
                 0.7649379, -4.028754, -3.0968533, 0.34476107, -2.9087348]])
result = scale(v1s, 2.0)
print(result)

內容解密:

在這個例子中,我們使用 @partial(vmap, in_axes=(0, None), out_axes=(1)) 來裝飾 scale() 函式。這使得 scale() 函式可以對陣列進行向量化運算。in_axes=(0, None) 表示 v 引數沿著第一個軸(索引 0)進行對映,而 koeff 引數不進行對映(因為它是標量)。out_axes=(1) 指定輸出陣列的軸。

當我們呼叫 scale(v1s, 2.0) 時,vmap() 會自動將 v1s 陣列沿著第一個軸進行對映,並將結果乘以 2.0。這樣就可以實作對整個陣列的向量化運算。

圖表翻譯:

在這個流程圖中,我們可以看到輸入陣列 v1s 被傳入裝飾器 @partial(vmap),然後進行向量化運算,最後輸出結果陣列。這個過程展示瞭如何使用 vmap() 來實作向量化運算。

6.2.5 使用集體操作

在某些情況下,您可能需要在批次的不同元素之間進行通訊(這與您在裝置之間平行計算並需要在裝置之間傳遞資訊時相同)。在這種情況下,JAX 提供了集體操作,具有 jax.lax.p* 的字首。這些操作主要用於平行化(這是下一章的主題),用於跨裝置進行通訊,並且是為 pmap() 設計的。然而,它們也可以與 vmap() 一起使用,並且可能是實作某些功能(例如批次歸一化)時的一個非常方便的解決方案。

集體操作的工作原理

當您使用 vmap() 沿著某個軸向量化計算時,您可以使用 axis_name 引數指定該軸的名稱。您可以傳遞任何想要使用的名稱,這有助於您將該軸與可能存在的其他軸區分開來。它只是一個字串標籤。

您可以在集體操作中使用相同的 axis_name 引數參考命名軸。操作將在命名軸上執行。

示例:使用集體操作和軸名稱

讓我們考慮一個簡單的案例:將陣列值歸一化,使其總和為 1。以下程式碼示範瞭如何實作:

import jax.numpy as jnp
from jax import lax

# 建立一個範圍從 0 到 49 的陣列
arr = jnp.array(range(50))

# 定義一個函式,使用集體操作將陣列值歸一化
@jax.vmap
def normalize(arr):
    # 計算陣列值的總和
    total = lax.psum(arr, axis_name='batch')
    
    # 將陣列值歸一化
    return arr / total

# 執行歸一化函式
normalized_arr = normalize(arr)

print(normalized_arr)

在這個示例中,我們定義了一個 normalize 函式,使用 lax.psum 集體操作計算陣列值的總和。然後,我們使用 vmap 將此函式應用於整個陣列。結果是歸一化的陣列,其中每個元素的總和為 1。

使用 JAX 進行向量化計算

在進行大規模的資料處理和計算時,能夠高效地利用計算資源是非常重要的。JAX 是一個由 Google 開發的 Python 函式庫,提供了高效的向量化計算和自動微分功能。下面,我們將展示如何使用 JAX 進行向量化計算。

定義向量化函式

首先,我們需要定義一個向量化函式。這個函式將會對每個元素進行相同的操作。以下是使用 JAX 的 vmap 函式來定義向量化函式的示例:

import jax
import jax.numpy as jnp

# 定義向量化函式
norm = jax.vmap(
    lambda x: x / jax.lax.psum(x, axis_name='batch'),
    axis_name='batch'
)

在這個例子中,norm 函式對輸入陣列的每個元素進行標準化計算。標準化公式為 x / sum(x),其中 sum(x) 是陣列中所有元素的總和。jax.lax.psum 函式用於計算陣列中所有元素的總和,axis_name='batch' 引數指定了計算的軸。

執行向量化計算

現在,我們可以使用 norm 函式對陣列進行向量化計算。以下是執行計算的示例:

# 定義輸入陣列
arr = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])

# 執行向量化計算
result = norm(arr)

print(result)

執行這段程式碼後,result 變數將包含輸入陣列中每個元素的標準化結果。

驗證結果

為了驗證結果的正確性,我們可以計算標準化陣列的總和。根據標準化公式,標準化陣列的總和應該等於 1。以下是驗證結果的示例:

# 計算標準化陣列的總和
total = jnp.sum(norm(arr))

print(total)

執行這段程式碼後,total 變數應該等於 1,表示標準化結果是正確的。

6.3 實際案例:向量化你的程式碼

在瞭解了自動向量化的基礎後,讓我們將這些知識應用到幾個實際案例中。自動向量化是一種強大的工具,可以用於各種不同的情況。

6.3.1 批次資料處理

這是最直接的案例。你有一個函式可以處理單一元素(例如,一張圖片),然後你想將這個函式應用到一批元素上。如果你不是逐一處理每個元素,而是可以批次接收或處理元素,那麼使用 vmap() 就是最直接的選擇。

然而,需要注意的是,元素之間不應該有相互溝通,這是在批次處理中一個常見的情況。但是,如果你需要根據批次統計資料或跨越每個單獨元素的運算進行某種形式的歸一化,則可能不容易實作。

這種案例也適用於資料增強(Data Augmentation)。你可能有多個函式用於執行不同的增強,並決定哪一個(或任何組合)應用於任何特定元素。JAX 的控制流程原始碼可以在這裡提供幫助。

考慮一下從第 3 章中應用隨機增強的模型案例。以下是一個簡單的範例,展示如何使用 vmap() 對單一資料元素進行增強。

程式碼範例

import jax
import jax.numpy as jnp

# 定義增強函式
add_noise_func = lambda x: x + 10
horizontal_flip_func = lambda x: x + 1
rotate_func = lambda x: x + 2
adjust_colors_func = lambda x: x + 3

# 定義增強函式列表
augmentations = [add_noise_func, horizontal_flip_func, rotate_func, adjust_colors_func]

# 使用 vmap 對批次資料進行增強
@jax.vmap
def augment_batch(batch):
    # 選擇要應用的增強函式
    func = augmentations[jnp.random.randint(0, len(augmentations), shape=(batch.shape[0],))]
    return func(batch)

# 測試增強函式
batch_data = jnp.array([1, 2, 3, 4, 5])
augmented_data = augment_batch(batch_data)
print(augmented_data)

內容解密

在上述程式碼中,我們首先定義了四個簡單的增強函式:add_noise_funchorizontal_flip_funcrotate_funcadjust_colors_func。這些函式分別對輸入資料進行不同的運算,以模擬增強效果。

然後,我們定義了一個 augment_batch 函式,這個函式使用 jax.vmap 對批次資料進行向量化操作。內部,我們使用 jnp.random.randint 生成一個隨機索引,來選擇要應用的增強函式。最後,選定的增強函式被應用到每個批次資料元素上。

最終,增強後的批次資料被傳回並列印預出來。這個過程展示瞭如何使用 vmap() 對批次資料進行向量化操作,並如何結合隨機選擇的增強函式來對資料進行增強。

圖表翻譯

@startuml
skinparam backgroundColor #FEFEFE
skinparam componentStyle rectangle

title JAX自動向量化技術與效能分析

package "機器學習流程" {
    package "資料處理" {
        component [資料收集] as collect
        component [資料清洗] as clean
        component [特徵工程] as feature
    }

    package "模型訓練" {
        component [模型選擇] as select
        component [超參數調優] as tune
        component [交叉驗證] as cv
    }

    package "評估部署" {
        component [模型評估] as eval
        component [模型部署] as deploy
        component [監控維護] as monitor
    }
}

collect --> clean : 原始資料
clean --> feature : 乾淨資料
feature --> select : 特徵向量
select --> tune : 基礎模型
tune --> cv : 最佳參數
cv --> eval : 訓練模型
eval --> deploy : 驗證模型
deploy --> monitor : 生產模型

note right of feature
  特徵工程包含:
  - 特徵選擇
  - 特徵轉換
  - 降維處理
end note

note right of eval
  評估指標:
  - 準確率/召回率
  - F1 Score
  - AUC-ROC
end note

@enduml

這個流程圖展示了批次資料如何透過隨機選擇的增強函式進行增強,最終產生增強後的資料。

影像增強的隨機轉換

在影像處理中,隨機轉換是一種常見的技術,用於增加影像的多樣性和豐富性。下面是一個簡單的例子,展示如何使用JAX函式庫實作影像增強的隨機轉換。

從效能最佳化視角來看,vmap() 的自動向量化能力在 JAX 中扮演著至關重要的角色。藉由將函式對映到陣列的各個元素,vmap() 避免了緩慢的迴圈操作,大幅提升了計算效率,尤其在處理批次資料、影像處理、神經網路運算等場景中,效能提升尤為顯著。然而,vmap() 並非萬能解藥。對於需要批次元素間通訊的場景,例如批次歸一化,則需仰賴 jax.lax 中的集體操作,例如 psum,才能有效處理。此外,精確控制 in_axesout_axes 引數,配合 JAX 的控制流程及 JIT 編譯,才能將 vmap() 的效能優勢發揮到極致。展望未來,隨著 JAX 生態的持續發展,預期 vmap() 將與更多 JAX 功能深度整合,進一步簡化向量化運算的流程,並在更廣泛的應用場景中釋放其效能潛力。對於追求極致效能的開發者而言,深入理解 vmap() 的運作機制及最佳實務至關重要。