利用Tensor Core加速矩陣乘法的代碼之理解

Tensor Core單元可以實現(xiàn)矩陣乘法的加速,也為CUDA核提供了調(diào)用接口。我最近學(xué)習(xí)了相關(guān)文檔,也讀到官方給出的矩陣乘法樣例代碼,在此記錄自己的經(jīng)驗和理解。
常見的矩陣乘法有兩種:D = A * B + C 與 C = A * B + C,兩者之間的區(qū)別是后者原地計算。
這里預(yù)先約定好矩陣形狀的代表符號,A[M, K], B[K, N],其中M是A的行數(shù),K是A的列數(shù),N是B的列數(shù),于是容易推算出D[M, N], C[M, N],這些符號需要記住。
在tensor core之前,cuda kernel核函數(shù)的矩陣運算如下圖所示,將A在y方向上分割成M / m塊[m, K]大小的矩陣,同樣將B在x方向上分割成多塊N / n塊[n, K]大小的矩陣,于是我們只需要申請 M / m * N / n個線程塊即可對A*B并行計算,具體實現(xiàn)來說,可以申請dim3 blockSize(n, m)的線程塊與dim3 gridSize(N / n, M / m)的網(wǎng)格。

而對于每個線程塊,在K方向上也可以劃分成K / k個小矩陣,A中每個小矩陣的形狀為[m, k],于是矩陣乘法變成了更小的子矩陣乘法,1*1 + 2*2 + 3*3...即可得到我們想要的結(jié)果,可以在kernel核函數(shù)中用循環(huán)實現(xiàn)這種計算。

這種矩陣乘法的實現(xiàn)比較常見,在cuda基礎(chǔ)教程中有代碼實現(xiàn)。
說回Tensor Core,其加速矩陣乘法與上述的思路類似,但我們需要先了解一下其硬件特性。與FP32 Core類似,Tensor Core就是一個運算單元,前者輸入兩個浮點數(shù),返回一個浮點數(shù)加法結(jié)果,后者輸入兩個矩陣,返回矩陣乘法結(jié)果。在cuda C的tensor core接口(wmma)中,kernel核函數(shù)中一次tensor core的運算需要占用一個warp的線程(32個)。由于tensor core的一次運算的矩陣大小是固定的,所需線程數(shù)也是固定的,所以我們多個tensor core并行運算只需要對矩陣、線程進(jìn)行分割即可,下面講講怎么分割。
假設(shè)tensor core的一次矩陣運算的形狀為[m, k] * [k, n] = [m, n],其中從A矩陣中分割出[m, k]的子矩陣,從B矩陣分割出[k, n]的子矩陣,得到一個[m, n]的子矩陣。通過簡單的計算可得,A矩陣要求在y方向上需要M / m個warp的線程(每個warp負(fù)責(zé)[m, k]的矩陣),B矩陣要求在x方向上需要N / n個warp的線程,而在kernel內(nèi)進(jìn)行K / k次的循環(huán)累加即可得到C中[m, n]的子矩陣。如果你熟悉之前的矩陣乘法,這一定不難想明白。

剩下的就是編程了:
首先預(yù)定義__CUDACC__,其實不做預(yù)定義也能編譯成功,但VS不會出現(xiàn)代碼的輸入補全提示,而且滿屏波浪號。

然后是初始化矩陣,這里A的內(nèi)容是1, 2, 3, 4....512的序列,B的元素全是1,C的元素全是0,由于tensor core不接受float的輸入,所以使用半精度half作為輸入,float作為輸出。

最后是定義tensor core接收矩陣形狀的大小核函數(shù)了

總之,使用tensor core的矩陣乘法與普通的矩陣乘法其實是類似的,只不過tensor core的運算粒度更大,吞吐量更高。
完整代碼如下:
#include<device_launch_parameters.h>
#include<iostream>
#include<thrust/device_vector.h>
#include<thrust/sequence.h>
#ifndef __CUDACC__
#define __CUDACC__
#endif // !__CUDACC__
#include<mma.h>
using namespace nvcuda;
#define uint unsigned int
#define coreSizeM 16
#define coreSizeN 16
#define coreSizeK 16
__global__ void TensorCoreMM(half* a, half* b, float* c,
const int lm, const int ln, const int lk)
{
const uint x = (blockDim.x * blockIdx.x + threadIdx.x) / 32;
const uint y = blockDim.y * blockIdx.y + threadIdx.y;
const uint la = lk, lb = ln, lc = ln;
const uint aRow = x * coreSizeM; // 當(dāng)前tile左上角在A上的行數(shù)
const uint bCol = y * coreSizeN; // 當(dāng)前tile左上角在B上的列數(shù)
if (aRow >= lm || bCol >= ln) return;
// 聲明fragment
wmma::fragment<wmma::matrix_a, coreSizeM, coreSizeN, coreSizeK, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, coreSizeM, coreSizeN, coreSizeK, half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, coreSizeM, coreSizeN, coreSizeK, float> c_frag;
// 清理c_frag
wmma::fill_fragment(c_frag, 0.f);
for (int i = 0; i < la; i += coreSizeK)
{
const uint aCol = i;
const uint bRow = i;
// load
wmma::load_matrix_sync(a_frag, a + aCol + aRow * la, la);
wmma::load_matrix_sync(b_frag, b + bCol + bRow * lb, lb);
// multiple and accumulate
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
// store
wmma::store_matrix_sync(c + bCol + aRow * lc, c_frag, lc, wmma::mem_row_major);
}
#define vectorPtr(x) thrust::raw_pointer_cast(x.data())
int main()
{
// C = A * B + C
size_t M = 32, N = 16, K = 16;
thrust::host_vector<float> A_float(M * K);
thrust::sequence(A_float.begin(), A_float.end());
thrust::device_vector<half> A(A_float.begin(), A_float.end());
thrust::device_vector<half> B(K * N, 1.f);
thrust::device_vector<float> C(M * N, 0.f);
dim3 blockSize(128, 4);
dim3 gridSize((M + blockSize.x - 1) / blockSize.x,?
(N + blockSize.y - 1) / blockSize.y);
TensorCoreMM<<<gridSize, blockSize>>>(vectorPtr(A), vectorPtr(B), vectorPtr(C), M, N, K);
for (int i = 0; i < M; ++i)
{
thrust::copy(C.begin() + i * N, C.begin() + (i + 1) * N, std::ostream_iterator<float>(std::cout, ", "));
std::cout << std::endl;
}
return 0;
}