0x0. 背景
隨著年紀(jì)越來越大,讀代碼越來越困難,如果你發(fā)現(xiàn)看不懂同事寫的代碼應(yīng)該怎么辦呢?不要擔(dān)心,大語言模型的時代了來了,chatgpt和gpt4會教會我們怎么讀代碼。本篇文章就來展示一下使用newbing(chatgpt)來讀oneflow softmax相關(guān)的fuse優(yōu)化kernel的過程。本文的代碼解釋均由chatgpt生成,我只是手工做了非常少的一點(diǎn)微調(diào)來保證對代碼解釋的正確性。完整代碼解釋見: https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/oneflow-cuda-optimize-skills/fused_softmax/fused_scale_mask_softmax.cu#L1452-L2098
我在【BBuf的CUDA筆記】八,對比學(xué)習(xí)OneFlow 和 FasterTransformer 的 Softmax Cuda實(shí)現(xiàn) 中對oneflow以及FasterTransformer做了源碼解析,讀者可以先了解一下。在oneflow的softmax cuda kernel中已經(jīng)提到它的每種實(shí)現(xiàn)均有LOAD、STORE兩個模板參數(shù)分別代表輸入輸出。使用 load.template load<pack_size>(ptr, row_id, col_id);
和store.template store<pack_size>(ptr, row_id, col_id);
進(jìn)行讀取和寫入。使用LOAD和STORE有兩個好處:1、可以在CUDA Kernel中只關(guān)心計(jì)算類型ComputeType,而不用關(guān)心具體的數(shù)據(jù)類型T。2、只需要加幾行代碼就可以快速支持Softmax和其他Kernel Fuse,減少帶寬需求,提升整體性能。普通的SoftmaxKernel直接使用DirectLoad和DirectStore,F(xiàn)usedScaleSoftmaxKernel(也就是本文要介紹的)只需要額外定義一個ScaleMaskLoad結(jié)構(gòu)用于對輸入x做Scale預(yù)處理以及ScaleMaskStore即可快速完成kernel fuse。
本篇文章在oneflow softmax的基礎(chǔ)上將fused_scale_mask_softmax的完整實(shí)現(xiàn)抽取到了https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/oneflow-cuda-optimize-skills/fused_softmax/fused_scale_mask_softmax.cu 文件中并進(jìn)行解析,讀者可以使用 https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/oneflow-cuda-optimize-skills/fused_softmax/Makefile 文件來編譯這個.cu文件并運(yùn)行,得到正確的結(jié)果。
0x1. Pattern介紹
我們先來介紹一下我們要做的這個fuse pattern是什么樣子的,其實(shí)它就是Transformer解碼模塊的softmax部分,由于解碼時無法看到當(dāng)前token的后續(xù)token,所以需要對輸入做一個mask scale操作。具體來說就是把mask tensor中為true的位置對應(yīng)的原始tensor元素乘以一個常數(shù)scale,并且把mask tensor中為false的位置對應(yīng)的原始tensor元素填充為一個-inf。其實(shí)也不一定是inf,只要足夠大就可以了,比如經(jīng)常取10000.0。用python代碼表達(dá)這個pattern就是這樣子的:
import numpy as np
import oneflow as flow
batch_size = 4
num_heads = 8
seq_length = 64
broadcast_dim = 1
fill_value = -10000.0
scale_value = 2.0
x = np.random.randn(batch_size, num_heads, seq_length, seq_length).astype(
np.float32
)
mask_size = [batch_size, num_heads, seq_length, seq_length]
if broadcast_dim:
mask_size[broadcast_dim] = 1
mask = np.random.randint(0, 2, size=mask_size, dtype=bool)
fused_x_tensor = flow.tensor(x, dtype=flow.float32).to("cuda")
fused_mask_tensor = flow.tensor(mask, dtype=flow.bool).to("cuda")
fused_x_tensor.requires_grad = True
fused_out = flow._C.fused_scale_mask_softmax(
fused_x_tensor, fused_mask_tensor, fill_value=fill_value, scale=scale_value,
)
origin_x_tensor = flow.tensor(x).to("cuda")
origin_mask_tensor = flow.tensor(mask, dtype=flow.float32).to("cuda")
origin_out = flow.mul(
origin_x_tensor, origin_mask_tensor
) * scale_value + fill_value * (1.0 - origin_mask_tensor)
origin_out = flow.softmax(origin_out, dim=-1)
assert(np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4))
這個pattern就是把flow.mul(origin_x_tensor, origin_mask_tensor) * scale_value + fill_value * (1.0 - origin_mask_tensor)
里面的操作融合為一個cuda kernel,也就是flow._C.fused_scale_mask_softmax(fused_x_tensor, fused_mask_tensor, fill_value=fill_value, scale=scale_value,)
。額外需要注意的是這個mask是支持廣播的,所以oneflow中對softmax kernel中的LOAD和STORE分別定義了BroadcastScaleMaskLoad/ElementwiseScaleMaskLoad以及BroadcastScaleMaskStore/ElementwiseScaleMaskStore 2組實(shí)現(xiàn)來進(jìn)行fused_scale_mask_softmax算子的數(shù)據(jù)讀取和存儲。
0x2. 使用chatgpt來讀代碼實(shí)現(xiàn)
接下來我們就用chatgpt來讀一讀BroadcastScaleMaskLoad/ElementwiseScaleMaskLoad以及BroadcastScaleMaskStore/ElementwiseScaleMaskStore的實(shí)現(xiàn)。
首先來看ElementwiseScaleMaskLoad的實(shí)現(xiàn):
template<typename SRC, typename DST, typename MASK>
struct ElementwiseScaleMaskLoad {
ElementwiseScaleMaskLoad(const SRC* src, const MASK* mask, ElementwiseMaskSoftmaxParams param)
: src(src), mask(mask), param(param) {}
template<int N>
__device__ void load(DST* dst, int64_t row, int64_t col) {
Pack<SRC, N> pack;
const int64_t offset = (row * param.row_size + col) / N;
pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset);
Pack<int8_t, N> mask_pack;
mask_pack.storage = *(reinterpret_cast<const PackType<MASK, N>*>(mask) + offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
if (mask_pack.elem[i] == 0) {
dst[i] = static_cast<DST>(param.fill);
} else {
dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(param.scale);
}
}
}
const SRC* src;
const MASK* mask;
ElementwiseMaskSoftmaxParams param;
};
我們讓chatgpt來幫我們分析一下這個結(jié)構(gòu)體,chatgpt為我們分析了這個結(jié)構(gòu)體并給出了以下的信息:
// 它的功能是定義一個模板類,用于在kernel函數(shù)中加載逐元素掩碼softmax的輸入。它的模板參數(shù)有:
// SRC: 源張量的數(shù)據(jù)類型
// DST: 目標(biāo)張量的數(shù)據(jù)類型
// MASK: 掩碼張量的數(shù)據(jù)類型
// 它的構(gòu)造函數(shù)接受以下參數(shù):
// src: 源張量的指針
// mask: 掩碼張量的指針
// param: 逐元素掩碼softmax的參數(shù),是一個ElementwiseMaskSoftmaxParams結(jié)構(gòu)體的實(shí)例
// load: 一個模板函數(shù),用于從源張量和掩碼張量中加載N個元素,并將它們存儲到目標(biāo)張量中。它接受以下參數(shù):
// dst: 目標(biāo)張量的指針
// row: 當(dāng)前的行號
// col: 當(dāng)前的列號
// 它的內(nèi)部邏輯是:
// 根據(jù)行號和列號計(jì)算源張量和掩碼張量的偏移量
// 從源張量和掩碼張量中讀取N個元素,分別存儲到pack和mask_pack中
// 遍歷N個元素,如果掩碼張量的元素為0,就將目標(biāo)張量的元素設(shè)為填充值,否則就將源張量的元素乘以縮放因子,并轉(zhuǎn)換為目標(biāo)張量的數(shù)據(jù)類型
// 這個模板類的成員變量有:
// src: 源張量的指針
// mask: 掩碼張量的指針
// param: 逐元素掩碼softmax的參數(shù),是一個ElementwiseMaskSoftmaxParams結(jié)構(gòu)體的實(shí)例
// 這個模板類的目的是為了方便在kernel函數(shù)中使用逐元素掩碼softmax的輸入,提高深度學(xué)習(xí)的性能。
可以看到chatgpt給出了非常詳細(xì)的邏輯解釋,并且還指出提高深度學(xué)習(xí)性能。提高深度學(xué)習(xí)性能這句話應(yīng)該是猜出來的,畢竟從這段代碼chatgpt是不可能了解到這個結(jié)構(gòu)體是和softmax fuse kernel相關(guān)的,但能感覺到chatgpt是真的有理解這段代碼已經(jīng)非常厲害了。接著我們問問#pragma unroll
這一行代碼是干什么的?
我們再問一下這里的Pack是什么意思?
可以看到chatgpt還是較為準(zhǔn)備的為我們解釋了這里Pack的含義以及指出是便于在GPU上進(jìn)行高效的內(nèi)存訪問。
最后再拷打一下pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset)
這句話涉及到的邏輯:
太強(qiáng)了,根本難不倒他。當(dāng)然了,一般element_wise相關(guān)的定義都是略顯簡單的,如果我們加入廣播的邏輯也就是讓chatgpt來讀一下BroadcastScaleMaskLoad看看效果如何:
// 它的功能是定義一個模板類,用于在內(nèi)核函數(shù)中加載廣播掩碼softmax的輸入。它的模板參數(shù)有:
// SRC: 源張量的數(shù)據(jù)類型
// DST: 目標(biāo)張量的數(shù)據(jù)類型
// MASK: 掩碼張量的數(shù)據(jù)類型
// num_dims: 張量的維度數(shù)
// IndexType: 索引的數(shù)據(jù)類型
// 它的構(gòu)造函數(shù)接受以下參數(shù):
// src: 源張量的指針
// mask: 掩碼張量的指針
// params: 廣播掩碼softmax的參數(shù),是一個BroadcastMaskSoftmaxParams結(jié)構(gòu)體的實(shí)例
// 它的成員函數(shù)有:
// load: 一個模板函數(shù),用于從源張量和掩碼張量中加載N個元素,并將它們存儲到目標(biāo)張量中。它接受以下參數(shù):
// dst: 目標(biāo)張量的指針
// row: 當(dāng)前的行號
// col: 當(dāng)前的列號
// 它的內(nèi)部邏輯是:
// 根據(jù)行號和列號計(jì)算源張量的偏移量
// 根據(jù)偏移量計(jì)算源張量的多維索引
// 根據(jù)多維索引計(jì)算掩碼張量的偏移量
// 從源張量和掩碼張量中讀取N個元素,分別存儲到pack和mask_pack中
// 遍歷N個元素,如果掩碼張量的元素為0,就將目標(biāo)張量的元素設(shè)為填充值,否則就將源張量的元素乘以縮放因子,并轉(zhuǎn)換為目標(biāo)張量的數(shù)據(jù)類型
// 這個模板類的成員變量有:
// src: 源張量的指針
// mask: 掩碼張量的指針
// mask_dims: 掩碼張量的每個維度的大小,是一個數(shù)組
// params: 廣播掩碼softmax的參數(shù),是一個BroadcastMaskSoftmaxParams結(jié)構(gòu)體的實(shí)例
// 這個模板類的目的是為了方便在內(nèi)核函數(shù)中使用廣播掩碼softmax的輸入,提高深度學(xué)習(xí)的性能。
template<typename SRC, typename DST, typename MASK, size_t num_dims, typename IndexType>
struct BroadcastScaleMaskLoad {
BroadcastScaleMaskLoad(const SRC* src, const MASK* mask,
BroadcastMaskSoftmaxParams<num_dims, IndexType> params)
: src(src), mask(mask), params(params) {
for (int i = 0; i < num_dims; i++) { mask_dims[i] = params.mask_dims[i]; }
}
template<int N>
__device__ void load(DST* dst, int64_t row, int64_t col) {
Pack<SRC, N> pack;
Pack<MASK, N> mask_pack;
const IndexType offset = row * params.row_size + col;
IndexType input_index[num_dims];
IndexType mask_index[num_dims];
params.src_index_helper.OffsetToNdIndex(offset, input_index);
for (int dim = 0; dim < num_dims; ++dim) {
if (mask_dims[dim] == 1) {
mask_index[dim] = 0;
} else {
mask_index[dim] = input_index[dim];
}
}
const IndexType mask_offset = params.mask_index_helper.NdIndexToOffset(mask_index);
pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset / N);
mask_pack.storage =
*(reinterpret_cast<const PackType<MASK, N>*>(mask) + mask_offset / N);
#pragma unroll
for (int i = 0; i < N; ++i) {
if (mask_pack.elem[i] == 0) {
dst[i] = static_cast<DST>(params.fill);
} else {
dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(params.scale);
}
}
}
const SRC* src;
const MASK* mask;
int64_t mask_dims[num_dims];
BroadcastMaskSoftmaxParams<num_dims, IndexType> params;
};
chatgpt給了我們一份看起來非常合理的注釋,并且似乎它已經(jīng)參透了這個結(jié)構(gòu)體類就是為了方便在kernel函數(shù)中使用廣播掩碼softmax的輸入。接下來我們依舊拷打一下細(xì)節(jié),首先我們來問一下BroadcastMaskSoftmaxParams這個結(jié)構(gòu)體類,并讓它給出解釋:
// 它的功能是定義一個結(jié)構(gòu)體,用于存儲廣播掩碼softmax的參數(shù)。
// 廣播掩碼softmax是一種在不同形狀的張量之間進(jìn)行softmax運(yùn)算的方法,
// 它會將較小的張量擴(kuò)展到較大的張量的形狀上,并用一個掩碼張量來指定哪些位置需要計(jì)算softmax,
// 哪些位置需要填充一個固定的值。這個結(jié)構(gòu)體的成員變量有:
// src_index_helper: 一個用于計(jì)算源張量索引的輔助類
// mask_index_helper: 一個用于計(jì)算掩碼張量索引的輔助類
// mask_dims: 掩碼張量的每個維度的大小
// row_size: 每一行的元素個數(shù)
// fill: 填充的值
// scale: softmax的縮放因子
// 這個結(jié)構(gòu)體的目的是為了方便在內(nèi)核函數(shù)中使用廣播掩碼softmax的參數(shù),提高深度學(xué)習(xí)的性能。
template<size_t num_dims, typename IndexType>
struct BroadcastMaskSoftmaxParams {
NdIndexOffsetHelper<IndexType, num_dims> src_index_helper;
NdIndexOffsetHelper<IndexType, num_dims> mask_index_helper;
const int64_t* mask_dims{};
int64_t row_size;
float fill;
float scale;
};
非常準(zhǔn)確!接著我們再問一下NdIndexOffsetHelper這個類的細(xì)節(jié),chatgpt給出的回答如下:
// 定義 NdIndexOffsetHelper 類來做多維張量的坐標(biāo)映射
template<typename T, int N>
class NdIndexOffsetHelper {
public:
NdIndexOffsetHelper() = default;
// 這段代碼是一個構(gòu)造函數(shù)模板,它的功能是根據(jù)給定的多維數(shù)組的維度初始化一個用于計(jì)算偏移量的輔助類。
// 它的參數(shù)是一個可變參數(shù)模板,表示可以接受任意個數(shù)和類型的參數(shù),但是第一個參數(shù)必須是一個T類型的值,
// 后面的參數(shù)必須是一個參數(shù)包。它的函數(shù)體是一個調(diào)用另一個函數(shù)的語句,這個函數(shù)是用于初始化步長數(shù)組的。
// 它使用了一個宏OF_DEVICE_FUNC,這個宏可能是用于指定函數(shù)在哪個設(shè)備上運(yùn)行的,比如GPU或者CPU。
// 這個構(gòu)造函數(shù)模板可能是用于實(shí)現(xiàn)一些多維數(shù)組的操作,比如插值或者轉(zhuǎn)置。
template<class... Ts>
OF_DEVICE_FUNC explicit NdIndexOffsetHelper(T d0, Ts... dims) {
constexpr int n = 1 + sizeof...(dims);
static_assert(n <= N, "");
T dims_arr[n] = {d0, static_cast<T>(dims)...};
// 初始化strides信息
InitStrides(dims_arr, n);
}
// 從一個類型為T的數(shù)組進(jìn)行構(gòu)造,并初始化strides信息,注意這里的strides長度設(shè)置為N
OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const T* dims) { InitStrides(dims, N); }
// 從一個類型為U的數(shù)組進(jìn)行構(gòu)造,并初始化strides信息,注意這里的strides長度設(shè)置為N
template<typename U>
OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const U* dims) {
T dims_arr[N];
for (int i = 0; i < N; ++i) { dims_arr[i] = dims[i]; }
InitStrides(dims_arr, N);
}
// 從一個類型為T的數(shù)組進(jìn)行構(gòu)造,并初始化strides信息,注意這里的strides長度自定義為n
OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const T* dims, int n) { InitStrides(dims, n); }
// 從一個類型為U的數(shù)組進(jìn)行構(gòu)造,并初始化strides信息,注意這里的strides長度自定義為n
template<typename U>
OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const U* dims, int n) {
T dims_arr[N];
for (int i = 0; i < N; ++i) {
if (i < n) { dims_arr[i] = dims[i]; }
}
InitStrides(dims_arr, n);
}
// virtual 表示這是一個虛析構(gòu)函數(shù),用于在刪除基類指針時調(diào)用派生類的析構(gòu)函數(shù),避免內(nèi)存泄漏。
// ~NdIndexOffsetHelper() 表示這是 NdIndexOffsetHelper 類的析構(gòu)函數(shù),用于釋放類對象占用的資源。
// = default; 表示這是一個默認(rèn)的析構(gòu)函數(shù),沒有自定義的操作,讓編譯器自動生成。
virtual ~NdIndexOffsetHelper() = default;
// 這段代碼是一個模板函數(shù),用于根據(jù)一個N維索引數(shù)組計(jì)算一個一維偏移量。函數(shù)的參數(shù)和返回值都是模板類型T,可以是任意數(shù)值類型。函數(shù)的主要步驟如下:
OF_DEVICE_FUNC T NdIndexToOffset(const T* index) const {
// 定義一個變量offset,初始值為0,用于存儲最終的偏移量。
T offset = 0;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
// 使用一個循環(huán),從0到N-1,遍歷索引數(shù)組的每一個元素。在循環(huán)中,使用一個數(shù)組stride_,
// 用于存儲每一個維度的步長,即每增加一個單位的索引,偏移量增加多少。
// 將索引數(shù)組的第i個元素乘以步長數(shù)組的第i個元素,然后累加到offset上。
for (int i = 0; i < N; ++i) { offset += index[i] * stride_[i]; }
return offset;
}
// 類似上面,不過這里是從0到n進(jìn)行循環(huán)
OF_DEVICE_FUNC T NdIndexToOffset(const T* index, int n) const {
assert(n <= N);
T offset = 0;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
if (i < n) { offset += index[i] * stride_[i]; }
}
return offset;
}
// 這段代碼是一個函數(shù)模板,它的功能是根據(jù)給定的多維索引計(jì)算一個一維偏移量。
// 它的參數(shù)是一個可變參數(shù)模板,表示可以接受任意個數(shù)和類型的參數(shù),但是第一個
// 參數(shù)必須是一個T類型的值,后面的參數(shù)必須是一個參數(shù)包。它的返回值也是一個T類型的值。
// 它的函數(shù)體是一個循環(huán),用于累加每個維度的索引乘以對應(yīng)的步長,得到最終的偏移量。
// 它使用了一個宏OF_DEVICE_FUNC,這個宏可能是用于指定函數(shù)在哪個設(shè)備上運(yùn)行的,
// 比如GPU或者CPU。這個函數(shù)模板可能是用于實(shí)現(xiàn)一些多維數(shù)組的操作,比如插值或者轉(zhuǎn)置。
template<class... Ts>
OF_DEVICE_FUNC T NdIndexToOffset(T d0, Ts... others) const {
constexpr int n = 1 + sizeof...(others);
static_assert(n <= N, "");
T index[n] = {d0, others...};
T offset = 0;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < n - 1; ++i) { offset += index[i] * stride_[i]; }
if (n == N) {
offset += index[n - 1];
} else {
offset += index[n - 1] * stride_[n - 1];
}
return offset;
}
// 這段代碼是一個成員函數(shù)模板,它的功能是根據(jù)給定的一維偏移量計(jì)算一個多維索引。
// 它的參數(shù)是一個T類型的值,表示偏移量,和一個T類型的指針,表示索引數(shù)組。
// 它的函數(shù)體是一個循環(huán),用于逐個維度地計(jì)算索引值,然后更新剩余的偏移量。
// 它使用了一個宏OF_DEVICE_FUNC,這個宏可能是用于指定函數(shù)在哪個設(shè)備上運(yùn)行的,比如GPU或者CPU 。
// 這個成員函數(shù)模板可能是用于實(shí)現(xiàn)一些多維數(shù)組的操作,比如插值或者轉(zhuǎn)置。
OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index) const {
T remaining = offset;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < N - 1; ++i) {
const T idx = remaining / stride_[i];
index[i] = idx;
remaining = remaining - idx * stride_[i];
}
index[N - 1] = remaining;
}
// 這段代碼是用C++語言編寫的,它定義了一個名為OffsetToNdIndex的函數(shù),
// 該函數(shù)的功能是將一維的偏移量轉(zhuǎn)換為高維的索引1。這個函數(shù)是OneFlow的內(nèi)部類,
// OneFlow是一個深度學(xué)習(xí)框架,它支持分布式訓(xùn)練和推理。這個函數(shù)的參數(shù)有三個,分別是:
// offset: 一個整數(shù),表示一維的偏移量。
// index: 一個整數(shù)數(shù)組,用于存儲轉(zhuǎn)換后的高維索引。
// n: 一個整數(shù),表示高維的維度數(shù),不能超過N,N是一個常量。
// 函數(shù)的主要邏輯是:
// 首先,用一個變量remaining存儲offset的值。
// 然后,用一個循環(huán)遍歷從0到N-1的整數(shù)i。
// 在循環(huán)中,如果i小于n,那么就用remaining除以stride_[i]得到一個整數(shù)idx,這個stride_[i]是一個預(yù)定義的數(shù)組,表示每個維度的步長。
// 然后,將idx賦值給index[i],并用remaining減去idx乘以stride_[i],更新remaining的值。
// 最后,結(jié)束循環(huán)。
// 這個函數(shù)的作用是將一維的偏移量映射到高維的索引,這在深度學(xué)習(xí)中有很多應(yīng)用,比如Unfold和Fold算子,它們可以將圖像的局部區(qū)域轉(zhuǎn)換為一維的向量,或者反過來。
OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index, int n) const {
assert(n <= N);
T remaining = offset;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
if (i < n) {
const T idx = remaining / stride_[i];
index[i] = idx;
remaining = remaining - idx * stride_[i];
}
}
}
// 它定義了一個名為OffsetToNdIndex的函數(shù)模板,該函數(shù)模板的功能和之前的函數(shù)類似,
// 也是將一維的偏移量轉(zhuǎn)換為高維的索引,但是它可以接受不同數(shù)量的參數(shù)。這個函數(shù)模板的參數(shù)有兩個,分別是:
// offset: 一個整數(shù),表示一維的偏移量。
// d0, others: 一系列的整數(shù)引用,用于存儲轉(zhuǎn)換后的高維索引。
// 函數(shù)模板的主要邏輯是:
// 首先,用一個常量n表示參數(shù)的個數(shù),它等于1加上others的個數(shù)。
// 然后,用一個靜態(tài)斷言檢查n是否小于等于N,N是一個常量。
// 然后,用一個指針數(shù)組index存儲d0和others的地址。
// 然后,用一個變量remaining存儲offset的值。
// 然后,用一個循環(huán)遍歷從0到n-2的整數(shù)i。
// 在循環(huán)中,如果i小于n-1,那么就用remaining除以stride_[i]得到一個整數(shù)idx,這個stride_[i]是一個預(yù)定義的數(shù)組,表示每個維度的步長。
// 然后,將idx賦值給index[i]所指向的變量,并用remaining減去idx乘以stride_[i],更新remaining的值。
// 最后,根據(jù)n和N的關(guān)系,分兩種情況處理最后一個參數(shù):
// 如果n等于N,那么就將remaining賦值給index[n-1]所指向的變量。
// 如果n小于N,那么就用remaining除以stride_[n-1]得到一個整數(shù),賦值給index[n-1]所指向的變量。
// 這個函數(shù)模板的作用是將一維的偏移量映射到高維的索引,它可以根據(jù)不同的參數(shù)個數(shù)進(jìn)行重載,這是C++的一種泛型編程的特性
template<class... Ts>
OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T& d0, Ts&... others) const {
constexpr int n = 1 + sizeof...(others);
static_assert(n <= N, "");
T* index[n] = {&d0, &others...};
T remaining = offset;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < n - 1; ++i) {
const T idx = remaining / stride_[i];
*index[i] = idx;
remaining = remaining - idx * stride_[i];
}
if (n == N) {
*index[n - 1] = remaining;
} else {
*index[n - 1] = remaining / stride_[n - 1];
}
}
OF_DEVICE_FUNC constexpr int Size() const { return N; }
protected:
// 這段代碼也是用C++語言編寫的,它定義了一個名為InitStrides的函數(shù),
// 該函數(shù)的功能是初始化stride_數(shù)組,該數(shù)組表示每個維度的步長。這個函數(shù)的參數(shù)有兩個,分別是:
// dims: 一個整數(shù)數(shù)組,表示高維的維度大小。
// n: 一個整數(shù),表示高維的維度數(shù),不能超過N,N是一個常量。
// 函數(shù)的主要邏輯是:
// 首先,用一個循環(huán)遍歷從n-1到N-1的整數(shù)i。
// 在循環(huán)中,將stride_[i]賦值為1。
// 然后,用一個循環(huán)遍歷從n-2到0的整數(shù)i。
// 在循環(huán)中,將stride_[i]賦值為dims[i+1]乘以stride_[i+1]。
// 這個函數(shù)的作用是計(jì)算每個維度的步長,這在之前的OffsetToNdIndex函數(shù)中有用到,它可以根據(jù)不同的維度大小和維度數(shù)進(jìn)行初始化。
OF_DEVICE_FUNC void InitStrides(const T* dims, const int n) {
for (int i = n - 1; i < N; ++i) { stride_[i] = 1; }
for (int i = n - 2; i >= 0; --i) { stride_[i] = dims[i + 1] * stride_[i + 1]; }
}
T stride_[N];
};
可以看到chatgpt不僅理解了NdIndexOffsetHelper是在做坐標(biāo)映射,甚至把每個構(gòu)造函數(shù)的功能以及涉及到的c++語法比如可變參數(shù)模板都給我們展示了。
看到這里,chatgpt已經(jīng)幫助我們理解了BroadcastScaleMaskLoad/ElementwiseScaleMaskLoad這兩個結(jié)構(gòu)體類的功能和細(xì)節(jié)?;緵]有出錯,非常讓我震撼。接下來,在fused_scale_mask_softmax的啟動過程中還調(diào)用了一個廣播維度簡化的函數(shù)SimplifyBroadcastDims:https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/oneflow-cuda-optimize-skills/fused_softmax/fused_scale_mask_softmax.cu#L2136 。我們來看看chatgpt是否可以合理的解釋這個函數(shù)呢?
// 這個函數(shù)的功能是簡化兩個張量的廣播維度,使它們能夠進(jìn)行算術(shù)運(yùn)算。
// 廣播是一種在不同形狀的張量之間進(jìn)行運(yùn)算的方法,它會將較小的張量擴(kuò)展到較大的張量的形狀上。
// 這段代碼的輸入?yún)?shù)有:
// num_a_dims: 張量A的維度數(shù)
// a_dims: 張量A的每個維度的大小
// num_b_dims: 張量B的維度數(shù)
// b_dims: 張量B的每個維度的大小
// simplified_num_dims: 簡化后的廣播維度數(shù)
// simplified_a_dims: 簡化后的張量A的每個維度的大小
// simplified_b_dims: 簡化后的張量B的每個維度的大小
// 這段代碼的主要邏輯是:
// 首先,找到兩個張量的最大維度數(shù),然后用一個閉包函數(shù)MakeGetDim來獲取每個張量在每個維度上的大小,如果維度數(shù)不足,就用1來填充。
// 然后,遍歷每個維度,計(jì)算兩個張量在該維度上的最大值,作為廣播后的維度大小。
// 如果該維度大小為1,就跳過,否則就判斷是否可以和上一個維度合并,如果可以,就乘以上一個維度的大小,
// 如果不可以,就添加到簡化后的維度數(shù)組中,并記錄是否是廣播維度。
// 最后,返回簡化后的廣播維度數(shù)和兩個張量的簡化后的維度大小。
// 這段代碼的目的是為了減少廣播運(yùn)算的開銷,提高深度學(xué)習(xí)的性能。
inline void SimplifyBroadcastDims(size_t num_a_dims, const int64_t* a_dims, size_t num_b_dims,
const int64_t* b_dims, size_t* simplified_num_dims,
int64_t* simplified_a_dims, int64_t* simplified_b_dims) {
const size_t num_max_dims = std::max(num_a_dims, num_b_dims);
auto MakeGetDim = [num_max_dims](size_t num_dims, const int64_t* dims) {
const int64_t num_padding_dims = num_max_dims - num_dims;
return [num_padding_dims, dims](size_t index) {
return index < num_padding_dims ? 1 : dims[index - num_padding_dims];
};
};
auto GetADim = MakeGetDim(num_a_dims, a_dims);
auto GetBDim = MakeGetDim(num_b_dims, b_dims);
*simplified_num_dims = 0;
bool prev_broadcast_a = false;
bool prev_broadcast_b = false;
for (int64_t i = 0; i < num_max_dims; ++i) {
const int64_t a_dim = GetADim(i);
const int64_t b_dim = GetBDim(i);
const int64_t broadcast_dim = std::max(a_dim, b_dim);
// CHECK_GT(broadcast_dim, 0);
const bool broadcast_a = (a_dim == 1);
const bool broadcast_b = (b_dim == 1);
// CHECK((a_dim == broadcast_dim) || broadcast_a);
// CHECK((b_dim == broadcast_dim) || broadcast_b);
if (broadcast_dim == 1) {
continue;
} else if (*simplified_num_dims != 0
&& (prev_broadcast_a == broadcast_a && prev_broadcast_b == broadcast_b)) {
simplified_a_dims[*simplified_num_dims - 1] *= a_dim;
simplified_b_dims[*simplified_num_dims - 1] *= b_dim;
} else {
simplified_a_dims[*simplified_num_dims] = a_dim;
simplified_b_dims[*simplified_num_dims] = b_dim;
*simplified_num_dims += 1;
prev_broadcast_a = broadcast_a;
prev_broadcast_b = broadcast_b;
}
}
}
可以看到chatgpt給出的注釋也是非常準(zhǔn)確的,可以指出這段代碼是為了減少廣播運(yùn)算的開銷,提高深度學(xué)習(xí)的性能。
讓chatgpt幫讀代碼到這里就結(jié)束了,BroadcastScaleMaskStore/ElementwiseScaleMaskStore的生成結(jié)果我也放到了 https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/oneflow-cuda-optimize-skills/fused_softmax/fused_scale_mask_softmax.cu#L2136 。這次體驗(yàn)讓我不得不感嘆chatgpt的邏輯理解能力已經(jīng)非常接近人類了,雖然寫代碼的能力相對還差一些但相信經(jīng)過openai的數(shù)據(jù)以及參數(shù)迭代寫代碼的能力也將突飛猛進(jìn)。
0x3. 性能表現(xiàn)
接著我們來實(shí)測一下這個fused_scale_mask_softmax的效果吧。就以Pattern這一節(jié)的代碼為例子,使用nsight compute來觀測一下fused_scale_mask_softmax相比于原始實(shí)現(xiàn)的性能。profile的代碼如下:
import numpy as np
import oneflow as flow
batch_size = 4
num_heads = 8
seq_length = 64
broadcast_dim = 1
fill_value = -10000.0
scale_value = 2.0
x = np.random.randn(batch_size, num_heads, seq_length, seq_length).astype(
np.float32
)
mask_size = [batch_size, num_heads, seq_length, seq_length]
if broadcast_dim:
mask_size[broadcast_dim] = 1
mask = np.random.randint(0, 2, size=mask_size, dtype=bool)
flow._oneflow_internal.profiler.RangePush('loop begin')
for i in range(100):
fused_x_tensor = flow.tensor(x, dtype=flow.float32).to("cuda")
fused_mask_tensor = flow.tensor(mask, dtype=flow.bool).to("cuda")
flow._oneflow_internal.profiler.RangePush('fused_scale_mask_softmax')
fused_out = flow._C.fused_scale_mask_softmax(
fused_x_tensor, fused_mask_tensor, fill_value=fill_value, scale=scale_value,
)
flow._oneflow_internal.profiler.RangePop()
origin_x_tensor = flow.tensor(x).to("cuda")
origin_mask_tensor = flow.tensor(mask, dtype=flow.float32).to("cuda")
flow._oneflow_internal.profiler.RangePush('origin scale mask softmax')
origin_out = flow.mul(
origin_x_tensor, origin_mask_tensor
) * scale_value + fill_value * (1.0 - origin_mask_tensor)
origin_out = flow.softmax(origin_out, dim=-1)
flow._oneflow_internal.profiler.RangePop()
flow._oneflow_internal.profiler.RangePop()
assert(np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4))
這里循環(huán)了100次計(jì)算并且使用內(nèi)置api flow._oneflow_internal.profiler.RangePush/Pop來埋點(diǎn)。獲得nsys文件之后使用NVIDIA Nsight Systems打開并進(jìn)行分析。原始實(shí)現(xiàn)對應(yīng)的時序圖如紅色框所示,計(jì)算一次的總時間大約100us:
而用上fused_scale_mask_softmax之后的時序圖如下所示,可以看到整個計(jì)算過程只有一個kernel并且只花了8個us,是非常大的性能提升。
文章來源:http://www.zghlxwxcb.cn/news/detail-429710.html
0x4. 總結(jié)
這篇文章展示了一下使用chatgpt閱讀oneflow softmax相關(guān)的fuse優(yōu)化(fused_scale_mask_softmax),在驚嘆于chatgpt的代碼邏輯理解能力的同時也可以體會到cuda中做kernel fuse相比于原始實(shí)現(xiàn)所能帶來的性能優(yōu)勢。實(shí)際上這個fuse在attention中只是局部fuse,已經(jīng)相對比較落后了。目前TensorRT/Xformers等都提供了多頭注意力的完整fuse實(shí)現(xiàn)使得多頭注意力部分的計(jì)算效率更高,oneflow中也集成了上述兩種fmha的實(shí)現(xiàn)并且做了很多擴(kuò)展提升易用性,后續(xù)有時間我將繼續(xù)解讀。文章來源地址http://www.zghlxwxcb.cn/news/detail-429710.html
到了這里,關(guān)于【BBuf的CUDA筆記】九,使用newbing(chatgpt)解析oneflow softmax相關(guān)的fuse優(yōu)化的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!