{"id":3319,"date":"2024-04-23T17:59:32","date_gmt":"2024-04-23T09:59:32","guid":{"rendered":"https:\/\/www.aqwu.net\/wp\/?p=3319"},"modified":"2024-05-01T18:23:06","modified_gmt":"2024-05-01T10:23:06","slug":"llm-c-%e7%9a%84%e4%b8%ad%e6%96%87%e6%b3%a8%e8%a7%a3-20240423","status":"publish","type":"post","link":"https:\/\/www.aqwu.net\/wp\/?p=3319","title":{"rendered":"llm.c \u7684\u4e2d\u6587\u6ce8\u89e3-20240423"},"content":{"rendered":"\n<p>\u8fd9\u91cc\u63a5<a href=\"https:\/\/www.aqwu.net\/wp\/?p=3229\">\u4e0a\u4e00\u8282<\/a> \u5bf9C\u7a0b\u5e8f\u7684\u4e2d\u6587\u6ce8\u89e3\uff0c\u4e0b\u9762\u662f\u5bf9 train_gpt2.cu \u7684\u6ce8\u89e3\uff0c\u6240\u6709\u6ce8\u89e3\u6765\u81eaChatGPT4\u3002<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>1. train_gpt2.cu<\/strong><\/h2>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.1 \u5f00\u59cb<\/strong><\/h3>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/*\nGPT-2 Transformer Neural Net trained in raw CUDA\nNon-trivial notes to be aware of:\nGPT-2 Transformer\u795e\u7ecf\u7f51\u7edc\u5728\u539f\u59cbCUDA\u4e2d\u8fdb\u884c\u8bad\u7ec3\n\u9700\u8981\u6ce8\u610f\u7684\u975e\u5e73\u51e1\u7ec6\u8282\uff1a\n\nWe are being clever in the backward pass to conserve memory.\nIn particular, all parameters use a += in the backward pass, so we\ncan later do gradient accumulation. But all activations have = instead of +=\nbecause these are faster (just read, no write). This is okay for all activations\nexcept for those in the residual stream, where the gradients have to add. We make\nsure that those parts work out ok and that we do a += as necessary. 
E.g.,\nthe layernorms are connected to the residuals so we += in layernorm backward.\n\u6211\u4eec\u5728\u53cd\u5411\u4f20\u64ad\u4e2d\u5de7\u5999\u5730\u8282\u7701\u5185\u5b58\u3002\n\u7279\u522b\u662f\uff0c\u6240\u6709\u53c2\u6570\u5728\u53cd\u5411\u4f20\u64ad\u4e2d\u4f7f\u7528 +=\uff0c\u56e0\u6b64\u6211\u4eec\n\u53ef\u4ee5\u968f\u540e\u8fdb\u884c\u68af\u5ea6\u7d2f\u79ef\u3002\u4f46\u6240\u6709\u6fc0\u6d3b\u503c\u90fd\u4f7f\u7528 = \u800c\u4e0d\u662f +=\uff0c\n\u56e0\u4e3a\u8fd9\u6837\u66f4\u5feb\uff08\u53ea\u8bfb\uff0c\u4e0d\u5199\uff09\u3002\u8fd9\u5bf9\u6240\u6709\u6fc0\u6d3b\u503c\u90fd\u662f\u53ef\u4ee5\u7684\uff0c\n\u9664\u4e86\u6b8b\u5dee\u6d41\u4e2d\u7684\u90a3\u4e9b\uff0c\u5176\u68af\u5ea6\u9700\u8981\u76f8\u52a0\u3002\u6211\u4eec\u786e\u4fdd\u8fd9\u4e9b\u90e8\u5206\n\u53ef\u4ee5\u6b63\u786e\u8fd0\u884c\uff0c\u5e76\u4e14\u5728\u5fc5\u8981\u65f6\u6267\u884c +=\u3002\u4f8b\u5982\uff0c\nlayernorms \u4e0e\u6b8b\u5dee\u76f8\u8fde\uff0c\u6240\u4ee5\u6211\u4eec\u5728layernorm\u7684\u53cd\u5411\u4f20\u64ad\u4e2d\u6267\u884c +=\u3002\n*\/\n\n#include &lt;stdio.h&gt;\n#include &lt;stdlib.h&gt;\n#include &lt;ctype.h&gt;\n#include &lt;math.h&gt;\n#include &lt;time.h&gt;\n#include &lt;assert.h&gt;\n#include &lt;float.h&gt;\n#include &lt;string.h&gt;\n#include &lt;unistd.h&gt;\n#include &lt;assert.h&gt;\n#include &lt;cublas_v2.h&gt;\n#include &lt;cuda_runtime.h&gt;\n#include &lt;cublasLt.h&gt;\n#include &lt;cooperative_groups.h&gt;\n#include &lt;cooperative_groups\/reduce.h&gt;<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.2 CUDA \u5de5\u5177<\/strong><\/h3>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ ----------------------------------------------------------------------------\n\/\/ CUDA utils\n\/\/ CUDA \u5de5\u5177\n\n\/\/ convenience macro for calculating grid\/block dimensions for kernels\n\/\/ 
\u7528\u4e8e\u8ba1\u7b97\u5185\u6838\u7684\u7f51\u683c\/\u5757\u5c3a\u5bf8\u7684\u4fbf\u6377\u5b8f\n#define CEIL_DIV(M, N) (((M) + (N)-1) \/ (N))\n\n\/\/ CUDA error checking\n\/\/ CUDA \u9519\u8bef\u68c0\u67e5\n\/\/ \u8fd9\u6bb5\u4ee3\u7801\u662f\u4e00\u4e2a\u7528\u4e8e\u68c0\u67e5CUDA\u51fd\u6570\u8c03\u7528\u662f\u5426\u6210\u529f\u7684\u5de5\u5177\u51fd\u6570\u3002\n\/\/ \u5982\u679c\u8c03\u7528\u5931\u8d25\uff0c\u5b83\u4f1a\u6253\u5370\u51fa\u9519\u8bef\u6240\u5728\u7684\u6587\u4ef6\u548c\u884c\u53f7\uff0c\u4ee5\u53ca\u9519\u8bef\u63cf\u8ff0\uff0c\u7136\u540e\u9000\u51fa\u7a0b\u5e8f\u3002\n\/\/ \u8fd9\u79cd\u9519\u8bef\u68c0\u67e5\u662fCUDA\u7f16\u7a0b\u4e2d\u5e38\u7528\u7684\u505a\u6cd5\uff0c\u7528\u4ee5\u786e\u4fddCUDA\u8c03\u7528\u6b63\u786e\u6267\u884c\uff0c\n\/\/ \u4ece\u800c\u4fbf\u4e8e\u8ffd\u8e2a\u548c\u5904\u7406\u53ef\u80fd\u53d1\u751f\u7684\u95ee\u9898\u3002\nvoid cudaCheck(cudaError_t error, const char *file, int line) {\n  if (error != cudaSuccess) {\n    printf(\"[CUDA ERROR] at file %s:%d:\\n%s\\n\", file, line,\n           cudaGetErrorString(error));\n    exit(EXIT_FAILURE);\n  }\n};\n#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))\n\n\/\/ cuBLAS error checking\n\/\/ cuBLAS \u9519\u8bef\u68c0\u67e5\n\/\/ \u8fd9\u6bb5\u4ee3\u7801\u662f\u4e00\u4e2a\u7528\u4e8e\u68c0\u67e5cuBLAS\uff08\u4e00\u4e2aNVIDIA\u7684CUDA\u57fa\u7840\u7ebf\u6027\u4ee3\u6570\u5b50\u7a0b\u5e8f\u5e93\uff09\u64cd\u4f5c\u7684\u72b6\u6001\u7684\u51fd\u6570\u3002\n\/\/ \u5982\u679ccuBLAS\u51fd\u6570\u8c03\u7528\u672a\u80fd\u8fd4\u56deCUBLAS_STATUS_SUCCESS\uff0c\u5373\u8868\u793a\u8c03\u7528\u51fa\u73b0\u9519\u8bef\uff0c\n\/\/ \u8be5\u51fd\u6570\u4f1a\u6253\u5370\u9519\u8bef\u4ee3\u7801\u3001\u6587\u4ef6\u540d\u548c\u884c\u53f7\uff0c\u5e76\u4f7f\u7a0b\u5e8f\u9000\u51fa\u3002\n\/\/ 
\u6b64\u51fd\u6570\u7684\u4e3b\u8981\u76ee\u7684\u662f\u4e3a\u4e86\u5728\u7a0b\u5e8f\u4e2d\u53ca\u65f6\u53d1\u73b0\u548c\u62a5\u544acuBLAS\u5e93\u64cd\u4f5c\u7684\u9519\u8bef\uff0c\u4ece\u800c\u4fbf\u4e8e\u8c03\u8bd5\u548c\u4fdd\u8bc1\u7a0b\u5e8f\u7684\u7a33\u5b9a\u8fd0\u884c\u3002\nvoid cublasCheck(cublasStatus_t status, const char *file, int line)\n{\n    if (status != CUBLAS_STATUS_SUCCESS) {\n        printf(\"[cuBLAS ERROR]: %d %s %d\\n\", status, file, line);\n        exit(EXIT_FAILURE);\n    }\n}\n#define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }\n\n\/\/ cuBLAS workspace. Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK\n\/\/ cuBLAS\u5de5\u4f5c\u7a7a\u95f4\u3002\u786c\u7f16\u7801\u4e3a32MiB\uff0c\u4f46\u53ea\u6709Hopper\u9700\u898132\uff0c\u5176\u4ed6\u76844MiB\u5c31\u591f\u4e86\n\/\/ \u8fd9\u6bb5\u4ee3\u7801\u4e3b\u8981\u6d89\u53caCUDA\u548ccuBLAS\u5e93\u7684\u521d\u59cb\u5316\u548c\u914d\u7f6e\u5de5\u4f5c\u3002\n\/\/ \u5b83\u8bbe\u7f6e\u4e86cuBLAS Lt\uff08\u4e00\u4e2a\u9488\u5bf9Tensor Core\u4f18\u5316\u7684cuBLAS\u5b50\u5e93\uff09\u6240\u9700\u7684\u5de5\u4f5c\u7a7a\u95f4\u5927\u5c0f\uff0c\n\/\/ \u5e76\u521d\u59cb\u5316\u4e86\u51e0\u4e2a\u91cd\u8981\u7684\u5e93\u53e5\u67c4\u548c\u53d8\u91cf\u3002cooperative_groups\u547d\u540d\u7a7a\u95f4\u5219\u7528\u4e8eCUDA\u4e2d\u7684\u7ebf\u7a0b\u7ec4\u534f\u4f5c\uff0c\n\/\/ \u6b64\u5904\u901a\u8fc7cg\u522b\u540d\u8fdb\u884c\u5f15\u7528\uff0c\u4ee5\u65b9\u4fbf\u5728\u4ee3\u7801\u4e2d\u4f7f\u7528\u3002\n\/\/\n\/\/ \u5c06\u5de5\u4f5c\u7a7a\u95f4\u5927\u5c0f\u8bbe\u7f6e\u4e3a32MiB\nstatic size_t cublaslt_workspace_size = 32 * 1024 * 1024;\n\/\/ \u521d\u59cb\u5316cuBLAS Lt\u5de5\u4f5c\u7a7a\u95f4\u6307\u9488\u4e3aNULL\nstatic void* cublaslt_workspace = NULL;\n\/\/ \u58f0\u660e\u4e00\u4e2acuBLAS\u8ba1\u7b97\u7c7b\u578b\u53d8\u91cf\nstatic cublasComputeType_t cublas_compute_type;\n\/\/ \u58f0\u660e\u4e00\u4e2acuBLAS\u5e93\u7684\u53e5\u67c4\ncublasHandle_t cublas_handle;\n\/\/ 
\u58f0\u660e\u4e00\u4e2acuBLAS Lt\u5e93\u7684\u53e5\u67c4\ncublasLtHandle_t cublaslt_handle;\n\n\/\/ \u521b\u5efa\u522b\u540dcg\uff0c\u6307\u5411cooperative_groups\u547d\u540d\u7a7a\u95f4\nnamespace cg = cooperative_groups;\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.3 \u6587\u4ef6\u548c\u5185\u5b58\u5de5\u5177<\/strong><\/h3>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ ----------------------------------------------------------------------------\n\/\/ fread convenience utils, with nice handling of error checking using macros\n\/\/ simple replace fopen, fread, fclose with fopenCheck, freadCheck, fcloseCheck\n\/\/ fread\u65b9\u4fbf\u5de5\u5177\uff0c\u4f7f\u7528\u5b8f\u8fdb\u884c\u4f18\u96c5\u7684\u9519\u8bef\u68c0\u67e5\u5904\u7406\n\/\/ \u7b80\u5355\u66ff\u6362fopen, fread, fclose\u4e3afopenCheck, freadCheck, fcloseCheck\n\nFILE *fopen_check(const char *path, const char *mode, const char *file, int line) {\n    FILE *fp = fopen(path, mode);\n    if (fp == NULL) {\n        fprintf(stderr, \"Error: Failed to open file '%s' at %s:%d\\n\", path, file, line);\n        fprintf(stderr, \"Error details:\\n\");\n        fprintf(stderr, \"  File: %s\\n\", file);\n        fprintf(stderr, \"  Line: %d\\n\", line);\n        fprintf(stderr, \"  Path: %s\\n\", path);\n        fprintf(stderr, \"  Mode: %s\\n\", mode);\n        exit(EXIT_FAILURE);\n    }\n    return fp;\n}\n\n#define fopenCheck(path, mode) fopen_check(path, mode, __FILE__, __LINE__)\n\nvoid fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) {\n    size_t result = fread(ptr, size, nmemb, stream);\n    if (result != nmemb) {\n        if (feof(stream)) {\n            fprintf(stderr, \"Error: Unexpected end of file at %s:%d\\n\", file, line);\n        } else if (ferror(stream)) {\n            fprintf(stderr, \"Error: File read error at %s:%d\\n\", file, line);\n        } else {\n            
fprintf(stderr, \"Error: Partial read at %s:%d. Expected %zu elements, read %zu\\n\",\n                    file, line, nmemb, result);\n        }\n        fprintf(stderr, \"Error details:\\n\");\n        fprintf(stderr, \"  File: %s\\n\", file);\n        fprintf(stderr, \"  Line: %d\\n\", line);\n        fprintf(stderr, \"  Expected elements: %zu\\n\", nmemb);\n        fprintf(stderr, \"  Read elements: %zu\\n\", result);\n        exit(EXIT_FAILURE);\n    }\n}\n\n#define freadCheck(ptr, size, nmemb, stream) fread_check(ptr, size, nmemb, stream, __FILE__, __LINE__)\n\nvoid fclose_check(FILE *fp, const char *file, int line) {\n    if (fclose(fp) != 0) {\n        fprintf(stderr, \"Error: Failed to close file at %s:%d\\n\", file, line);\n        fprintf(stderr, \"Error details:\\n\");\n        fprintf(stderr, \"  File: %s\\n\", file);\n        fprintf(stderr, \"  Line: %d\\n\", line);\n        exit(EXIT_FAILURE);\n    }\n}\n\n#define fcloseCheck(fp) fclose_check(fp, __FILE__, __LINE__)\n\n\/\/ ----------------------------------------------------------------------------\n\/\/ malloc error-handling wrapper util\n\/\/ malloc \u9519\u8bef\u5904\u7406\u5c01\u88c5\u5de5\u5177\n\nvoid *malloc_check(size_t size, const char *file, int line) {\n    void *ptr = malloc(size);\n    if (ptr == NULL) {\n        fprintf(stderr, \"Error: Memory allocation failed at %s:%d\\n\", file, line);\n        fprintf(stderr, \"Error details:\\n\");\n        fprintf(stderr, \"  File: %s\\n\", file);\n        fprintf(stderr, \"  Line: %d\\n\", line);\n        fprintf(stderr, \"  Size: %zu bytes\\n\", size);\n        exit(EXIT_FAILURE);\n    }\n    return ptr;\n}\n\n#define mallocCheck(size) malloc_check(size, __FILE__, __LINE__)\n\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.4 warpReduceMax<\/strong><\/h3>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ all the kernels\n\/\/ \u6240\u6709\u7684\u6838\u51fd\u6570\n\n\/\/ 
warp-level reduction for finding the maximum value\n\/\/ \u7528\u4e8e\u5bfb\u627e\u6700\u5927\u503c\u7684warp\u7ea7\u522b\u5f52\u7ea6\n__device__ float warpReduceMax(float val) {\n    for (int offset = 16; offset &gt; 0; offset \/= 2) {\n        \/\/ \u4f7f\u7528__shfl_down_sync\u51fd\u6570\u8fdb\u884c\u540c\u6b65\u7684shuffle\u64cd\u4f5c\n        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));\n    }\n    \/\/ \u8fd4\u56de\u5f52\u7ea6\u540e\u7684\u6700\u5927\u503c\n    return val;\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u6bb5CUDA\u8bbe\u5907\u4ee3\u7801\u5b9a\u4e49\u4e86\u4e00\u4e2a\u540d\u4e3a<code>warpReduceMax<\/code>\u7684\u51fd\u6570\uff0c\u7528\u4e8e\u5728GPU\u7684warp\u7ea7\u522b\u8fdb\u884c\u5e76\u884c\u5f52\u7ea6\uff0c\u4ee5\u627e\u5230\u6700\u5927\u503c\u3002\u5728\u8fd9\u4e2a\u51fd\u6570\u4e2d\uff0c\u901a\u8fc7\u8fed\u4ee3\u51cf\u534a\u504f\u79fb\u91cf\u5e76\u4f7f\u7528CUDA\u7684<code>__shfl_down_sync<\/code>\u5185\u5efa\u51fd\u6570\u6765\u4f20\u9012\u548c\u6bd4\u8f83\u6d6e\u70b9\u6570\u503c\u3002\u8fd9\u4e2a\u5185\u5efa\u51fd\u6570\u4f7f\u5f97\u7ebf\u7a0b\u80fd\u591f\u4ece\u76f8\u540cwarp\u4e2d\u7684\u53e6\u4e00\u4e2a\u7ebf\u7a0b\u83b7\u53d6\u53d8\u91cf\u503c\uff0c\u5b9e\u73b0\u9ad8\u6548\u7684\u6570\u636e\u4ea4\u6362\u548c\u5f52\u7ea6\u3002\u6bcf\u4e00\u6b65\u5f52\u7ea6\u64cd\u4f5c\u90fd\u4f7f\u7528<code>fmaxf<\/code>\u51fd\u6570\u4fdd\u8bc1\u53d6\u5f97\u4e24\u4e2a\u503c\u4e2d\u7684\u6700\u5927\u503c\uff0c\u4ece\u800c\u6700\u7ec8\u5728\u4e00\u4e2awarp\u5185\u5f97\u5230\u6700\u5927\u503c\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.5 warpReduceSum<\/strong><\/h3>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ warp-level reduction for summing values\n\/\/ \u7528\u4e8e\u6c42\u548c\u7684warp\u7ea7\u522b\u5f52\u7ea6\n__device__ float warpReduceSum(float val) {\n    for (int offset = 16; offset &gt; 0; offset \/= 2) {\n        \/\/ 
\u4f7f\u7528__shfl_down_sync\u51fd\u6570\u8fdb\u884c\u540c\u6b65\u7684shuffle\u64cd\u4f5c\uff0c\u7d2f\u52a0\u503c\n        val += __shfl_down_sync(0xFFFFFFFF, val, offset);\n    }\n    \/\/ \u8fd4\u56de\u5f52\u7ea6\u540e\u7684\u603b\u548c\n    return val;\n}\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.6 encoder_forward_kernel2<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9a\u4e49\u4e86\u4e00\u4e2a\u540d\u4e3a<code>encoder_forward_kernel2<\/code>\u7684CUDA\u5168\u5c40\u5185\u6838\u51fd\u6570\uff0c\u7528\u4e8e\u5728\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e2d\u7684\u7f16\u7801\u5668\u524d\u5411\u4f20\u64ad\u8ba1\u7b97\u4e2d\u3002\u51fd\u6570\u63a5\u6536\u51e0\u4e2a\u53c2\u6570\uff0c\u5305\u62ec\u8f93\u51fa\u5411\u91cf\u3001\u8f93\u5165\u7d22\u5f15\u3001\u8bcd\u5d4c\u5165\u5411\u91cf\u3001\u4f4d\u7f6e\u5d4c\u5165\u5411\u91cf\u4ee5\u53ca\u4e00\u4e9b\u7ef4\u5ea6\u53c2\u6570<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ float* out: \u8f93\u51fa\u6570\u636e\u7684\u6307\u9488\u3002\n\/\/ int* inp: \u8f93\u5165\u7d22\u5f15\u7684\u6307\u9488\u3002\n\/\/ float* wte: \u8bcd\u5d4c\u5165\u6743\u91cd\u7684\u6307\u9488\u3002\n\/\/ float* wpe: \u4f4d\u7f6e\u5d4c\u5165\u6743\u91cd\u7684\u6307\u9488\u3002\n\/\/ int B: batch\u5927\u5c0f\u3002\n\/\/ int T: \u5e8f\u5217\u957f\u5ea6\u3002\n\/\/ int C: \u5d4c\u5165\u7ef4\u5ea6\u3002\n__global__ void encoder_forward_kernel2(float* out,\n                               int* inp, float* wte, float* wpe,\n                               int B, int T, int C) {\n    \/\/ 1.\u7d22\u5f15\u8ba1\u7b97\uff1a\u901a\u8fc7blockIdx.x, blockDim.x\u548cthreadIdx.x\u8ba1\u7b97\u5f53\u524d\u7ebf\u7a0b\u7684\u5168\u5c40\u7d22\u5f15idx\u3002\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int N = B * T * C;\n\n    \/\/ 2.\u8fb9\u754c\u68c0\u67e5\uff1a\u786e\u4fddidx\u5728\u6709\u6548\u8303\u56f4\u5185\uff0c\u5373\u5c0f\u4e8eN\uff0c\u5176\u4e2dN = B * T * 
C\u4e3a\u8f93\u51fa\u6570\u636e\u7684\u603b\u5143\u7d20\u6570\u3002\n    if (idx &lt; N) {\n        \/\/ 3.\u6620\u5c04\u5230\u4e09\u7ef4\u7d22\u5f15\uff1a\u8ba1\u7b97\u5bf9\u5e94\u7684batch\u7d22\u5f15b\uff0c\u65f6\u95f4\u6b65\u7d22\u5f15t\u548c\u5d4c\u5165\u7ef4\u5ea6\u7d22\u5f15c\u3002\n        int bt = idx \/ C;\n        int b = bt \/ T;\n        int t = bt % T;\n        int c = idx % C;\n\n        \/\/ 4.\u6570\u636e\u8bbf\u95ee\uff1a\u901a\u8fc7\u7ed9\u5b9a\u7684\u8f93\u5165\u7d22\u5f15inp[b * T + t]\u83b7\u53d6\u8bcd\u5d4c\u5165\u7d22\u5f15ix\u3002\n        int ix = inp[b * T + t];\n\n        \/\/ 5. \u6743\u91cd\u8bbf\u95ee\u548c\u76f8\u52a0\uff1a\u901a\u8fc7wte\u548cwpe\u8bbf\u95ee\u76f8\u5e94\u7684\u8bcd\u5d4c\u5165\u548c\u4f4d\u7f6e\u5d4c\u5165\uff0c\u5e76\u5c06\u5b83\u4eec\u76f8\u52a0\u5f97\u5230\u8f93\u51fa\u3002\n        float* out_btc = out + b * T * C + t * C + c;\n        float* wte_ix = wte + ix * C + c;\n        float* wpe_tc = wpe + t * C + c;\n        *out_btc = *wte_ix + *wpe_tc;\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u8be5\u5185\u6838\u51fd\u6570\u7684\u4f5c\u7528\u662f\u5bf9\u6bcf\u4e2a\u8f93\u5165\u8bcd\u6c47\u7684\u8bcd\u5d4c\u5165\u548c\u5bf9\u5e94\u7684\u4f4d\u7f6e\u5d4c\u5165\u8fdb\u884c\u76f8\u52a0\uff0c\u8fd9\u662f\u8bb8\u591a\u57fa\u4e8e\u6ce8\u610f\u529b\u7684\u795e\u7ecf\u7f51\u7edc\u67b6\u6784\uff08\u5982Transformer\uff09\u4e2d\u7684\u5e38\u89c1\u6b65\u9aa4\u3002\u901a\u8fc7\u5229\u7528CUDA\u7684\u5e76\u884c\u8ba1\u7b97\u80fd\u529b\uff0c\u8be5\u51fd\u6570\u80fd\u591f\u9ad8\u6548\u5730\u5904\u7406\u5927\u89c4\u6a21\u6570\u636e\u96c6\uff0c\u9002\u7528\u4e8e\u5904\u7406\u5927\u578b\u8bed\u8a00\u6a21\u578b\u6216\u5176\u4ed6\u9700\u8981\u5927\u91cf\u5e76\u884c\u6587\u672c\u6570\u636e\u5904\u7406\u7684\u5e94\u7528\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.7 
encoder_backward_kernel<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9a\u4e49\u4e86\u4e00\u4e2a\u540d\u4e3a<code>encoder_backward_kernel<\/code>\u7684CUDA\u5168\u5c40\u5185\u6838\u51fd\u6570\uff0c\u7528\u4e8e\u5728\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e2d\u7684\u7f16\u7801\u5668\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u4e2d\u7684\u68af\u5ea6\u66f4\u65b0\u3002\u8fd9\u4e2a\u51fd\u6570\u4f7f\u7528\u4e86<code>atomicAdd<\/code>\uff0c\u4e00\u4e2a\u539f\u5b50\u64cd\u4f5c\u51fd\u6570\uff0c\u4ee5\u786e\u4fdd\u5728\u5e76\u884c\u6267\u884c\u4e2d\u5bf9\u5171\u4eab\u6570\u636e\u7684\u4fee\u6539\u662f\u5b89\u5168\u7684\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ really bad naive kernel with atomicAdd\n\/\/ \u4f7f\u7528atomicAdd\u7684\u975e\u5e38\u521d\u7ea7\u7684\u5185\u6838\u51fd\u6570\n\/\/\n\/\/ float* dwte: \u8bcd\u5d4c\u5165\u68af\u5ea6\u7684\u6307\u9488\u3002\n\/\/ float* dwpe: \u4f4d\u7f6e\u5d4c\u5165\u68af\u5ea6\u7684\u6307\u9488\u3002\n\/\/ const float* dout: \u6765\u81ea\u540e\u7eed\u5c42\u7684\u68af\u5ea6\u4f20\u9012\u7684\u6307\u9488\u3002\n\/\/ const int* inp: \u8f93\u5165\u7d22\u5f15\u7684\u6307\u9488\u3002\n\/\/ int B: batch\u5927\u5c0f\u3002\n\/\/ int T: \u5e8f\u5217\u957f\u5ea6\u3002\n\/\/ int C: \u5d4c\u5165\u7ef4\u5ea6\u3002\n\/\/ \u6838\u5fc3\u8ba1\u7b97\u8fc7\u7a0b\u5305\u62ec\uff1a\n__global__ void encoder_backward_kernel(float* dwte, float* dwpe,\n                                        const float* dout, const int* inp,\n                                        int B, int T, int C) {\n    \/\/ 1. \u7d22\u5f15\u8ba1\u7b97\uff1a\u901a\u8fc7blockIdx.x, blockDim.x\u548cthreadIdx.x\u8ba1\u7b97\u5f53\u524d\u7ebf\u7a0b\u7684\u5168\u5c40\u7d22\u5f15idx\u3002\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int N = B * T * C;\n\n    \/\/ 2. 
\u8fb9\u754c\u68c0\u67e5\uff1a\u786e\u4fddidx\u5728\u6709\u6548\u8303\u56f4\u5185\uff0c\u5373\u5c0f\u4e8eN\uff0c\u5176\u4e2dN = B * T * C\u4e3a\u6570\u636e\u7684\u603b\u5143\u7d20\u6570\u3002\n    if (idx &lt; N) {\n        \/\/ 3. \u6620\u5c04\u5230\u4e09\u7ef4\u7d22\u5f15\uff1a\u8ba1\u7b97\u5bf9\u5e94\u7684batch\u7d22\u5f15b\uff0c\u65f6\u95f4\u6b65\u7d22\u5f15t\u548c\u5d4c\u5165\u7ef4\u5ea6\u7d22\u5f15c\u3002\n        int bt = idx \/ C;\n        int b = bt \/ T;\n        int t = bt % T;\n        int c = idx % C;\n\n        \/\/ 4. \u6570\u636e\u8bbf\u95ee\uff1a\u901a\u8fc7\u7ed9\u5b9a\u7684\u8f93\u5165\u7d22\u5f15inp[b * T + t]\u83b7\u53d6\u8bcd\u5d4c\u5165\u548c\u4f4d\u7f6e\u5d4c\u5165\u7684\u66f4\u65b0\u7d22\u5f15ix\u3002\n        int ix = inp[b * T + t];\n\n        \/\/ 5. \u68af\u5ea6\u7d2f\u52a0\uff1a\u4f7f\u7528atomicAdd\u6765\u5b89\u5168\u5730\u5728dwte\u548cdwpe\u7684\u5bf9\u5e94\u4f4d\u7f6e\u7d2f\u52a0\u68af\u5ea6\u503c\u3002\n        const float* dout_btc = dout + b * T * C + t * C + c;\n        float* dwte_ix = dwte + ix * C + c;\n        float* dwpe_tc = dwpe + t * C + c;\n\n        atomicAdd(dwte_ix, *dout_btc);\n        atomicAdd(dwpe_tc, *dout_btc);\n    }\n}<\/pre><\/div>\n\n\n\n<p><strong>\u6027\u80fd\u8003\u91cf<\/strong>\uff1a 
\u4f7f\u7528<code>atomicAdd<\/code>\u5728GPU\u7f16\u7a0b\u4e2d\u662f\u51fa\u4e8e\u9700\u8981\u4fdd\u8bc1\u5e76\u884c\u66f4\u65b0\u7684\u6570\u636e\u4e00\u81f4\u6027\u548c\u6b63\u786e\u6027\u3002\u7136\u800c\uff0c<code>atomicAdd<\/code>\u53ef\u80fd\u5bfc\u81f4\u6027\u80fd\u74f6\u9888\uff0c\u5c24\u5176\u662f\u5f53\u591a\u4e2a\u7ebf\u7a0b\u9891\u7e41\u5730\u5bf9\u540c\u4e00\u5185\u5b58\u4f4d\u7f6e\u8fdb\u884c\u66f4\u65b0\u65f6\uff0c\u8fd9\u53ef\u80fd\u5bfc\u81f4\u4e25\u91cd\u7684\u6027\u80fd\u4e0b\u964d\u3002\u8fd9\u4e2a\u5185\u6838\u51fd\u6570\u88ab\u6807\u8bb0\u4e3a\u201cnaive\u201d\uff08\u521d\u7ea7\u7684\u3001\u7b80\u5355\u7684\uff09\u4e3b\u8981\u662f\u56e0\u4e3a\u5b83\u5728\u8bbe\u8ba1\u4e0a\u6ca1\u6709\u4f18\u5316\u8fd9\u79cd\u9ad8\u51b2\u7a81\u7684\u5199\u64cd\u4f5c\uff0c\u53ef\u80fd\u4f1a\u5728\u5b9e\u9645\u5e94\u7528\u4e2d\u9047\u5230\u6027\u80fd\u95ee\u9898\u3002\u5728\u5927\u89c4\u6a21\u6570\u636e\u548c\u9ad8\u5ea6\u5e76\u884c\u7684\u60c5\u51b5\u4e0b\uff0c\u4f18\u5316\u8fd9\u79cd\u7c7b\u578b\u7684\u5185\u6838\u662f\u975e\u5e38\u91cd\u8981\u7684\uff0c\u6bd4\u5982\u901a\u8fc7\u8bbe\u8ba1\u66f4\u9ad8\u6548\u7684\u6570\u636e\u8bbf\u95ee\u6a21\u5f0f\u6216\u4f7f\u7528\u66f4\u5148\u8fdb\u7684\u5e76\u884c\u5f52\u7ea6\u6280\u672f\u6765\u51cf\u5c11\u5bf9<code>atomicAdd<\/code>\u7684\u4f9d\u8d56\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.8 layernorm_forward_kernel3<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9a\u4e49\u4e86\u4e00\u4e2a\u540d\u4e3a<code>layernorm_forward_kernel3<\/code>\u7684CUDA\u5168\u5c40\u5185\u6838\u51fd\u6570\uff0c\u7528\u4e8e\u6267\u884c\u5c42\u5f52\u4e00\u5316\uff08Layer Normalization\uff09\u7684\u524d\u5411\u8ba1\u7b97\u3002\u8be5\u51fd\u6570\u4f7f\u7528\u4e86NVIDIA Cooperative Groups\u5e93\u6765\u8fdb\u884c\u5e76\u884c\u8ba1\u7b97\uff0c\u4f18\u5316\u4e86\u5185\u5b58\u8bbf\u95ee\u548c\u6570\u636e\u5f52\u7ea6\u8fc7\u7a0b\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c 
decode:true \">\/\/ float* __restrict__ out: \u8f93\u51fa\u6570\u7ec4\u7684\u6307\u9488\u3002\n\/\/ float* __restrict__ mean: \u5747\u503c\u6570\u7ec4\u7684\u6307\u9488\u3002\n\/\/ float* __restrict__ rstd: \u9006\u6807\u51c6\u5dee\u6570\u7ec4\u7684\u6307\u9488\u3002\n\/\/ const float* __restrict__ inp: \u8f93\u5165\u6570\u7ec4\u7684\u6307\u9488\u3002\n\/\/ const float* __restrict__ weight: \u6743\u91cd\u6570\u7ec4\u7684\u6307\u9488\u3002\n\/\/ const float* __restrict__ bias: \u504f\u5dee\u6570\u7ec4\u7684\u6307\u9488\u3002\n\/\/ int N: \u5904\u7406\u7684\u6570\u636e\u884c\u6570\u3002\n\/\/ int C: \u6bcf\u884c\u7684\u6570\u636e\u6570\u91cf\u3002\n__global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,\n                                    const float*  __restrict__ inp, const float*  __restrict__ weight,\n                                    const float* __restrict__ bias, int N, int C) {\n    \/\/ 1. \u7ebf\u7a0b\u5757\u548c\u5206\u533a\uff1a\u4f7f\u7528cooperative_groups\u5e93\u7684\u529f\u80fd\u6765\u5b9a\u4e49\u7ebf\u7a0b\u5757\u548c\u5206\u533a\uff0c\u4ee5\u786e\u4fdd\u66f4\u6709\u6548\u7684\u6570\u636e\u5904\u7406\u548c\u5f52\u7ea6\u3002\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile&lt;32&gt; warp = cg::tiled_partition&lt;32&gt;(block);\n    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    if(idx &gt;= N) {\n        return;\n    }\n\n    \/\/ the row of input that this group of threads is responsible for\n    \/\/ \u8fd9\u7ec4\u7ebf\u7a0b\u8d1f\u8d23\u7684\u8f93\u5165\u884c\n    const float* x = inp + idx * C;\n\n    \/\/ mean\n    \/\/ 2. 
\u5747\u503c\u8ba1\u7b97\uff1a\u6bcf\u4e2awarp\u8ba1\u7b97\u81ea\u5df1\u8d1f\u8d23\u7684\u8f93\u5165\u884c\u7684\u5143\u7d20\u4e4b\u548c\uff0c\u7136\u540e\u4f7f\u7528cg::reduce\u51fd\u6570\u8fdb\u884c\u5f52\u7ea6\u5f97\u5230\u5747\u503c\u3002\n    float sum = 0.0f;\n    for (int i = warp.thread_rank(); i &lt; C; i += warp.size()) {\n        sum += x[i];\n    }\n    sum = cg::reduce(warp, sum, cg::plus&lt;float&gt;{});\n    float m = sum \/ C;\n    if(warp.thread_rank() == 0 &amp;&amp; mean != nullptr) {\n        __stcs(mean + idx, m);\n    }\n\n    \/\/ rstd\n    \/\/ 3. \u9006\u6807\u51c6\u5dee\u8ba1\u7b97\uff1a\u8ba1\u7b97\u6bcf\u4e2a\u5143\u7d20\u4e0e\u5747\u503c\u7684\u5dee\u7684\u5e73\u65b9\u548c\uff0c\u518d\u6b21\u4f7f\u7528cg::reduce\u8fdb\u884c\u5f52\u7ea6\uff0c\u6700\u540e\u8ba1\u7b97\u9006\u6807\u51c6\u5dee\u3002\n    sum = 0.0f;\n    for (int i = warp.thread_rank(); i &lt; C; i += warp.size()) {\n        float diff = x[i] - m;\n        sum += diff * diff;\n    }\n    sum = cg::reduce(warp, sum, cg::plus&lt;float&gt;{});\n    float s = rsqrtf(sum \/ C + 1e-5f);\n    if(warp.thread_rank() == 0 &amp;&amp; rstd != nullptr) {\n        __stcs(rstd + idx, s);\n    }\n\n    \/\/ final normalization and scaling by weight\/bias\n    \/\/ \u6700\u7ec8\u7684\u5f52\u4e00\u5316\u548c\u901a\u8fc7\u6743\u91cd\/\u504f\u5dee\u7684\u7f29\u653e\u5904\u7406\n    \/\/ 4. 
\u5f52\u4e00\u5316\u548c\u7f29\u653e\uff1a\u5bf9\u8f93\u5165\u6570\u636e\u8fdb\u884c\u5f52\u4e00\u5316\uff0c\u7136\u540e\u4e58\u4ee5\u6743\u91cd\u5e76\u52a0\u4e0a\u504f\u5dee\uff0c\u8ba1\u7b97\u6700\u7ec8\u7684\u8f93\u51fa\u503c\u3002\n    float* o = out + idx * C;\n    for (int c = warp.thread_rank(); c &lt; C; c += warp.size()) {\n        \/\/ load and store using the .cs \"streaming\" hint to the compiler,\n        \/\/ indicating that this data will not be reused soon, and can be streamed through the caches\n        \/\/ this allows the threads to get more cache-hits for the (shared) weight and bias parameters\n        float n = s * (__ldcs(x+c) - m);\n        __stcs(o+c, n * weight[c] + bias[c]);\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u6027\u80fd\u4f18\u5316\u63aa\u65bd\uff1a<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u7f13\u5b58\u63d0\u793a<\/strong>\uff1a\u4f7f\u7528<code>.cs<\/code>\uff08cache streaming\uff09\u6307\u793a\u7b26\u6765\u544a\u8bc9\u7f16\u8bd1\u5668\u53ef\u4ee5\u901a\u8fc7\u7f13\u5b58\u8fdb\u884c\u6570\u636e\u6d41\u5904\u7406\uff0c\u8fd9\u6709\u52a9\u4e8e\u63d0\u9ad8\u5185\u5b58\u8bbf\u95ee\u6548\u7387\u3002<\/li>\n\n\n\n<li><strong>\u5e76\u884c\u5f52\u7ea6<\/strong>\uff1a\u4f7f\u7528<code>cooperative_groups<\/code>\u4e2d\u7684\u5f52\u7ea6\u64cd\u4f5c\u6765\u4f18\u5316\u5e76\u884c\u5f52\u7ea6\u7684\u6027\u80fd\u3002<\/li>\n<\/ul>\n\n\n\n<p>\u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u4e0d\u4ec5\u4fdd\u8bc1\u4e86\u8ba1\u7b97\u7684\u51c6\u786e\u6027\uff0c\u4e5f\u4f18\u5316\u4e86\u6267\u884c\u6548\u7387\uff0c\u7279\u522b\u9002\u7528\u4e8e\u5904\u7406\u5927\u89c4\u6a21\u6570\u636e\u96c6\uff0c\u5982\u5728\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u7684\u8bad\u7ec3\u548c\u63a8\u7406\u4e2d\u5e38\u89c1\u7684\u573a\u666f\u3002<\/p>\n\n\n\n<p>\u65e2\u7136ChatGPT4 \u63d0\u51fa\u4e86\u4f18\u5316\uff0c\u90a3\u7ee7\u7eed\u8ba9ChatGPT4\u4fee\u6539\u4ee3\u7801\uff0c\u4e0b\u9762\u662fChatGPT4 
\u7684\u4f18\u5316\u4ee3\u7801\uff0c\u7531\u4e8e\u6d89\u53ca\u5230\u5171\u4eab\u5185\u5b58\uff0c\u90a3\u81ea\u7136\u903b\u8f91\u5c31\u53d8\u4e86\uff0c\u540e\u9762\u770b\u662f\u5426\u53ef\u4ee5\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">__global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,\n                                    const float* __restrict__ inp, const float* __restrict__ weight,\n                                    const float* __restrict__ bias, int N, int C) {\n    extern __shared__ float shared_data[];\n    float* shared_weight = shared_data;\n    float* shared_bias = &amp;shared_data[C];\n\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile&lt;32&gt; warp = cg::tiled_partition&lt;32&gt;(block);\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if(idx &gt;= N) {\n        return;\n    }\n\n    if (threadIdx.x &lt; C) {\n        shared_weight[threadIdx.x] = weight[threadIdx.x];\n        shared_bias[threadIdx.x] = bias[threadIdx.x];\n    }\n    __syncthreads(); \/\/ \u786e\u4fdd\u6240\u6709\u6570\u636e\u90fd\u88ab\u52a0\u8f7d\u5230\u5171\u4eab\u5185\u5b58\n\n    const float* x = inp + idx * C;\n    float sum = 0.0f;\n    for (int i = warp.thread_rank(); i &lt; C; i += warp.size()) {\n        sum += x[i];\n    }\n    sum = cg::reduce(warp, sum, cg::plus&lt;float&gt;{});\n    float m = sum \/ C;\n    if(warp.thread_rank() == 0 &amp;&amp; mean != nullptr) {\n        __stcs(mean + idx, m);\n    }\n\n    \/\/ \u9006\u6807\u51c6\u5dee\u8ba1\u7b97\n    sum = 0.0f;\n    for (int i = warp.thread_rank(); i &lt; C; i += warp.size()) {\n        float diff = x[i] - m;\n        sum += diff * diff;\n    }\n    sum = cg::reduce(warp, sum, cg::plus&lt;float&gt;{});\n    float s = rsqrtf(sum \/ C + 1e-5f);\n    if(warp.thread_rank() == 0 &amp;&amp; rstd != nullptr) 
{\n        __stcs(rstd + idx, s);\n    }\n\n    \/\/ \u5f52\u4e00\u5316\u548c\u7f29\u653e\n    float* o = out + idx * C;\n    for (int c = warp.thread_rank(); c &lt; C; c += warp.size()) {\n        float n = s * (__ldcs(x+c) - m);\n        __stcs(o+c, n * shared_weight[c] + shared_bias[c]);\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u786e\u4fdd\uff1a<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>\u6743\u91cd\u548c\u504f\u7f6e\u662f\u901a\u8fc7\u5171\u4eab\u5185\u5b58\u8fdb\u884c\u8bbf\u95ee\u7684\uff0c\u4ece\u800c\u51cf\u5c11\u4e86\u5168\u5c40\u5185\u5b58\u7684\u8bbf\u95ee\u6b21\u6570\u548c\u63d0\u9ad8\u4e86\u8bbf\u95ee\u6548\u7387\u3002<\/li>\n\n\n\n<li>\u4f7f\u7528cooperative_groups\u7684warp\u7ea7\u522b\u5f52\u7ea6\u6765\u4f18\u5316\u5f52\u7ea6\u64cd\u4f5c\u7684\u6548\u7387\u3002<\/li>\n\n\n\n<li>\u4f7f\u7528\u7f13\u5b58\u6d41\u6307\u793a\u7b26 <code>.cs<\/code> \u6765\u4f18\u5316\u6570\u636e\u52a0\u8f7d\u548c\u5b58\u50a8\u64cd\u4f5c\u3002<\/li>\n<\/ul>\n\n\n\n<p>\u8bf7\u6ce8\u610f\uff0c\u5171\u4eab\u5185\u5b58\u7684\u5927\u5c0f\u9700\u8981\u5728\u5185\u6838\u8c03\u7528\u65f6\u6307\u5b9a\uff0c\u8fd9\u901a\u5e38\u662f\u5728\u5185\u6838\u8c03\u7528\u7684\u914d\u7f6e\u53c2\u6570\u4e2d\u8bbe\u7f6e\uff0c\u5982 <code>&lt;&lt;&lt;numBlocks, blockSize, sharedMemSize&gt;&gt;&gt;<\/code>\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c<code>sharedMemSize<\/code> \u5e94\u8bbe\u7f6e\u4e3a <code>2 * C * sizeof(float)<\/code>\uff0c\u4ee5\u786e\u4fdd\u4e3a\u6743\u91cd\u548c\u504f\u7f6e\u5206\u914d\u8db3\u591f\u7684\u5171\u4eab\u5185\u5b58\u7a7a\u95f4\u3002<\/p>\n\n\n\n<p>\u6ce8\u610f\uff1a\u8fd9\u6bb5\u4ee3\u7801\u4e3a\u6bcf\u4e2a\u7ebf\u7a0b\u5355\u72ec\u8ba1\u7b97\u884c\u7d22\u5f15<code>idx<\/code>\uff08\u539f\u5185\u6838\u662f\u6bcf\u4e2awarp\u5904\u7406\u4e00\u884c\uff09\uff0c\u5bfc\u81f4\u540c\u4e00warp\u5185\u7684\u5f52\u7ea6\u4f5c\u7528\u5728\u4e0d\u540c\u884c\u7684\u6570\u636e\u4e0a\uff1b\u5b83\u5728<code>__syncthreads()<\/code>\u4e4b\u524d\u63d0\u524d\u8fd4\u56de\uff1b\u5e76\u4e14\u5f53C\u5927\u4e8e<code>blockDim.x<\/code>\u65f6\uff0c\u6743\u91cd\u548c\u504f\u7f6e\u4e0d\u4f1a\u88ab\u5168\u90e8\u52a0\u8f7d\u5230\u5171\u4eab\u5185\u5b58\u3002\u56e0\u6b64\u5b83\u5e76\u4e0d\u80fd\u76f4\u63a5\u66ff\u6362\u539f\u5185\u6838\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.9 permute_kernel<\/strong><\/h3>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \" >__global__ void permute_kernel(float* q, float* k, float* v,\n                               const float* inp,\n                               int B, int N, int NH, int d) {\n    \/\/ okay so now, this kernel wants Q,K,V 
to all be of shape (B, NH, N, d)\n    \/\/ but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d)\n    \/\/ \u6b64\u5185\u6838\u51fd\u6570\u7684\u76ee\u7684\u662f\u5c06Q, K, V\u7684\u5f62\u72b6\u53d8\u4e3a (B, NH, N, d)\n    \/\/ \u4f46\u539f\u59cb\u7684\u8f93\u5165\u5f20\u91cfQKV (inp) \u7684\u5f62\u72b6\u4e3a (B, N, 3, NH, d)\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    \/\/ Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]\n    \/\/ \u8ba1\u7b97Q, K, V\u7684\u7d22\u5f15\uff0c\u5373 Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]\n    if (idx &lt; B * NH * N * d) {\n        \/\/ \u8ba1\u7b97batch\u7d22\u5f15 b\n        int b = idx \/ (NH * N * d);\n        int rest = idx % (NH * N * d);\n        \/\/ \u8ba1\u7b97\u5934\u7d22\u5f15 nh_\n        int nh_ = rest \/ (N * d);\n        rest = rest % (N * d);\n        \/\/ \u8ba1\u7b97\u5e8f\u5217\u7d22\u5f15 n\n        int n = rest \/ d;\n        \/\/ \u8ba1\u7b97\u7ef4\u5ea6\u7d22\u5f15 d_\n        int d_ = rest % d;\n        \/\/ \u8ba1\u7b97inp\u4e2d\u5bf9\u5e94\u7684\u7d22\u5f15\u4f4d\u7f6e\n        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;\n        \/\/ \u4eceinp\u7684\u7279\u5b9a\u4f4d\u7f6e\u52a0\u8f7dQuery\u6570\u636e\u5230q\n        q[idx] = __ldcs(&amp;inp[inp_idx]);\n        \/\/ \u4eceinp\u7684\u7279\u5b9a\u4f4d\u7f6e\u52a0\u8f7dKey\u6570\u636e\u5230k\uff08\u6ce8\u610f\u504f\u79fb\u91cf\u4e3aNH * d\uff09\n        k[idx] = __ldcs(&amp;inp[inp_idx + NH * d]);\n        \/\/ \u4eceinp\u7684\u7279\u5b9a\u4f4d\u7f6e\u52a0\u8f7dValue\u6570\u636e\u5230v\uff08\u6ce8\u610f\u504f\u79fb\u91cf\u4e3a2 * NH * d\uff09\n        v[idx] = __ldcs(&amp;inp[inp_idx + 2 * (NH * d)]);\n    }\n}\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\">\u89e3\u91ca<\/h4>\n\n\n\n<ol 
class=\"wp-block-list\">\n<li><strong>\u7d22\u5f15\u8ba1\u7b97<\/strong>\uff1a\u6bcf\u4e2a\u7ebf\u7a0b\u901a\u8fc7\u5176\u5168\u5c40\u7d22\u5f15<code>idx<\/code>\u8d1f\u8d23\u5904\u7406\u4e00\u4e2a\u7279\u5b9a\u7684\u6570\u636e\u5143\u7d20\u3002\u5168\u5c40\u7d22\u5f15\u662f\u6839\u636e\u5757\u7684\u7d22\u5f15\uff08<code>blockIdx.x<\/code>\uff09\u3001\u5757\u7684\u7ef4\u5ea6\uff08<code>blockDim.x<\/code>\uff09\u548c\u7ebf\u7a0b\u7684\u7d22\u5f15\uff08<code>threadIdx.x<\/code>\uff09\u8ba1\u7b97\u5f97\u51fa\u3002<\/li>\n\n\n\n<li><strong>\u5f62\u72b6\u8f6c\u6362<\/strong>\uff1a\u8f93\u5165\u5f20\u91cf<code>inp<\/code>\u7684\u5f62\u72b6\u662f<code>(B, N, 3, NH, d)<\/code>\uff0c\u5176\u4e2d3\u4ee3\u8868\u5408\u5e76\u4e86Q\u3001K\u3001V\u4e09\u79cd\u7c7b\u578b\u7684\u6570\u636e\u3002\u6b64\u51fd\u6570\u7684\u76ee\u6807\u662f\u5c06\u8fd9\u4e9b\u6570\u636e\u91cd\u65b0\u6392\u5217\u4e3a\u4e09\u4e2a\u5206\u5f00\u7684\u5f20\u91cf\uff0c\u6bcf\u4e2a\u5f20\u91cf\u7684\u5f62\u72b6\u4e3a<code>(B, NH, N, d)<\/code>\u3002<\/li>\n\n\n\n<li><strong>\u5185\u5b58\u8bbf\u95ee<\/strong>\uff1a\u4f7f\u7528<code>__ldcs<\/code>\uff08load with cache-streaming hint\uff0c\u5373 PTX \u4e2d\u7684 <code>ld.global.cs<\/code>\uff09\u4ece\u5168\u5c40\u5185\u5b58\u52a0\u8f7d\u6570\u636e\uff0c\u5e76\u63d0\u793a\u8be5\u6570\u636e\u53ef\u80fd\u53ea\u4f1a\u88ab\u8bbf\u95ee\u4e00\u6b21\u3001\u5e94\u88ab\u4f18\u5148\u9010\u51fa\u7f13\u5b58\uff0c\u4ece\u800c\u4f7f\u7f13\u5b58\u4f18\u5148\u4fdd\u7559\u5176\u4ed6\u9700\u8981\u91cd\u590d\u8bbf\u95ee\u7684\u6570\u636e\u3002\u6ce8\u610f\uff1a\u5b83\u5e76\u4e0d\u662f\u4ece\u5e38\u91cf\u5185\u5b58\u52a0\u8f7d\u6570\u636e\u3002<\/li>\n<\/ol>\n\n\n\n<p>\u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u4f7f\u5f97\u5185\u6838\u80fd\u9ad8\u6548\u5730\u4ece\u4e00\u4e2a\u590d\u5408\u5f20\u91cf\u4e2d\u63d0\u53d6\u5e76\u91cd\u6392\u6570\u636e\u5230\u4e09\u4e2a\u72ec\u7acb\u7684\u5f20\u91cf\u4e2d\uff0c\u8fd9\u662f\u5728\u5904\u7406\u57fa\u4e8eTransformer\u67b6\u6784\u7684\u795e\u7ecf\u7f51\u7edc\u6a21\u578b\u4e2d\u5e38\u89c1\u7684\u64cd\u4f5c\uff0c\u7279\u522b\u662f\u5728\u591a\u5934\u81ea\u6ce8\u610f\u529b\u673a\u5236\u7684\u5b9e\u73b0\u4e2d\u3002<\/p>\n\n\n\n<h3 
class=\"wp-block-heading\"><strong>1.10 
permute_kernel_backward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9a\u4e49\u4e86\u4e00\u4e2a\u540d\u4e3a <code>permute_kernel_backward<\/code> \u7684CUDA\u5185\u6838\u51fd\u6570\uff0c\u7528\u4e8e\u6267\u884c\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u4e2d\u4ece\u5206\u6563\u7684\u67e5\u8be2\uff08Query, Q\uff09\u3001\u952e\uff08Key, K\uff09\u548c\u503c\uff08Value, V\uff09\u5f20\u91cf\u91cd\u65b0\u6784\u9020\u56de\u5408\u5e76\u5f20\u91cf <code>dinp<\/code>\u3002\u8fd9\u4e2a\u5185\u6838\u51fd\u6570\u5b9e\u8d28\u4e0a\u662f <code>permute_kernel<\/code> \u7684\u9006\u8fc7\u7a0b\uff0c\u5176\u4e2d\u5c06\u5206\u5f00\u7684\u5f20\u91cf\u91cd\u65b0\u7ec4\u5408\u5230\u4e00\u4e2a\u5927\u7684\u5f20\u91cf\u4e2d\uff0c\u7528\u4e8e\u53ef\u80fd\u7684\u540e\u7eed\u68af\u5ea6\u4f20\u9012\u6216\u53c2\u6570\u66f4\u65b0\u3002\u4e0b\u9762\u662f\u5bf9\u4ee3\u7801\u7684\u8be6\u7ec6\u89e3\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \" >\/\/ float* dinp\uff1a\u8f93\u51fa\u7684\u5408\u5e76\u540e\u7684\u68af\u5ea6\u5f20\u91cf\uff0c\u5f62\u72b6\u4e3a (B, N, 3, NH, d)\u3002\n\/\/ const float* dq\uff1a\u8f93\u5165\u7684Query\u68af\u5ea6\u5f20\u91cf\uff0c\u5f62\u72b6\u4e3a (B, NH, N, d)\u3002\n\/\/ const float* dk\uff1a\u8f93\u5165\u7684Key\u68af\u5ea6\u5f20\u91cf\uff0c\u5f62\u72b6\u4e3a (B, NH, N, d)\u3002\n\/\/ const float* dv\uff1a\u8f93\u5165\u7684Value\u68af\u5ea6\u5f20\u91cf\uff0c\u5f62\u72b6\u4e3a (B, NH, N, d)\u3002\n\/\/ int B\uff1a\u6279\u5904\u7406\u5927\u5c0f\u3002\n\/\/ int N\uff1a\u5e8f\u5217\u957f\u5ea6\u3002\n\/\/ int NH\uff1a\u6ce8\u610f\u529b\u5934\u6570\u3002\n\/\/ int d\uff1a\u6bcf\u4e2a\u5934\u7684\u7279\u5f81\u7ef4\u6570\u3002\n__global__ void permute_kernel_backward(float* dinp,\n                                        const float* dq, const float* dk, const float* dv,\n                                        int B, int N, int NH, int d) {\n    \/\/ 1. 
\u7d22\u5f15\u8ba1\u7b97\uff1a\u901a\u8fc7blockIdx.x * blockDim.x + threadIdx.x\u8ba1\u7b97\u5f53\u524d\u7ebf\u7a0b\u7684\u5168\u5c40\u7d22\u5f15idx\u3002\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    \/\/ 2. \u8fb9\u754c\u68c0\u67e5\uff1a\u786e\u4fddidx\u5728\u6709\u6548\u8303\u56f4\u5185\uff08\u5373\u5c0f\u4e8eB * NH * N * d\uff09\u3002\n    if (idx &lt; B * NH * N * d) {\n        \/\/ 3. \u6620\u5c04\u5230\u4e09\u7ef4\u7d22\u5f15\uff1a\u901a\u8fc7\u4e00\u7cfb\u5217\u7684\u6574\u6570\u9664\u6cd5\u548c\u53d6\u4f59\u64cd\u4f5c\uff0c\u89e3\u6790\u51fa\u4e94\u7ef4\u7d22\u5f15\u4e2d\u7684b\uff08\u6279\u6b21\u7d22\u5f15\uff09\u3001nh_\uff08\u5934\u7d22\u5f15\uff09\u3001n\uff08\u5e8f\u5217\u7d22\u5f15\uff09\u548cd_\uff08\u7ef4\u5ea6\u7d22\u5f15\uff09\u3002\n        int b = idx \/ (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest \/ (N * d);\n        rest = rest % (N * d);\n        int n = rest \/ d;\n        int d_ = rest % d;\n\n        \/\/ 4. \u5408\u5e76\u5f20\u91cf\u7d22\u5f15\u8ba1\u7b97\uff1a\u8ba1\u7b97\u5728\u8f93\u51fa\u5408\u5e76\u5f20\u91cfdinp\u4e2d\u5bf9\u5e94\u4f4d\u7f6e\u7684\u7d22\u5f15inp_idx\u3002\n        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;\n        \/\/ \u5c06Query\u7684\u68af\u5ea6dq[idx]\u8d4b\u503c\u5230dinp\u7684\u5bf9\u5e94\u4f4d\u7f6e\u3002\n        dinp[inp_idx] = dq[idx];\n        \/\/ \u5c06Key\u7684\u68af\u5ea6dk[idx]\u8d4b\u503c\u5230dinp\u7684NH * d\u540e\u7684\u4f4d\u7f6e\u3002\n        dinp[inp_idx + NH * d] = dk[idx];\n        \/\/ \u5c06Value\u7684\u68af\u5ea6dv[idx]\u8d4b\u503c\u5230dinp\u76842 * NH * d\u540e\u7684\u4f4d\u7f6e\u3002\n        dinp[inp_idx + 2 * (NH * d)] = dv[idx];\n    }\n}\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\">\u6027\u80fd\u8003\u8651\uff1a<\/h4>\n\n\n\n<ul 
class=\"wp-block-list\">\n<li>\u8fd9\u4e2a\u5185\u6838\u5229\u7528\u4e86\u7b80\u5355\u7684\u7ebf\u6027\u5185\u5b58\u64cd\u4f5c\u6765\u91cd\u5efa\u5408\u5e76\u5f20\u91cf\uff0c\u5bf9\u5185\u5b58\u5e26\u5bbd\u7684\u9700\u6c42\u8f83\u9ad8\u3002<\/li>\n\n\n\n<li>\u5185\u5b58\u8bbf\u95ee\u6a21\u5f0f\u662f\u8fde\u7eed\u7684\uff0c\u8fd9\u6709\u52a9\u4e8eGPU\u8fdb\u884c\u9ad8\u6548\u7684\u5185\u5b58\u8bbf\u95ee\u548c\u7f13\u5b58\u4f18\u5316\u3002<\/li>\n\n\n\n<li>\u8be5\u5185\u6838\u51fd\u6570\u53ef\u4ee5\u5e76\u884c\u5730\u7531\u591a\u4e2a\u7ebf\u7a0b\u6267\u884c\uff0c\u6bcf\u4e2a\u7ebf\u7a0b\u72ec\u7acb\u5904\u7406\u4e00\u4e2a\u6570\u636e\u70b9\uff0c\u4ece\u800c\u5b9e\u73b0\u9ad8\u5e76\u884c\u6027\u3002<\/li>\n<\/ul>\n\n\n\n<p>\u901a\u8fc7\u8fd9\u79cd\u65b9\u5f0f\uff0c<code>permute_kernel_backward<\/code> \u5728\u4fdd\u8bc1\u6570\u636e\u6b63\u786e\u6027\u7684\u540c\u65f6\u5b9e\u73b0\u4e86\u9ad8\u6548\u7684\u5185\u5b58\u64cd\u4f5c\uff0c\u9002\u5408\u5728\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e2d\u8fdb\u884c\u68af\u5ea6\u7684\u53cd\u5411\u4f20\u64ad\u8ba1\u7b97\uff0c\u7279\u522b\u662f\u5728\u5904\u7406\u57fa\u4e8eTransformer\u7ed3\u6784\u7684\u6a21\u578b\u65f6\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.11 unpermute_kernel<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5CUDA\u5185\u6838\u51fd\u6570 <code>unpermute_kernel<\/code> \u7684\u76ee\u7684\u662f\u5bf9\u4e00\u4e2a\u5f20\u91cf\u8fdb\u884c\u91cd\u65b0\u6392\u5217\uff08\u53cd\u7f6e\u6362\uff09\uff0c\u4ee5\u7b26\u5408\u67d0\u4e2a\u7279\u5b9a\u7684\u7ef4\u5ea6\u987a\u5e8f\u3002\u8f93\u5165\u5f20\u91cf <code>inp<\/code> \u7684\u5f62\u72b6\u4e3a <code>(B, NH, N, d)<\/code>\uff0c\u800c\u8f93\u51fa\u5f20\u91cf <code>out<\/code> \u7684\u671f\u671b\u5f62\u72b6\u4e3a <code>(B, N, NH, 
d)<\/code>\u3002\u8fd9\u79cd\u53d8\u6362\u5728\u5904\u7406\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e2d\u7684\u6570\u636e\u65f6\u5f88\u5e38\u89c1\uff0c\u5c24\u5176\u662f\u5728\u9700\u8981\u6539\u53d8\u6570\u636e\u5e03\u5c40\u4ee5\u9002\u5e94\u4e0d\u540c\u64cd\u4f5c\u6216\u5c42\u7684\u8981\u6c42\u65f6\u3002<\/p>\n\n\n\n<p>\u4ee5\u4e0b\u662f\u5bf9\u4ee3\u7801\u7684\u8be6\u7ec6\u89e3\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \" >\/\/ float* inp\uff1a\u8f93\u5165\u5f20\u91cf\uff0c\u5176\u5f62\u72b6\u4e3a (B, NH, N, d)\u3002\n\/\/ float* out\uff1a\u8f93\u51fa\u5f20\u91cf\uff0c\u5176\u76ee\u6807\u5f62\u72b6\u4e3a (B, N, NH, d)\u3002\n\/\/ int B\uff1a\u6279\u5904\u7406\u5927\u5c0f\u3002\n\/\/ int N\uff1a\u5e8f\u5217\u957f\u5ea6\u3002\n\/\/ int NH\uff1a\u6ce8\u610f\u529b\u5934\u6570\u3002\n\/\/ int d\uff1a\u6bcf\u4e2a\u5934\u7684\u7279\u5f81\u7ef4\u6570\u3002\n__global__ void unpermute_kernel(float* inp, float *out, int B, int N, int NH, int d) {\n   \/\/ out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)\n    \/\/ 1. \u7d22\u5f15\u8ba1\u7b97\uff1a\u901a\u8fc7 blockIdx.x * blockDim.x + threadIdx.x \u8ba1\u7b97\u5f53\u524d\u7ebf\u7a0b\u7684\u5168\u5c40\u7d22\u5f15 idx\uff0c\u8fd9\u4e00\u7d22\u5f15\u4ee3\u8868 inp \u4e2d\u7684\u7ebf\u6027\u4f4d\u7f6e\u3002\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    \/\/ out[b][n][nh_][d_] &lt;- inp[b][nh_][n][d_]\n    \/\/ 2. \u8fb9\u754c\u68c0\u67e5\uff1a\u786e\u4fdd idx \u5728\u6709\u6548\u8303\u56f4\u5185\uff08\u5373\u5c0f\u4e8e B * NH * N * d\uff09\u3002\n    if (idx &lt; B * NH * N * d) {\n        \/\/ 3. 
\u6620\u5c04\u5230\u56db\u7ef4\u7d22\u5f15\uff1a\u901a\u8fc7\u6574\u6570\u9664\u6cd5\u548c\u53d6\u4f59\u64cd\u4f5c\uff0c\u5c06\u4e00\u7ef4\u7ebf\u6027\u7d22\u5f15 idx \u8f6c\u6362\u6210\u56db\u7ef4\u7d22\u5f15 (b, nh_, n, d_)\uff0c\u8fd9\u56db\u4e2a\u7d22\u5f15\u5206\u522b\u5bf9\u5e94\u6279\u6b21\u3001\u5934\u7d22\u5f15\u3001\u5e8f\u5217\u7d22\u5f15\u548c\u7ef4\u5ea6\u7d22\u5f15\u3002\n        int b = idx \/ (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest \/ (N * d);\n        rest = rest % (N * d);\n        int n = rest \/ d;\n        int d_ = rest % d;\n        \/\/ 4. \u8ba1\u7b97\u8f93\u51fa\u5f20\u91cf\u7684\u7d22\u5f15\uff1a\u6839\u636e\u8f93\u5165\u7684\u56db\u7ef4\u7d22\u5f15\u8ba1\u7b97\u8f93\u51fa\u5f20\u91cf out \u7684\u5bf9\u5e94\u7d22\u5f15 other_idx\u3002\u8fd9\u4e00\u6b65\u662f\u5c06 nh_\uff08\u5934\u7d22\u5f15\uff09\u548c n\uff08\u5e8f\u5217\u7d22\u5f15\uff09\u7684\u4f4d\u7f6e\u4e92\u6362\uff0c\u4ee5\u7b26\u5408\u8f93\u51fa\u5f20\u91cf\u7684\u671f\u671b\u5f62\u72b6\u3002\n        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;\n        \/\/ 5. 
\u6570\u636e\u8d4b\u503c\uff1a\u5c06\u8f93\u5165\u5f20\u91cf inp \u4e2d\u7684\u6570\u636e\u6839\u636e\u8ba1\u7b97\u51fa\u7684\u65b0\u7d22\u5f15\u8d4b\u503c\u5230\u8f93\u51fa\u5f20\u91cf out \u4e2d\u3002\u4f7f\u7528 __ldcs \uff08\u4ece\u5e38\u91cf\u5185\u5b58\u52a0\u8f7d\u6570\u636e\u5e76\u7f13\u5b58\uff09\u53ef\u4ee5\u4f18\u5316\u5185\u5b58\u8bbf\u95ee\n        out[other_idx] = __ldcs(&amp;inp[idx]);\n    }\n}\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\">\u6027\u80fd\u8003\u8651\uff1a<\/h4>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u5185\u5b58\u8bbf\u95ee\u6a21\u5f0f<\/strong>\uff1a\u7531\u4e8e\u6570\u636e\u91cd\u65b0\u6392\u5217\u901a\u5e38\u6d89\u53ca\u975e\u8fde\u7eed\u7684\u5185\u5b58\u8bbf\u95ee\u6a21\u5f0f\uff0c\u8fd9\u53ef\u80fd\u5bfc\u81f4\u5185\u5b58\u8bbf\u95ee\u6548\u7387\u964d\u4f4e\u3002\u4f7f\u7528 <code>__ldcs<\/code> \u53ef\u4ee5\u5e2e\u52a9\u51cf\u5c11\u8fd9\u79cd\u5f71\u54cd\uff0c\u56e0\u4e3a\u5b83\u5141\u8bb8\u66f4\u6709\u6548\u5730\u4f7f\u7528GPU\u7684\u7f13\u5b58\u3002<\/li>\n\n\n\n<li><strong>\u5e76\u884c\u5ea6<\/strong>\uff1a\u5185\u6838\u7684\u8bbe\u8ba1\u5141\u8bb8\u5b8c\u5168\u5e76\u884c\u7684\u6267\u884c\uff0c\u6bcf\u4e2a\u7ebf\u7a0b\u72ec\u7acb\u5904\u7406\u4e00\u4e2a\u6570\u636e\u70b9\uff0c\u8fd9\u6709\u52a9\u4e8e\u9ad8\u6548\u5229\u7528GPU\u7684\u5e76\u884c\u5904\u7406\u80fd\u529b\u3002<\/li>\n<\/ul>\n\n\n\n<p>\u6b64\u5185\u6838\u9002\u7528\u4e8e\u5728\u9700\u8981\u5c06\u6570\u636e\u4ece\u4e00\u4e2a\u5e03\u5c40\u8f6c\u6362\u4e3a\u53e6\u4e00\u4e2a\u5e03\u5c40\u7684\u573a\u666f\u4e2d\uff0c\u5e38\u89c1\u4e8e\u5904\u7406\u591a\u5934\u81ea\u6ce8\u610f\u529b\u673a\u5236\u8f93\u51fa\u7684\u6570\u636e\u8f6c\u6362\uff0c\u7279\u522b\u662f\u5728\u5c06\u8fd9\u4e9b\u6570\u636e\u4f20\u9012\u5230\u540e\u7eed\u5c42\u4e4b\u524d\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.12 unpermute_kernel_backward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5CUDA\u5185\u6838\u51fd\u6570 <code>unpermute_kernel_backward<\/code> 
\u7528\u4e8e\u6267\u884c\u53cd\u5411\u8fc7\u7a0b\u7684\u6570\u636e\u91cd\u6392\u5217\uff08\u53cd\u5411\u7f6e\u6362\uff09\u3002\u5b83\u4ece\u53d8\u6362\u540e\u7684\u68af\u5ea6\u5f20\u91cf <code>dout<\/code> \u4e2d\u8bfb\u53d6\u6570\u636e\uff0c\u5e76\u5c06\u5176\u6b63\u786e\u653e\u56de\u539f\u59cb\u68af\u5ea6\u5f20\u91cf <code>dinp<\/code> \u7684\u76f8\u5e94\u4f4d\u7f6e\u3002\u8fd9\u901a\u5e38\u662f\u5728\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u7684\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u4e2d\u9700\u8981\u7684\uff0c\u4ee5\u786e\u4fdd\u68af\u5ea6\u53ef\u4ee5\u6b63\u786e\u5730\u4f20\u9012\u5230\u9002\u5f53\u7684\u5c42\u3002\u8be5\u51fd\u6570\u662f <code>unpermute_kernel<\/code> \u7684\u9006\u64cd\u4f5c\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \" >\/\/ float* dinp\uff1a\u8f93\u51fa\u5f20\u91cf\uff0c\u9700\u8981\u5c06\u68af\u5ea6\u6570\u636e\u5199\u5165\u7684\u5f20\u91cf\uff0c\u5176\u5f62\u72b6\u4e3a (B, NH, N, d)\u3002\n\/\/ const float* dout\uff1a\u8f93\u5165\u5f20\u91cf\uff0c\u6765\u6e90\u4e8e\u524d\u5411\u4f20\u64ad\u7684\u53cd\u7f6e\u6362\uff0c\u5176\u5f62\u72b6\u4e3a (B, N, NH, d)\u3002\n\/\/ int B\uff1a\u6279\u5904\u7406\u5927\u5c0f\u3002\n\/\/ int N\uff1a\u5e8f\u5217\u957f\u5ea6\u3002\n\/\/ int NH\uff1a\u6ce8\u610f\u529b\u5934\u6570\u3002\n\/\/ int d\uff1a\u6bcf\u4e2a\u5934\u7684\u7279\u5f81\u7ef4\u6570\u3002\n__global__ void unpermute_kernel_backward(float* dinp, const float *dout, int B, int N, int NH, int d) {\n    \/\/ 1. \u7d22\u5f15\u8ba1\u7b97\uff1a\u901a\u8fc7 blockIdx.x * blockDim.x + threadIdx.x \u8ba1\u7b97\u5f53\u524d\u7ebf\u7a0b\u7684\u5168\u5c40\u7d22\u5f15 idx\uff0c\u8fd9\u4e00\u7d22\u5f15\u8868\u793a dinp \u4e2d\u7684\u7ebf\u6027\u4f4d\u7f6e\u3002\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    \/\/ 2. 
\u8fb9\u754c\u68c0\u67e5\uff1a\u786e\u4fdd idx \u5728\u6709\u6548\u8303\u56f4\u5185\uff08\u5373\u5c0f\u4e8e B * NH * N * d\uff09\u3002\n    if (idx &lt; B * NH * N * d) {\n        \/\/ 3. \u89e3\u6790\u56db\u7ef4\u7d22\u5f15\uff1a\u901a\u8fc7\u6574\u6570\u9664\u6cd5\u548c\u53d6\u4f59\u64cd\u4f5c\uff0c\u5c06\u4e00\u7ef4\u7ebf\u6027\u7d22\u5f15 idx \u8f6c\u6362\u6210\u56db\u7ef4\u7d22\u5f15 (b, nh_, n, d_)\uff0c\u8fd9\u56db\u4e2a\u7d22\u5f15\u5206\u522b\u5bf9\u5e94\u6279\u6b21\u3001\u5934\u7d22\u5f15\u3001\u5e8f\u5217\u7d22\u5f15\u548c\u7ef4\u5ea6\u7d22\u5f15\u3002\n        int b = idx \/ (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest \/ (N * d);\n        rest = rest % (N * d);\n        int n = rest \/ d;\n        int d_ = rest % d;\n        \/\/ 4. \u8ba1\u7b97\u8f93\u5165\u5f20\u91cf\u7684\u7d22\u5f15\uff1a\u6839\u636e\u8f93\u5165\u7684\u56db\u7ef4\u7d22\u5f15\u8ba1\u7b97\u8f93\u5165\u5f20\u91cf dout \u7684\u5bf9\u5e94\u7d22\u5f15 other_idx\u3002\u8fd9\u91cc\u4e92\u6362\u4e86 nh_ \u548c n \u7684\u4f4d\u7f6e\uff0c\u4ee5\u5339\u914d dout \u7684\u5f62\u72b6 (B, N, NH, d)\u3002\n        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;\n        \/\/ 5. 
\u6570\u636e\u8d4b\u503c\uff1a\u5c06\u8f93\u5165\u5f20\u91cf dout \u4e2d\u7684\u6570\u636e\u6839\u636e\u8ba1\u7b97\u51fa\u7684\u7d22\u5f15 other_idx \u8d4b\u503c\u5230\u8f93\u51fa\u5f20\u91cf dinp \u7684\u5bf9\u5e94\u4f4d\u7f6e\u3002\n        dinp[idx] = dout[other_idx];\n    }\n}\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\">\u6027\u80fd\u8003\u8651\uff1a<\/h4>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u5185\u5b58\u8bbf\u95ee\u6a21\u5f0f<\/strong>\uff1a\u7531\u4e8e\u6570\u636e\u91cd\u6392\u901a\u5e38\u6d89\u53ca\u975e\u8fde\u7eed\u7684\u5185\u5b58\u8bbf\u95ee\uff0c\u8fd9\u53ef\u80fd\u5bfc\u81f4\u5185\u5b58\u8bbf\u95ee\u6548\u7387\u964d\u4f4e\u3002\u5408\u7406\u5229\u7528\u5185\u5b58\u8bbf\u95ee\u6a21\u5f0f\u548c\u7f13\u5b58\u53ef\u4ee5\u5e2e\u52a9\u6539\u5584\u6027\u80fd\u3002<\/li>\n\n\n\n<li><strong>\u5e76\u884c\u5ea6<\/strong>\uff1a\u5185\u6838\u7684\u8bbe\u8ba1\u4f7f\u5f97\u53ef\u4ee5\u5e76\u884c\u5730\u6267\u884c\uff0c\u6bcf\u4e2a\u7ebf\u7a0b\u72ec\u7acb\u5904\u7406\u4e00\u4e2a\u6570\u636e\u70b9\uff0c\u4ece\u800c\u9ad8\u6548\u5229\u7528GPU\u7684\u5e76\u884c\u5904\u7406\u80fd\u529b\u3002<\/li>\n<\/ul>\n\n\n\n<p>\u8fd9\u79cd\u5185\u6838\u51fd\u6570\u5728\u6df1\u5ea6\u5b66\u4e60\u7684\u53cd\u5411\u4f20\u64ad\u4e2d\u7279\u522b\u6709\u7528\uff0c\u56e0\u4e3a\u5b83\u786e\u4fdd\u68af\u5ea6\u53ef\u4ee5\u6b63\u786e\u5730\u6309\u539f\u59cb\u524d\u5411\u4f20\u64ad\u65f6\u7684\u5e03\u5c40\u53cd\u5411\u4f20\u9012\uff0c\u8fd9\u5bf9\u4e8e\u57fa\u4e8e\u6a21\u578b\u53c2\u6570\u7684\u6b63\u786e\u66f4\u65b0\u81f3\u5173\u91cd\u8981\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.13 vec_at<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u4e24\u4e2aCUDA\u8bbe\u5907\u51fd\u6570 <code>vec_at<\/code> \u662f\u4e3a\u4e86\u63d0\u4f9b\u5bf9 <code>float4<\/code> \u7c7b\u578b\u5411\u91cf\u4e2d\u5355\u4e2a\u5143\u7d20\u7684\u8bbf\u95ee\u3002<code>float4<\/code> 
\u662f\u4e00\u4e2a\u5185\u7f6e\u7684CUDA\u6570\u636e\u7c7b\u578b\uff0c\u5b83\u5c01\u88c5\u4e86\u56db\u4e2a\u6d6e\u70b9\u6570\u3002\u8fd9\u4e9b\u51fd\u6570\u901a\u8fc7\u5bf9 <code>float4<\/code> \u5bf9\u8c61\u4f7f\u7528 <code>reinterpret_cast<\/code> \u6765\u4f5c\u4e3a <code>float<\/code> \u6570\u7ec4\u8fdb\u884c\u5904\u7406\uff0c\u4ece\u800c\u53ef\u4ee5\u76f4\u63a5\u8bbf\u95ee\u5176\u5355\u4e2a\u5143\u7d20\u3002\u8fd9\u6837\u7684\u5b9e\u73b0\u589e\u52a0\u4e86\u7075\u6d3b\u6027\uff0c\u5141\u8bb8\u4ee5\u6570\u7ec4\u7d22\u5f15\u7684\u65b9\u5f0f\u8bbf\u95ee <code>float4<\/code> \u7684\u5404\u4e2a\u7ec4\u4ef6\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \" >\/\/ \u975e\u5e38\u91cf\u7248\u672c\n\/\/ \u8fd9\u4e2a\u7248\u672c\u8fd4\u56de\u4e00\u4e2a\u5bf9 float4 \u4e2d\u76f8\u5e94\u6d6e\u70b9\u6570\u7684\u5f15\u7528\u3002\u8fd9\u5141\u8bb8\u4f60\u76f4\u63a5\u4fee\u6539 float4 \u5b9e\u4f8b\u4e2d\u7684\u76f8\u5e94\u503c\u3002\n\/\/ \u4f8b\u5982\uff0c\u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e2a\u51fd\u6570\u6765\u8bbe\u7f6e float4 \u4e2d\u67d0\u4e2a\u7279\u5b9a\u4f4d\u7f6e\u7684\u503c\u3002\n__device__ float&amp; vec_at(float4&amp; vec, int index) {\n    return reinterpret_cast&lt;float*&gt;(&amp;vec)[index];\n}\n\n\/\/ \u5e38\u91cf\u7248\u672c\n\/\/ \u8fd9\u4e2a\u7248\u672c\u9002\u7528\u4e8e\u4e0d\u9700\u8981\u4fee\u6539 float4 \u5b9e\u4f8b\u5185\u5bb9\u7684\u573a\u666f\u3002\u5b83\u8fd4\u56de\u7684\u662f\u4e00\u4e2a\u503c\uff0c\u800c\u4e0d\u662f\u5f15\u7528\uff0c\n\/\/ \u8fd9\u4fdd\u8bc1\u4e86\u51fd\u6570\u4f7f\u7528\u7684\u5b89\u5168\u6027\uff0c\u907f\u514d\u4e86\u4e0d\u5c0f\u5fc3\u4fee\u6539\u6570\u636e\u7684\u98ce\u9669\u3002\n__device__ float vec_at(const float4&amp; vec, int index) {\n    return reinterpret_cast&lt;const float*&gt;(&amp;vec)[index];\n}\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\">\u5e94\u7528\u573a\u666f\uff1a<\/h4>\n\n\n\n<p>\u8fd9\u4e9b\u51fd\u6570\u5728\u9700\u8981\u5bf9 
<code>float4<\/code> \u6570\u636e\u7ed3\u6784\u8fdb\u884c\u66f4\u7ec6\u7c92\u5ea6\u64cd\u4f5c\u65f6\u975e\u5e38\u6709\u7528\uff0c\u7279\u522b\u662f\u5728\u5904\u7406\u56fe\u5f62\u548c\u7269\u7406\u8ba1\u7b97\u4e2d\uff0c\u8fd9\u4e9b\u8ba1\u7b97\u53ef\u80fd\u9700\u8981\u5bf9\u5355\u72ec\u7684\u5411\u91cf\u7ec4\u4ef6\u8fdb\u884c\u8bfb\u53d6\u6216\u4fee\u6539\u3002\u4f8b\u5982\uff0c\u5f53\u4f60\u9700\u8981\u6839\u636e\u8fd0\u7b97\u7ed3\u679c\u52a8\u6001\u4fee\u6539\u67d0\u4e2a\u7ec4\u4ef6\u800c\u4e0d\u5f71\u54cd\u5176\u4ed6\u7ec4\u4ef6\u65f6\uff0c\u8fd9\u79cd\u65b9\u6cd5\u975e\u5e38\u9002\u7528\u3002<\/p>\n\n\n\n<h4 class=\"wp-block-heading\">\u6027\u80fd\u548c\u5b89\u5168\u6027\uff1a<\/h4>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u6027\u80fd<\/strong>\uff1a\u7531\u4e8e\u8fd9\u4e9b\u51fd\u6570\u53ea\u6d89\u53ca\u7c7b\u578b\u8f6c\u6362\u548c\u57fa\u672c\u7684\u7d22\u5f15\u64cd\u4f5c\uff0c\u5b83\u4eec\u7684\u6027\u80fd\u5f00\u9500\u975e\u5e38\u4f4e\u3002<code>reinterpret_cast<\/code> \u5728\u8fd0\u884c\u65f6\u51e0\u4e4e\u6ca1\u6709\u6210\u672c\uff0c\u56e0\u4e3a\u5b83\u4ec5\u662f\u5728\u7f16\u8bd1\u5668\u5c42\u9762\u4e0a\u91cd\u65b0\u89e3\u91ca\u5df2\u6709\u7684\u6570\u636e\u3002<\/li>\n\n\n\n<li><strong>\u5b89\u5168\u6027<\/strong>\uff1a\u5c3d\u7ba1 <code>reinterpret_cast<\/code> \u7528\u6cd5\u7b80\u5355\u76f4\u63a5\uff0c\u4f46\u4f7f\u7528\u65f6\u9700\u8981\u4fdd\u8bc1\u4e0d\u4f1a\u8d8a\u754c\u8bbf\u95ee\uff0c\u5373\u7d22\u5f15\u503c\u5fc5\u987b\u5728 0 \u5230 3 \u4e4b\u95f4\u3002\u8d85\u51fa\u8fd9\u4e2a\u8303\u56f4\uff0c\u884c\u4e3a\u662f\u672a\u5b9a\u4e49\u7684\uff0c\u53ef\u80fd\u4f1a\u5bfc\u81f4\u9519\u8bef\u6216\u6570\u636e\u635f\u574f\u3002<\/li>\n<\/ul>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.14 softmax_forward_kernel5<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5CUDA\u5185\u6838 <code>softmax_forward_kernel5<\/code> 
\u4e3b\u8981\u7528\u4e8e\u6267\u884c\u5e26\u6709\u6e29\u5ea6\u56e0\u5b50\u7684Softmax\u51fd\u6570\u7684\u524d\u5411\u8ba1\u7b97\uff0c\u7279\u522b\u662f\u5728\u81ea\u6ce8\u610f\u529b\u673a\u5236\u4e2d\u5904\u7406\u5e8f\u5217\u6570\u636e\u3002\u5b83\u5229\u7528\u5757\u3001\u7ebf\u7a0b\u548c\u5206\u5757\u7684CUDA\u7279\u6027\u6765\u9ad8\u6548\u5730\u8ba1\u7b97Softmax\u3002\u6b64\u5185\u6838\u8ba1\u7b97 <code>(N, T, T)<\/code> \u5f62\u72b6\u5f20\u91cf\u7684Softmax\uff0c\u5176\u4e2d <code>N<\/code> \u901a\u5e38\u662f\u6279\u5927\u5c0f\u4e0e\u5934\u6570\u7684\u4e58\u79ef\uff0c<code>T<\/code> \u662f\u5e8f\u5217\u957f\u5ea6\u3002<\/p>\n\n\n\n<h4 class=\"wp-block-heading\">\u6838\u5fc3\u7279\u6027\u4e0e\u64cd\u4f5c\uff1a<\/h4>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>\u6e29\u5ea6\u56e0\u5b50\u8c03\u6574<\/strong>\uff1a\u901a\u8fc7 <code>inv_temperature<\/code> \uff08\u6e29\u5ea6\u7684\u5012\u6570\uff09\u8c03\u6574Softmax\u7684\u654f\u611f\u5ea6\u3002<\/li>\n\n\n\n<li><strong>\u53cd\u5411\u8fed\u4ee3<\/strong>\uff1a\u4e3a\u4e86\u7f13\u5b58\u4f18\u5316\uff0c\u5185\u6838\u4ece\u540e\u5411\u524d\u8ba1\u7b97\uff0c\u4ee5\u4fbf\u5728\u63a5\u4e0b\u6765\u7684\u77e9\u9635\u4e58\u6cd5\u64cd\u4f5c\u4e2d\u66f4\u597d\u5730\u5229\u7528\u7f13\u5b58\u3002<\/li>\n\n\n\n<li><strong>\u5757\u4e0e\u7ebf\u7a0b\u4f7f\u7528<\/strong>\uff1a\n<ul class=\"wp-block-list\">\n<li>\u4f7f\u7528 Cooperative Groups (<code>cg<\/code>) \u5e93\u6765\u7ba1\u7406\u7ebf\u7a0b\u4e4b\u95f4\u7684\u534f\u4f5c\uff0c\u5305\u62ec\u6570\u636e\u5f52\u7ea6\u3002<\/li>\n\n\n\n<li>\u901a\u8fc7 <code>warp<\/code> 
\u5c06\u7ebf\u7a0b\u5206\u7ec4\uff0c\u4ee5\u51cf\u5c11\u540c\u6b65\u548c\u5f52\u7ea6\u64cd\u4f5c\u7684\u590d\u6742\u5ea6\u3002<\/li>\n<\/ul>\n<\/li>\n\n\n\n<li><strong>\u5728\u7ebfSoftmax\u8ba1\u7b97<\/strong>\uff1a\u91c7\u7528\u5728\u7ebf\u7b97\u6cd5\u9010\u6b65\u8ba1\u7b97Softmax\uff0c\u5373\u8fb9\u8bfb\u53d6\u8fb9\u8ba1\u7b97\uff0c\u6709\u6548\u907f\u514d\u4e86\u4e00\u6b21\u6027\u8bfb\u5165\u6574\u884c\u6570\u636e\u53ef\u80fd\u5f15\u8d77\u7684\u5185\u5b58\u538b\u529b\u548c\u8ba1\u7b97\u5ef6\u8fdf\u3002<\/li>\n<\/ol>\n\n\n\n<h4 class=\"wp-block-heading\">\u8ba1\u7b97\u7ec6\u8282\uff1a<\/h4>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u6700\u5927\u503c\u548c\u6c42\u548c<\/strong>\uff1a\u4e3a\u4e86\u6570\u503c\u7a33\u5b9a\u6027\uff0c\u9996\u5148\u5728Softmax\u8ba1\u7b97\u524d\u627e\u51fa\u6700\u5927\u503c\uff0c\u7136\u540e\u6839\u636e\u6700\u5927\u503c\u8c03\u6574\u6240\u6709\u6570\u503c\u8fdb\u884c\u6307\u6570\u8fd0\u7b97\uff0c\u5e76\u7d2f\u52a0\u3002<\/li>\n\n\n\n<li><strong>\u5f52\u7ea6\u64cd\u4f5c<\/strong>\uff1a\u4f7f\u7528 <code>cg::reduce<\/code> \u8fdb\u884c\u8de8\u7ebf\u7a0b\u7684\u6700\u5927\u503c\u548c\u6c42\u548c\u5f52\u7ea6\u3002<\/li>\n\n\n\n<li><strong>\u6807\u51c6\u5316<\/strong>\uff1a\u5c06\u5f52\u7ea6\u540e\u5f97\u5230\u7684\u6c42\u548c\u503c\u7528\u4e8e\u5f52\u4e00\u5316\u6240\u6709\u6307\u6570\u503c\uff0c\u4ee5\u5f97\u5230Softmax\u7684\u8f93\u51fa\u3002<\/li>\n<\/ul>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \" >\/\/ float* out: \u8f93\u51fa\u6570\u7ec4\uff0c\u5b58\u50a8\u8ba1\u7b97\u5b8c\u6210\u540e\u7684Softmax\u7ed3\u679c\uff0c\u5f62\u72b6\u4e3a (N, T, T)\u3002\n\/\/ float inv_temperature: 
\u6e29\u5ea6\u56e0\u5b50\u7684\u5012\u6570\u3002\u5728Softmax\u8ba1\u7b97\u4e2d\uff0c\u6e29\u5ea6\u56e0\u5b50\u7528\u4e8e\u8c03\u6574\u8f93\u51fa\u5206\u5e03\u7684\u201c\u5c16\u9510\u5ea6\u201d\uff0c\u5176\u4e2d\u8f83\u4f4e\u7684\u6e29\u5ea6\u4f7f\u8f93\u51fa\u66f4\u5c16\u9510\uff08\u66f4\u96c6\u4e2d\u4e8e\u6700\u5927\u503c\uff09\u3002\n\/\/ const float* inp: \u8f93\u5165\u6570\u7ec4\uff0c\u5b58\u50a8Softmax\u8ba1\u7b97\u4e4b\u524d\u7684\u539f\u59cb\u5f97\u5206\u6216\u5bf9\u6570\u6982\u7387\uff0c\u5f62\u72b6\u540c\u6837\u4e3a (N, T, T)\u3002\n\/\/ int N: \u7b2c\u4e00\u7ef4\u7684\u5927\u5c0f\uff0c\u901a\u5e38\u662f\u6279\u5927\u5c0f\uff08B\uff09\u4e0e\u5934\u6570\uff08NH\uff09\u7684\u4e58\u79ef\uff0c\u5373 N = B * NH\u3002\n\/\/ int T: \u5e8f\u5217\u957f\u5ea6\u6216\u65f6\u95f4\u6b65\u7684\u6570\u91cf\uff0c\u540c\u6837\u7528\u4e8e\u8868\u793a inp \u548c out \u7684\u7b2c\u4e8c\u548c\u7b2c\u4e09\u7ef4\u3002\n__global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {\n    \/\/ inp, out shape: (N, T, T), where N = B * NH\n    \/\/ fuses the multiplication by scale inside attention\n    \/\/ directly autoregressive, so we only compute the lower triangular part\n    \/\/ uses the online softmax algorithm\n\t\t\/\/ inp, out\u5f62\u72b6\uff1a(N, T, T)\uff0c\u5176\u4e2dN = B * NH\n\t\t\/\/ \u5728\u6ce8\u610f\u529b\u8ba1\u7b97\u4e2d\u878d\u5408\u4e86\u7f29\u653e\u4e58\u6cd5\n\t\t\/\/ \u76f4\u63a5\u81ea\u56de\u5f52\uff0c\u56e0\u6b64\u6211\u4eec\u53ea\u8ba1\u7b97\u4e0b\u4e09\u89d2\u90e8\u5206\n\t\t\/\/ \u4f7f\u7528\u5728\u7ebfSoftmax\u7b97\u6cd5\n    assert(T % 4  == 0);\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile&lt;32&gt; warp = cg::tiled_partition&lt;32&gt;(block);\n    \/\/ micro-optimization: we iterate backwards so that\n    \/\/ after the softmax backward operation completes, the cache retains the\n    \/\/ part of the matrix close to the upper left corner, which benefits the\n    
\/\/ matmul operation that immediately follows.\n    \/\/ int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); \/\/ forward order\n\t\t\/\/ \u5fae\u4f18\u5316\uff1a\u6211\u4eec\u53cd\u5411\u8fed\u4ee3\uff0c\u8fd9\u6837\n\t\t\/\/ \u5728Softmax\u53cd\u5411\u64cd\u4f5c\u5b8c\u6210\u540e\uff0c\u7f13\u5b58\u4f1a\u4fdd\u7559\u63a5\u8fd1\u5de6\u4e0a\u89d2\u7684\n\t\t\/\/ \u77e9\u9635\u90e8\u5206\uff0c\u8fd9\u5bf9\u7d27\u63a5\u7740\u7684\u77e9\u9635\u4e58\u6cd5\u64cd\u4f5c\u6709\u76ca\u3002\n\t\t\/\/ int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); \/\/ \u6b63\u5411\u987a\u5e8f\n    \/\/ idx \u53cd\u5411\u8ba1\u7b97\u7d22\u5f15\u4ee5\u5229\u7528\u7f13\u5b58\uff0c\u6539\u5584\u6027\u80fd\u3002\u8fd9\u91cc\u8c03\u6574\u4e86\u7d22\u5f15\u8ba1\u7b97\u65b9\u5f0f\uff0c\u4ee5\u4ece\u6570\u636e\u672b\u5c3e\u5411\u5f00\u5934\u8fdb\u884c\u8fed\u4ee3\uff0c\u5e2e\u52a9\u540e\u7eed\u7684\u77e9\u9635\u4e58\u6cd5\u64cd\u4f5c\u4e2d\u7f13\u5b58\u4f7f\u7528\u3002\n    int idx = (gridDim.x - blockIdx.x - 1) * warp.meta_group_size() + warp.meta_group_rank(); \/\/ backward order\n    if(idx &gt;= N * T) {\n        return;\n    }\n    int own_pos = idx % T;\n    int pos_by_4 = own_pos \/ 4;\n\n    \/\/ one row of inp, i.e. 
inp[idx, :] of shape (T,)\n    \/\/ inp\u7684\u4e00\u884c\uff0c\u5373 inp[idx, :] \u7684\u5f62\u72b6\u662f (T,)\n    const float* x = inp + idx * T;\n\n    \/\/ not INF, so we don't get NaNs accidentally when subtracting two values.\n    \/\/ \u4e0d\u662f\u65e0\u7a77\u5927\uff0c\u8fd9\u6837\u5728\u51cf\u53bb\u4e24\u4e2a\u503c\u65f6\u4e0d\u4f1a\u610f\u5916\u5f97\u5230NaN\u3002\n    float maxval = -FLT_MAX;\n    float sumval = 0.0f;\n\n    \/\/ \u8ba1\u7b97Softmax\u7684\u6838\u5fc3\u6b65\u9aa4\uff0c\u5305\u62ec\u6700\u5927\u503c\u641c\u7d22\u3001\u6307\u6570\u6c42\u548c\u548c\u5f52\u4e00\u5316\n    const float4* x_vec = reinterpret_cast&lt;const float4*&gt;(x);\n    for (int i = warp.thread_rank(); i &lt; pos_by_4; i += warp.size()) {\n        float4 v = x_vec[i];\n        float old_maxval = maxval;\n        for(int k = 0; k &lt; 4; ++k) {\n            \/\/ \u66f4\u65b0\u6700\u5927\u503c\n            maxval = fmaxf(maxval, vec_at(v, k));\n        }\n        \/\/ \u8c03\u6574\u65e7\u7684\u6c42\u548c\u503c\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        for(int k = 0; k &lt; 4; ++k) {\n            \/\/ \u7d2f\u52a0\u65b0\u7684\u6307\u6570\u503c\n            sumval += expf(inv_temperature * (vec_at(v, k) - maxval));\n        }\n    }\n\n    if(4*pos_by_4 + warp.thread_rank() &lt;= own_pos) {\n        \/\/ \u5355\u4e2a\u5143\u7d20\u7684\u5904\u7406\uff0c\u4ee5\u5904\u7406\u4e0d\u80fd\u6574\u96644\u7684\u5269\u4f59\u5143\u7d20\n        float old_maxval = maxval;\n        maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]);\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        sumval += expf(inv_temperature * (x[4*pos_by_4 + warp.thread_rank()] - maxval));\n    }\n\n    float global_maxval = cg::reduce(warp, maxval, cg::greater&lt;float&gt;{});\n    sumval *= expf(inv_temperature * (maxval - global_maxval));\n\n    float sum = cg::reduce(warp, sumval, cg::plus&lt;float&gt;{});\n    float norm = 1.f \/ sum;\n\n    \/\/ 
divide the whole row by the sum\n    for (int i = warp.thread_rank(); i &lt;= own_pos; i += warp.size()) {\n        \/\/ recalculation is faster than doing the round-trip through memory.\n        float ev = expf(inv_temperature * (__ldcs(x + i) - global_maxval));\n        __stcs(out + idx * T + i, ev * norm);\n    }\n}\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\">\u6027\u80fd\u4f18\u5316\uff1a<\/h4>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u77e2\u91cf\u5316\u8bbf\u95ee<\/strong>\uff1a\u901a\u8fc7\u4f7f\u7528 <code>float4<\/code> \u8fdb\u884c\u77e2\u91cf\u5316\u5185\u5b58\u8bbf\u95ee\u6765\u63d0\u9ad8\u5e26\u5bbd\u5229\u7528\u7387\u3002<\/li>\n\n\n\n<li><strong>\u5c40\u90e8\u8ba1\u7b97\u4e0e\u5ef6\u8fdf\u52a0\u8f7d<\/strong>\uff1a\u901a\u8fc7\u5728\u9700\u8981\u65f6\u624d\u52a0\u8f7d\u5e76\u8ba1\u7b97\u6570\u636e\uff0c\u51cf\u5c11\u4e86\u5185\u5b58\u8bbf\u95ee\u6b21\u6570\uff0c\u63d0\u9ad8\u4e86\u8ba1\u7b97\u6548\u7387\u3002<\/li>\n\n\n\n<li><strong>\u7f13\u5b58\u4f18\u5316<\/strong>\uff1a\u901a\u8fc7\u53cd\u5411\u8fed\u4ee3\u987a\u5e8f\uff0c\u5c3d\u91cf\u4fdd\u7559\u5728\u7f13\u5b58\u4e2d\u9891\u7e41\u8bbf\u95ee\u7684\u6570\u636e\u3002<\/li>\n<\/ul>\n\n\n\n<p>\u8fd9\u4e2a\u5185\u6838\u5728\u8bbe\u8ba1\u4e0a\u975e\u5e38\u9002\u5408\u7528\u4e8e\u5904\u7406\u5927\u89c4\u6a21\u6570\u636e\u96c6\u5408\uff0c\u5728\u4f8b\u5982Transformer\u6a21\u578b\u4e2d\u5904\u7406\u81ea\u6ce8\u610f\u529b\u7684Softmax\u5c42\u65f6\u80fd\u591f\u63d0\u4f9b\u9ad8\u6548\u7684\u8ba1\u7b97\u6027\u80fd\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.15 residual_forward_kernel<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5CUDA\u5185\u6838\u51fd\u6570 <code>residual_forward_kernel<\/code> 
\u7528\u4e8e\u8ba1\u7b97\u6b8b\u5dee\u8fde\u63a5\u7684\u8f93\u51fa\u3002\u5728\u6df1\u5ea6\u5b66\u4e60\u4e2d\uff0c\u5c24\u5176\u662f\u5728\u50cfTransformer\u8fd9\u6837\u7684\u7f51\u7edc\u7ed3\u6784\u4e2d\uff0c\u6b8b\u5dee\u8fde\u63a5\u662f\u4e00\u79cd\u5e38\u7528\u7684\u6280\u672f\uff0c\u5b83\u6709\u52a9\u4e8e\u51cf\u5c11\u68af\u5ea6\u6d88\u5931\u95ee\u9898\uff0c\u5141\u8bb8\u66f4\u6df1\u7684\u7f51\u7edc\u7ed3\u6784\u8fdb\u884c\u8bad\u7ec3\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \" >\/\/ float* out: \u8f93\u51fa\u6570\u7ec4\uff0c\u5b58\u50a8\u7ed3\u679c\u6570\u636e\u3002\u8fd9\u4e2a\u6570\u7ec4\u7684\u957f\u5ea6\u5e94\u8be5\u4e0e\u8f93\u5165\u6570\u7ec4 inp1 \u548c inp2 \u4e00\u81f4\u3002\n\/\/ float* inp1: \u7b2c\u4e00\u4e2a\u8f93\u5165\u6570\u7ec4\uff0c\u5176\u4e2d\u5305\u542b\u4e00\u4e9b\u524d\u4e00\u5c42\u6216\u64cd\u4f5c\u7684\u8f93\u51fa\u6570\u636e\u3002\n\/\/ float* inp2: \u7b2c\u4e8c\u4e2a\u8f93\u5165\u6570\u7ec4\uff0c\u901a\u5e38\u5305\u542b\u53e6\u4e00\u5c42\u6216\u64cd\u4f5c\u7684\u8f93\u51fa\u6570\u636e\uff0c\u8fd9\u4e24\u4e2a\u6570\u7ec4\u5728\u76f8\u540c\u7d22\u5f15\u5904\u7684\u5143\u7d20\u5c06\u88ab\u76f8\u52a0\u3002\n\/\/ int N: \u6570\u7ec4 inp1\u3001inp2 \u548c out \u7684\u5143\u7d20\u6570\u91cf\uff0c\u6307\u660e\u4e86\u5728\u8fd9\u4e09\u4e2a\u6570\u7ec4\u4e2d\u6709\u591a\u5c11\u5143\u7d20\u9700\u8981\u8fdb\u884c\u5904\u7406\u3002\n__global__ void residual_forward_kernel(float* out, float* inp1, float* inp2, int N) {\n    \/\/ 1. \u7d22\u5f15\u8ba1\u7b97\uff1a\u901a\u8fc7 blockIdx.x * blockDim.x + threadIdx.x \u8ba1\u7b97\u5f53\u524d\u7ebf\u7a0b\u7684\u5168\u5c40\u7d22\u5f15 idx\uff0c\u8fd9\u4e00\u7d22\u5f15\u7528\u4e8e\u8bbf\u95ee\u8f93\u5165\u548c\u8f93\u51fa\u6570\u7ec4\u4e2d\u7684\u5143\u7d20\u3002\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    \/\/ 2. 
\u8fb9\u754c\u68c0\u67e5\uff1a\u786e\u4fdd\u5f53\u524d\u7ebf\u7a0b\u7684\u7d22\u5f15 idx \u4e0d\u8d85\u8fc7\u6570\u7ec4\u7684\u5927\u5c0f N\uff0c\u8fd9\u662f\u4e3a\u4e86\u9632\u6b62\u6570\u7ec4\u8d8a\u754c\u8bbf\u95ee\u3002\n    if (idx &lt; N) {\n        \/\/ \u4f7f\u7528 __ldcs(&amp;inp1[idx]) \u4ece\u7b2c\u4e00\u4e2a\u8f93\u5165\u6570\u7ec4 inp1 \u8bfb\u53d6\u6570\u636e\uff0c__ldcs \u51fd\u6570\u4ece\u5168\u5c40\u5185\u5b58\uff08\u800c\u4e0d\u662f\u5e38\u91cf\u5185\u5b58\uff09\u52a0\u8f7d\u6570\u636e\uff0c\u5e76\u4f7f\u7528\u6d41\u5f0f\u7f13\u5b58\u63d0\u793a\uff08cache streaming\uff09\uff0c\u5373\u8be5\u6570\u636e\u9884\u8ba1\u53ea\u88ab\u8bbf\u95ee\u4e00\u6b21\uff0c\u4ece\u800c\u4f7f\u7f13\u5b58\u5c3d\u91cf\u4fdd\u7559\u9700\u8981\u9891\u7e41\u8bbf\u95ee\u7684\u6570\u636e\u3002\n        \/\/ \u4f7f\u7528 __ldcs(&amp;inp2[idx]) \u4ece\u7b2c\u4e8c\u4e2a\u8f93\u5165\u6570\u7ec4 inp2 \u8bfb\u53d6\u6570\u636e\u3002\n        \/\/ \u5c06\u4e24\u4e2a\u6570\u636e\u76f8\u52a0\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5230\u8f93\u51fa\u6570\u7ec4 out \u7684\u76f8\u5e94\u4f4d\u7f6e\u3002\n        out[idx] = __ldcs(&amp;inp1[idx]) + __ldcs(&amp;inp2[idx]);\n    }\n}\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\">\u6027\u80fd\u4f18\u5316\uff1a<\/h4>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u4f7f\u7528 <code>__ldcs<\/code><\/strong>\uff1a\u8be5\u51fd\u6570\u4ece\u5168\u5c40\u5185\u5b58\u52a0\u8f7d\u6570\u636e\uff0c\u5e76\u4f7f\u7528\u6d41\u5f0f\u7f13\u5b58\u63d0\u793a\uff08cache streaming\uff09\uff0c\u5373\u6570\u636e\u9884\u8ba1\u53ea\u88ab\u8bbf\u95ee\u4e00\u6b21\u3002\u8fd9\u91cc\u6bcf\u4e2a\u8f93\u5165\u5143\u7d20\u53ea\u8bfb\u53d6\u4e00\u6b21\uff0c\u56e0\u6b64\u8fd9\u53ef\u4ee5\u4f7f\u7f13\u5b58\u5c3d\u91cf\u4fdd\u7559\u9700\u8981\u9891\u7e41\u8bbf\u95ee\u7684\u6570\u636e\u3002\u6ce8\u610f\uff1a\u5b83\u5e76\u4e0d\u662f\u4ece\u5e38\u91cf\u5185\u5b58\u6216\u5e38\u91cf\u7f13\u5b58\u52a0\u8f7d\u6570\u636e\u3002<\/li>\n\n\n\n<li><strong>\u7b80\u6d01\u7684\u6570\u636e\u64cd\u4f5c<\/strong>\uff1a\u8be5\u6838\u51fd\u6570\u4ec5\u8fdb\u884c\u52a0\u6cd5\u64cd\u4f5c\u548c\u6570\u636e\u5b58\u53d6\uff0c\u4f7f\u5f97\u6574\u4f53\u8ba1\u7b97\u975e\u5e38\u9ad8\u6548\u3002<\/li>\n<\/ul>\n\n\n\n<h4 
class=\"wp-block-heading\">\u793a\u4f8b\u5e94\u7528\u573a\u666f\uff1a<\/h4>\n\n\n\n<p>\u8fd9\u4e2a\u5185\u6838\u51fd\u6570\u53ef\u4ee5\u5728\u6267\u884c\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u4e2d\u7684\u524d\u5411\u4f20\u64ad\u65f6\u4f7f\u7528\uff0c\u7279\u522b\u662f\u5728\u90a3\u4e9b\u4f7f\u7528\u6b8b\u5dee\u8fde\u63a5\u7684\u7f51\u7edc\u67b6\u6784\u4e2d\u3002\u4f8b\u5982\uff0c\u5728\u6bcf\u4e2aTransformer\u7f16\u7801\u5668\u6216\u89e3\u7801\u5668\u7684\u5c42\u540e\u6dfb\u52a0\u6b8b\u5dee\u8fde\u63a5\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4fdd\u6301\u4e0d\u540c\u5c42\u95f4\u4fe1\u606f\u7684\u4f20\u9012\uff0c\u63d0\u9ad8\u7f51\u7edc\u7684\u8bad\u7ec3\u7a33\u5b9a\u6027\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.16 gelu_forward_kernel<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5CUDA\u5185\u6838\u51fd\u6570 <code>gelu_forward_kernel<\/code> \u7528\u4e8e\u8ba1\u7b97Gaussian Error Linear Unit (GELU) \u6fc0\u6d3b\u51fd\u6570\u7684\u8f93\u51fa\u3002GELU \u6fc0\u6d3b\u51fd\u6570\u662f\u5728\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e2d\u5e38\u7528\u7684\u975e\u7ebf\u6027\u6fc0\u6d3b\u51fd\u6570\uff0c\u7279\u522b\u662f\u5728Transformer\u548cBERT\u7b49\u81ea\u7136\u8bed\u8a00\u5904\u7406\u6a21\u578b\u4e2d\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \" >\/\/ float* out: \u8f93\u51fa\u6570\u7ec4\uff0c\u5b58\u50a8\u8ba1\u7b97\u540e\u7684GELU\u6fc0\u6d3b\u51fd\u6570\u7684\u7ed3\u679c\u3002\n\/\/ const float* inp: \u8f93\u5165\u6570\u7ec4\uff0c\u5305\u542b\u5e94\u7528GELU\u51fd\u6570\u524d\u7684\u539f\u59cb\u503c\u3002\n\/\/ int N: \u8f93\u5165\u548c\u8f93\u51fa\u6570\u7ec4\u7684\u5143\u7d20\u6570\u91cf\uff0c\u6307\u660e\u4e86\u5728\u8fd9\u4e24\u4e2a\u6570\u7ec4\u4e2d\u6709\u591a\u5c11\u5143\u7d20\u9700\u8981\u8fdb\u884c\u5904\u7406\u3002\n__global__ void gelu_forward_kernel(float* out, const float* inp, int N) {\n    \/\/ 1. 
\u7d22\u5f15\u8ba1\u7b97\uff1a\u901a\u8fc7 blockIdx.x * blockDim.x + threadIdx.x \u8ba1\u7b97\u5f53\u524d\u7ebf\u7a0b\u7684\u5168\u5c40\u7d22\u5f15 i\uff0c\n    int i = blockIdx.x * blockDim.x + threadIdx.x;\n    \/\/ 2. \u8fb9\u754c\u68c0\u67e5\uff1a\u786e\u4fdd\u5f53\u524d\u7ebf\u7a0b\u7684\u7d22\u5f15 i \u4e0d\u8d85\u8fc7\u6570\u7ec4\u7684\u5927\u5c0f N\uff0c\u8fd9\u662f\u4e3a\u4e86\u9632\u6b62\u6570\u7ec4\u8d8a\u754c\u8bbf\u95ee\u3002\n    if (i &lt; N) {\n        \/\/ 3. GELU\u6fc0\u6d3b\u8ba1\u7b97\uff1a\n        \/\/ \u9996\u5148\u8bfb\u53d6\u8f93\u5165\u503c xi\u3002\n        \/\/ \u8ba1\u7b97 xi \u7684\u4e09\u6b21\u65b9\uff0c\u5e76\u4e58\u4ee5\u56fa\u5b9a\u7cfb\u6570 0.044715f \u5f97\u5230 cube\u3002\n        \/\/ \u4f7f\u7528\u9884\u5b9a\u4e49\u7684 GELU_SCALING_FACTOR\uff08\u57fa\u4e8e sqrt(2.0 \/ \u03c0) \u7684\u5e38\u91cf\uff09\u548c tanh \u51fd\u6570\u6765\u8ba1\u7b97GELU\u6fc0\u6d3b\u51fd\u6570\u7684\u4e3b\u4f53\u3002\n        \/\/ \u6700\u7ec8\u6fc0\u6d3b\u503c\u901a\u8fc7 0.5 * xi * (1.0 + tanhf(GELU_SCALING_FACTOR * (xi + cube))) \u8ba1\u7b97\u5f97\u51fa\uff0c\u5e76\u5b58\u50a8\u5728\u8f93\u51fa\u6570\u7ec4 out \u4e2d\u3002\n        float xi = inp[i];\n        float cube = 0.044715f * xi * xi * xi;\n        out[i] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)));\n    }\n}\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\">\u6027\u80fd\u4f18\u5316\uff1a<\/h4>\n\n\n\n<ul 
class=\"wp-block-list\">\n<li><strong>\u5e76\u884c\u5316\u5904\u7406<\/strong>\uff1a\u6b64\u5185\u6838\u9ad8\u5ea6\u5e76\u884c\u5316\uff0c\u6bcf\u4e2a\u7ebf\u7a0b\u72ec\u7acb\u5904\u7406\u4e00\u4e2a\u6570\u7ec4\u5143\u7d20\uff0c\u5927\u5e45\u63d0\u9ad8\u5904\u7406\u6548\u7387\u3002<\/li>\n\n\n\n<li><strong>\u7b80\u6d01\u7684\u6570\u636e\u64cd\u4f5c<\/strong>\uff1a\u901a\u8fc7\u76f4\u63a5\u8ba1\u7b97\u76f8\u5173\u6570\u5b66\u8868\u8fbe\u5f0f\uff0c\u907f\u514d\u4e86\u4e0d\u5fc5\u8981\u7684\u5185\u5b58\u8bbf\u95ee\u548c\u8ba1\u7b97\uff0c\u4f18\u5316\u4e86\u6267\u884c\u901f\u5ea6\u3002<\/li>\n<\/ul>\n\n\n\n<h4 class=\"wp-block-heading\">\u793a\u4f8b\u5e94\u7528\u573a\u666f\uff1a<\/h4>\n\n\n\n<p>GELU\u6fc0\u6d3b\u51fd\u6570\u5e7f\u6cdb\u5e94\u7528\u4e8e\u5404\u79cd\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e2d\uff0c\u7279\u522b\u662f\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u9886\u57df\u3002\u5176\u5728BERT\u548cGPT\u7b49\u6a21\u578b\u4e2d\u7684\u4f7f\u7528\uff0c\u5e2e\u52a9\u6a21\u578b\u5728\u5404\u79cd\u4efb\u52a1\u4e2d\u5b9e\u73b0\u66f4\u597d\u7684\u6027\u80fd\u3002\u8fd9\u4e2a\u5185\u6838\u51fd\u6570\u53ef\u4ee5\u76f4\u63a5\u7528\u4e8e\u8fd9\u4e9b\u6a21\u578b\u7684\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\u4e2d\uff0c\u5904\u7406\u6fc0\u6d3b\u51fd\u6570\u90e8\u5206\u3002<\/p>\n\n\n\n<p>\u901a\u8fc7\u8fd9\u79cd\u65b9\u5f0f\uff0c<code>gelu_forward_kernel<\/code> \u5728\u4fdd\u8bc1\u8ba1\u7b97\u6b63\u786e\u6027\u7684\u540c\u65f6\uff0c\u63d0\u4f9b\u4e86\u9ad8\u6548\u7684\u6267\u884c\u65b9\u5f0f\uff0c\u9002\u5408\u5728\u9700\u8981\u9ad8\u6027\u80fd\u8ba1\u7b97\u7684\u6df1\u5ea6\u5b66\u4e60\u5e94\u7528\u4e2d\u4f7f\u7528\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.17 gelu_backward_kernel<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5CUDA\u5185\u6838\u51fd\u6570 <code>gelu_backward_kernel<\/code> \u7528\u4e8e\u8ba1\u7b97Gaussian Error Linear Unit (GELU) 
\u6fc0\u6d3b\u51fd\u6570\u7684\u68af\u5ea6\u3002\u5728\u6df1\u5ea6\u5b66\u4e60\u4e2d\uff0c\u8fd9\u4e00\u6b65\u9aa4\u901a\u5e38\u5728\u6a21\u578b\u7684\u53cd\u5411\u4f20\u64ad\u9636\u6bb5\u8fdb\u884c\uff0c\u7528\u4e8e\u8ba1\u7b97\u635f\u5931\u51fd\u6570\u5173\u4e8e\u6bcf\u4e2a\u8f93\u5165\u8282\u70b9\u7684\u504f\u5bfc\u6570\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \" >\/\/ float* dinp: \u8f93\u51fa\u68af\u5ea6\u6570\u7ec4\uff0c\u5b58\u50a8\u8ba1\u7b97\u540e\u7684GELU\u6fc0\u6d3b\u51fd\u6570\u7684\u68af\u5ea6\u7ed3\u679c\u3002\n\/\/ const float* inp: \u8f93\u5165\u6570\u7ec4\uff0c\u5305\u542b\u5e94\u7528GELU\u51fd\u6570\u524d\u7684\u539f\u59cb\u503c\u3002\n\/\/ const float* dout: \u8f93\u5165\u68af\u5ea6\u6570\u7ec4\uff0c\u5305\u542b\u4ece\u4e0a\u6e38\u4f20\u9012\u4e0b\u6765\u7684\u68af\u5ea6\u3002\n\/\/ int N: \u8f93\u5165\u548c\u8f93\u51fa\u6570\u7ec4\u7684\u5143\u7d20\u6570\u91cf\uff0c\u6307\u660e\u4e86\u5728\u8fd9\u4e24\u4e2a\u6570\u7ec4\u4e2d\u6709\u591a\u5c11\u5143\u7d20\u9700\u8981\u8fdb\u884c\u5904\u7406\u3002\n__global__ void gelu_backward_kernel(float* dinp, const float* inp, const float* dout, const int N) {\n    \/\/ 1. \u7d22\u5f15\u8ba1\u7b97\uff1a\u901a\u8fc7 blockIdx.x * blockDim.x + threadIdx.x \u8ba1\u7b97\u5f53\u524d\u7ebf\u7a0b\u7684\u5168\u5c40\u7d22\u5f15 i\uff0c\n    int i = blockIdx.x * blockDim.x + threadIdx.x;\n    \/\/ 2. \u8fb9\u754c\u68c0\u67e5\uff1a\u786e\u4fdd\u5f53\u524d\u7ebf\u7a0b\u7684\u7d22\u5f15 i \u4e0d\u8d85\u8fc7\u6570\u7ec4\u7684\u5927\u5c0f N\uff0c\u8fd9\u662f\u4e3a\u4e86\u9632\u6b62\u6570\u7ec4\u8d8a\u754c\u8bbf\u95ee\u3002\n    if (i &lt; N) {\n        \/\/ 3. 
GELU\u68af\u5ea6\u8ba1\u7b97\uff1a\n        \/\/ \u9996\u5148\u8bfb\u53d6\u8f93\u5165\u503c x\u3002\n        \/\/ \u8ba1\u7b97 x \u7684\u4e09\u6b21\u65b9\uff0c\u5e76\u4e58\u4ee5\u56fa\u5b9a\u7cfb\u6570 0.044715f \u5f97\u5230 cube\u3002\n        \/\/ \u6784\u9020 tanh \u51fd\u6570\u7684\u53c2\u6570 tanh_arg\u3002\n        \/\/ \u4f7f\u7528 tanhf \u548c coshf \u51fd\u6570\u6765\u8ba1\u7b97 tanh \u8f93\u51fa\u503c tanh_out \u548c cosh \u8f93\u51fa\u503c coshf_out\u3002\n        \/\/ \u8ba1\u7b97 sech \u7684\u5e73\u65b9 (sech_out)\uff0c\u5373 1.0 \/ (coshf_out * coshf_out)\u3002\n        \/\/ \u6839\u636eGELU\u7684\u68af\u5ea6\u516c\u5f0f\uff0c\u8ba1\u7b97\u672c\u5730\u68af\u5ea6 local_grad\uff1a\n        \/\/ \u5c06\u8ba1\u7b97\u5f97\u5230\u7684\u672c\u5730\u68af\u5ea6\u4e58\u4ee5\u4e0a\u6e38\u4f20\u6765\u7684\u68af\u5ea6 dout[i]\uff0c\u5f97\u5230\u6700\u7ec8\u7684\u68af\u5ea6\u503c\uff0c\u5b58\u50a8\u5728 dinp[i] \u4e2d\u3002\n        float x = inp[i];\n        float cube = 0.044715f * x * x * x;\n        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);\n        float tanh_out = tanhf(tanh_arg);\n        float coshf_out = coshf(tanh_arg);\n        float sech_out = 1.0f \/ (coshf_out * coshf_out);\n        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);\n        dinp[i] = local_grad * dout[i];\n    }\n}\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\">\u6027\u80fd\u4f18\u5316\uff1a<\/h4>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u5e76\u884c\u5316\u5904\u7406<\/strong>\uff1a\u6b64\u5185\u6838\u9ad8\u5ea6\u5e76\u884c\u5316\uff0c\u6bcf\u4e2a\u7ebf\u7a0b\u72ec\u7acb\u5904\u7406\u4e00\u4e2a\u6570\u7ec4\u5143\u7d20\uff0c\u5927\u5e45\u63d0\u9ad8\u5904\u7406\u6548\u7387\u3002<\/li>\n\n\n\n<li><strong>\u7cbe\u786e\u7684\u6570\u5b66\u51fd\u6570\u4f7f\u7528<\/strong>\uff1a\u901a\u8fc7\u7cbe\u786e\u8ba1\u7b97 <code>tanh<\/code> \u548c <code>cosh<\/code> 
\u53ca\u5176\u5012\u6570\uff0c\u786e\u4fdd\u4e86\u68af\u5ea6\u8ba1\u7b97\u7684\u51c6\u786e\u6027\uff0c\u8fd9\u5bf9\u4e8e\u6a21\u578b\u8bad\u7ec3\u7684\u7a33\u5b9a\u6027\u548c\u6700\u7ec8\u6027\u80fd\u81f3\u5173\u91cd\u8981\u3002<\/li>\n<\/ul>\n\n\n\n<h4 class=\"wp-block-heading\">\u793a\u4f8b\u5e94\u7528\u573a\u666f\uff1a<\/h4>\n\n\n\n<p>\u8fd9\u4e2a\u5185\u6838\u51fd\u6570\u5728\u6267\u884c\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u4e2d\u7684\u53cd\u5411\u4f20\u64ad\u65f6\u975e\u5e38\u6709\u7528\uff0c\u7279\u522b\u662f\u5728\u90a3\u4e9b\u4f7f\u7528GELU\u6fc0\u6d3b\u51fd\u6570\u7684\u7f51\u7edc\u67b6\u6784\u4e2d\uff0c\u4f8b\u5982BERT\u548cTransformer\u3002\u6b63\u786e\u7684\u68af\u5ea6\u8ba1\u7b97\u5bf9\u4e8e\u7f51\u7edc\u6743\u91cd\u7684\u6709\u6548\u66f4\u65b0\u548c\u6a21\u578b\u8bad\u7ec3\u7684\u6536\u655b\u81f3\u5173\u91cd\u8981\u3002\u901a\u8fc7\u8fd9\u79cd\u65b9\u5f0f\uff0c<code>gelu_backward_kernel<\/code> \u63d0\u4f9b\u4e86\u9ad8\u6548\u4e14\u51c6\u786e\u7684\u6267\u884c\u65b9\u5f0f\uff0c\u9002\u5408\u5728\u9700\u8981\u9ad8\u6027\u80fd\u8ba1\u7b97\u7684\u6df1\u5ea6\u5b66\u4e60\u5e94\u7528\u4e2d\u4f7f\u7528\u3002<\/p>\n","protected":false},"excerpt":{"rendered":"<p>\u8fd9\u91cc\u63a5\u4e0a\u4e00\u8282 \u5bf9C\u7a0b\u5e8f\u7684\u4e2d\u6587\u6ce8\u89e3\uff0c\u4e0b\u9762\u662f\u5bf9 train_gpt2.cu \u7684\u6ce8\u89e3\uff0c\u6240\u6709\u6ce8\u89e3\u6765\u81eaChatGPT4\u3002 
[&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"site-sidebar-layout":"default","site-content-layout":"","ast-site-content-layout":"default","site-content-style":"default","site-sidebar-style":"default","ast-global-header-display":"","ast-banner-title-visibility":"","ast-main-header-display":"","ast-hfb-above-header-display":"","ast-hfb-below-header-display":"","ast-hfb-mobile-header-display":"","site-post-title":"","ast-breadcrumbs-content":"","ast-featured-img":"","footer-sml-layout":"","theme-transparent-header-meta":"","adv-header-id-meta":"","stick-header-meta":"","header-above-stick-meta":"","header-main-stick-meta":"","header-below-stick-meta":"","astra-migrate-meta-layouts":"set","ast-page-background-enabled":"default","ast-page-background-meta":{"desktop":{"background-color":"var(--ast-global-color-4)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"ast-content-background-meta":{"desktop":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[444,443,442],"tags":[242,431,314],"class_list":["post-3319","post","type-post","status-publish","format-standard","hentry","category-ai","category-llm","category-llms","tag-chatgpt","tag-llm-c","tag-openai-api"],"views":2538,"jetpack_sharing_enabled":true,"jetpack_featured_media_url":"","_links":{"self":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/3319","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=3319"}],"version-history":[{"count":29,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/3319\/revisions"}],"predecessor-version":[{"id":3352,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/3319\/
revisions\/3352"}],"wp:attachment":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=3319"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=3319"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=3319"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}