最美情侣中文字幕电影,在线麻豆精品传媒,在线网站高清黄,久久黄色视频

歡迎光臨散文網(wǎng) 會員登陸 & 注冊

利用Tensor Core加速矩陣乘法的代碼之理解

2022-08-26 21:38 作者:不會跑路的小向晚  | 我要投稿


Tensor Core的官方文檔名字叫Programming Guide

Tensor Core單元可以實現(xiàn)矩陣乘法的加速,也為CUDA核提供了調(diào)用接口。我最近學(xué)習(xí)了相關(guān)文檔,也讀到官方給出的矩陣乘法樣例代碼,在此記錄自己的經(jīng)驗和理解。

常見的矩陣乘法有兩種:D = A * B + C 與 C = A * B + C,兩者之間的區(qū)別是后者原地計算。

這里預(yù)先約定好矩陣形狀的代表符號,A[M, K], B[K, N],其中M是A的行數(shù),K是A的列數(shù),N是B的列數(shù),于是容易推算出D[M, N], C[M, N],這些符號需要記住。

在tensor core之前,cuda kernel核函數(shù)的矩陣運算如下圖所示,將A在y方向上分割成M / m塊[m, K]大小的矩陣,同樣將B在x方向上分割成多塊N / n塊[n, K]大小的矩陣,于是我們只需要申請 M / m * N / n個線程塊即可對A*B并行計算,具體實現(xiàn)來說,可以申請dim3 blockSize(n, m)的線程塊與dim3 gridSize(N / n, M / m)的網(wǎng)格。

而對于每個線程塊,在K方向上也可以劃分成K / k個小矩陣,A中每個小矩陣的形狀為[m, k],于是矩陣乘法變成了更小的子矩陣乘法,1*1 + 2*2 + 3*3...即可得到我們想要的結(jié)果,可以在kernel核函數(shù)中用循環(huán)實現(xiàn)這種計算。

這種矩陣乘法的實現(xiàn)比較常見,在cuda基礎(chǔ)教程中有代碼實現(xiàn)。

說回Tensor Core,其加速矩陣乘法與上述的思路類似,但我們需要先了解一下其硬件特性。與FP32 Core類似,Tensor Core就是一個運算單元,前者輸入兩個浮點數(shù),返回一個浮點數(shù)加法結(jié)果,后者輸入兩個矩陣,返回矩陣乘法結(jié)果。在cuda C的tensor core接口(wmma)中,kernel核函數(shù)中一次tensor core的運算需要占用一個warp的線程(32個)。由于tensor core的一次運算的矩陣大小是固定的,所需線程數(shù)也是固定的,所以我們多個tensor core并行運算只需要對矩陣、線程進(jìn)行分割即可,下面講講怎么分割。

假設(shè)tensor core的一次矩陣運算的形狀為[m, k] * [k, n] = [m, n],其中從A矩陣中分割出[m, k]的子矩陣,從B矩陣分割出[k, n]的子矩陣,得到一個[m, n]的子矩陣。通過簡單的計算可得,A矩陣要求在y方向上需要M / m個warp的線程(每個warp負(fù)責(zé)[m, k]的矩陣),B矩陣要求在x方向上需要N / n個warp的線程,而在kernel內(nèi)進(jìn)行K / k次的循環(huán)累加即可得到C中[m, n]的子矩陣。如果你熟悉之前的矩陣乘法,這一定不難想明白。

剩下的就是編程了:

首先預(yù)定義__CUDACC__,其實不做預(yù)定義也能編譯成功,但VS不會出現(xiàn)代碼的輸入補全提示,而且滿屏波浪號。

然后是初始化矩陣,這里A的內(nèi)容是1, 2, 3, 4....512的序列,B的元素全是1,C的元素全是0,由于tensor core不接受float的輸入,所以使用半精度half作為輸入,float作為輸出。

最后是定義tensor core接收矩陣形狀的大小核函數(shù)了

總之,使用tensor core的矩陣乘法與普通的矩陣乘法其實是類似的,只不過tensor core的運算粒度更大,吞吐量更高。

完整代碼如下:

#include<device_launch_parameters.h>

#include<iostream>

#include<thrust/device_vector.h>

#include<thrust/sequence.h>


#ifndef __CUDACC__

#define __CUDACC__

#endif // !__CUDACC__


#include<mma.h>


using namespace nvcuda;


#define uint unsigned int


#define coreSizeM 16

#define coreSizeN 16

#define coreSizeK 16

__global__ void TensorCoreMM(half* a, half* b, float* c,

const int lm, const int ln, const int lk)

{

const uint x = (blockDim.x * blockIdx.x + threadIdx.x) / 32;

const uint y = blockDim.y * blockIdx.y + threadIdx.y;


const uint la = lk, lb = ln, lc = ln;


const uint aRow = x * coreSizeM; // 當(dāng)前tile左上角在A上的行數(shù)

const uint bCol = y * coreSizeN; // 當(dāng)前tile左上角在B上的列數(shù)


if (aRow >= lm || bCol >= ln) return;


// 聲明fragment

wmma::fragment<wmma::matrix_a, coreSizeM, coreSizeN, coreSizeK, half, wmma::row_major> a_frag;

wmma::fragment<wmma::matrix_b, coreSizeM, coreSizeN, coreSizeK, half, wmma::row_major> b_frag;

wmma::fragment<wmma::accumulator, coreSizeM, coreSizeN, coreSizeK, float> c_frag;


// 清理c_frag

wmma::fill_fragment(c_frag, 0.f);


for (int i = 0; i < la; i += coreSizeK)

{

const uint aCol = i;

const uint bRow = i;

// load

wmma::load_matrix_sync(a_frag, a + aCol + aRow * la, la);

wmma::load_matrix_sync(b_frag, b + bCol + bRow * lb, lb);


// multiple and accumulate

wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

}


// store

wmma::store_matrix_sync(c + bCol + aRow * lc, c_frag, lc, wmma::mem_row_major);

}


#define vectorPtr(x) thrust::raw_pointer_cast(x.data())


int main()

{

// C = A * B + C

size_t M = 32, N = 16, K = 16;

thrust::host_vector<float> A_float(M * K);

thrust::sequence(A_float.begin(), A_float.end());

thrust::device_vector<half> A(A_float.begin(), A_float.end());

thrust::device_vector<half> B(K * N, 1.f);

thrust::device_vector<float> C(M * N, 0.f);


dim3 blockSize(128, 4);

dim3 gridSize((M + blockSize.x - 1) / blockSize.x,?

(N + blockSize.y - 1) / blockSize.y);

TensorCoreMM<<<gridSize, blockSize>>>(vectorPtr(A), vectorPtr(B), vectorPtr(C), M, N, K);


for (int i = 0; i < M; ++i)

{

thrust::copy(C.begin() + i * N, C.begin() + (i + 1) * N, std::ostream_iterator<float>(std::cout, ", "));

std::cout << std::endl;

}

return 0;

}


利用Tensor Core加速矩陣乘法的代碼之理解的評論 (共 條)

分享到微博請遵守國家法律
会理县| 忻城县| 鲁甸县| 饶平县| 凌云县| 浦江县| 砚山县| 彝良县| 安康市| 武义县| 左权县| 永嘉县| 平阴县| 永靖县| 台东市| 麦盖提县| 白山市| 曲松县| 合阳县| 广南县| 灵川县| 海阳市| 岳阳县| 辛集市| 河北省| 八宿县| 台中市| 米泉市| 黑水县| 云安县| 秦安县| 屯昌县| 新丰县| 会宁县| 大石桥市| 保山市| 澳门| 慈利县| 阿拉善左旗| 墨竹工卡县| 临江市|