{"id":3229,"date":"2024-04-21T13:54:41","date_gmt":"2024-04-21T05:54:41","guid":{"rendered":"https:\/\/www.aqwu.net\/wp\/?p=3229"},"modified":"2024-04-28T19:53:29","modified_gmt":"2024-04-28T11:53:29","slug":"llm-c-%e7%9a%84%e4%b8%ad%e6%96%87%e6%b3%a8%e8%a7%a3-20240421","status":"publish","type":"post","link":"https:\/\/www.aqwu.net\/wp\/?p=3229","title":{"rendered":"llm.c \u7684\u4e2d\u6587\u6ce8\u89e3-20240421"},"content":{"rendered":"\n<p>llm.c \u7b80\u5355\u3001\u7eaf C\/CUDA \u7684 LLM \u8bad\u7ec3\u3002\u4e0d\u9700\u8981 245MB \u7684 PyTorch \u6216 107MB \u7684 cPython\u3002\u8bad\u7ec3 GPT-2 \uff08CPU\uff0c fp32\uff09 \u5728\u5355\u4e2a\u6587\u4ef6&nbsp;<a href=\"https:\/\/github.com\/karpathy\/llm.c\/blob\/master\/train_gpt2.c\">train_gpt2.c<\/a>&nbsp;\u4e2d\u662f ~1,000 \u884c\u5e72\u51c0\u4ee3\u7801\uff0c\u5728 GPU \u4e0a\u8bad\u7ec3\u5b83\u662f ~2,000 \u884c\uff08\u6dfb\u52a0 CUDA \u5185\u6838\uff09\u5728&nbsp;<a href=\"https:\/\/github.com\/karpathy\/llm.c\/blob\/master\/train_gpt2.cu\">train_gpt2.cu<\/a>&nbsp;\u4e2d\u3002\u4ee3\u7801\u7acb\u5373\u7f16\u8bd1\u5e76\u8fd0\u884c\uff0c\u5b83\u4e0e PyTorch \u53c2\u8003\u5b9e\u73b0\u5b8c\u5168\u5339\u914d\uff0c\u5e76\u4e14\u5b83 ~\u5339\u914d\uff08\u7f16\u8bd1\uff09PyTorch \u7684\u901f\u5ea6\uff08fp32\uff0cno flash attention\uff09\u3002\u6211\u9009\u62e9 GPT-2 \u4f5c\u4e3a\u7b2c\u4e00\u4e2a\u5de5\u4f5c\u793a\u4f8b\uff0c\u56e0\u4e3a\u5b83\u662f LLM \u7684\u7956\u7236\uff0c\u662f\u73b0\u4ee3\u5806\u6808\u7b2c\u4e00\u6b21\u7ec4\u5408\u5728\u4e00\u8d77\u3002<\/p>\n\n\n\n<p>\u53c2\u89c1\uff1a<a href=\"https:\/\/github.com\/karpathy\/llm.c\/tree\/master\">karpathy\/llm.c: LLM training in simple, raw C\/CUDA (github.com)<\/a> \u83b7\u5f97\u66f4\u591a\u77e5\u8bc6<\/p>\n\n\n\n<p>\u4e0b\u9762\u7684\u4ee3\u7801\u4e3b\u8981\u662f\u4ece train_gpt2.c\u548ctest_gpt2.c \u4e2d\u83b7\u53d6\uff0c\u8fdb\u884c\u4e2d\u6587\u7684\u6ce8\u91ca\uff0c\u4ee3\u7801\u662f\u4f7f\u7528ChatGPT4 
\u8fdb\u884c\u6ce8\u89e3\u3002<\/p>\n\n\n\n<p>\u4f9d\u7167\u5f53\u524d\u6587\u4ef6\u7684\u51fd\u6570\u6392\u5e8f\uff0c\u8fdb\u884c\u6ce8\u89e3\u3002\u4e3a\u4e86\u4fdd\u6301\u548c\u6e90\u4ee3\u7801\u7684\u4e00\u81f4\u6027\uff0c\u65b9\u4fbf\u5bf9\u6bd4\uff0c\u82f1\u6587\u6ce8\u91ca\u4fdd\u6301\u4e0d\u53d8\u3002<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>1. train_gpt2.c<\/strong><\/h2>\n\n\n\n<p>train_gpt2.c \u662fC \u8bed\u8a00\u7248\u7684\u8bad\u7ec3\u4ee3\u7801\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.1 encoder_forward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u662f\u7f16\u7801\u5668\uff08encoder\uff09\u7684\u524d\u5411\u4f20\u64ad\u51fd\u6570\uff0c\u5b83\u7684\u4f5c\u7528\u662f\u5c06\u8f93\u5165\u7684token\u5e8f\u5217\u8f6c\u6362\u6210\u5bf9\u5e94\u7684\u5d4c\u5165\u5411\u91cf\u5e8f\u5217\u3002\u8fd9\u4e9b\u5d4c\u5165\u5411\u91cf\u5e8f\u5217\u65e2\u5305\u62ec\u4e86\u6bcf\u4e2atoken\u7684\u4fe1\u606f\uff0c\u4e5f\u5305\u62ec\u4e86token\u5728\u5e8f\u5217\u4e2d\u7684\u4f4d\u7f6e\u4fe1\u606f\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ all the individual layers' forward and backward passes\n\/\/ B = batch_size, T = sequence_length, C = channels, V = vocab_size\n\/\/ \u6240\u6709\u5355\u72ec\u5c42\u7684\u524d\u5411\u548c\u53cd\u5411\u4f20\u64ad\n\/\/ B = \u6279\u5927\u5c0f, T = \u5e8f\u5217\u957f\u5ea6, C = \u901a\u9053\u6570, V = \u8bcd\u6c47\u8868\u5927\u5c0f\n\/\/ B = batch_size: \u6279\u91cf\u5927\u5c0f\uff08Batch Size\uff09\u662f\u6307\u5728\u8bad\u7ec3\u6a21\u578b\u65f6\uff0c\u4e00\u6b21\u6027\u8f93\u5165\u6a21\u578b\u7684\u6570\u636e\u6837\u672c\u6570\u91cf\u3002\n\/\/     
\u8f83\u5927\u7684\u6279\u91cf\u5927\u5c0f\u53ef\u4ee5\u63d0\u9ad8\u5185\u5b58\u5229\u7528\u7387\u548c\u8bad\u7ec3\u901f\u5ea6\uff0c\u4f46\u4e5f\u53ef\u80fd\u9700\u8981\u66f4\u591a\u7684\u5185\u5b58\uff0c\u5e76\u4e14\u53ef\u80fd\u5f71\u54cd\u6a21\u578b\u7684\u6cdb\u5316\u80fd\u529b\u3002\n\/\/ T = sequence_length: \u5e8f\u5217\u957f\u5ea6\uff08Sequence Length\uff09\u662f\u6307\u5728\u5904\u7406\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u5e8f\u5217\u4e2d\u5143\u7d20\u7684\u6570\u91cf\u3002\n\/\/     \u5bf9\u4e8e\u6587\u672c\u6570\u636e\uff0c\u8fd9\u53ef\u4ee5\u662f\u53e5\u5b50\u4e2d\u7684\u5355\u8bcd\u6570\u91cf\uff1b\u5bf9\u4e8e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\uff0c\u8fd9\u53ef\u4ee5\u662f\u4e00\u7cfb\u5217\u8fde\u7eed\u89c2\u6d4b\u503c\u7684\u6570\u91cf\u3002\n\/\/ C = channels: \u901a\u9053\u6570\uff08Channels\uff09\u5728\u4e0d\u540c\u4e0a\u4e0b\u6587\u4e2d\u53ef\u80fd\u6709\u4e0d\u540c\u7684\u542b\u4e49\u3002\n\/\/     \u5728\u5904\u7406\u56fe\u50cf\u6570\u636e\u65f6\uff0c\u5b83\u53ef\u80fd\u6307\u7684\u662f\u989c\u8272\u901a\u9053\uff08\u5982RGB\u56fe\u50cf\u67093\u4e2a\u901a\u9053\uff09\u3002\n\/\/     \u5728\u5904\u7406\u6587\u672c\u6216\u5176\u4ed6\u7c7b\u578b\u7684\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u5b83\u53ef\u80fd\u6307\u7684\u662f\u5d4c\u5165\u5411\u91cf\u7684\u7ef4\u5ea6\u3002\n\/\/ V = vocab_size: \u8bcd\u6c47\u8868\u5927\u5c0f\uff08Vocabulary Size\uff09\u662f\u6307\u6a21\u578b\u7528\u4e8e\u8868\u793a\u6570\u636e\u7684\u8bcd\u6c47\u8868\u4e2d\u552f\u4e00\u8bcd\u6c47\u7684\u6570\u91cf\u3002\n\/\/     \u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\uff08NLP\uff09\u4efb\u52a1\u4e2d\uff0c\u8bcd\u6c47\u8868\u5927\u5c0f\u51b3\u5b9a\u4e86\u6a21\u578b\u53ef\u4ee5\u8bc6\u522b\u7684\u4e0d\u540c\u5355\u8bcd\u6216token\u7684\u6570\u91cf\u3002\n\nvoid encoder_forward(float* out,\n                   int* inp, float* wte, float* wpe,\n                   int B, int T, int C) {\n    \/\/ out is (B,T,C). 
At each position (b,t), a C-dimensional vector summarizing token &amp; position\n    \/\/ inp is (B,T) of integers, holding the token ids at each (b,t) position\n    \/\/ wte is (V,C) of token embeddings, short for \"weight token embeddings\"\n    \/\/ wpe is (maxT,C) of position embeddings, short for \"weight positional embedding\"\n    \/\/ out\u662f(B,T,C)\u7ef4\u7684\u8f93\u51fa\u3002\u5728\u6bcf\u4e2a\u4f4d\u7f6e(b,t)\uff0c\u90fd\u6709\u4e00\u4e2aC\u7ef4\u5411\u91cf\uff0c\u7efc\u5408\u4e86token\u548c\u4f4d\u7f6e\u7684\u4fe1\u606f\n    \/\/ inp\u662f(B,T)\u7ef4\u7684\u6574\u6570\u6570\u7ec4\uff0c\u6bcf\u4e2a\u4f4d\u7f6e(b,t)\u6301\u6709\u4e00\u4e2atoken\u7684id\n    \/\/ wte\u662f(V,C)\u7ef4\u7684token\u5d4c\u5165\u77e9\u9635\uff0c\u540d\u79f0\u662f\"weight token embeddings\"\u7684\u7f29\u5199\n    \/\/ wpe\u662f(maxT,C)\u7ef4\u7684\u4f4d\u7f6e\u5d4c\u5165\u77e9\u9635\uff0c\u540d\u79f0\u662f\"weight positional embedding\"\u7684\u7f29\u5199\n    \/\/ \u904d\u5386\u6bcf\u4e2a\u6279\u6b21\n    for (int b = 0; b &lt; B; b++) {\n        \/\/ \u904d\u5386\u6bcf\u4e2a\u4f4d\u7f6e\n        for (int t = 0; t &lt; T; t++) {\n            \/\/ seek to the output position in out[b,t,:]\n            \/\/ \u5b9a\u4f4d\u5230\u8f93\u51fa\u5f20\u91cf\u7684\u6307\u5b9a\u4f4d\u7f6eout[b,t,:]\n            float* out_bt = out + b * T * C + t * C;\n            \/\/ get the index of the token at inp[b, t]\n            \/\/ \u83b7\u53d6\u5f53\u524d\u4f4d\u7f6e\u7684token\u7d22\u5f15\n            int ix = inp[b * T + t];\n            \/\/ seek to the position in wte corresponding to the token\n            \/\/ \u5b9a\u4f4d\u5230token\u5d4c\u5165\u77e9\u9635\u4e2d\u5bf9\u5e94\u7684token\u5d4c\u5165\u5411\u91cf\n            float* wte_ix = wte + ix * C;\n            \/\/ seek to the position in wpe corresponding to the position\n            \/\/ \u5b9a\u4f4d\u5230\u4f4d\u7f6e\u5d4c\u5165\u77e9\u9635\u4e2d\u5bf9\u5e94\u7684\u4f4d\u7f6e\u5d4c\u5165\u5411\u91cf\n            float* wpe_t = wpe + t * C;\n            \/\/ add the 
two vectors and store the result in out[b,t,:]\n            \/\/ \u5c06token\u5d4c\u5165\u5411\u91cf\u548c\u4f4d\u7f6e\u5d4c\u5165\u5411\u91cf\u76f8\u52a0\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728out[b,t,:]\u4e2d\n            for (int i = 0; i &lt; C; i++) {\n                out_bt[i] = wte_ix[i] + wpe_t[i];\n            }\n        }\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u5728\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u7684\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\uff0c\u6bcf\u4e00\u5c42\u90fd\u4f1a\u8fdb\u884c\u524d\u5411\u4f20\u64ad\uff0c\u8ba1\u7b97\u5176\u8f93\u51fa\u4ee5\u4f9b\u4e0b\u4e00\u5c42\u4f7f\u7528\uff0c\u76f4\u5230\u751f\u6210\u6700\u7ec8\u7684\u8f93\u51fa\u3002\u5728\u8ba1\u7b97\u635f\u5931\uff08\u5373\u6a21\u578b\u8f93\u51fa\u4e0e\u5b9e\u9645\u6807\u7b7e\u4e4b\u95f4\u7684\u5dee\u5f02\uff09\u540e\uff0c\u6a21\u578b\u901a\u8fc7\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u8ba1\u7b97\u635f\u5931\u76f8\u5bf9\u4e8e\u6bcf\u4e2a\u53c2\u6570\u7684\u68af\u5ea6\uff0c\u7136\u540e\u4f7f\u7528\u8fd9\u4e9b\u68af\u5ea6\u6765\u66f4\u65b0\u6a21\u578b\u53c2\u6570\u3002\u8fd9\u4e2a\u8fc7\u7a0b\u5728\u591a\u4e2a\u8bad\u7ec3\u5468\u671f\uff08Epochs\uff09\u4e2d\u91cd\u590d\u8fdb\u884c\uff0c\u76f4\u5230\u6a21\u578b\u6027\u80fd\u8fbe\u5230\u6ee1\u610f\u7684\u6c34\u5e73\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.2 encoder_backward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u6267\u884c\u7f16\u7801\u5668\u5c42\u7684\u53cd\u5411\u4f20\u64ad\u64cd\u4f5c\uff0c\u7528\u4e8e\u66f4\u65b0\u8bcd\u5d4c\u5165\uff08word embeddings\uff09\u548c\u4f4d\u7f6e\u5d4c\u5165\uff08positional embeddings\uff09\u7684\u68af\u5ea6\u3002\u8fd9\u662f\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\uff0c\u5c24\u5176\u662f\u5728\u5904\u7406\u5e8f\u5217\u6570\u636e\u5982\u6587\u672c\u65f6\u7684\u5173\u952e\u6b65\u9aa4\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c 
decode:true \">void encoder_backward(float* dwte, float* dwpe,\n                      float* dout, int* inp,\n                      int B, int T, int C) {\n    \/\/ \u904d\u5386\u6279\u6b21\u548c\u5e8f\u5217\u4e2d\u7684\u6bcf\u4e2a\u4f4d\u7f6e\n    for (int b = 0; b &lt; B; b++) {\n        for (int t = 0; t &lt; T; t++) {\n            \/\/ \u5b9a\u4f4d\u5230\u7279\u5b9a\u6279\u6b21\u548c\u65f6\u95f4\u6b65\u7684\u8f93\u51fa\u68af\u5ea6\n            float* dout_bt = dout + b * T * C + t * C;\n            \/\/ \u83b7\u53d6\u5f53\u524d\u4f4d\u7f6e\u7684\u8f93\u5165\u8bcd\u7d22\u5f15\n            int ix = inp[b * T + t];\n            \/\/ \u5b9a\u4f4d\u5230\u4e0e\u5f53\u524d\u8bcd\u7d22\u5f15\u5bf9\u5e94\u7684\u8bcd\u5d4c\u5165\u68af\u5ea6\n            float* dwte_ix = dwte + ix * C;\n            \/\/ \u5b9a\u4f4d\u5230\u4e0e\u5f53\u524d\u4f4d\u7f6e\u5bf9\u5e94\u7684\u4f4d\u7f6e\u5d4c\u5165\u68af\u5ea6\n            float* dwpe_t = dwpe + t * C;\n            \/\/ \u7d2f\u52a0\u8ba1\u7b97\u8bcd\u5d4c\u5165\u548c\u4f4d\u7f6e\u5d4c\u5165\u7684\u68af\u5ea6\n            for (int i = 0; i &lt; C; i++) {\n            \t  \/\/ \u83b7\u53d6\u5f53\u524d\u68af\u5ea6\n                float d = dout_bt[i];\n                \/\/ \u66f4\u65b0\u8bcd\u5d4c\u5165\u68af\u5ea6\n                dwte_ix[i] += d;\n                \/\/ \u66f4\u65b0\u4f4d\u7f6e\u5d4c\u5165\u68af\u5ea6\n                dwpe_t[i] += d;\n            }\n        }\n    
}\n}\n<\/pre><\/div>\n\n\n\n<p>\u5728\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\uff0c\u6a21\u578b\u7684\u524d\u5411\u4f20\u64ad\u4f1a\u8ba1\u7b97\u51fa\u6700\u7ec8\u7684\u8f93\u51fa\uff0c\u7136\u540e\u901a\u8fc7\u6bd4\u8f83\u9884\u6d4b\u7ed3\u679c\u548c\u5b9e\u9645\u7ed3\u679c\u6765\u8ba1\u7b97\u635f\u5931\u3002\u5728\u53cd\u5411\u4f20\u64ad\u9636\u6bb5\uff0c\u6839\u636e\u635f\u5931\u51fd\u6570\u76f8\u5bf9\u4e8e\u6a21\u578b\u53c2\u6570\u7684\u68af\u5ea6\uff0c\u66f4\u65b0\u6a21\u578b\u7684\u53c2\u6570\uff0c\u4ee5\u51cf\u5c11\u672a\u6765\u7684\u9884\u6d4b\u8bef\u5dee\u3002\u8fd9\u6bb5\u4ee3\u7801\u7279\u522b\u5904\u7406\u4e86\u5bf9\u4e8e\u5e8f\u5217\u6a21\u578b\u4e2d\u4e24\u79cd\u91cd\u8981\u7c7b\u578b\u5d4c\u5165\u2014\u2014\u8bcd\u5d4c\u5165\u548c\u4f4d\u7f6e\u5d4c\u5165\u7684\u68af\u5ea6\u66f4\u65b0\uff0c\u8fd9\u5bf9\u4e8e\u6a21\u578b\u7406\u89e3\u8f93\u5165\u5e8f\u5217\u7684\u8bed\u4e49\u548c\u7ed3\u6784\u81f3\u5173\u91cd\u8981\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.3 layernorm_forward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u5c42\u5f52\u4e00\u5316\uff08Layer Normalization\uff09\u7684\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\u3002\u5c42\u5f52\u4e00\u5316\u662f\u6df1\u5ea6\u5b66\u4e60\u4e2d\u5e38\u7528\u7684\u4e00\u79cd\u6280\u672f\uff0c\u7279\u522b\u662f\u5728\u5904\u7406\u5e8f\u5217\u6570\u636e\u7684\u6a21\u578b\u4e2d\uff0c\u5982\u5faa\u73af\u795e\u7ecf\u7f51\u7edc\uff08RNNs\uff09\u548cTransformer\u3002\u5b83\u6709\u52a9\u4e8e\u7a33\u5b9a\u795e\u7ecf\u7f51\u7edc\u7684\u8bad\u7ec3\u8fc7\u7a0b\uff0c\u901a\u8fc7\u5bf9\u6bcf\u4e00\u5c42\u7684\u6fc0\u6d3b\u8fdb\u884c\u5f52\u4e00\u5316\u6765\u51cf\u5c11\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u7684\u5185\u90e8\u534f\u53d8\u91cf\u504f\u79fb\uff08Internal Covariate Shift\uff09\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void 
layernorm_forward(float* out, float* mean, float* rstd,\n                       float* inp, float* weight, float* bias,\n                       int B, int T, int C) {\n    \/\/ reference: https:\/\/pytorch.org\/docs\/stable\/generated\/torch.nn.LayerNorm.html\n    \/\/ both inp and out are (B,T,C) of the activations\n    \/\/ mean and rstd are (B,T) buffers, to be used later in backward pass\n    \/\/ at each position (b,t) of the input, the C-dimensional vector\n    \/\/ of activations gets normalized, then scaled and shifted\n    \/\/ \u53c2\u8003\uff1ahttps:\/\/pytorch.org\/docs\/stable\/generated\/torch.nn.LayerNorm.html\n    \/\/ \u8f93\u5165(inp)\u548c\u8f93\u51fa(out)\u90fd\u662f(B,T,C)\u7ef4\u5ea6\u7684\u6fc0\u6d3b\u503c\n    \/\/ mean\u548crstd\u662f(B,T)\u7ef4\u5ea6\u7684\u7f13\u51b2\u533a\uff0c\u7a0d\u540e\u5728\u53cd\u5411\u4f20\u64ad\u4e2d\u4f7f\u7528\n    \/\/ \u5bf9\u4e8e\u8f93\u5165\u7684\u6bcf\u4e2a\u4f4d\u7f6e(b,t)\uff0cC\u7ef4\u7684\u6fc0\u6d3b\u5411\u91cf\u4f1a\u88ab\u5f52\u4e00\u5316\uff0c\u7136\u540e\u8fdb\u884c\u7f29\u653e\u548c\u504f\u79fb\n    \/\/ \u907f\u514d\u9664\u4ee50\n    float eps = 1e-5f;\n    for (int b = 0; b &lt; B; b++) {\n        for (int t = 0; t &lt; T; t++) {\n            \/\/ seek to the input position inp[b,t,:]\n            \/\/ \u5b9a\u4f4d\u5230\u8f93\u5165\u4f4d\u7f6einp[b,t,:]\n            float* x = inp + b * T * C + t * C;\n            \/\/ calculate the mean\n            \/\/ \u8ba1\u7b97\u5747\u503c\n            float m = 0.0f;\n            for (int i = 0; i &lt; C; i++) {\n                m += x[i];\n            }\n            m = m\/C;\n            \/\/ calculate the variance (without any bias correction)\n            \/\/ \u8ba1\u7b97\u65b9\u5dee\uff08\u4e0d\u8fdb\u884c\u504f\u5dee\u6821\u6b63\uff09\n            float v = 0.0f;\n            for (int i = 0; i &lt; C; i++) {\n                float xshift = x[i] - m;\n                v += xshift * xshift;\n            }\n            v = v\/C;\n            \/\/ 
calculate the rstd (reciprocal standard deviation)\n            \/\/ \u8ba1\u7b97\u9006\u6807\u51c6\u5dee\uff08reciprocal standard deviation\uff09\n            float s = 1.0f \/ sqrtf(v + eps);\n            \/\/ seek to the output position in out[b,t,:]\n            \/\/ \u5b9a\u4f4d\u5230\u8f93\u51fa\u4f4d\u7f6eout[b,t,:]\n            float* out_bt = out + b * T * C + t * C;\n            for (int i = 0; i &lt; C; i++) {\n                \/\/ \u5f52\u4e00\u5316\n                float n = (s * (x[i] - m)); \/\/ normalize\n                \/\/ \u7f29\u653e\u548c\u504f\u79fb\n                float o = n * weight[i] + bias[i]; \/\/ scale and shift\n                \/\/ \u5199\u5165\u7ed3\u679c\n                out_bt[i] = o; \/\/ write\n            }\n            \/\/ cache the mean and rstd for the backward pass later\n            \/\/ \u4e3a\u53cd\u5411\u4f20\u64ad\u7f13\u5b58\u5747\u503c\u548c\u9006\u6807\u51c6\u5dee\n            mean[b * T + t] = m;\n            rstd[b * T + t] = s;\n        }\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u901a\u8fc7\u8ba1\u7b97\u6bcf\u4e2a\u4f4d\u7f6e\u4e0a\u7684\u6fc0\u6d3b\u5411\u91cf\u7684\u5747\u503c\u548c\u6807\u51c6\u5dee\uff0c\u7136\u540e\u5bf9\u6bcf\u4e2a\u6fc0\u6d3b\u5411\u91cf\u8fdb\u884c\u5f52\u4e00\u5316\u5904\u7406\uff08\u5373\uff0c\u51cf\u53bb\u5747\u503c\uff0c\u9664\u4ee5\u6807\u51c6\u5dee\uff09\uff0c\u6700\u540e\u5bf9\u5f52\u4e00\u5316\u540e\u7684\u5411\u91cf\u8fdb\u884c\u7ebf\u6027\u53d8\u6362\uff08\u901a\u8fc7\u6743\u91cd\u548c\u504f\u7f6e\u8fdb\u884c\u7f29\u653e\u548c\u504f\u79fb\uff09\uff0c\u5f97\u5230\u6700\u7ec8\u7684\u8f93\u51fa\u3002\u8fd9\u79cd\u5904\u7406\u65b9\u5f0f\u4f7f\u5f97\u6a21\u578b\u7684\u6bcf\u4e00\u5c42\u90fd\u80fd\u6709\u7a33\u5b9a\u7684\u6fc0\u6d3b\u5206\u5e03\uff0c\u4ece\u800c\u6709\u52a9\u4e8e\u6539\u5584\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u7684\u6570\u503c\u7a33\u5b9a\u6027\u548c\u6536\u655b\u901f\u5ea6\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.4 
layernorm_backward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u5c42\u5f52\u4e00\u5316\uff08Layer Normalization\uff09\u7684\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u3002\u5728\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u7684\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\uff0c\u53cd\u5411\u4f20\u64ad\u662f\u4e00\u79cd\u8ba1\u7b97\u635f\u5931\u51fd\u6570\u5173\u4e8e\u6bcf\u4e2a\u53c2\u6570\u68af\u5ea6\u7684\u65b9\u6cd5\uff0c\u8fd9\u4e9b\u68af\u5ea6\u968f\u540e\u88ab\u7528\u6765\u66f4\u65b0\u6a21\u578b\u7684\u53c2\u6570\u3002\u5c42\u5f52\u4e00\u5316\u7684\u53cd\u5411\u4f20\u64ad\u7279\u522b\u91cd\u8981\uff0c\u56e0\u4e3a\u5b83\u6d89\u53ca\u5230\u5982\u4f55\u5c06\u6765\u81ea\u540e\u7eed\u5c42\u7684\u68af\u5ea6\u4f20\u9012\u56de\u524d\u9762\u7684\u5c42\uff0c\u5e76\u66f4\u65b0\u76f8\u5173\u7684\u6743\u91cd\u548c\u504f\u7f6e\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void layernorm_backward(float* dinp, float* dweight, float* dbias,\n                        float* dout, float* inp, float* weight, float* mean, float* rstd,\n                        int B, int T, int C) {\n    \/\/ \u904d\u5386\u6bcf\u4e2a\u6279\u6b21\u548c\u65f6\u95f4\u6b65\n    for (int b = 0; b &lt; B; b++) {\n        for (int t = 0; t &lt; T; t++) {\n            \/\/ \u5b9a\u4f4d\u5230\u7279\u5b9a\u6279\u6b21\u548c\u65f6\u95f4\u6b65\u7684\u8f93\u51fa\u68af\u5ea6\n            float* dout_bt = dout + b * T * C + t * C;\n            \/\/ \u5b9a\u4f4d\u5230\u5bf9\u5e94\u7684\u8f93\u5165\u548c\u8f93\u5165\u68af\u5ea6\n            float* inp_bt = inp + b * T * C + t * C;\n            float* dinp_bt = dinp + b * T * C + t * C;\n            \/\/ \u83b7\u53d6\u5bf9\u5e94\u7684\u5747\u503c\u548c\u9006\u6807\u51c6\u5dee\n            float mean_bt = mean[b * T + t];\n            float rstd_bt = rstd[b * T + t];\n\n            \/\/ first: two reduce 
operations\n            \/\/ \u9996\u5148\uff1a\u8fdb\u884c\u4e24\u6b21reduce\u64cd\u4f5c\u4ee5\u51c6\u5907\u68af\u5ea6\u8ba1\u7b97\n            float dnorm_mean = 0.0f;\n            float dnorm_norm_mean = 0.0f;\n            for (int i = 0; i &lt; C; i++) {\n                \/\/ \u8ba1\u7b97\u5f52\u4e00\u5316\u7684\u503c\u53ca\u5176\u4e0e\u8f93\u51fa\u68af\u5ea6\u7684\u4e58\u79ef\n                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n                float dnorm_i = weight[i] * dout_bt[i];\n                dnorm_mean += dnorm_i;\n                dnorm_norm_mean += dnorm_i * norm_bti;\n            }\n            \/\/ \u5bf9\u4e24\u4e2a\u6c42\u548c\u7ed3\u679c\u53d6\u5e73\u5747\n            dnorm_mean = dnorm_mean \/ C;\n            dnorm_norm_mean = dnorm_norm_mean \/ C;\n\n            \/\/ now iterate again and accumulate all the gradients\n            \/\/ \u518d\u6b21\u904d\u5386\u5e76\u7d2f\u52a0\u6240\u6709\u68af\u5ea6\n            for (int i = 0; i &lt; C; i++) {\n                \/\/ \u8ba1\u7b97\u5f52\u4e00\u5316\u503c\u53ca\u5176\u4e0e\u8f93\u51fa\u68af\u5ea6\u7684\u4e58\u79ef\n                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n                float dnorm_i = weight[i] * dout_bt[i];\n                \/\/ gradient contribution to bias\n                \/\/ \u5bf9\u504f\u7f6e\u7684\u68af\u5ea6\u8d21\u732e\n                dbias[i] += dout_bt[i];\n                \/\/ gradient contribution to weight\n                \/\/ \u5bf9\u6743\u91cd\u7684\u68af\u5ea6\u8d21\u732e\n                dweight[i] += norm_bti * dout_bt[i];\n                \/\/ gradient contribution to input\n                \/\/ \u5bf9\u8f93\u5165\u7684\u68af\u5ea6\u8d21\u732e\n                float dval = 0.0f;\n                dval += dnorm_i; \/\/ term 1\n                dval -= dnorm_mean; \/\/ term 2\n                dval -= norm_bti * dnorm_norm_mean; \/\/ term 3\n                dval *= rstd_bt; \/\/ final scale\n                dinp_bt[i] += dval;\n         
   }\n        }\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e2a\u8fc7\u7a0b\u9996\u5148\u8ba1\u7b97\u4e86\u4e24\u4e2a\u91cd\u8981\u7684\u4e2d\u95f4\u53d8\u91cf\uff1a<code>dnorm_mean<\/code>\uff08\u7ecf\u6743\u91cd\u7f29\u653e\u540e\u7684\u8f93\u51fa\u68af\u5ea6\u7684\u5e73\u5747\u503c\uff09\u548c<code>dnorm_norm_mean<\/code>\uff08\u5f52\u4e00\u5316\u503c\u4e0e\u7ecf\u6743\u91cd\u7f29\u653e\u540e\u7684\u8f93\u51fa\u68af\u5ea6\u4e58\u79ef\u7684\u5e73\u5747\u503c\uff09\u3002\u7136\u540e\uff0c\u5b83\u4f7f\u7528\u8fd9\u4e9b\u4e2d\u95f4\u53d8\u91cf\u6765\u8ba1\u7b97\u5bf9\u6743\u91cd\u3001\u504f\u7f6e\u548c\u8f93\u5165\u7684\u68af\u5ea6\u3002\u8fd9\u79cd\u65b9\u5f0f\u786e\u4fdd\u4e86\u68af\u5ea6\u7684\u8ba1\u7b97\u8003\u8651\u4e86\u5f52\u4e00\u5316\u7684\u5f71\u54cd\uff0c\u5e76\u9002\u5f53\u5730\u7d2f\u52a0\u4e86\u6743\u91cd\u548c\u504f\u7f6e\u7684\u68af\u5ea6\uff0c\u4ee5\u6539\u5584\u6a21\u578b\u5728\u540e\u7eed\u8fed\u4ee3\u4e2d\u7684\u8868\u73b0\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.5 matmul_forward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u662f\u77e9\u9635\u4e58\u6cd5\uff08Matrix Multiplication\uff09\u7684\u524d\u5411\u4f20\u64ad\u51fd\u6570\uff0c\u4e3b\u8981\u7528\u4e8e\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e2d\u7684\u7ebf\u6027\u5c42\uff08\u6216\u5168\u8fde\u63a5\u5c42\uff09\u3002\u5b83\u5c06\u8f93\u5165\u6570\u636e\u548c\u6743\u91cd\u77e9\u9635\u76f8\u4e58\uff0c\u7136\u540e\u52a0\u4e0a\u504f\u7f6e\u9879\uff0c\u751f\u6210\u8f93\u51fa\u6570\u636e\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void matmul_forward(float* out,\n                    float* inp, float* weight, float* bias,\n                    int B, int T, int C, int OC) {\n    \/\/ most of the running time is spent here and in matmul_backward\n    \/\/ OC is short for \"output channels\"\n    \/\/ inp is (B,T,C), weight is (OC, C), bias is (OC)\n    \/\/ out will be (B,T,OC)\n    \/\/ 
\u6b64\u51fd\u6570\u548cmatmul_backward\u51fd\u6570\u5360\u636e\u4e86\u5927\u90e8\u5206\u8fd0\u884c\u65f6\u95f4\n    \/\/ OC\u662f\u201c\u8f93\u51fa\u901a\u9053\u6570\u201d\u7684\u7b80\u79f0\n    \/\/ inp\u7684\u7ef4\u5ea6\u662f(B,T,C)\uff0cweight\u7684\u7ef4\u5ea6\u662f(OC, C)\uff0cbias\u7684\u7ef4\u5ea6\u662f(OC)\n    \/\/ \u8f93\u51faout\u7684\u7ef4\u5ea6\u5c06\u4f1a\u662f(B,T,OC)\n    #pragma omp parallel for collapse(2)\n    \/\/ \u904d\u5386\u6279\u6b21\n    for (int b = 0; b &lt; B; b++) {\n        \/\/ \u904d\u5386\u65f6\u95f4\u6b65\u6216\u5e8f\u5217\u957f\u5ea6\n        for (int t = 0; t &lt; T; t++) {\n            \/\/ \u5b9a\u4f4d\u5230\u8f93\u51fa\u77e9\u9635\u7684\u5177\u4f53\u4f4d\u7f6e\n            float* out_bt = out + b * T * OC + t * OC;\n            \/\/ \u5b9a\u4f4d\u5230\u8f93\u5165\u77e9\u9635\u7684\u5177\u4f53\u4f4d\u7f6e\n            float* inp_bt = inp + b * T * C + t * C;\n            \/\/ \u904d\u5386\u8f93\u51fa\u901a\u9053\n            for (int o = 0; o &lt; OC; o++) {\n                \/\/ \u5982\u679c\u6709\u504f\u7f6e\u9879\uff0c\u5219\u521d\u59cb\u5316\u4e3a\u504f\u7f6e\u503c\uff0c\u5426\u5219\u4e3a0\n                float val = (bias != NULL) ? 
bias[o] : 0.0f;\n                \/\/ \u5b9a\u4f4d\u5230\u6743\u91cd\u77e9\u9635\u7684\u5177\u4f53\u884c\n                float* wrow = weight + o*C;\n                \/\/ \u5bf9\u5f53\u524d\u884c\u7684\u6bcf\u4e2a\u5143\u7d20\u8fdb\u884c\u7d2f\u52a0\n                for (int i = 0; i &lt; C; i++) {\n                    \/\/ \u6267\u884c\u70b9\u4e58\u64cd\u4f5c\n                    val += inp_bt[i] * wrow[i];\n                }\n                \/\/ \u5c06\u8ba1\u7b97\u7ed3\u679c\u5b58\u50a8\u5230\u8f93\u51fa\u77e9\u9635\u4e2d\n                out_bt[o] = val;\n            }\n        }\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u6b64\u51fd\u6570\u5229\u7528OpenMP\u8fdb\u884c\u5e76\u884c\u5316\u5904\u7406\uff0c\u4ee5\u63d0\u9ad8\u8ba1\u7b97\u6548\u7387\u3002<code>#pragma omp parallel for collapse(2)<\/code>\u6307\u4ee4\u4f1a\u5e76\u884c\u5316\u5d4c\u5957\u7684\u4e24\u5c42\u5faa\u73af\uff0c\u4ece\u800c\u52a0\u5feb\u6279\u6b21\u548c\u65f6\u95f4\u6b65\u7684\u904d\u5386\u8fc7\u7a0b\u3002<\/p>\n\n\n\n<p>\u8fd9\u4e2a\u77e9\u9635\u4e58\u6cd5\u64cd\u4f5c\u662f\u6df1\u5ea6\u5b66\u4e60\u4e2d\u5e38\u89c1\u7684\u64cd\u4f5c\u4e4b\u4e00\uff0c\u5b83\u5728\u5168\u8fde\u63a5\u5c42\u3001\u5377\u79ef\u5c42\u7684\u8ba1\u7b97\u4e2d\u90fd\u6709\u5e7f\u6cdb\u5e94\u7528\u3002\u901a\u8fc7\u5c06\u8f93\u5165\u6570\u636e\u548c\u6743\u91cd\u77e9\u9635\u76f8\u4e58\uff0c\u518d\u52a0\u4e0a\u504f\u7f6e\u9879\uff0c\u53ef\u4ee5\u5b9e\u73b0\u5bf9\u6570\u636e\u7684\u7ebf\u6027\u53d8\u6362\uff0c\u8fd9\u662f\u795e\u7ecf\u7f51\u7edc\u5b66\u4e60\u590d\u6742\u8868\u793a\u7684\u57fa\u7840\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.6 
matmul_backward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u77e9\u9635\u4e58\u6cd5\u64cd\u4f5c\u7684\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u4e3b\u8981\u7528\u4e8e\u66f4\u65b0\u8f93\u5165\u3001\u6743\u91cd\u548c\u504f\u7f6e\u7684\u68af\u5ea6\u3002\u5728\u6df1\u5ea6\u5b66\u4e60\u8bad\u7ec3\u4e2d\uff0c\u53cd\u5411\u4f20\u64ad\u662f\u8ba1\u7b97\u635f\u5931\u51fd\u6570\u5173\u4e8e\u7f51\u7edc\u53c2\u6570\u68af\u5ea6\u7684\u5173\u952e\u6b65\u9aa4\uff0c\u7528\u4e8e\u53c2\u6570\u7684\u68af\u5ea6\u4e0b\u964d\u66f4\u65b0\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void matmul_backward(float* dinp, float* dweight, float* dbias,\n                     float* dout, float* inp, float* weight,\n                     int B, int T, int C, int OC) {\n    \/\/ most of the running time is spent here and in matmul_forward\n    \/\/ this backward could be done in a single \"round\" of loops\n    \/\/ but that doesn't afford an efficient parallelization strategy\n    \/\/ \u5927\u90e8\u5206\u8fd0\u884c\u65f6\u95f4\u82b1\u8d39\u5728\u8fd9\u91cc\u548cmatmul_forward\u51fd\u6570\n    \/\/ \u8fd9\u4e2a\u53cd\u5411\u4f20\u64ad\u53ef\u4ee5\u5728\u4e00\u8f6e\u5faa\u73af\u4e2d\u5b8c\u6210\n    \/\/ \u4f46\u8fd9\u6837\u505a\u4e0d\u5229\u4e8e\u6709\u6548\u7684\u5e76\u884c\u5316\u7b56\u7565\n\n    \/\/ backward into inp first, parallelize over B,T\n    \/\/ \u9996\u5148\u53cd\u5411\u4f20\u64ad\u5230\u8f93\u5165dinp\uff0c\u5bf9B,T\u8fdb\u884c\u5e76\u884c\u5316\n    #pragma omp parallel for collapse(2)\n    for (int b = 0; b &lt; B; b++) {\n        for (int t = 0; t &lt; T; t++) {\n            \/\/ \u5b9a\u4f4d\u5230\u7279\u5b9a\u6279\u6b21\u548c\u65f6\u95f4\u6b65\u7684\u8f93\u51fa\u68af\u5ea6\n            float* dout_bt = dout + b * T * OC + t * OC;\n            \/\/ 
\u5b9a\u4f4d\u5230\u5bf9\u5e94\u7684\u8f93\u5165\u68af\u5ea6\n            float* dinp_bt = dinp + b * T * C + t * C;\n            for (int o = 0; o &lt; OC; o++) {\n                \/\/ \u5b9a\u4f4d\u5230\u6743\u91cd\u77e9\u9635\u7684\u7279\u5b9a\u884c\n                float* wrow = weight + o*C;\n                \/\/ \u83b7\u53d6\u8f93\u51fa\u68af\u5ea6\n                float d = dout_bt[o];\n                for (int i = 0; i &lt; C; i++) {\n                    \/\/ \u66f4\u65b0\u8f93\u5165\u68af\u5ea6\n                    dinp_bt[i] += wrow[i] * d;\n                }\n            }\n        }\n    }\n    \/\/ backward into weight\/bias, parallelize over output channels OC\n    \/\/ \u7136\u540e\u53cd\u5411\u4f20\u64ad\u5230\u6743\u91cddweight\u548c\u504f\u7f6edbias\uff0c\u5bf9\u8f93\u51fa\u901a\u9053OC\u8fdb\u884c\u5e76\u884c\u5316\n    #pragma omp parallel for\n    for (int o = 0; o &lt; OC; o++) {\n        for (int b = 0; b &lt; B; b++) {\n            for (int t = 0; t &lt; T; t++) {\n                \/\/ \u5b9a\u4f4d\u5230\u8f93\u51fa\u68af\u5ea6\n                float* dout_bt = dout + b * T * OC + t * OC;\n                \/\/ \u5b9a\u4f4d\u5230\u8f93\u5165\n                float* inp_bt = inp + b * T * C + t * C;\n                \/\/ \u5b9a\u4f4d\u5230\u6743\u91cd\u68af\u5ea6\u7684\u7279\u5b9a\u884c\n                float* dwrow = dweight + o*C;\n                \/\/ \u83b7\u53d6\u8f93\u51fa\u68af\u5ea6\n                float d = dout_bt[o];\n                \/\/ \u5982\u679c\u6709\u504f\u7f6e\uff0c\u66f4\u65b0\u504f\u7f6e\u68af\u5ea6\n                if (dbias != NULL) { dbias[o] += d; }\n                for (int i = 0; i &lt; C; i++) {\n                    \/\/ \u66f4\u65b0\u6743\u91cd\u68af\u5ea6\n                    dwrow[i] += inp_bt[i] * d;\n                }\n            }\n        }\n    
}\n}\n<\/pre><\/div>\n\n\n\n<p>\u5728\u8fd9\u4e2a\u51fd\u6570\u4e2d\uff0c\u9996\u5148\u901a\u8fc7\u5e76\u884c\u5316\u5904\u7406\u5bf9\u6bcf\u4e2a\u8f93\u5165\u5355\u5143<code>dinp<\/code>\u8fdb\u884c\u68af\u5ea6\u66f4\u65b0\u3002\u8fd9\u662f\u901a\u8fc7\u5c06\u6743\u91cd\u77e9\u9635<code>weight<\/code>\u548c\u8f93\u51fa\u68af\u5ea6<code>dout<\/code>\u76f8\u4e58\u5b8c\u6210\u7684\u3002\u968f\u540e\uff0c\u6743\u91cd\u68af\u5ea6<code>dweight<\/code>\u901a\u8fc7\u5c06\u8f93\u5165<code>inp<\/code>\u548c\u8f93\u51fa\u68af\u5ea6<code>dout<\/code>\u76f8\u4e58\u5e76\u7d2f\u52a0\u5f97\u5230\uff0c\u504f\u7f6e\u68af\u5ea6<code>dbias<\/code>\u5219\u76f4\u63a5\u7d2f\u52a0\u8f93\u51fa\u68af\u5ea6<code>dout<\/code>\uff08\u4ec5\u5f53<code>dbias<\/code>\u4e0d\u4e3aNULL\u65f6\uff09\u3002\u8fd9\u4e2a\u8fc7\u7a0b\u53cd\u6620\u4e86\u77e9\u9635\u4e58\u6cd5\u7684\u7279\u6027\uff0c\u5373\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u5b9e\u9645\u4e0a\u662f\u53e6\u4e00\u79cd\u5f62\u5f0f\u7684\u77e9\u9635\u4e58\u6cd5\uff0c\u4f46\u662f\u8003\u8651\u4e86\u8f93\u51fa\u68af\u5ea6\u5bf9\u6743\u91cd\u548c\u8f93\u5165\u7684\u5f71\u54cd\u3002<\/p>\n\n\n\n<p>\u8fd9\u79cd\u5904\u7406\u65b9\u5f0f\u4f7f\u5f97\u53ef\u4ee5\u6709\u6548\u5730\u901a\u8fc7\u53cd\u5411\u4f20\u64ad\u6765\u66f4\u65b0\u6a21\u578b\u4e2d\u7ebf\u6027\u5c42\u7684\u53c2\u6570\uff0c\u4ece\u800c\u5728\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u6539\u5584\u6a21\u578b\u7684\u6027\u80fd\u3002\u901a\u8fc7OpenMP\u7684\u5e76\u884c\u5316\u6307\u4ee4\uff0c\u8fd9\u4e2a\u8fc7\u7a0b\u8fd8\u5229\u7528\u4e86\u591a\u6838\u5904\u7406\u5668\u7684\u8ba1\u7b97\u80fd\u529b\uff0c\u4ee5\u52a0\u901f\u68af\u5ea6\u7684\u8ba1\u7b97\u548c\u66f4\u65b0\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.7 attention_forward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86Transformer\u6a21\u578b\u4e2d\u6ce8\u610f\u529b\u673a\u5236\uff08Attention 
Mechanism\uff09\u7684\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\u3002\u6ce8\u610f\u529b\u673a\u5236\u662fTransformer\u6a21\u578b\u7684\u6838\u5fc3\uff0c\u5141\u8bb8\u6a21\u578b\u5728\u5904\u7406\u5e8f\u5217\u6570\u636e\u65f6\u52a8\u6001\u5730\u5173\u6ce8\uff08\u6216\u201c\u6ce8\u610f\u201d\uff09\u5e8f\u5217\u4e2d\u7684\u4e0d\u540c\u90e8\u5206\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void attention_forward(float* out, float* preatt, float* att,\n                       float* inp,\n                       int B, int T, int C, int NH) {\n    \/\/ input is (B, T, 3C) holding the query, key, value (Q, K, V) vectors\n    \/\/ preatt, att are (B, NH, T, T). NH = number of heads, T = sequence length\n    \/\/ that holds the pre-attention and post-attention scores (used in backward)\n    \/\/ output is (B, T, C)\n    \/\/ attention is the only layer that mixes information across time\n    \/\/ every other operation is applied at every (b,t) position independently\n    \/\/ (and of course, no layer mixes information across batch)\n    \/\/ \u8f93\u5165inp\u7684\u7ef4\u5ea6\u662f(B, T, 3C)\uff0c\u5305\u542b\u67e5\u8be2\uff08Query\uff09\u3001\u952e\uff08Key\uff09\u548c\u503c\uff08Value\uff09\u5411\u91cf\n    \/\/ preatt\u548catt\u7684\u7ef4\u5ea6\u662f(B, NH, T, T)\uff0cNH\u662f\u5934\u7684\u6570\u91cf\uff0cT\u662f\u5e8f\u5217\u957f\u5ea6\n    \/\/ \u5b83\u4eec\u5b58\u50a8\u4e86\u6ce8\u610f\u529b\u8ba1\u7b97\u4e4b\u524d\u548c\u4e4b\u540e\u7684\u5206\u6570\uff08\u5728\u53cd\u5411\u4f20\u64ad\u4e2d\u4f7f\u7528\uff09\n    \/\/ \u8f93\u51faout\u7684\u7ef4\u5ea6\u662f(B, T, C)\n    \/\/ \u6ce8\u610f\u529b\u5c42\u662f\u552f\u4e00\u4e00\u4e2a\u8de8\u65f6\u95f4\u6df7\u5408\u4fe1\u606f\u7684\u5c42\n    \/\/ \u5176\u4ed6\u6240\u6709\u64cd\u4f5c\u90fd\u5728\u6bcf\u4e2a(b,t)\u4f4d\u7f6e\u72ec\u7acb\u5e94\u7528\n    \/\/ 
\uff08\u5f53\u7136\uff0c\u6ca1\u6709\u4efb\u4f55\u5c42\u4f1a\u8de8\u6279\u6b21\u6df7\u5408\u4fe1\u606f\uff09\n    int C3 = C*3;\n    \/\/ \u6bcf\u4e2a\u5934\u7684\u5927\u5c0f\n    int hs = C \/ NH; \/\/ head size\n    \/\/ \u7528\u4e8e\u7f29\u653e\u70b9\u79ef\u7684\u56e0\u5b50\n    float scale = 1.0 \/ sqrtf(hs);\n\n    #pragma omp parallel for collapse(3)\n    for (int b = 0; b &lt; B; b++) {\n        for (int t = 0; t &lt; T; t++) {\n            for (int h = 0; h &lt; NH; h++) {\n                float* query_t = inp + b * T * C3 + t * C3 + h * hs;\n                float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;\n                float* att_bth = att + b*NH*T*T + h*T*T + t*T;\n\n                \/\/ pass 1: calculate query dot key and maxval\n                \/\/ \u6b65\u9aa41: \u8ba1\u7b97\u67e5\u8be2\u4e0e\u952e\u7684\u70b9\u79ef\u5e76\u627e\u5230\u6700\u5927\u503c\n                \/\/ \u7528\u4e8e\u6570\u503c\u7a33\u5b9a\u6027\n                float maxval = -10000.0f; \/\/ TODO something better\n                for (int t2 = 0; t2 &lt;= t; t2++) {\n                    \/\/ \u952e\u5411\u91cf\u4f4d\u7f6e\n                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; \/\/ +C because it's key\n\n                    \/\/ (query_t) dot (key_t2)\n                    \/\/ \u8ba1\u7b97\u70b9\u79ef\u5e76\u7f29\u653e\n                    float val = 0.0f;\n                    for (int i = 0; i &lt; hs; i++) {\n                        val += query_t[i] * key_t2[i];\n                    }\n                    val *= scale;\n                    if (val &gt; maxval) {\n                        maxval = val;\n                    }\n\n                    preatt_bth[t2] = val;\n                }\n\n                \/\/ pass 2: calculate the exp and keep track of sum\n                \/\/ maxval is being calculated and subtracted only for numerical stability\n                \/\/ \u6b65\u9aa42: \u8ba1\u7b97exp\u5e76\u4fdd\u6301\u6c42\u548c\n                
float expsum = 0.0f;\n                for (int t2 = 0; t2 &lt;= t; t2++) {\n                    float expv = expf(preatt_bth[t2] - maxval);\n                    expsum += expv;\n                    att_bth[t2] = expv;\n                }\n                float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f \/ expsum;\n\n                \/\/ pass 3: normalize to get the softmax\n                \/\/ \u6b65\u9aa43: \u5f52\u4e00\u5316\u4ee5\u83b7\u5f97softmax\n                for (int t2 = 0; t2 &lt; T; t2++) {\n                    if (t2 &lt;= t) {\n                        att_bth[t2] *= expsum_inv;\n                    } else {\n                        \/\/ causal attention mask. not strictly necessary to set to zero here\n                        \/\/ only doing this explicitly for debugging and checking to PyTorch\n                        \/\/ \u56e0\u679c\u6ce8\u610f\u529b\u63a9\u7801\u3002\u8fd9\u91cc\u4e0d\u4e25\u683c\u5fc5\u8981\u8bbe\u7f6e\u4e3a\u96f6\n                        att_bth[t2] = 0.0f;\n                    }\n                }\n\n                \/\/ pass 4: accumulate weighted values into the output of attention\n                \/\/ \u6b65\u9aa44: \u7d2f\u79ef\u52a0\u6743\u503c\u5230\u6ce8\u610f\u529b\u8f93\u51fa\n                float* out_bth = out + b * T * C + t * C + h * hs;\n                for (int i = 0; i &lt; hs; i++) { out_bth[i] = 0.0f; }\n                for (int t2 = 0; t2 &lt;= t; t2++) {\n                    \/\/ \u503c\u5411\u91cf\u4f4d\u7f6e\n                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; \/\/ +C*2 because it's value\n                    float att_btht2 = att_bth[t2];\n                    for (int i = 0; i &lt; hs; i++) {\n                        out_bth[i] += att_btht2 * value_t2[i];\n                    }\n                }\n            }\n        }\n    
}\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e2a\u51fd\u6570\u9996\u5148\u8ba1\u7b97\u67e5\u8be2\uff08Query\uff09\u548c\u952e\uff08Key\uff09\u7684\u7f29\u653e\u70b9\u79ef\u6ce8\u610f\u529b\uff0c\u7136\u540e\u5e94\u7528softmax\u51fd\u6570\u4ee5\u83b7\u53d6\u6ce8\u610f\u529b\u6743\u91cd\uff0c\u63a5\u7740\u6839\u636e\u8fd9\u4e9b\u6743\u91cd\u5bf9\u503c\uff08Value\uff09\u5411\u91cf\u8fdb\u884c\u52a0\u6743\u548c\uff0c\u6700\u7ec8\u5f97\u5230\u6ce8\u610f\u529b\u5c42\u7684\u8f93\u51fa\u3002\u8fd9\u79cd\u673a\u5236\u4f7f\u5f97\u6a21\u578b\u80fd\u591f\u52a8\u6001\u5730\u805a\u7126\u4e8e\u8f93\u5165\u5e8f\u5217\u7684\u4e0d\u540c\u90e8\u5206\uff0c\u4ece\u800c\u6355\u6349\u5e8f\u5217\u5185\u7684\u957f\u8ddd\u79bb\u4f9d\u8d56\u5173\u7cfb\u3002\u901a\u8fc7OpenMP\u7684\u5e76\u884c\u5316\u6307\u4ee4\uff0c\u8fd9\u4e2a\u8fc7\u7a0b\u8fd8\u5229\u7528\u4e86\u591a\u6838\u5904\u7406\u5668\u7684\u8ba1\u7b97\u80fd\u529b\uff0c\u4ee5\u52a0\u901f\u8ba1\u7b97\u8fc7\u7a0b\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.8 attention_backward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u6ce8\u610f\u529b\u673a\u5236\u7684\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u662f\u6df1\u5ea6\u5b66\u4e60\u4e2dTransformer\u6a21\u578b\u7684\u6838\u5fc3\u7ec4\u6210\u90e8\u5206\u3002\u5728\u8bad\u7ec3\u795e\u7ecf\u7f51\u7edc\u65f6\uff0c\u53cd\u5411\u4f20\u64ad\u662f\u4e00\u79cd\u8ba1\u7b97\u635f\u5931\u51fd\u6570\u5173\u4e8e\u7f51\u7edc\u53c2\u6570\u68af\u5ea6\u7684\u65b9\u6cd5\uff0c\u7528\u4e8e\u66f4\u65b0\u6a21\u578b\u53c2\u6570\u4ee5\u6539\u5584\u6027\u80fd\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void attention_backward(float* dinp, float* dpreatt, float* datt,\n                        float* dout, float* inp, float* att,\n                        int B, int T, int C, int NH) {\n    \/\/ inp\/dinp are (B, T, 3C) 
Q,K,V\n    \/\/ att\/datt\/dpreatt are (B, NH, T, T)\n    \/\/ dout is (B, T, C)\n    \/\/ inp\/dinp\u662f(B, T, 3C)\u7ef4\u7684\uff0c\u5305\u542b\u4e86\u67e5\u8be2\uff08Q\uff09\u3001\u952e\uff08K\uff09\u3001\u503c\uff08V\uff09\n    \/\/ att\/datt\/dpreatt\u662f(B, NH, T, T)\u7ef4\u7684\n    \/\/ dout\u662f(B, T, C)\u7ef4\u7684\u8f93\u51fa\u68af\u5ea6\n    int C3 = C*3;\n    \/\/ \u6bcf\u4e2a\u5934\u7684\u5927\u5c0f\n    int hs = C \/ NH; \/\/ head size\n    \/\/ \u7f29\u653e\u56e0\u5b50\uff0c\u7528\u4e8e\u7f29\u653e\u67e5\u8be2\u4e0e\u952e\u7684\u70b9\u79ef\n    float scale = 1.0 \/ sqrtf(hs);\n\n    for (int b = 0; b &lt; B; b++) {\n        for (int t = 0; t &lt; T; t++) {\n            for (int h = 0; h &lt; NH; h++) {\n                \/\/ \u83b7\u53d6att\u3001datt\u548cdpreatt\u7684\u6307\u9488\n                float* att_bth = att + b*NH*T*T + h*T*T + t*T;\n                float* datt_bth = datt + b*NH*T*T + h*T*T + t*T;\n                float* dpreatt_bth = dpreatt + b*NH*T*T + h*T*T + t*T;\n                \/\/ \u83b7\u53d6\u67e5\u8be2\uff08Q\uff09\u7684\u68af\u5ea6\u548c\u539f\u59cb\u503c\u7684\u6307\u9488\n                float* dquery_t = dinp + b * T * C3 + t * C3 + h * hs;\n                float* query_t = inp + b * T * C3 + t * C3 + h * hs;\n\n                \/\/ backward pass 4, through the value accumulation\n                \/\/ \u53cd\u5411\u4f20\u64ad\u6b65\u9aa44\uff0c\u901a\u8fc7\u503c\uff08V\uff09\u7684\u7d2f\u79ef\n                float* dout_bth = dout + b * T * C + t * C + h * hs;\n                for (int t2 = 0; t2 &lt;= t; t2++) {\n                    \/\/ \u83b7\u53d6\u503c\uff08V\uff09\u7684\u6307\u9488\n                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; \/\/ +C*2 because it's value\n                    \/\/ \u83b7\u53d6\u503c\uff08V\uff09\u7684\u68af\u5ea6\u6307\u9488\n                    float* dvalue_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C*2;\n                    for (int i = 0; i &lt; hs; i++) 
{\n                        \/\/ in the forward pass this was:\n                        \/\/ out_bth[i] += att_bth[t2] * value_t2[i];\n                        \/\/ so now we have:\n                        \/\/ \u66f4\u65b0\u6ce8\u610f\u529b\u6743\u91cd\u548c\u503c\uff08V\uff09\u7684\u68af\u5ea6\n                        datt_bth[t2] += value_t2[i] * dout_bth[i];\n                        dvalue_t2[i] += att_bth[t2] * dout_bth[i];\n                    }\n                }\n\n                \/\/ backward pass 2 &amp; 3, the softmax\n                \/\/ note that softmax (like e.g. tanh) doesn't need the input (preatt) to backward\n                \/\/ \u53cd\u5411\u4f20\u64ad\u6b65\u9aa42\u548c3\uff0csoftmax\u90e8\u5206\n                for (int t2 = 0; t2 &lt;= t; t2++) {\n                    for (int t3 = 0; t3 &lt;= t; t3++) {\n                        \/\/ \u6307\u793a\u5668\u51fd\u6570\n                        float indicator = t2 == t3 ? 1.0f : 0.0f;\n                        \/\/ \u672c\u5730\u5bfc\u6570\n                        float local_derivative = att_bth[t2] * (indicator - att_bth[t3]);\n                        dpreatt_bth[t3] += local_derivative * datt_bth[t2];\n                    }\n                }\n\n                \/\/ backward pass 1, the query @ key matmul\n                \/\/ \u53cd\u5411\u4f20\u64ad\u6b65\u9aa41\uff0c\u67e5\u8be2\uff08Q\uff09\u4e0e\u952e\uff08K\uff09\u7684\u77e9\u9635\u4e58\u6cd5\n                for (int t2 = 0; t2 &lt;= t; t2++) {\n                    \/\/ \u83b7\u53d6\u952e\uff08K\uff09\u7684\u6307\u9488\n                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; \/\/ +C because it's key\n                    \/\/ \u83b7\u53d6\u952e\uff08K\uff09\u7684\u68af\u5ea6\u6307\u9488\n                    float* dkey_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C; \/\/ +C because it's key\n                    for (int i = 0; i &lt; hs; i++) {\n                        \/\/ in the forward pass this was:\n                
        \/\/ preatt_bth[t2] += (query_t[i] * key_t2[i]) * scale;\n                        \/\/ so now we have:\n                        \/\/ \u66f4\u65b0\u67e5\u8be2\uff08Q\uff09\u548c\u952e\uff08K\uff09\u7684\u68af\u5ea6\n                        dquery_t[i] += key_t2[i] * dpreatt_bth[t2] * scale;\n                        dkey_t2[i] += query_t[i] * dpreatt_bth[t2] * scale;\n                    }\n                }\n            }\n        }\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u5728\u8fd9\u4e2a\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u4e2d\uff0c\u9996\u5148\u5904\u7406\u503c\uff08Value\uff09\u7684\u7d2f\u79ef\uff0c\u7136\u540e\u5904\u7406softmax\u90e8\u5206\uff0c\u6700\u540e\u5904\u7406\u67e5\u8be2\uff08Query\uff09\u4e0e\u952e\uff08Key\uff09\u7684\u70b9\u79ef\u3002\u8fd9\u79cd\u5206\u6b65\u5904\u7406\u65b9\u5f0f\u53cd\u6620\u4e86\u6ce8\u610f\u529b\u673a\u5236\u7684\u8ba1\u7b97\u8fc7\u7a0b\uff1a\u9996\u5148\u8ba1\u7b97\u67e5\u8be2\u548c\u952e\u4e4b\u95f4\u7684\u76f8\u4f3c\u5ea6\uff0c\u7136\u540e\u901a\u8fc7softmax\u51fd\u6570\u5c06\u8fd9\u4e9b\u76f8\u4f3c\u5ea6\u8f6c\u6362\u6210\u6ce8\u610f\u529b\u6743\u91cd\uff0c\u6700\u540e\u4f7f\u7528\u8fd9\u4e9b\u6743\u91cd\u6765\u52a0\u6743\u503c\uff08Value\uff09\u5411\u91cf\uff0c\u751f\u6210\u6700\u7ec8\u7684\u8f93\u51fa\u3002\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u5219\u662f\u8fd9\u4e00\u8ba1\u7b97\u8fc7\u7a0b\u7684\u9006\u8fc7\u7a0b\uff0c\u6839\u636e\u8f93\u51fa\u68af\u5ea6\uff08dout\uff09\u6765\u8ba1\u7b97\u8f93\u5165\uff08inp\uff09\u3001\u6ce8\u610f\u529b\u6743\u91cd\uff08att\uff09\u53ca\u5176\u9884\u6fc0\u6d3b\u503c\uff08preatt\uff09\u7684\u68af\u5ea6\uff08dinp\u3001datt\u3001dpreatt\uff09\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.9 gelu_forward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86Gaussian Error Linear Unit (GeLU) 
\u6fc0\u6d3b\u51fd\u6570\u7684\u524d\u5411\u4f20\u64ad\u3002GeLU\u6fc0\u6d3b\u51fd\u6570\u5728Transformer\u6a21\u578b\u4e2d\u7684\u591a\u5c42\u611f\u77e5\u673a\uff08MLP\uff09\u90e8\u5206\u88ab\u5e7f\u6cdb\u4f7f\u7528\uff0c\u5b83\u53ef\u4ee5\u63d0\u4f9b\u4e00\u79cd\u975e\u7ebf\u6027\u53d8\u6362\uff0c\u6709\u52a9\u4e8e\u6a21\u578b\u6355\u6349\u590d\u6742\u7684\u7279\u5f81\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">#define GELU_SCALING_FACTOR sqrtf(2.0f \/ M_PI)\nvoid gelu_forward(float* out, float* inp, int N) {\n    \/\/ (approximate) GeLU elementwise non-linearity in the MLP block of Transformer\n    \/\/ Transformer\u4e2dMLP\u5757\u7684\uff08\u8fd1\u4f3c\u7684\uff09GeLU\u9010\u5143\u7d20\u975e\u7ebf\u6027\u6fc0\u6d3b\u51fd\u6570\n    for (int i = 0; i &lt; N; i++) {\n        \/\/ \u5f53\u524d\u5143\u7d20\u7684\u503c\n        float x = inp[i];\n        \/\/ x\u7684\u4e09\u6b21\u65b9\u4e58\u4ee50.044715\n        float cube = 0.044715f * x * x * x;\n        \/\/ \u8ba1\u7b97GeLU\u51fd\u6570\u7684\u503c\u5e76\u8d4b\u503c\u7ed9\u8f93\u51fa\n        out[i] = 0.5f * x * (1.0f + tanhf(GELU_SCALING_FACTOR * (x + cube)));\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>GeLU\u51fd\u6570\u662f\u901a\u8fc7\u5c06\u8f93\u5165\u503c<code>x<\/code>\u8fdb\u884c\u9ad8\u65af\u5206\u5e03\u7684\u7d2f\u79ef\u5206\u5e03\u51fd\u6570\uff08CDF\uff09\u53d8\u6362\u5f97\u5230\u7684\u3002\u8fd9\u91cc\u4f7f\u7528\u7684\u662fGeLU\u51fd\u6570\u7684\u8fd1\u4f3c\u5f62\u5f0f\uff0c\u5176\u901a\u8fc7tanh\u51fd\u6570\u6765\u8fd1\u4f3c\u5b9e\u73b0\u3002<code>GELU_SCALING_FACTOR<\/code>\u662f\u6839\u636eGeLU\u51fd\u6570\u7684\u5b9a\u4e49\u8ba1\u7b97\u5f97\u5230\u7684\u7f29\u653e\u56e0\u5b50\uff0c\u7528\u4e8e\u8c03\u6574\u8f93\u5165\u503c\u7684\u5c3a\u5ea6\u3002<code>0.044715f * x * x * 
x<\/code>\u9879\u662f\u7528\u6765\u589e\u52a0\u51fd\u6570\u7684\u975e\u7ebf\u6027\u5ea6\u3002\u901a\u8fc7\u8fd9\u79cd\u65b9\u5f0f\uff0cGeLU\u6fc0\u6d3b\u51fd\u6570\u5141\u8bb8\u4e00\u4e9b\u8f93\u5165\u76f4\u63a5\u901a\u8fc7\uff08\u5bf9\u5e94\u4e8e\u7ebf\u6027\u533a\u57df\uff09\uff0c\u540c\u65f6\u5bf9\u5176\u4ed6\u8f93\u5165\u8fdb\u884c\u975e\u7ebf\u6027\u53d8\u6362\uff0c\u8fd9\u6709\u52a9\u4e8e\u6a21\u578b\u5728\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u5b66\u4e60\u5230\u66f4\u52a0\u590d\u6742\u548c\u62bd\u8c61\u7684\u8868\u793a\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.10 gelu_backward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86GeLU\u6fc0\u6d3b\u51fd\u6570\u7684\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u7528\u4e8e\u8ba1\u7b97GeLU\u6fc0\u6d3b\u51fd\u6570\u76f8\u5bf9\u4e8e\u5176\u8f93\u5165\u7684\u68af\u5ea6\uff0c\u5e76\u6839\u636e\u94fe\u5f0f\u6cd5\u5219\u66f4\u65b0\u524d\u4e00\u5c42\u7684\u68af\u5ea6\u3002\u5728\u6df1\u5ea6\u5b66\u4e60\u4e2d\uff0c\u6b63\u786e\u5730\u8ba1\u7b97\u68af\u5ea6\u5bf9\u4e8e\u901a\u8fc7\u68af\u5ea6\u4e0b\u964d\u7b97\u6cd5\u6709\u6548\u5730\u8bad\u7ec3\u6a21\u578b\u81f3\u5173\u91cd\u8981\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ we want to use -Ofast optimization, but sadly GeLU breaks, so disable this flag just for it (#168)\n\/\/ \u7531\u4e8e\u4f7f\u7528-Ofast\u4f18\u5316\u65f6GeLU\u51fd\u6570\u4f1a\u51fa\u73b0\u95ee\u9898\uff0c\u56e0\u6b64\u7279\u610f\u4e3aGeLU\u51fd\u6570\u7981\u7528\u8be5\u4f18\u5316\u6807\u5fd7\uff08#168\uff09\n#pragma float_control(precise, on, push) \/\/ On msvc \/fp:fast is a lot faster, but the expf inside coshf breaks the model\n__attribute__((optimize(\"no-finite-math-only\"))) \/\/ same for gcc -Ofast\nvoid gelu_backward(float* dinp, float* inp, float* dout, int N) {\n    for (int i = 0; i &lt; N; i++) {\n   
     \/\/ \u8f93\u5165\u503c\n        float x = inp[i];\n        \/\/ \u8ba1\u7b97x\u7684\u4e09\u6b21\u65b9\u9879\n        float cube = 0.044715f * x * x * x;\n        \/\/ \u8ba1\u7b97tanh\u51fd\u6570\u7684\u53c2\u6570\n        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);\n        \/\/ \u8ba1\u7b97tanh\u7684\u8f93\u51fa\n        float tanh_out = tanhf(tanh_arg);\n        \/\/ \u8ba1\u7b97cosh\u7684\u8f93\u51fa\uff0c\u7528\u4e8e\u8ba1\u7b97sech\n        float coshf_out = coshf(tanh_arg);\n        \/\/ \u8ba1\u7b97sech\u7684\u5e73\u65b9\uff08\u5373tanh\u7684\u5bfc\u6570\uff09\n        float sech_out = 1.0f \/ (coshf_out * coshf_out);\n        \/\/ \u8ba1\u7b97\u5c40\u90e8\u68af\u5ea6\n        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);\n        \/\/ \u66f4\u65b0\u68af\u5ea6\n        dinp[i] += local_grad * dout[i];\n    }\n}\n#pragma float_control(pop)\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u91cc\u4f7f\u7528\u7684\u7f16\u8bd1\u5668\u6307\u4ee4\uff08\u5982<code>#pragma 
float_control<\/code>\u548c<code>__attribute__<\/code>\uff09\u662f\u4e3a\u4e86\u786e\u4fdd\u5728\u7f16\u8bd1\u65f6\u4e0d\u4f1a\u56e0\u4e3a\u4f18\u5316\u9009\u9879\u800c\u6539\u53d8GeLU\u51fd\u6570\u7684\u6570\u5b66\u884c\u4e3a\uff0c\u4fdd\u8bc1\u6a21\u578b\u7684\u51c6\u786e\u6027\u548c\u7a33\u5b9a\u6027\u3002\u7279\u522b\u5730\uff0cGeLU\u6fc0\u6d3b\u51fd\u6570\u7684\u53cd\u5411\u4f20\u64ad\u6d89\u53ca\u5230\u4e86tanh\u548csech\u51fd\u6570\u7684\u5bfc\u6570\uff0c\u9700\u8981\u7cbe\u786e\u7684\u6d6e\u70b9\u8fd0\u7b97\u6765\u4fdd\u8bc1\u68af\u5ea6\u8ba1\u7b97\u7684\u51c6\u786e\u6027\u3002\u901a\u8fc7\u8ba1\u7b97\u6bcf\u4e2a\u8f93\u5165\u5143\u7d20\u7684\u5c40\u90e8\u68af\u5ea6\u5e76\u4e58\u4ee5\u6765\u81ea\u540e\u4e00\u5c42\u7684\u68af\u5ea6\uff08<code>dout<\/code>\uff09\uff0c\u8be5\u51fd\u6570\u80fd\u591f\u4e3a\u524d\u4e00\u5c42\uff08\u5373GeLU\u51fd\u6570\u7684\u8f93\u5165\u5c42\uff09\u6b63\u786e\u5730\u66f4\u65b0\u68af\u5ea6\uff08<code>dinp<\/code>\uff09\uff0c\u8fd9\u662f\u6a21\u578b\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u68af\u5ea6\u53cd\u5411\u4f20\u64ad\u7684\u91cd\u8981\u6b65\u9aa4\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.11 residual_forward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u6b8b\u5dee\u8fde\u63a5\uff08Residual Connection\uff09\u7684\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\u3002\u5728\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e2d\uff0c\u7279\u522b\u662f\u5728\u6df1\u5ea6\u7f51\u7edc\u5982ResNet\u548cTransformer\u4e2d\uff0c\u6b8b\u5dee\u8fde\u63a5\u5e2e\u52a9\u6a21\u578b\u6709\u6548\u5730\u8bad\u7ec3\uff0c\u901a\u8fc7\u6dfb\u52a0\u8f93\u5165\u5230\u8f93\u51fa\u6765\u9632\u6b62\u68af\u5ea6\u6d88\u5931\u95ee\u9898\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void residual_forward(float* out, float* inp1, float* inp2, int N) {\n    \/\/ 
\u904d\u5386\u6bcf\u4e2a\u5143\u7d20\n    for (int i = 0; i &lt; N; i++) {\n        \/\/ \u5c06\u4e24\u4e2a\u8f93\u5165\u76f8\u52a0\u5e76\u5b58\u50a8\u5230\u8f93\u51fa\u4e2d\n        out[i] = inp1[i] + inp2[i];\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u91cc<code>inp1<\/code>\u548c<code>inp2<\/code>\u5206\u522b\u662f\u4e24\u4e2a\u8f93\u5165\u6570\u7ec4\uff0c\u5b83\u4eec\u53ef\u4ee5\u4ee3\u8868\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u4e2d\u67d0\u5c42\u7684\u8f93\u5165\u548c\u8be5\u5c42\u7684\u53d8\u6362\u8f93\u51fa\u3002<code>out<\/code>\u662f\u8f93\u51fa\u6570\u7ec4\uff0c\u5176\u4e2d\u5b58\u50a8\u4e86<code>inp1<\/code>\u548c<code>inp2<\/code>\u9010\u5143\u7d20\u76f8\u52a0\u7684\u7ed3\u679c\u3002<code>N<\/code>\u662f\u6570\u7ec4\u4e2d\u5143\u7d20\u7684\u603b\u6570\u3002\u901a\u8fc7\u8fd9\u79cd\u65b9\u5f0f\uff0c\u6b8b\u5dee\u8fde\u63a5\u5141\u8bb8\u7f51\u7edc\u5b66\u4e60\u5bf9\u8f93\u5165\u7684\u6052\u7b49\u53d8\u6362\uff08identity transformation\uff09\uff0c\u5373\u76f4\u63a5\u5c06\u8f93\u5165\u4f20\u9012\u5230\u8f93\u51fa\uff0c\u8fd9\u6709\u52a9\u4e8e\u89e3\u51b3\u66f4\u6df1\u5c42\u7f51\u7edc\u5728\u8bad\u7ec3\u65f6\u53ef\u80fd\u9047\u5230\u7684\u68af\u5ea6\u6d88\u5931\u95ee\u9898\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.12 residual_backward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u6b8b\u5dee\u8fde\u63a5\u7684\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u3002\u5728\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e2d\uff0c\u53cd\u5411\u4f20\u64ad\u662f\u8ba1\u7b97\u68af\u5ea6\u5e76\u66f4\u65b0\u6a21\u578b\u53c2\u6570\u7684\u5173\u952e\u6b65\u9aa4\u3002\u5bf9\u4e8e\u6b8b\u5dee\u8fde\u63a5\uff0c\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u7b80\u5355\u76f4\u63a5\uff0c\u56e0\u4e3a\u6b8b\u5dee\u8fde\u63a5\u53ea\u6d89\u53ca\u7b80\u5355\u7684\u52a0\u6cd5\u64cd\u4f5c\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre 
class=\"lang:c decode:true \">void residual_backward(float* dinp1, float* dinp2, float* dout, int N) {\n    \/\/ \u904d\u5386\u6bcf\u4e2a\u5143\u7d20\n    for (int i = 0; i &lt; N; i++) {\n        \/\/ \u5c06\u8f93\u51fa\u68af\u5ea6\u76f4\u63a5\u4f20\u9012\u7ed9\u4e24\u4e2a\u8f93\u5165\u68af\u5ea6\n        dinp1[i] += dout[i];\n        dinp2[i] += dout[i];\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u91cc<code>dout<\/code>\u662f\u4ece\u540e\u7eed\u5c42\u4f20\u56de\u7684\u68af\u5ea6\u6570\u7ec4\uff0c<code>dinp1<\/code>\u548c<code>dinp2<\/code>\u5206\u522b\u662f\u9700\u8981\u66f4\u65b0\u7684\u4e24\u4e2a\u8f93\u5165\u68af\u5ea6\u6570\u7ec4\u3002<code>N<\/code>\u662f\u6570\u7ec4\u4e2d\u5143\u7d20\u7684\u603b\u6570\u3002\u7531\u4e8e\u6b8b\u5dee\u8fde\u63a5\u7684\u524d\u5411\u4f20\u64ad\u53ea\u662f\u5c06\u4e24\u4e2a\u8f93\u5165\u76f8\u52a0\uff0c\u6240\u4ee5\u5176\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u4e2d\u68af\u5ea6\u7684\u4f20\u9012\u975e\u5e38\u76f4\u63a5\u2014\u2014\u540e\u7eed\u5c42\u4f20\u56de\u7684\u68af\u5ea6<code>dout<\/code>\u88ab\u7b80\u5355\u5730\u7d2f\u52a0\u5230\u4e24\u4e2a\u8f93\u5165\u7684\u68af\u5ea6<code>dinp1<\/code>\u548c<code>dinp2<\/code>\u4e0a\u3002<\/p>\n\n\n\n<p>\u6b8b\u5dee\u8fde\u63a5\u901a\u8fc7\u8fd9\u79cd\u65b9\u5f0f\u7b80\u5316\u4e86\u68af\u5ea6\u7684\u6d41\u52a8\uff0c\u6709\u52a9\u4e8e\u9632\u6b62\u6df1\u5c42\u7f51\u7edc\u4e2d\u68af\u5ea6\u6d88\u5931\u6216\u68af\u5ea6\u7206\u70b8\u7684\u95ee\u9898\uff0c\u4f7f\u5f97\u8bad\u7ec3\u6df1\u5c42\u7f51\u7edc\u53d8\u5f97\u66f4\u52a0\u53ef\u884c\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.13 
softmax_forward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86softmax\u51fd\u6570\u7684\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\u3002Softmax\u51fd\u6570\u5e38\u7528\u4e8e\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e2d\u7684\u591a\u5206\u7c7b\u4efb\u52a1\uff0c\u5b83\u53ef\u4ee5\u5c06\u4e00\u7ec4\u672a\u5f52\u4e00\u5316\u7684\u5206\u6570\uff08logits\uff09\u8f6c\u6362\u6210\u6982\u7387\u5206\u5e03\uff0c\u5176\u4e2d\u6bcf\u4e2a\u5206\u6570\u901a\u8fc7\u6307\u6570\u51fd\u6570\u8f6c\u6362\u540e\uff0c\u518d\u9664\u4ee5\u6240\u6709\u8f6c\u6362\u5206\u6570\u7684\u548c\uff0c\u4ee5\u786e\u4fdd\u6240\u6709\u8f93\u51fa\u6982\u7387\u4e4b\u548c\u4e3a1\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void softmax_forward(float* probs, float* logits, int B, int T, int V) {\n    \/\/ output: probs are (B,T,V) of the probabilities (sums to 1.0 in each b,t position)\n    \/\/ input: logits is (B,T,V) of the unnormalized log probabilities\n    \/\/ \u8f93\u51fa\uff1aprobs\u662f(B,T,V)\u7ef4\u7684\u6982\u7387\u6570\u7ec4\uff08\u5728\u6bcf\u4e2ab,t\u4f4d\u7f6e\u7684\u548c\u4e3a1.0\uff09\n    \/\/ \u8f93\u5165\uff1alogits\u662f(B,T,V)\u7ef4\u7684\u672a\u5f52\u4e00\u5316\u7684\u5bf9\u6570\u6982\u7387\n    #pragma omp parallel for collapse(2)\n    for (int b = 0; b &lt; B; b++) {\n        for (int t = 0; t &lt; T; t++) {\n            \/\/ probs &lt;- softmax(logits)\n            \/\/ \u5c06softmax\u5e94\u7528\u4e8elogits\n            \/\/ \u5b9a\u4f4d\u5230\u7279\u5b9a\u6279\u6b21\u548c\u65f6\u95f4\u6b65\u7684logits\n            float* logits_bt = logits + b * T * V + t * V;\n            \/\/ \u5b9a\u4f4d\u5230\u5bf9\u5e94\u7684\u8f93\u51fa\u6982\u7387\n            float* probs_bt = probs + b * T * V + t * V;\n\n            \/\/ maxval is only calculated and subtracted for numerical stability\n            \/\/ 
\u4e3a\u4e86\u6570\u503c\u7a33\u5b9a\u6027\uff0c\u5148\u8ba1\u7b97\u5e76\u51cf\u53bb\u6700\u5927\u503c\n            \/\/ \u521d\u59cb\u5316\u4e3a\u4e00\u4e2a\u5f88\u5c0f\u7684\u6570\n            float maxval = -10000.0f; \/\/ TODO something better\n            for (int i = 0; i &lt; V; i++) {\n                if (logits_bt[i] &gt; maxval) {\n                    \/\/ \u627e\u5230\u6700\u5927\u7684logit\u503c\n                    maxval = logits_bt[i];\n                }\n            }\n            float sum = 0.0f;\n            for (int i = 0; i &lt; V; i++) {\n                \/\/ \u901a\u8fc7\u6307\u6570\u51fd\u6570\u8f6c\u6362\uff0c\u5e76\u51cf\u53bb\u6700\u5927\u503c\u4ee5\u63d0\u9ad8\u6570\u503c\u7a33\u5b9a\u6027\n                probs_bt[i] = expf(logits_bt[i] - maxval);\n                \/\/ \u8ba1\u7b97\u6240\u6709\u8f6c\u6362\u540e\u5206\u6570\u7684\u548c\n                sum += probs_bt[i];\n            }\n            for (int i = 0; i &lt; V; i++) {\n                \/\/ \u5f52\u4e00\u5316\u4ee5\u786e\u4fdd\u6982\u7387\u4e4b\u548c\u4e3a1\n                probs_bt[i] \/= sum;\n            }\n        }\n    
}\n}\n<\/pre><\/div>\n\n\n\n<p>\u5728\u5904\u7406\u5927\u89c4\u6a21\u6570\u636e\u65f6\uff0c\u51cf\u53bb<code>logits<\/code>\u6570\u7ec4\u4e2d\u6bcf\u4e2a\u5143\u7d20\u7684\u6700\u5927\u503c\u662f\u4e00\u79cd\u5e38\u7528\u7684\u6570\u503c\u7a33\u5b9a\u6280\u5de7\uff0c\u8fd9\u53ef\u4ee5\u9632\u6b62\u5728\u8ba1\u7b97\u6307\u6570\u51fd\u6570\u65f6\u53d1\u751f\u6570\u503c\u6ea2\u51fa\u3002\u901a\u8fc7OpenMP\u7684\u5e76\u884c\u5316\u6307\u4ee4\uff0c\u8fd9\u4e2a\u8fc7\u7a0b\u8fd8\u5229\u7528\u4e86\u591a\u6838\u5904\u7406\u5668\u7684\u8ba1\u7b97\u80fd\u529b\uff0c\u4ee5\u52a0\u901fsoftmax\u51fd\u6570\u7684\u8ba1\u7b97\u3002\u8fd9\u6837\uff0c\u6bcf\u4e2a\u8f93\u5165\u5411\u91cf\uff08\u6216\u8005\u8bf4\uff0c\u6bcf\u4e2a\u65f6\u95f4\u6b65\u7684\u6240\u6709\u7c7b\u522b\u7684\u5206\u6570\uff09\u90fd\u88ab\u8f6c\u6362\u6210\u4e00\u4e2a\u6982\u7387\u5206\u5e03\uff0c\u8fd9\u4e9b\u6982\u7387\u5206\u5e03\u53ef\u7528\u4e8e\u540e\u7eed\u7684\u8bad\u7ec3\u6216\u9884\u6d4b\u8fc7\u7a0b\u4e2d\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.14 crossentropy_forward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\uff08Cross Entropy Loss\uff09\u7684\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8fd9\u662f\u6df1\u5ea6\u5b66\u4e60\u4e2d\u7528\u4e8e\u591a\u5206\u7c7b\u95ee\u9898\u7684\u4e00\u79cd\u5e38\u7528\u635f\u5931\u51fd\u6570\u3002\u5b83\u8861\u91cf\u7684\u662f\u6a21\u578b\u8f93\u51fa\u7684\u6982\u7387\u5206\u5e03\u4e0e\u771f\u5b9e\u6807\u7b7e\u4e4b\u95f4\u7684\u5dee\u5f02\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void crossentropy_forward(float* losses,\n                          float* probs, int* targets,\n                          int B, int T, int V) {\n    \/\/ output: losses is (B,T) of the individual losses at each position\n    \/\/ input: probs 
are (B,T,V) of the probabilities\n    \/\/ input: targets is (B,T) of integers giving the correct index in logits\n    \/\/ \u8f93\u51fa\uff1alosses\u662f(B,T)\u7ef4\u7684\uff0c\u8868\u793a\u6bcf\u4e2a\u4f4d\u7f6e\u7684\u5355\u72ec\u635f\u5931\n    \/\/ \u8f93\u5165\uff1aprobs\u662f(B,T,V)\u7ef4\u7684\u6982\u7387\n    \/\/ \u8f93\u5165\uff1atargets\u662f(B,T)\u7ef4\u7684\u6574\u6570\uff0c\u8868\u793alogits\u4e2d\u6b63\u786e\u7c7b\u522b\u7684\u7d22\u5f15\n    for (int b = 0; b &lt; B; b++) {\n        for (int t = 0; t &lt; T; t++) {\n            \/\/ loss = -log(probs[target])\n            \/\/ \u635f\u5931\u8ba1\u7b97\u4e3a\uff1a-log(probs[\u76ee\u6807\u7c7b\u522b])\n            \/\/ \u5b9a\u4f4d\u5230\u7279\u5b9a\u6279\u6b21\u548c\u65f6\u95f4\u6b65\u7684\u6982\u7387\n            float* probs_bt = probs + b * T * V + t * V;\n            \/\/ \u83b7\u53d6\u771f\u5b9e\u6807\u7b7e\u7684\u7d22\u5f15\n            int ix = targets[b * T + t];\n            \/\/ \u8ba1\u7b97\u5e76\u5b58\u50a8\u635f\u5931\n            losses[b * T + t] = -logf(probs_bt[ix]);\n        }\n    
}\n}\n<\/pre><\/div>\n\n\n\n<p>\u5728\u8fd9\u91cc\uff0c<code>probs<\/code>\u6570\u7ec4\u5305\u542b\u4e86\u6a21\u578b\u5bf9\u6bcf\u4e2a\u7c7b\u522b\u9884\u6d4b\u7684\u6982\u7387\uff0c<code>targets<\/code>\u6570\u7ec4\u5305\u542b\u4e86\u6bcf\u4e2a\u6837\u672c\u7684\u771f\u5b9e\u7c7b\u522b\u7d22\u5f15\uff0c<code>losses<\/code>\u6570\u7ec4\u7528\u4e8e\u5b58\u50a8\u6bcf\u4e2a\u6837\u672c\u7684\u635f\u5931\u3002\u5bf9\u4e8e\u6bcf\u4e2a\u6837\u672c\uff0c\u5b83\u7684\u4ea4\u53c9\u71b5\u635f\u5931\u662f\u901a\u8fc7\u53d6\u771f\u5b9e\u7c7b\u522b\u5bf9\u5e94\u6982\u7387\u7684\u8d1f\u5bf9\u6570\u6765\u8ba1\u7b97\u7684\u3002\u8fd9\u610f\u5473\u7740\u5982\u679c\u6a21\u578b\u5bf9\u771f\u5b9e\u7c7b\u522b\u7684\u9884\u6d4b\u6982\u7387\u5f88\u9ad8\uff08\u63a5\u8fd11\uff09\uff0c\u635f\u5931\u5c06\u4f1a\u5f88\u5c0f\uff1b\u5982\u679c\u6a21\u578b\u5bf9\u771f\u5b9e\u7c7b\u522b\u7684\u9884\u6d4b\u6982\u7387\u5f88\u4f4e\uff08\u63a5\u8fd10\uff09\uff0c\u635f\u5931\u5c06\u4f1a\u5f88\u5927\u3002<\/p>\n\n\n\n<p>\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u662f\u4f18\u5316\u5206\u7c7b\u6a21\u578b\u5e38\u7528\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u56e0\u4e3a\u5b83\u76f4\u63a5\u9488\u5bf9\u6a21\u578b\u8f93\u51fa\u7684\u6982\u7387\u5206\u5e03\uff0c\u4f7f\u5f97\u6a21\u578b\u80fd\u591f\u5728\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u9010\u6e10\u5b66\u4e60\u5230\u5c06\u6b63\u786e\u7c7b\u522b\u7684\u6982\u7387\u9884\u6d4b\u5f97\u66f4\u9ad8\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.15 
crossentropy_softmax_backward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u548csoftmax\u6fc0\u6d3b\u51fd\u6570\u7684\u8054\u5408\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\u3002\u8fd9\u5728\u8bad\u7ec3\u6df1\u5ea6\u5b66\u4e60\u5206\u7c7b\u6a21\u578b\u65f6\u975e\u5e38\u5e38\u89c1\uff0c\u56e0\u4e3a\u5f88\u591a\u6a21\u578b\u5728\u8f93\u51fa\u5c42\u4f7f\u7528softmax\u51fd\u6570\u5c06logits\u8f6c\u6362\u4e3a\u6982\u7387\u5206\u5e03\uff0c\u7136\u540e\u4f7f\u7528\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u6765\u8ba1\u7b97\u9884\u6d4b\u7684\u6982\u7387\u5206\u5e03\u4e0e\u771f\u5b9e\u6807\u7b7e\u4e4b\u95f4\u7684\u5dee\u5f02\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void crossentropy_softmax_backward(float* dlogits,\n                           float* dlosses, float* probs, int* targets,\n                           int B, int T, int V) {\n    \/\/ backwards through both softmax and crossentropy\n    \/\/ \u540c\u65f6\u5bf9softmax\u548c\u4ea4\u53c9\u71b5\u6267\u884c\u53cd\u5411\u4f20\u64ad\n    for (int b = 0; b &lt; B; b++) {\n        for (int t = 0; t &lt; T; t++) {\n            \/\/ \u5b9a\u4f4d\u5230\u7279\u5b9a\u6279\u6b21\u548c\u65f6\u95f4\u6b65\u7684logits\u68af\u5ea6\n            float* dlogits_bt = dlogits + b * T * V + t * V;\n            \/\/ \u5b9a\u4f4d\u5230\u7279\u5b9a\u6279\u6b21\u548c\u65f6\u95f4\u6b65\u7684\u6982\u7387\n            float* probs_bt = probs + b * T * V + t * V;\n            \/\/ \u83b7\u53d6\u8be5\u4f4d\u7f6e\u7684\u635f\u5931\u68af\u5ea6\n            float dloss = dlosses[b * T + t];\n            \/\/ \u83b7\u53d6\u771f\u5b9e\u7c7b\u522b\u7684\u7d22\u5f15\n            int ix = targets[b * T + t];\n\n            for (int i = 0; i &lt; V; i++) {\n                \/\/ 
\u5f53\u524d\u7c7b\u522b\u7684\u9884\u6d4b\u6982\u7387\n                float p = probs_bt[i];\n                \/\/ \u6307\u793a\u5668\u51fd\u6570\uff0c\u5982\u679ci\u662f\u771f\u5b9e\u7c7b\u522b\u5219\u4e3a1\uff0c\u5426\u5219\u4e3a0\n                float indicator = i == ix ? 1.0f : 0.0f;\n                \/\/ \u66f4\u65b0logits\u7684\u68af\u5ea6\n                dlogits_bt[i] += (p - indicator) * dloss;\n            }\n        }\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u5728\u8fd9\u4e2a\u8fc7\u7a0b\u4e2d\uff0c<code>dlogits<\/code>\u6570\u7ec4\u5b58\u50a8\u4e86\u6bcf\u4e2alogit\u76f8\u5bf9\u4e8e\u635f\u5931\u7684\u68af\u5ea6\uff0c<code>dlosses<\/code>\u6570\u7ec4\u5305\u542b\u4e86\u6bcf\u4e2a\u6837\u672c\u635f\u5931\u76f8\u5bf9\u4e8e\u6a21\u578b\u8f93\u51fa\u7684\u68af\u5ea6\uff0c<code>probs<\/code>\u6570\u7ec4\u5305\u542b\u4e86\u6a21\u578b\u9884\u6d4b\u7684\u6982\u7387\u5206\u5e03\uff0c<code>targets<\/code>\u6570\u7ec4\u5305\u542b\u4e86\u6bcf\u4e2a\u6837\u672c\u7684\u771f\u5b9e\u7c7b\u522b\u7d22\u5f15\u3002\u53cd\u5411\u4f20\u64ad\u7684\u76ee\u7684\u662f\u8ba1\u7b97<code>dlogits<\/code>\uff0c\u5373\u6bcf\u4e2a\u8f93\u51falogit\u76f8\u5bf9\u4e8e\u635f\u5931\u51fd\u6570\u7684\u68af\u5ea6\uff0c\u8fd9\u4e9b\u68af\u5ea6\u5c06\u88ab\u7528\u6765\u66f4\u65b0\u6a21\u578b\u53c2\u6570\u3002<\/p>\n\n\n\n<p>\u5bf9\u4e8e\u6bcf\u4e2a\u7c7b\u522b<code>i<\/code>\uff0c\u5176logit\u7684\u68af\u5ea6\u662f\u7531\u9884\u6d4b\u6982\u7387<code>p<\/code>\u51cf\u53bb\u6307\u793a\u5668\u51fd\u6570<code>indicator<\/code>\uff08\u5f53\u4e14\u4ec5\u5f53<code>i<\/code>\u662f\u771f\u5b9e\u7c7b\u522b\u65f6\u4e3a1\uff0c\u5426\u5219\u4e3a0\uff09\u518d\u4e58\u4ee5\u8be5\u4f4d\u7f6e\u7684\u635f\u5931\u68af\u5ea6<code>dloss<\/code>\u8ba1\u7b97\u5f97\u6765\u3002\u8fd9\u79cd\u8ba1\u7b97\u65b9\u5f0f\u7b80\u6d01\u5730\u8868\u8fbe\u4e86softmax\u51fd\u6570\u548c\u4ea4\u53c9\u71b5\u635f\u5931\u7684\u8054\u5408\u68af\u5ea6\uff0c\u5141\u8bb8\u6a21\u578b\u901a\u8fc7\u68af\u5ea6\u4e0b\u964d\u7b97\u6cd5\u8fdb\u884c\
u5b66\u4e60\u548c\u4f18\u5316\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.16 ParameterTensors<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9a\u4e49\u4e86GPT-2\u6a21\u578b\u4e2d\u7528\u5230\u7684\u53c2\u6570\u7ed3\u6784\u4f53<code>ParameterTensors<\/code>\u3002GPT-2\u662f\u4e00\u4e2a\u57fa\u4e8eTransformer\u7684\u9884\u8bad\u7ec3\u8bed\u8a00\u6a21\u578b\uff0c\u5e7f\u6cdb\u7528\u4e8e\u5404\u79cd\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4efb\u52a1\u3002<code>ParameterTensors<\/code>\u7ed3\u6784\u4f53\u4e2d\u5305\u542b\u4e86\u6a21\u578b\u6240\u6709\u5fc5\u9700\u7684\u6743\u91cd\u548c\u504f\u7f6e\u53c2\u6570\u3002\u4ee5\u4e0b\u662f\u5bf9\u6bcf\u4e2a\u6210\u5458\u7684\u7b80\u8981\u8bf4\u660e\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">#define NUM_PARAMETER_TENSORS 16\ntypedef struct {\n    \/\/ wte\uff1a\u8bcd\u5d4c\u5165\u6743\u91cd\uff08Word Token Embeddings\uff09\uff0c\u7ef4\u5ea6\u662f(V, C)\uff0c\n    \/\/ \u5176\u4e2dV\u662f\u8bcd\u6c47\u8868\u5927\u5c0f\uff0cC\u662f\u5d4c\u5165\u7ef4\u5ea6\u3002\n    float* wte; \/\/ (V, C)\n    \/\/ wpe\uff1a\u4f4d\u7f6e\u5d4c\u5165\u6743\u91cd\uff08Word Position Embeddings\uff09\uff0c\u7ef4\u5ea6\u662f(maxT, C)\uff0c\n    \/\/ \u5176\u4e2dmaxT\u662f\u6a21\u578b\u53ef\u4ee5\u5904\u7406\u7684\u6700\u5927\u5e8f\u5217\u957f\u5ea6\u3002\n    float* wpe; \/\/ (maxT, C)\n    \/\/ ln1w\u548cln1b\uff1a\u7b2c\u4e00\u5c42\u5f52\u4e00\u5316\u7684\u6743\u91cd\u548c\u504f\u7f6e\uff0c\u7ef4\u5ea6\u5206\u522b\u662f(L, C)\u3002\n    \/\/ L\u662fTransformer\u5c42\u7684\u6570\u91cf\u3002\n    float* ln1w; \/\/ (L, C)\n    float* ln1b; \/\/ (L, C)\n    \/\/ qkvw\u548cqkvb\uff1a\u67e5\u8be2\uff08Query\uff09\u3001\u952e\uff08Key\uff09\u548c\u503c\uff08Value\uff09\u7684\u6743\u91cd\u548c\u504f\u7f6e\uff0c\n    \/\/ \u7528\u4e8e\u81ea\u6ce8\u610f\u529b\u673a\u5236\uff0c\u7ef4\u5ea6\u5206\u522b\u662f(L, 3C, C)\u548c(L, 3C)\u3002\n    \/\/ 
\u6bcf\u5c42\u6709\u4e09\u4e2aC\u7ef4\u7684\u5411\u91cf\u5206\u522b\u5bf9\u5e94\u67e5\u8be2\u3001\u952e\u548c\u503c\u3002\n    float* qkvw; \/\/ (L, 3*C, C)\n    float* qkvb; \/\/ (L, 3*C)\n    \/\/ attprojw\u548cattprojb\uff1a\u6ce8\u610f\u529b\u8f93\u51fa\u7684\u6295\u5f71\u6743\u91cd\u548c\u504f\u7f6e\uff0c\u7ef4\u5ea6\u5206\u522b\u662f(L, C, C)\u548c(L, C)\u3002\n    float* attprojw; \/\/ (L, C, C)\n    float* attprojb; \/\/ (L, C)\n    \/\/ ln2w\u548cln2b\uff1a\u7b2c\u4e8c\u5c42\u5f52\u4e00\u5316\u7684\u6743\u91cd\u548c\u504f\u7f6e\uff0c\u7ef4\u5ea6\u5747\u4e3a(L, C)\u3002\n    float* ln2w; \/\/ (L, C)\n    float* ln2b; \/\/ (L, C)\n    \/\/ fcw\u548cfcb\uff1a\u524d\u9988\u7f51\u7edc\uff08Feedforward Network\uff09\u7684\u6743\u91cd\u548c\u504f\u7f6e\uff0c\u7ef4\u5ea6\u5206\u522b\u662f(L, 4C, C)\u548c(L, 4C)\u3002\n    \/\/ \u524d\u9988\u7f51\u7edc\u4e2d\u4f7f\u7528\u4e86\u6269\u5c55\u7684\u5185\u90e8\u7ef4\u5ea6\uff084*C\uff09\u3002\n    float* fcw; \/\/ (L, 4*C, C)\n    float* fcb; \/\/ (L, 4*C)\n    \/\/ fcprojw\u548cfcprojb\uff1a\u524d\u9988\u7f51\u7edc\u8f93\u51fa\u7684\u6295\u5f71\u6743\u91cd\u548c\u504f\u7f6e\uff0c\u7ef4\u5ea6\u5206\u522b\u662f(L, C, 4*C)\u548c(L, C)\u3002\n    float* fcprojw; \/\/ (L, C, 4*C)\n    float* fcprojb; \/\/ (L, C)\n    \/\/ lnfw\u548clnfb\uff1a\u6700\u540e\u4e00\u5c42\u5f52\u4e00\u5316\u7684\u6743\u91cd\u548c\u504f\u7f6e\uff0c\u7ef4\u5ea6\u5747\u4e3a(C)\u3002\n    float* lnfw; \/\/ (C)\n    float* lnfb; \/\/ (C)\n} 
ParameterTensors;\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e2a\u7ed3\u6784\u4f53\u662f\u6784\u5efa\u548c\u64cd\u4f5cGPT-2\u6a21\u578b\u7684\u5173\u952e\uff0c\u5b83\u786e\u4fdd\u4e86\u6240\u6709\u5fc5\u8981\u7684\u6a21\u578b\u53c2\u6570\u90fd\u53ef\u4ee5\u88ab\u6709\u6548\u5730\u5b58\u50a8\u548c\u8bbf\u95ee\u3002\u8fd9\u4e9b\u53c2\u6570\u5728\u6a21\u578b\u7684\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u4f1a\u4e0d\u65ad\u66f4\u65b0\uff0c\u4ee5\u4f7f\u6a21\u578b\u80fd\u591f\u66f4\u597d\u5730\u5b66\u4e60\u548c\u7406\u89e3\u8bed\u8a00\u6570\u636e\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.17 malloc_and_point_parameters<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u4e3aGPT-2\u6a21\u578b\u7684\u53c2\u6570\u5206\u914d\u5185\u5b58\uff0c\u5e76\u5c06\u5404\u4e2a\u53c2\u6570\u5f20\u91cf\u6307\u5411\u6b63\u786e\u7684\u5185\u5b58\u4f4d\u7f6e\u7684\u529f\u80fd\u3002\u901a\u8fc7\u8fd9\u79cd\u65b9\u5f0f\uff0c\u6240\u6709\u6a21\u578b\u53c2\u6570\u90fd\u88ab\u8fde\u7eed\u5730\u5b58\u50a8\u5728\u4e00\u5757\u5185\u5b58\u533a\u57df\u4e2d\uff0c\u800c<code>ParameterTensors<\/code>\u7ed3\u6784\u4f53\u4e2d\u7684\u6307\u9488\u5219\u6307\u5411\u8fd9\u5757\u5185\u5b58\u4e2d\u5bf9\u5e94\u53c2\u6570\u7684\u4f4d\u7f6e\u3002\u8fd9\u79cd\u65b9\u6cd5\u6709\u52a9\u4e8e\u63d0\u9ad8\u5185\u5b58\u4f7f\u7528\u6548\u7387\u548c\u7b80\u5316\u53c2\u6570\u7ba1\u7406\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ allocate memory for the parameters and point the individual tensors to the right places\n\/\/ \u4e3a\u53c2\u6570\u5206\u914d\u5185\u5b58\uff0c\u5e76\u5c06\u5404\u4e2a\u5f20\u91cf\u6307\u5411\u6b63\u786e\u7684\u4f4d\u7f6e\u3002\nfloat* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes) {\n    size_t num_parameters = 0;\n    \/\/ 
\u904d\u5386\u6240\u6709\u53c2\u6570\u5f20\u91cf\uff0c\u8ba1\u7b97\u603b\u7684\u53c2\u6570\u6570\u91cf\n    for (size_t i = 0; i &lt; NUM_PARAMETER_TENSORS; i++) {\n        num_parameters += param_sizes[i];\n    }\n    \/\/ malloc all parameters all at once\n    \/\/ \u4e00\u6b21\u6027\u4e3a\u6240\u6709\u53c2\u6570\u5206\u914d\u8db3\u591f\u7684\u8fde\u7eed\u5185\u5b58\u7a7a\u95f4\n    float* params_memory = (float*)malloc(num_parameters * sizeof(float));\n    \/\/ assign all the tensors\n     \/\/ \u5c06\u5404\u4e2a\u53c2\u6570\u5f20\u91cf\u7684\u6307\u9488\u6307\u5411\u5206\u914d\u7684\u5185\u5b58\u4e2d\u6b63\u786e\u7684\u4f4d\u7f6e\n    float** ptrs[] = {\n        &amp;params-&gt;wte, &amp;params-&gt;wpe, &amp;params-&gt;ln1w, &amp;params-&gt;ln1b, &amp;params-&gt;qkvw, &amp;params-&gt;qkvb,\n        &amp;params-&gt;attprojw, &amp;params-&gt;attprojb, &amp;params-&gt;ln2w, &amp;params-&gt;ln2b, &amp;params-&gt;fcw, &amp;params-&gt;fcb,\n        &amp;params-&gt;fcprojw, &amp;params-&gt;fcprojb, &amp;params-&gt;lnfw, &amp;params-&gt;lnfb\n    };\n    \/\/ \u4f7f\u7528\u8fed\u4ee3\u5668\u904d\u5386\u5206\u914d\u7684\u5185\u5b58\u5e76\u4e3a\u6bcf\u4e2a\u53c2\u6570\u5f20\u91cf\u8d4b\u503c\n    float* params_memory_iterator = params_memory;\n    for (size_t i = 0; i &lt; NUM_PARAMETER_TENSORS; i++) {\n        \/\/ \u8bbe\u7f6e\u6307\u9488\u6307\u5411\u5f53\u524d\u53c2\u6570\u7684\u4f4d\u7f6e\n        *(ptrs[i]) = params_memory_iterator;\n        \/\/ \u66f4\u65b0\u8fed\u4ee3\u5668\u4ee5\u6307\u5411\u4e0b\u4e00\u6bb5\u53c2\u6570\u5185\u5b58\n        params_memory_iterator += param_sizes[i];\n    }\n    \/\/ \u8fd4\u56de\u5206\u914d\u7684\u5185\u5b58\u5757\u7684\u6307\u9488\uff0c\u7528\u4e8e\u540e\u7eed\u7684\u91ca\u653e\u7b49\u64cd\u4f5c\n    return 
params_memory;\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e2a\u51fd\u6570\u7684\u5173\u952e\u5728\u4e8e\u5b83\u5141\u8bb8<code>ParameterTensors<\/code>\u7ed3\u6784\u4f53\u4e2d\u7684\u6240\u6709\u53c2\u6570\u5f20\u91cf\u901a\u8fc7\u5355\u4e00\u7684\u5185\u5b58\u5206\u914d\u6765\u7ba1\u7406\uff0c\u800c\u4e0d\u662f\u4e3a\u6bcf\u4e2a\u5f20\u91cf\u5355\u72ec\u5206\u914d\u5185\u5b58\u3002\u8fd9\u6837\u4e0d\u4ec5\u51cf\u5c11\u4e86\u5185\u5b58\u788e\u7247\uff0c\u8fd8\u7b80\u5316\u4e86\u5185\u5b58\u7ba1\u7406\u3002\u901a\u8fc7<code>param_sizes<\/code>\u6570\u7ec4\u6307\u5b9a\u6bcf\u4e2a\u53c2\u6570\u5f20\u91cf\u6240\u9700\u7684\u5185\u5b58\u5927\u5c0f\uff0c\u5e76\u5229\u7528\u6307\u9488\u6570\u7ec4<code>ptrs<\/code>\u5c06\u6bcf\u4e2a\u53c2\u6570\u5f20\u91cf\u6307\u5411\u5206\u914d\u7684\u5185\u5b58\u5757\u4e2d\u7684\u6b63\u786e\u4f4d\u7f6e\u3002\u6700\u7ec8\uff0c\u51fd\u6570\u8fd4\u56de\u6307\u5411\u5206\u914d\u5185\u5b58\u5757\u7684\u6307\u9488\uff0c\u8fd9\u5141\u8bb8\u5728\u4e0d\u9700\u8981\u8fd9\u4e9b\u53c2\u6570\u65f6\u6b63\u786e\u5730\u91ca\u653e\u5185\u5b58\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.18 ActivationTensors<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u4e2a\u7ed3\u6784\u4f53<code>ActivationTensors<\/code>\u5b9a\u4e49\u4e86\u5728GPT-2\u6a21\u578b\u6216\u7c7b\u4f3c\u7684Transformer\u6a21\u578b\u4e2d\u4f7f\u7528\u7684\u6fc0\u6d3b\u5f20\u91cf\u3002\u8fd9\u4e9b\u6fc0\u6d3b\u5f20\u91cf\u5b58\u50a8\u4e86\u6a21\u578b\u7684\u4e2d\u95f4\u8f93\u51fa\uff0c\u5982\u7f16\u7801\u540e\u7684\u5d4c\u5165\u3001\u5c42\u5f52\u4e00\u5316\u7684\u7ed3\u679c\u3001\u81ea\u6ce8\u610f\u529b\u673a\u5236\u7684\u8f93\u51fa\u7b49\u3002\u4ee5\u4e0b\u662f\u5bf9\u6bcf\u4e2a\u6210\u5458\u7684\u7b80\u8981\u8bf4\u660e\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">#define NUM_ACTIVATION_TENSORS 23\ntypedef struct {\n    \/\/ encoded\uff1a\u7f16\u7801\u540e\u7684\u5d4c\u5165\uff0c\u7ef4\u5ea6\u4e3a(B, T, 
C)\uff0cB\u662f\u6279\u91cf\u5927\u5c0f\uff0cT\u662f\u5e8f\u5217\u957f\u5ea6\uff0cC\u662f\u9690\u85cf\u5c42\u5927\u5c0f\u3002\n    float* encoded; \/\/ (B, T, C)\n    \/\/ ln1, ln1_mean, ln1_rstd\uff1a\u7b2c\u4e00\u5c42\u5f52\u4e00\u5316\u53ca\u5176\u7edf\u8ba1\u91cf\uff0c\u5206\u522b\u5bf9\u5e94\u4e8e\u5f52\u4e00\u5316\u540e\u7684\u503c\u3001\u5747\u503c\u548c\u9006\u6807\u51c6\u5dee\u3002\n    float* ln1; \/\/ (L, B, T, C)\n    float* ln1_mean; \/\/ (L, B, T)\n    float* ln1_rstd; \/\/ (L, B, T)\n    \/\/ qkv\uff1a\u67e5\u8be2\uff08Query\uff09\u3001\u952e\uff08Key\uff09\u3001\u503c\uff08Value\uff09\u7684\u5408\u5e76\u5f20\u91cf\uff0c\u7528\u4e8e\u81ea\u6ce8\u610f\u529b\u8ba1\u7b97\u3002\n    float* qkv; \/\/ (L, B, T, 3*C)\n    \/\/ atty\uff1a\u81ea\u6ce8\u610f\u529b\u7684\u8f93\u51fa\u3002\n    float* atty; \/\/ (L, B, T, C)\n    \/\/ preatt, att\uff1a\u81ea\u6ce8\u610f\u529b\u8ba1\u7b97\u4e2d\u7684\u9884\u6ce8\u610f\u529b\u5206\u6570\u548c\u6ce8\u610f\u529b\u5206\u6570\u3002\n    float* preatt; \/\/ (L, B, NH, T, T)\n    float* att; \/\/ (L, B, NH, T, T)\n    \/\/ attproj\uff1a\u81ea\u6ce8\u610f\u529b\u8f93\u51fa\u7684\u6295\u5f71\u3002\n    float* attproj; \/\/ (L, B, T, C)\n    \/\/ residual2\uff1a\u7b2c\u4e8c\u4e2a\u6b8b\u5dee\u8fde\u63a5\u540e\u7684\u7ed3\u679c\u3002\n    float* residual2; \/\/ (L, B, T, C)\n    \/\/ ln2, ln2_mean, ln2_rstd\uff1a\u7b2c\u4e8c\u5c42\u5f52\u4e00\u5316\u53ca\u5176\u7edf\u8ba1\u91cf\u3002\n    float* ln2; \/\/ (L, B, T, C)\n    float* ln2_mean; \/\/ (L, B, T)\n    float* ln2_rstd; \/\/ (L, B, T)\n    \/\/ fch, fch_gelu\uff1a\u524d\u9988\u7f51\u7edc\u7684\u8f93\u51fa\u548c\u7ecf\u8fc7GeLU\u6fc0\u6d3b\u51fd\u6570\u540e\u7684\u7ed3\u679c\u3002\n    float* fch; \/\/ (L, B, T, 4*C)\n    float* fch_gelu; \/\/ (L, B, T, 4*C)\n    \/\/ fcproj\uff1a\u524d\u9988\u7f51\u7edc\u8f93\u51fa\u7684\u6295\u5f71\u3002\n    float* fcproj; \/\/ (L, B, T, C)\n    \/\/ 
residual3\uff1a\u7b2c\u4e09\u4e2a\u6b8b\u5dee\u8fde\u63a5\u540e\u7684\u7ed3\u679c\u3002\n    float* residual3; \/\/ (L, B, T, C)\n    \/\/ lnf, lnf_mean, lnf_rstd\uff1a\u6700\u7ec8\u8f93\u51fa\u524d\u7684\u5c42\u5f52\u4e00\u5316\u53ca\u5176\u7edf\u8ba1\u91cf\u3002\n    float* lnf; \/\/ (B, T, C)\n    float* lnf_mean; \/\/ (B, T)\n    float* lnf_rstd; \/\/ (B, T)\n    \/\/ logits\uff1a\u6a21\u578b\u7684\u6700\u7ec8\u8f93\u51falogits\uff0c\u5373\u672a\u5f52\u4e00\u5316\u7684\u5206\u6570\u3002\n    float* logits; \/\/ (B, T, V)\n    \/\/ probs\uff1a\u901a\u8fc7softmax\u5f52\u4e00\u5316\u540e\u7684\u6982\u7387\u5206\u5e03\u3002\n    float* probs; \/\/ (B, T, V)\n    \/\/ losses\uff1a\u6bcf\u4e2a\u65f6\u95f4\u6b65\u7684\u635f\u5931\u3002\n    float* losses; \/\/ (B, T)\n} ActivationTensors;\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e2a\u7ed3\u6784\u4f53\u6355\u83b7\u4e86\u6a21\u578b\u4ece\u8f93\u5165\u5230\u8f93\u51fa\u7684\u6574\u4e2a\u6d41\u7a0b\u4e2d\u7684\u5173\u952e\u4e2d\u95f4\u72b6\u6001\uff0c\u4e3a\u6a21\u578b\u7684\u524d\u5411\u4f20\u64ad\u548c\u53cd\u5411\u4f20\u64ad\u63d0\u4f9b\u4e86\u5fc5\u8981\u7684\u6570\u636e\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.19 malloc_and_point_activations<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u4e3aGPT-2\u6a21\u578b\u4e2d\u5b9a\u4e49\u7684\u6fc0\u6d3b\u5f20\u91cf\u5206\u914d\u5185\u5b58\uff0c\u5e76\u786e\u4fdd\u5404\u4e2a\u6fc0\u6d3b\u5f20\u91cf\u6307\u5411\u6b63\u786e\u7684\u5185\u5b58\u4f4d\u7f6e\u3002\u8fd9\u6837\u505a\u65e8\u5728\u7b80\u5316\u6a21\u578b\u4e2d\u6fc0\u6d3b\u5f20\u91cf\u7684\u7ba1\u7406\uff0c\u5e76\u786e\u4fdd\u6240\u6709\u6fc0\u6d3b\u6570\u636e\u90fd\u5b58\u50a8\u5728\u8fde\u7eed\u7684\u5185\u5b58\u5757\u4e2d\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">float* malloc_and_point_activations(ActivationTensors* 
acts, size_t* act_sizes) {\n    size_t num_activations = 0;\n    \/\/ \u904d\u5386\u6240\u6709\u6fc0\u6d3b\u5f20\u91cf\uff0c\u8ba1\u7b97\u603b\u7684\u6fc0\u6d3b\u6570\u91cf\n    for (size_t i = 0; i &lt; NUM_ACTIVATION_TENSORS; i++) {\n        num_activations += act_sizes[i];\n    }\n    \/\/ \u4e00\u6b21\u6027\u4e3a\u6240\u6709\u6fc0\u6d3b\u5f20\u91cf\u5206\u914d\u8db3\u591f\u7684\u8fde\u7eed\u5185\u5b58\u7a7a\u95f4\n    float* acts_memory = (float*)malloc(num_activations * sizeof(float));\n    \/\/ \u5c06\u5404\u4e2a\u6fc0\u6d3b\u5f20\u91cf\u7684\u6307\u9488\u6307\u5411\u5206\u914d\u7684\u5185\u5b58\u4e2d\u6b63\u786e\u7684\u4f4d\u7f6e\n    float** ptrs[] = {\n        &amp;acts-&gt;encoded, &amp;acts-&gt;ln1, &amp;acts-&gt;ln1_mean, &amp;acts-&gt;ln1_rstd, &amp;acts-&gt;qkv, &amp;acts-&gt;atty,\n        &amp;acts-&gt;preatt, &amp;acts-&gt;att, &amp;acts-&gt;attproj, &amp;acts-&gt;residual2, &amp;acts-&gt;ln2, &amp;acts-&gt;ln2_mean,\n        &amp;acts-&gt;ln2_rstd, &amp;acts-&gt;fch, &amp;acts-&gt;fch_gelu, &amp;acts-&gt;fcproj, &amp;acts-&gt;residual3, &amp;acts-&gt;lnf,\n        &amp;acts-&gt;lnf_mean, &amp;acts-&gt;lnf_rstd, &amp;acts-&gt;logits, &amp;acts-&gt;probs, &amp;acts-&gt;losses\n    };\n    \/\/ \u4f7f\u7528\u8fed\u4ee3\u5668\u904d\u5386\u5206\u914d\u7684\u5185\u5b58\u5e76\u4e3a\u6bcf\u4e2a\u6fc0\u6d3b\u5f20\u91cf\u8d4b\u503c\n    float* acts_memory_iterator = acts_memory;\n    for (size_t i = 0; i &lt; NUM_ACTIVATION_TENSORS; i++) {\n        \/\/ \u8bbe\u7f6e\u6307\u9488\u6307\u5411\u5f53\u524d\u6fc0\u6d3b\u5f20\u91cf\u7684\u4f4d\u7f6e\n        *(ptrs[i]) = acts_memory_iterator;\n        \/\/ \u66f4\u65b0\u8fed\u4ee3\u5668\u4ee5\u6307\u5411\u4e0b\u4e00\u6bb5\u6fc0\u6d3b\u6570\u636e\u7684\u5185\u5b58\n        acts_memory_iterator += act_sizes[i];\n    }\n    \/\/ \u8fd4\u56de\u5206\u914d\u7684\u5185\u5b58\u5757\u7684\u6307\u9488\uff0c\u7528\u4e8e\u540e\u7eed\u7684\u91ca\u653e\u7b49\u64cd\u4f5c\n    return 
acts_memory;\n}\n<\/pre><\/div>\n\n\n\n<p>\u901a\u8fc7\u8fd9\u79cd\u65b9\u6cd5\uff0c<code>ActivationTensors<\/code>\u7ed3\u6784\u4f53\u4e2d\u7684\u6240\u6709\u6fc0\u6d3b\u5f20\u91cf\u901a\u8fc7\u5355\u4e00\u7684\u5185\u5b58\u5206\u914d\u6765\u7ba1\u7406\uff0c\u800c\u4e0d\u662f\u4e3a\u6bcf\u4e2a\u5f20\u91cf\u5355\u72ec\u5206\u914d\u5185\u5b58\u3002\u8fd9\u6837\u4e0d\u4ec5\u51cf\u5c11\u4e86\u5185\u5b58\u788e\u7247\uff0c\u8fd8\u7b80\u5316\u4e86\u5185\u5b58\u7ba1\u7406\u3002<code>act_sizes<\/code>\u6570\u7ec4\u6307\u5b9a\u4e86\u6bcf\u4e2a\u6fc0\u6d3b\u5f20\u91cf\u6240\u9700\u7684\u5185\u5b58\u5927\u5c0f\uff0c<code>ptrs<\/code>\u6307\u9488\u6570\u7ec4\u7528\u4e8e\u5c06\u6bcf\u4e2a\u6fc0\u6d3b\u5f20\u91cf\u6307\u5411\u5206\u914d\u7684\u5185\u5b58\u5757\u4e2d\u7684\u6b63\u786e\u4f4d\u7f6e\u3002\u6700\u7ec8\uff0c\u51fd\u6570\u8fd4\u56de\u6307\u5411\u5206\u914d\u5185\u5b58\u5757\u7684\u6307\u9488\uff0c\u5141\u8bb8\u5728\u4e0d\u9700\u8981\u8fd9\u4e9b\u6fc0\u6d3b\u6570\u636e\u65f6\u6b63\u786e\u5730\u91ca\u653e\u5185\u5b58\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.20 GPT2Config<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u4e2a\u7ed3\u6784\u4f53<code>GPT2Config<\/code>\u5b9a\u4e49\u4e86GPT-2\u6a21\u578b\u7684\u914d\u7f6e\u53c2\u6570\u3002\u8fd9\u4e9b\u53c2\u6570\u662f\u5728\u6a21\u578b\u6784\u5efa\u548c\u8bad\u7ec3\u65f6\u5fc5\u9700\u7684\u57fa\u672c\u8bbe\u7f6e\uff0c\u5b83\u4eec\u786e\u5b9a\u4e86\u6a21\u578b\u7684\u5927\u5c0f\u3001\u590d\u6742\u5ea6\u548c\u5904\u7406\u80fd\u529b\u3002\u4ee5\u4e0b\u662f\u5bf9\u6bcf\u4e2a\u6210\u5458\u7684\u7b80\u8981\u8bf4\u660e\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">typedef struct {\n    \/\/ max_seq_len\uff1a\u6700\u5927\u5e8f\u5217\u957f\u5ea6\uff0c\u5373\u6a21\u578b\u80fd\u591f\u5904\u7406\u7684\u6700\u5927\u8f93\u5165\u957f\u5ea6\u3002\n    \/\/ 
\u4f8b\u5982\uff0c1024\u8868\u793a\u6a21\u578b\u80fd\u591f\u5904\u7406\u7684\u6700\u5927\u5355\u8bcd\u6216token\u6570\u91cf\u4e3a1024\u3002\n    int max_seq_len; \/\/ max sequence length, e.g. 1024\n    \/\/vocab_size\uff1a\u8bcd\u6c47\u8868\u5927\u5c0f\uff0c\u5373\u6a21\u578b\u80fd\u591f\u8bc6\u522b\u7684\u552f\u4e00\u5355\u8bcd\u6216token\u7684\u6570\u91cf\u3002\n    \/\/ \u4f8b\u5982\uff0c50257\u8868\u793a\u6a21\u578b\u7684\u8bcd\u6c47\u8868\u4e2d\u670950257\u4e2a\u4e0d\u540c\u7684token\u3002\n    int vocab_size; \/\/ vocab size, e.g. 50257\n    \/\/ num_layers\uff1a\u6a21\u578b\u4e2dTransformer\u5c42\u7684\u6570\u91cf\u3002\n    \/\/ \u4f8b\u5982\uff0c12\u8868\u793a\u6a21\u578b\u753112\u4e2aTransformer\u5c42\u7ec4\u6210\u3002\n    int num_layers; \/\/ number of layers, e.g. 12\n    \/\/ num_heads\uff1a\u81ea\u6ce8\u610f\u529b\uff08Self-attention\uff09\u673a\u5236\u4e2d\u7684\u5934\u6570\u3002\n    \/\/ \u4f8b\u5982\uff0c12\u8868\u793a\u6bcf\u4e2aTransformer\u5c42\u7684\u81ea\u6ce8\u610f\u529b\u673a\u5236\u5305\u542b12\u4e2a\u5934\u3002\n    int num_heads; \/\/ number of heads in attention, e.g. 12\n    \/\/ channels\uff1a\u901a\u9053\u6570\uff0c\u4e5f\u53ef\u4ee5\u7406\u89e3\u4e3a\u9690\u85cf\u5c42\u7684\u7ef4\u5ea6\u3002\n    \/\/ \u4f8b\u5982\uff0c768\u8868\u793a\u6bcf\u4e2aTransformer\u5c42\u7684\u8f93\u51fa\u7ef4\u5ea6\u4e3a768\u3002\n    int channels; \/\/ number of channels, e.g. 
768\n} GPT2Config;\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e9b\u914d\u7f6e\u53c2\u6570\u5171\u540c\u5b9a\u4e49\u4e86GPT-2\u6a21\u578b\u7684\u67b6\u6784\u548c\u80fd\u529b\uff0c\u5f71\u54cd\u6a21\u578b\u7684\u8868\u8fbe\u80fd\u529b\u3001\u53c2\u6570\u6570\u91cf\u548c\u8ba1\u7b97\u590d\u6742\u5ea6\u3002\u5728\u5b9e\u9645\u4f7f\u7528\u4e2d\uff0c\u53ef\u4ee5\u6839\u636e\u5177\u4f53\u4efb\u52a1\u7684\u9700\u6c42\u548c\u53ef\u7528\u8ba1\u7b97\u8d44\u6e90\u8c03\u6574\u8fd9\u4e9b\u53c2\u6570\u4ee5\u8fbe\u5230\u6700\u4f73\u6548\u679c\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.21 GPT2<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u4e2a\u7ed3\u6784\u4f53<code>GPT2<\/code>\u5b9a\u4e49\u4e86GPT-2\u6a21\u578b\u7684\u6574\u4f53\u7ed3\u6784\uff0c\u5305\u62ec\u6a21\u578b\u914d\u7f6e\u3001\u53c2\u6570\u3001\u6fc0\u6d3b\u503c\u53ca\u5176\u68af\u5ea6\u7b49\u5173\u952e\u7ec4\u6210\u90e8\u5206\u3002\u8fd9\u662f\u4e00\u4e2a\u5168\u9762\u7684\u6570\u636e\u7ed3\u6784\uff0c\u65e8\u5728\u6355\u83b7\u8bad\u7ec3\u548c\u63a8\u65ad\u8fc7\u7a0b\u4e2d\u6240\u9700\u7684\u6240\u6709\u4fe1\u606f\u3002\u4ee5\u4e0b\u662f\u5bf9\u5404\u4e2a\u6210\u5458\u7684\u8be6\u7ec6\u89e3\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">typedef struct {\n    \/\/ config\uff1a\u6a21\u578b\u7684\u914d\u7f6e\uff0c\u5305\u542b\u4e86\u6700\u5927\u5e8f\u5217\u957f\u5ea6\u3001\u8bcd\u6c47\u8868\u5927\u5c0f\u3001\u5c42\u6570\u3001\u6ce8\u610f\u529b\u673a\u5236\u7684\u5934\u6570\u4ee5\u53ca\u901a\u9053\u6570\u7b49\u4fe1\u606f\u3002\n    GPT2Config config;\n    \/\/ the weights (parameters) of the model, and their sizes\n    \/\/ params\uff1a\u6a21\u578b\u7684\u6743\u91cd\uff08\u53c2\u6570\uff09\uff0c\u6309\u7167ParameterTensors\u7ed3\u6784\u4f53\u7ec4\u7ec7\uff0c\n    \/\/ \u5305\u62ec\u4e86\u8bcd\u5d4c\u5165\u3001\u4f4d\u7f6e\u5d4c\u5165\u3001\u81ea\u6ce8\u610f\u529b\u548c\u524d\u9988\u7f51\u7edc\u7684\u6743\u91cd\u7b49\u3002\n    ParameterTensors params;\n 
   \/\/ param_sizes\uff1a\u6570\u7ec4\uff0c\u8bb0\u5f55\u4e86\u6bcf\u4e2a\u53c2\u6570\u5f20\u91cf\u7684\u5927\u5c0f\uff0c\u7528\u4e8e\u5185\u5b58\u5206\u914d\u3002\n    size_t param_sizes[NUM_PARAMETER_TENSORS];\n    \/\/ params_memory\uff1a\u6307\u5411\u4e00\u5757\u8fde\u7eed\u5185\u5b58\u7684\u6307\u9488\uff0c\u8be5\u5185\u5b58\u5757\u5b58\u50a8\u4e86\u6240\u6709\u7684\u6a21\u578b\u53c2\u6570\u3002\n    float* params_memory;\n    \/\/ num_parameters\uff1a\u6a21\u578b\u53c2\u6570\u7684\u603b\u6570\u91cf\u3002\n    size_t num_parameters;\n    \/\/ gradients of the weights\n    \/\/ grads\uff1a\u6a21\u578b\u53c2\u6570\u7684\u68af\u5ea6\uff0c\u4e0eparams\u7ed3\u6784\u76f8\u540c\u3002\n    ParameterTensors grads;\n    \/\/ grads_memory\uff1a\u6307\u5411\u5b58\u50a8\u6240\u6709\u53c2\u6570\u68af\u5ea6\u7684\u8fde\u7eed\u5185\u5b58\u5757\u7684\u6307\u9488\u3002\n    float* grads_memory;\n    \/\/ buffers for the AdamW optimizer\n    \/\/ m_memory\u548cv_memory\uff1aAdamW\u4f18\u5316\u5668\u4e2d\u7528\u4e8e\u5b58\u50a8\u4e00\u9636\u548c\u4e8c\u9636\u52a8\u91cf\u7684\u7f13\u51b2\u533a\u3002\n    float* m_memory;\n    float* v_memory;\n    \/\/ the activations of the model, and their sizes\n\t\t\/\/ acts\uff1a\u6a21\u578b\u7684\u6fc0\u6d3b\u503c\uff0c\u6309\u7167ActivationTensors\u7ed3\u6784\u4f53\u7ec4\u7ec7\uff0c\u5305\u62ec\u4e86\u7f16\u7801\u540e\u7684\u5d4c\u5165\u3001\u81ea\u6ce8\u610f\u529b\u7684\u8f93\u51fa\u3001\u5c42\u5f52\u4e00\u5316\u7684\u7ed3\u679c\u7b49\u3002\n    ActivationTensors acts;\n\t\t\/\/ act_sizes\uff1a\u6570\u7ec4\uff0c\u8bb0\u5f55\u4e86\u6bcf\u4e2a\u6fc0\u6d3b\u5f20\u91cf\u7684\u5927\u5c0f\u3002\n    size_t act_sizes[NUM_ACTIVATION_TENSORS];\n\t\t\/\/ acts_memory\uff1a\u6307\u5411\u5b58\u50a8\u6240\u6709\u6fc0\u6d3b\u503c\u7684\u8fde\u7eed\u5185\u5b58\u5757\u7684\u6307\u9488\u3002\n    float* acts_memory;\n\t\t\/\/ num_activations\uff1a\u6a21\u578b\u6fc0\u6d3b\u503c\u7684\u603b\u6570\u91cf\u3002\n    size_t num_activations;\n    \/\/ gradients of 
the activations\n\t\t\/\/ grads_acts\uff1a\u6fc0\u6d3b\u503c\u7684\u68af\u5ea6\u3002\n    ActivationTensors grads_acts;\n\t\t\/\/ grads_acts_memory\uff1a\u6307\u5411\u5b58\u50a8\u6240\u6709\u6fc0\u6d3b\u503c\u68af\u5ea6\u7684\u8fde\u7eed\u5185\u5b58\u5757\u7684\u6307\u9488\u3002\n    float* grads_acts_memory;\n    \/\/ other run state configuration\n\t\t\/\/ batch_size\uff08B\uff09\uff1a\u5f53\u524d\u524d\u5411\u4f20\u9012\u7684\u6279\u91cf\u5927\u5c0f\u3002\n    int batch_size; \/\/ the batch size (B) of current forward pass\n\t\t\/\/ seq_len\uff08T\uff09\uff1a\u5f53\u524d\u524d\u5411\u4f20\u9012\u7684\u5e8f\u5217\u957f\u5ea6\u3002\n    int seq_len; \/\/ the sequence length (T) of current forward pass\n\t\t\/\/ inputs\uff1a\u5f53\u524d\u524d\u5411\u4f20\u9012\u7684\u8f93\u5165token\u3002\n    int* inputs; \/\/ the input tokens for the current forward pass\n\t\t\/\/ targets\uff1a\u5f53\u524d\u524d\u5411\u4f20\u9012\u7684\u76ee\u6807token\u3002\n    int* targets; \/\/ the target tokens for the current forward pass\n\t\t\/\/ mean_loss\uff1a\u5728\u8fdb\u884c\u5e26\u76ee\u6807\u7684\u524d\u5411\u4f20\u9012\u540e\uff0c\u8be5\u503c\u4f1a\u88ab\u586b\u5145\u4e3a\u5e73\u5747\u635f\u5931\u503c\u3002\n    float mean_loss; \/\/ after a forward pass with targets, will be populated with the mean loss\n} GPT2;\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e2a\u7ed3\u6784\u4f53\u63d0\u4f9b\u4e86\u4e00\u4e2a\u6846\u67b6\uff0c\u4ee5\u652f\u6301GPT-2\u6a21\u578b\u7684\u8bad\u7ec3\u548c\u63a8\u7406\uff0c\u4f7f\u5f97\u6a21\u578b\u7684\u53c2\u6570\u7ba1\u7406\u3001\u524d\u5411\u548c\u53cd\u5411\u4f20\u64ad\u4ee5\u53ca\u53c2\u6570\u66f4\u65b0\u53d8\u5f97\u66f4\u52a0\u7cfb\u7edf\u5316\u548c\u9ad8\u6548\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.22 
gpt2_build_from_checkpoint<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u4ece\u68c0\u67e5\u70b9\u6587\u4ef6\u8bfb\u53d6GPT-2\u6a21\u578b\u7684\u529f\u80fd\u3002\u68c0\u67e5\u70b9\u6587\u4ef6\u901a\u5e38\u7528\u4e8e\u4fdd\u5b58\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u7684\u6a21\u578b\u72b6\u6001\uff0c\u5305\u62ec\u6a21\u578b\u7684\u53c2\u6570\u548c\u914d\u7f6e\uff0c\u4ee5\u4fbf\u4e8e\u540e\u7eed\u7684\u6062\u590d\u8bad\u7ec3\u6216\u63a8\u7406\u4f7f\u7528\u3002\u4ee5\u4e0b\u662f\u5bf9\u8fd9\u4e2a\u51fd\u6570\u7684\u8be6\u7ec6\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void gpt2_build_from_checkpoint(GPT2 *model, char* checkpoint_path) {\n\n    \/\/ read in model from a checkpoint file\n    \/\/ \u4ece\u68c0\u67e5\u70b9\u6587\u4ef6\u4e2d\u8bfb\u53d6\u6a21\u578b\n    FILE *model_file = fopen(checkpoint_path, \"rb\");\n    if (model_file == NULL) { printf(\"Error opening model file\\n\"); exit(1); }\n    int model_header[256];\n    fread(model_header, sizeof(int), 256, model_file);\n    \/\/ \u68c0\u67e5\u6587\u4ef6\u7684\u9b54\u6570\u548c\u7248\u672c\uff0c\u786e\u4fdd\u6587\u4ef6\u683c\u5f0f\u6b63\u786e\n    if (model_header[0] != 20240326) { printf(\"Bad magic model file\"); exit(1); }\n    if (model_header[1] != 1) { printf(\"Bad version in model file\"); exit(1); }\n\n    \/\/ read in hyperparameters\n   \/\/ \u8bfb\u53d6\u8d85\u53c2\u6570\n    int maxT, V, L, NH, C;\n    model-&gt;config.max_seq_len = maxT = model_header[2];\n    model-&gt;config.vocab_size = V = model_header[3];\n    model-&gt;config.num_layers = L = model_header[4];\n    model-&gt;config.num_heads = NH = model_header[5];\n    model-&gt;config.channels = C = model_header[6];\n    printf(\"[GPT-2]\\n\");\n    printf(\"max_seq_len: %d\\n\", maxT);\n    printf(\"vocab_size: %d\\n\", V);\n    printf(\"num_layers: %d\\n\", L);\n    printf(\"num_heads: %d\\n\", NH);\n    printf(\"channels: 
%d\\n\", C);\n\n    \/\/ allocate space for all the parameters and read them in\n    \/\/ \u4e3a\u6240\u6709\u53c2\u6570\u5206\u914d\u7a7a\u95f4\u5e76\u4ece\u6587\u4ef6\u4e2d\u8bfb\u53d6\n    model-&gt;param_sizes[0] = V * C; \/\/ wte\n    model-&gt;param_sizes[1] = maxT * C; \/\/ wpe\n    model-&gt;param_sizes[2] = L * C; \/\/ ln1w\n    model-&gt;param_sizes[3] = L * C; \/\/ ln1b\n    model-&gt;param_sizes[4] = L * (3 * C) * C; \/\/ qkvw\n    model-&gt;param_sizes[5] = L * (3 * C); \/\/ qkvb\n    model-&gt;param_sizes[6] = L * C * C; \/\/ attprojw\n    model-&gt;param_sizes[7] = L * C; \/\/ attprojb\n    model-&gt;param_sizes[8] = L * C; \/\/ ln2w\n    model-&gt;param_sizes[9] = L * C; \/\/ ln2b\n    model-&gt;param_sizes[10] = L * (4 * C) * C; \/\/ fcw\n    model-&gt;param_sizes[11] = L * (4 * C); \/\/ fcb\n    model-&gt;param_sizes[12] = L * C * (4 * C); \/\/ fcprojw\n    model-&gt;param_sizes[13] = L * C; \/\/ fcprojb\n    model-&gt;param_sizes[14] = C; \/\/ lnfw\n    model-&gt;param_sizes[15] = C; \/\/ lnfb\n\n    \/\/ count the number of parameters\n    \/\/ \u8ba1\u7b97\u53c2\u6570\u603b\u6570\n    size_t num_parameters = 0;\n    for (size_t i = 0; i &lt; NUM_PARAMETER_TENSORS; i++) {\n        num_parameters += model-&gt;param_sizes[i];\n    }\n    printf(\"num_parameters: %zu\\n\", num_parameters);\n    model-&gt;num_parameters = num_parameters;\n\n    \/\/ read in all the parameters from file\n    \/\/ \u4ece\u6587\u4ef6\u4e2d\u8bfb\u53d6\u6240\u6709\u53c2\u6570\n    model-&gt;params_memory = malloc_and_point_parameters(&amp;model-&gt;params, model-&gt;param_sizes);\n    fread(model-&gt;params_memory, sizeof(float), num_parameters, model_file);\n    fclose(model_file);\n\n    \/\/ other inits\n    \/\/ \u5176\u4ed6\u521d\u59cb\u5316\n    model-&gt;acts_memory = NULL;\n    model-&gt;grads_memory = NULL;\n    model-&gt;m_memory = NULL;\n    model-&gt;v_memory = NULL;\n    model-&gt;grads_acts_memory = NULL;\n    model-&gt;inputs = NULL;\n    
model-&gt;targets = NULL;\n    model-&gt;batch_size = 0;\n    model-&gt;seq_len = 0;\n    \/\/ \u4f7f\u7528-1.0f\u6807\u8bb0\u6ca1\u6709\u635f\u5931\n    model-&gt;mean_loss = -1.0f; \/\/ -1.0f will designate no loss\n}\n<\/pre><\/div>\n\n\n\n<p>\u901a\u8fc7\u8fd9\u79cd\u65b9\u5f0f\uff0c<code>gpt2_build_from_checkpoint<\/code>\u51fd\u6570\u80fd\u591f\u4ece\u4e00\u4e2a\u9884\u5148\u4fdd\u5b58\u7684\u68c0\u67e5\u70b9\u6587\u4ef6\u4e2d\u6062\u590dGPT-2\u6a21\u578b\u7684\u72b6\u6001\uff0c\u5305\u62ec\u6a21\u578b\u7684\u7ed3\u6784\u914d\u7f6e\u548c\u53c2\u6570\u3002\u8fd9\u5bf9\u4e8e\u6a21\u578b\u7684\u7ee7\u7eed\u8bad\u7ec3\u6216\u8fdb\u884c\u63a8\u7406\u9884\u6d4b\u975e\u5e38\u6709\u7528\u3002\u5728\u6a21\u578b\u4f7f\u7528\u4e4b\u524d\uff0c\u786e\u4fdd\u6240\u6709\u76f8\u5173\u7684\u521d\u59cb\u5316\u548c\u8d44\u6e90\u5206\u914d\u90fd\u5df2\u6b63\u786e\u5b8c\u6210\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.23 gpt2_forward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9a\u4e49\u4e86GPT-2\u6a21\u578b\u7684\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\u3002\u5b83\u8d1f\u8d23\u6839\u636e\u7ed9\u5b9a\u7684\u8f93\u5165tokens\u8ba1\u7b97\u6a21\u578b\u7684\u8f93\u51fa\uff0c\u5e76\u53ef\u9009\u5730\u6839\u636e\u76ee\u6807tokens\u8ba1\u7b97\u635f\u5931\u503c\u3002\u524d\u5411\u4f20\u64ad\u662f\u6df1\u5ea6\u5b66\u4e60\u4e2d\u8ba1\u7b97\u6a21\u578b\u8f93\u51fa\u548c\u635f\u5931\u7684\u57fa\u672c\u8fc7\u7a0b\u3002\u4ee5\u4e0b\u662f\u8be6\u7ec6\u7684\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {\n    \/\/ targets are optional and could be NULL\n    \/\/ \u76ee\u6807tokens\u662f\u53ef\u9009\u7684\uff0c\u53ef\u4ee5\u4e3aNULL\n\n    \/\/ ensure the model was initialized or error out\n    \/\/ \u786e\u4fdd\u6a21\u578b\u5df2\u7ecf\u6b63\u786e\u521d\u59cb\u5316\n    if (model-&gt;params_memory == 
NULL) {\n        printf(\"Error: model was not initialized properly.\\n\");\n        exit(1);\n    }\n\n    \/\/ convenience parameters\n    \/\/ \u65b9\u4fbf\u8d77\u89c1\uff0c\u63d0\u53d6\u6a21\u578b\u914d\u7f6e\u4e2d\u7684\u4e00\u4e9b\u53c2\u6570\n    \/\/ \u8bcd\u6c47\u8868\u5927\u5c0f\n    int V = model-&gt;config.vocab_size;\n    \/\/ \u5c42\u6570\n    int L = model-&gt;config.num_layers;\n    \/\/ \u6ce8\u610f\u529b\u5934\u6570\n    int NH = model-&gt;config.num_heads;\n    \/\/ \u901a\u9053\u6570\uff08\u9690\u85cf\u5c42\u5927\u5c0f\uff09\n    int C = model-&gt;config.channels;\n\n    \/\/ validate inputs, all indices must be in the range [0, V)\n    \/\/ \u9a8c\u8bc1\u8f93\u5165\uff0c\u6240\u6709\u7d22\u5f15\u5fc5\u987b\u5728[0, V)\u8303\u56f4\u5185\n    for(int i = 0; i &lt; B * T; i++) {\n        assert(0 &lt;= inputs[i] &amp;&amp; inputs[i] &lt; V);\n        if (targets != NULL) {\n            assert(0 &lt;= targets[i] &amp;&amp; targets[i] &lt; V);\n        }\n    }\n\n    \/\/ allocate space for all the activations if needed (done here, lazily)\n    \/\/ \u5982\u6709\u5fc5\u8981\uff0c\u61d2\u52a0\u8f7d\u65b9\u5f0f\u4e3a\u6240\u6709\u6fc0\u6d3b\u503c\u5206\u914d\u7a7a\u95f4\n    if(model-&gt;acts_memory == NULL) {\n        \/\/ record the current B,T as well\n        \/\/ \u8bb0\u5f55\u5f53\u524d\u7684B,T\n        model-&gt;batch_size = B;\n        model-&gt;seq_len = T;\n        \/\/ and now allocate the space\n        \/\/ \u73b0\u5728\u5206\u914d\u7a7a\u95f4\n        model-&gt;act_sizes[0] = B * T * C; \/\/ encoded\n        model-&gt;act_sizes[1] = L * B * T * C; \/\/ ln1\n        model-&gt;act_sizes[2] = L * B * T;  \/\/ ln1_mean\n        model-&gt;act_sizes[3] = L * B * T;  \/\/ ln1_rstd\n        model-&gt;act_sizes[4] = L * B * T * 3*C; \/\/ qkv\n        model-&gt;act_sizes[5] = L * B * T * C;  \/\/ atty\n        model-&gt;act_sizes[6] = L * B * NH * T * T;  \/\/ preatt\n        model-&gt;act_sizes[7] = L * B * NH * T * T;  \/\/ att\n        
model-&gt;act_sizes[8] = L * B * T * C; \/\/ attproj\n        model-&gt;act_sizes[9] = L * B * T * C; \/\/ residual2\n        model-&gt;act_sizes[10] = L * B * T * C; \/\/ ln2\n        model-&gt;act_sizes[11] = L * B * T; \/\/ ln2_mean\n        model-&gt;act_sizes[12] = L * B * T; \/\/ ln2_rstd\n        model-&gt;act_sizes[13] = L * B * T * 4*C; \/\/ fch\n        model-&gt;act_sizes[14] = L * B * T * 4*C; \/\/ fch_gelu\n        model-&gt;act_sizes[15] = L * B * T * C; \/\/ fcproj\n        model-&gt;act_sizes[16] = L * B * T * C; \/\/ residual3\n        model-&gt;act_sizes[17] = B * T * C; \/\/ lnf\n        model-&gt;act_sizes[18] = B * T; \/\/ lnf_mean\n        model-&gt;act_sizes[19] = B * T; \/\/ lnf_rstd\n        model-&gt;act_sizes[20] = B * T * V; \/\/ logits\n        model-&gt;act_sizes[21] = B * T * V; \/\/ probs\n        model-&gt;act_sizes[22] = B * T; \/\/ losses\n        size_t num_activations = 0;\n        for (size_t i = 0; i &lt; NUM_ACTIVATION_TENSORS; i++) {\n            num_activations += model-&gt;act_sizes[i];\n        }\n        printf(\"num_activations: %zu\\n\", num_activations);\n        model-&gt;num_activations = num_activations;\n        model-&gt;acts_memory = malloc_and_point_activations(&amp;model-&gt;acts, model-&gt;act_sizes);\n        \/\/ also create memory for caching inputs and targets\n        \/\/ \u540c\u65f6\u4e3a\u8f93\u5165\u548c\u76ee\u6807\u521b\u5efa\u5185\u5b58\u7f13\u5b58\n        model-&gt;inputs = (int*)malloc(B * T * sizeof(int));\n        \/\/ \u5982\u679c\u6211\u4eec\u6ca1\u6709\u76ee\u6807\uff0c\u8fd9\u90e8\u5206\u53ef\u80fd\u4e0d\u4f1a\u7528\u5230\uff0c\u4f46\u5f00\u9500\u5f88\u5c0f\n        model-&gt;targets = (int*)malloc(B * T * sizeof(int)); \/\/ might be unused if we never have targets but it's small\n    } else {\n        \/\/ validate B,T is consistent with how we've allocated the memory before\n        \/\/ in principle we could get more clever here in the future, for now this is safest\n        \/\/ 
\u9a8c\u8bc1B,T\u662f\u5426\u4e0e\u4e4b\u524d\u5206\u914d\u7684\u5185\u5b58\u4e00\u81f4\n        if (B != model-&gt;batch_size || T != model-&gt;seq_len) {\n            printf(\"Model: B=%d T=%d, Desired: B=%d T=%d\\n\", model-&gt;batch_size, model-&gt;seq_len, B, T);\n            exit(EXIT_FAILURE);\n        }\n    }\n\n    \/\/ cache the inputs\/targets\n    \/\/ \u7f13\u5b58\u8f93\u5165\/\u76ee\u6807\n    memcpy(model-&gt;inputs, inputs, B * T * sizeof(int));\n    if (targets != NULL) {\n        memcpy(model-&gt;targets, targets, B * T * sizeof(int));\n    }\n\n    \/\/ forward pass\n    \/\/ \u524d\u5411\u4f20\u64ad\n    ParameterTensors params = model-&gt;params; \/\/ for brevity\n    ActivationTensors acts = model-&gt;acts;\n    float* residual;\n    \/\/ \u4f7f\u7528encoder_forward\u51fd\u6570\u5bf9\u8f93\u5165tokens\u8fdb\u884c\u7f16\u7801\u3002\n    \/\/ \u8fd9\u4e2a\u6b65\u9aa4\u5c06\u8f93\u5165tokens\u8f6c\u6362\u4e3a\u6a21\u578b\u53ef\u4ee5\u7406\u89e3\u7684\u683c\u5f0f\uff0c\n    \/\/ \u5373\u6bcf\u4e2atoken\u5bf9\u5e94\u7684\u5411\u91cf\u8868\u793a\u3002\u8fd9\u91cc\u4f7f\u7528\u7684\u662f\u8bcd\u5d4c\u5165\uff08params.wte\uff09\u548c\u4f4d\u7f6e\u5d4c\u5165\uff08params.wpe\uff09\uff0c\n    \/\/ \u5206\u522b\u6355\u83b7\u4e86\u8bcd\u6c47\u7684\u8bed\u4e49\u4fe1\u606f\u548c\u5728\u5e8f\u5217\u4e2d\u7684\u4f4d\u7f6e\u4fe1\u606f\u3002\u7f16\u7801\u540e\u7684\u7ed3\u679c\u5b58\u50a8\u5728acts.encoded\u4e2d\u3002\n    encoder_forward(acts.encoded, inputs, params.wte, params.wpe, B, T, C); \/\/ encoding goes into residual[0]\n    \/\/ \u904d\u5386\u6bcf\u4e00\u5c42\uff1a \n    \/\/ \u4ee3\u7801\u8fdb\u5165\u4e00\u4e2a\u5faa\u73af\uff0c\u5bf9\u6a21\u578b\u7684\u6bcf\u4e00\u5c42\uff08\u4ece\u7b2c0\u5c42\u5230\u7b2cL-1\u5c42\uff09\u6267\u884c\u4e00\u7cfb\u5217\u8ba1\u7b97\u3002\n    \/\/ \u5faa\u73af\u5185\u7684\u7b2c\u4e00\u6b65\u662f\u66f4\u65b0residual\u53d8\u91cf\u7684\u503c\u3002\n    \/\/ 
\u5bf9\u4e8e\u7b2c0\u5c42\uff0cresidual\u4fdd\u6301\u4e3a\u7f16\u7801\u540e\u7684\u8f93\u5165\uff08\u5373\uff0c\u7b2c\u4e00\u5c42\u7684\u8f93\u5165\u662f\u7f16\u7801\u540e\u7684\u5e8f\u5217\uff09\uff1b\n    \/\/ \u5bf9\u4e8e\u5176\u4ed6\u5c42\uff0cresidual\u6307\u5411\u4e0a\u4e00\u5c42\u7684\u6b8b\u5dee\u8fde\u63a5\u8f93\u51fa\uff08acts.residual3\u52a0\u4e0a\u504f\u79fb\u91cf\uff09\u3002\n    \/\/ \u8fd9\u91cc\u7684(l-1) * B * T * C\u8ba1\u7b97\u786e\u4fdd\u4e86residual\u6b63\u786e\u6307\u5411\u4e86\u524d\u4e00\u5c42\u7684\u6b8b\u5dee\u8f93\u51fa\uff0c\u4e3a\u5f53\u524d\u5c42\u7684\u8ba1\u7b97\u63d0\u4f9b\u8f93\u5165\u3002\n    for (int l = 0; l &lt; L; l++) {\n\n        \/\/ residual\u53d8\u91cf\u88ab\u521d\u59cb\u5316\u4e3a\u6307\u5411\u7f16\u7801\u540e\u7684\u8f93\u5165acts.encoded\u3002\n        \/\/ \u5728Transformer\u67b6\u6784\u4e2d\uff0c\u6bcf\u4e00\u5c42\u7684\u8f93\u51fa\u90fd\u4f1a\u548c\u8f93\u5165\u8fdb\u884c\u76f8\u52a0\uff0c\u5f62\u6210\u6240\u8c13\u7684\u6b8b\u5dee\u8fde\u63a5\u3002\n        \/\/ \u8fd9\u4e00\u673a\u5236\u6709\u52a9\u4e8e\u907f\u514d\u6df1\u5c42\u7f51\u7edc\u4e2d\u7684\u68af\u5ea6\u6d88\u5931\u95ee\u9898\uff0c\u4f7f\u5f97\u6a21\u578b\u80fd\u591f\u6709\u6548\u5730\u5b66\u4e60\u3002\n        residual = l == 0 ? 
acts.encoded : acts.residual3 + (l-1) * B * T * C;\n\n        \/\/ get the pointers of the weights for this layer\n        \/\/ \u83b7\u53d6\u8fd9\u4e00\u5c42\u7684\u6743\u91cd\u6307\u9488\u3002\n        float* l_ln1w = params.ln1w + l * C;\n        float* l_ln1b = params.ln1b + l * C;\n        float* l_qkvw = params.qkvw + l * 3*C * C;\n        float* l_qkvb = params.qkvb + l * 3*C;\n        float* l_attprojw = params.attprojw + l * C * C;\n        float* l_attprojb = params.attprojb + l * C;\n        float* l_ln2w = params.ln2w + l * C;\n        float* l_ln2b = params.ln2b + l * C;\n        float* l_fcw = params.fcw + l * 4*C * C;\n        float* l_fcb = params.fcb + l * 4*C;\n        float* l_fcprojw = params.fcprojw + l * C * 4*C;\n        float* l_fcprojb = params.fcprojb + l * C;\n\n        \/\/ get the pointers of the activations for this layer\n        \/\/ \u83b7\u53d6\u8fd9\u4e00\u5c42\u7684\u6fc0\u6d3b\u503c\u6307\u9488\u3002\n        float* l_ln1 = acts.ln1 + l * B * T * C;\n        float* l_ln1_mean = acts.ln1_mean + l * B * T;\n        float* l_ln1_rstd = acts.ln1_rstd + l * B * T;\n        float* l_qkv = acts.qkv + l * B * T * 3*C;\n        float* l_atty = acts.atty + l * B * T * C;\n        float* l_preatt = acts.preatt + l * B * NH * T * T;\n        float* l_att = acts.att + l * B * NH * T * T;\n        float* l_attproj = acts.attproj + l * B * T * C;\n        float* l_residual2 = acts.residual2 + l * B * T * C;\n        float* l_ln2 = acts.ln2 + l * B * T * C;\n        float* l_ln2_mean = acts.ln2_mean + l * B * T;\n        float* l_ln2_rstd = acts.ln2_rstd + l * B * T;\n        float* l_fch = acts.fch + l * B * T * 4*C;\n        float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;\n        float* l_fcproj = acts.fcproj + l * B * T * C;\n        float* l_residual3 = acts.residual3 + l * B * T * C;\n\n        \/\/ now do the forward pass\n        \/\/ \u73b0\u5728\u8fdb\u884c\u524d\u5411\u4f20\u64ad\u3002\n        \/\/ \n        \/\/ 
\u5c42\u5f52\u4e00\u5316\uff08Layer Normalization\uff09\uff1a \n        \/\/ \u5bf9\u6b8b\u5dee\u8fde\u63a5\u540e\u7684\u8f93\u51faresidual\u8fdb\u884c\u5f52\u4e00\u5316\u5904\u7406\uff0c\u4f7f\u5f97\u5176\u5206\u5e03\u66f4\u52a0\u7a33\u5b9a\u3002\n        \/\/ \u8fd9\u6709\u52a9\u4e8e\u52a0\u901f\u8bad\u7ec3\u8fc7\u7a0b\u5e76\u63d0\u9ad8\u6a21\u578b\u7684\u6027\u80fd\u3002\n        layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C);\n        \/\/ \u77e9\u9635\u4e58\u6cd5\uff08Matrix Multiplication\uff09\uff1a \n        \/\/ \u5bf9\u5c42\u5f52\u4e00\u5316\u540e\u7684\u7ed3\u679cl_ln1\u548c\u6743\u91cdl_qkvw\u8fdb\u884c\u77e9\u9635\u4e58\u6cd5\uff0c\u52a0\u4e0a\u504f\u7f6el_qkvb\uff0c\n        \/\/ \u8ba1\u7b97\u67e5\u8be2\uff08Query\uff09\u3001\u952e\uff08Key\uff09\u3001\u503c\uff08Value\uff09\u7684\u5408\u5e76\u8868\u793al_qkv\u3002\n        matmul_forward(l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C);\n        \/\/ \u81ea\u6ce8\u610f\u529b\u673a\u5236\uff08Self-Attention\uff09\uff1a \n        \/\/ \u4f7f\u7528\u81ea\u6ce8\u610f\u529b\u673a\u5236\u5904\u7406l_qkv\uff0c\u751f\u6210\u6ce8\u610f\u529b\u52a0\u6743\u7684\u8f93\u51fal_atty\u3002\n        \/\/ \u81ea\u6ce8\u610f\u529b\u673a\u5236\u5141\u8bb8\u6a21\u578b\u5728\u5904\u7406\u6bcf\u4e2a\u5355\u8bcd\u65f6\u8003\u8651\u5230\u6574\u4e2a\u5e8f\u5217\u7684\u4e0a\u4e0b\u6587\u4fe1\u606f\u3002\n        attention_forward(l_atty, l_preatt, l_att, l_qkv, B, T, C, NH);\n        \/\/ \u518d\u6b21\u8fdb\u884c\u77e9\u9635\u4e58\u6cd5\uff1a \n        \/\/ \u5c06\u81ea\u6ce8\u610f\u529b\u7684\u8f93\u51fal_atty\u901a\u8fc7\u53e6\u4e00\u4e2a\u7ebf\u6027\u5c42\uff08\u6743\u91cd\u4e3al_attprojw\uff0c\u504f\u7f6e\u4e3al_attprojb\uff09\uff0c\n        \/\/ \u5f97\u5230\u81ea\u6ce8\u610f\u529b\u5c42\u7684\u6700\u7ec8\u8f93\u51fal_attproj\u3002\n        matmul_forward(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C);\n        \/\/ \u6b8b\u5dee\u8fde\u63a5\uff08Residual Connection\uff09\uff1a \n   
     \/\/ \u5c06\u81ea\u6ce8\u610f\u529b\u5c42\u7684\u8f93\u51fal_attproj\u4e0e\u6b8b\u5dee\u8fde\u63a5\u7684\u8f93\u5165residual\u76f8\u52a0\uff0c\n        \/\/ \u5f62\u6210\u65b0\u7684\u6b8b\u5dee\u8fde\u63a5\u8f93\u51fal_residual2\u3002\n        \/\/ \u6b8b\u5dee\u8fde\u63a5\u6709\u52a9\u4e8e\u7f13\u89e3\u6df1\u5c42\u7f51\u7edc\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u7684\u68af\u5ea6\u6d88\u5931\u6216\u7206\u70b8\u95ee\u9898\u3002\n        residual_forward(l_residual2, residual, l_attproj, B*T*C);\n        \/\/ \u7b2c\u4e8c\u6b21\u5c42\u5f52\u4e00\u5316\uff1a \n        \/\/ \u5bf9\u65b0\u7684\u6b8b\u5dee\u8fde\u63a5\u8f93\u51fal_residual2\u8fdb\u884c\u5c42\u5f52\u4e00\u5316\uff0c\u5f97\u5230l_ln2\uff0c\u8fdb\u4e00\u6b65\u7a33\u5b9a\u6a21\u578b\u7684\u8bad\u7ec3\u8fc7\u7a0b\u3002\n        layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);\n        \/\/ \u524d\u9988\u7f51\u7edc\uff08Feedforward Network\uff09\uff1a \n        \/\/ \u901a\u8fc7\u4e00\u4e2a\u524d\u9988\u7f51\u7edc\uff08\u7ebf\u6027\u5c42\u52a0\u4e0a\u6fc0\u6d3b\u51fd\u6570\uff09\uff0c\u5904\u7406l_ln2\uff0c\u5176\u4e2d\u4f7f\u7528GELU\u4f5c\u4e3a\u6fc0\u6d3b\u51fd\u6570\uff0c\u5f97\u5230l_fch\u3002\n        matmul_forward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C);\n        \/\/ GELU\u6fc0\u6d3b\u51fd\u6570\uff1a \n        \/\/ \u5bf9\u524d\u9988\u7f51\u7edc\u7684\u8f93\u51fal_fch\u5e94\u7528GELU\u6fc0\u6d3b\u51fd\u6570\uff0c\u5f97\u5230l_fch_gelu\u3002\n        \/\/ GELU\u6fc0\u6d3b\u51fd\u6570\u6709\u52a9\u4e8e\u5f15\u5165\u975e\u7ebf\u6027\uff0c\u589e\u5f3a\u6a21\u578b\u7684\u8868\u8fbe\u80fd\u529b\u3002\n        gelu_forward(l_fch_gelu, l_fch, B*T*4*C);\n        \/\/ \u6700\u540e\u7684\u7ebf\u6027\u53d8\u6362\uff1a \n        \/\/ \u5c06\u6fc0\u6d3b\u540e\u7684\u7ed3\u679cl_fch_gelu\u901a\u8fc7\u6700\u540e\u4e00\u4e2a\u7ebf\u6027\u5c42\uff08\u6743\u91cd\u4e3al_fcprojw\uff0c\u504f\u7f6e\u4e3al_fcprojb\uff09\uff0c\n        \/\/ 
\u5f97\u5230\u6b64\u5c42\u7684\u8f93\u51fal_fcproj\u3002\n        matmul_forward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C);\n        \/\/ \u6700\u7ec8\u7684\u6b8b\u5dee\u8fde\u63a5\uff1a \n        \/\/ \u5c06\u6700\u540e\u4e00\u4e2a\u7ebf\u6027\u5c42\u7684\u8f93\u51fal_fcproj\u4e0e\u4e4b\u524d\u7684\u6b8b\u5dee\u8fde\u63a5\u8f93\u51fal_residual2\u76f8\u52a0\uff0c\n        \/\/ \u5f62\u6210\u6700\u7ec8\u7684\u6b8b\u5dee\u8fde\u63a5\u8f93\u51fal_residual3\u3002\u8fd9\u4e00\u6b65\u5b8c\u6210\u4e86\u4e00\u4e2aTransformer\u5c42\u7684\u5168\u90e8\u8ba1\u7b97\u6d41\u7a0b\u3002\n        residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C);\n        \/\/ \u5728\u6a21\u578b\u7684\u6bcf\u4e2aTransformer\u5c42\u4e2d\uff0c\u4e0a\u9762\u6b65\u9aa4\u88ab\u91cd\u590d\u6267\u884c\uff0c\n        \/\/ \u6bcf\u4e00\u5c42\u7684\u8f93\u51fa\u90fd\u4f1a\u4f5c\u4e3a\u4e0b\u4e00\u5c42\u7684\u8f93\u5165\uff0c\u76f4\u5230\u6240\u6709\u7684\u5c42\u90fd\u6267\u884c\u5b8c\u6bd5\u3002\n        \/\/ \u8fd9\u4e2a\u8fc7\u7a0b\u5141\u8bb8\u6a21\u578b\u6355\u83b7\u548c\u5904\u7406\u590d\u6742\u7684\u5e8f\u5217\u4f9d\u8d56\u5173\u7cfb\uff0c\u4ece\u800c\u751f\u6210\u51c6\u786e\u7684\u9884\u6d4b\u6216\u8bed\u8a00\u6a21\u578b\u8f93\u51fa\u3002\n    }\n    \/\/ \u6700\u540e\u7684\u6b8b\u5dee\u5728residual3\u4e2d\u3002\n    \/\/ \u9009\u62e9\u6700\u540e\u4e00\u5c42\u7684\u6b8b\u5dee\uff1a \n    \/\/ \u901a\u8fc7acts.residual3 + (L-1) * B * T * C\u8868\u8fbe\u5f0f\uff0c\n    \/\/ \u6211\u4eec\u9009\u62e9\u4e86\u6700\u540e\u4e00\u4e2aTransformer\u5c42\u7684\u6b8b\u5dee\u8fde\u63a5\u8f93\u51fa\u4f5c\u4e3a\u63a5\u4e0b\u6765\u5c42\u5f52\u4e00\u5316\u7684\u8f93\u5165\u3002\n    \/\/ \u8fd9\u91cc\u7684(L-1)\u8868\u793a\u6700\u540e\u4e00\u5c42\uff0c\u56e0\u4e3a\u5c42\u7684\u7d22\u5f15\u4ece0\u5f00\u59cb\u3002\n    residual = acts.residual3 + (L-1) * B * T * C; \/\/ last residual is in residual3\n    \/\/ \u6700\u540e\u4e00\u5c42\u7684\u5c42\u5f52\u4e00\u5316\uff08Layer 
Normalization\uff09\uff1a \n    \/\/ \u4f7f\u7528layernorm_forward\u51fd\u6570\u5bf9\u6700\u540e\u4e00\u5c42\u7684\u6b8b\u5dee\u8fde\u63a5\u8f93\u51fa\u8fdb\u884c\u5f52\u4e00\u5316\u5904\u7406\uff0c\u5f97\u5230acts.lnf\u3002\n    \/\/ \u8fd9\u6709\u52a9\u4e8e\u6a21\u578b\u7684\u8bad\u7ec3\u7a33\u5b9a\u6027\u548c\u6027\u80fd\u3002\n    layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);\n    \/\/ \u8f93\u51fa\u5c42\u7684\u7ebf\u6027\u53d8\u6362\uff1a \n    \/\/ \u901a\u8fc7matmul_forward\u51fd\u6570\uff0c\n    \/\/ \u5c06\u5c42\u5f52\u4e00\u5316\u540e\u7684\u8f93\u51faacts.lnf\u4e0e\u8bcd\u5d4c\u5165\u6743\u91cdparams.wte\u8fdb\u884c\u77e9\u9635\u4e58\u6cd5\u8fd0\u7b97\uff08\u6b64\u5904\u6ca1\u6709\u4f7f\u7528\u504f\u7f6e\uff09\uff0c\n    \/\/ \u5f97\u5230\u6a21\u578b\u7684\u539f\u59cb\u8f93\u51faacts.logits\u3002\n    \/\/ \u8fd9\u4e00\u6b65\u9aa4\u751f\u6210\u4e86\u6bcf\u4e2a\u8bcd\u6c47\u5728\u7ed9\u5b9a\u4e0a\u4e0b\u6587\u4e2d\u7684\u5f97\u5206\uff08\u672a\u5f52\u4e00\u5316\u7684\u6982\u7387\uff09\u3002\n    matmul_forward(acts.logits, acts.lnf, params.wte, NULL, B, T, C, V);\n    \/\/ Softmax\u5f52\u4e00\u5316\uff1a \n    \/\/ \u4f7f\u7528softmax_forward\u51fd\u6570\uff0c\u5bf9acts.logits\u8fdb\u884csoftmax\u5904\u7406\uff0c\n    \/\/ \u5c06\u539f\u59cb\u8f93\u51fa\u8f6c\u6362\u4e3a\u6982\u7387\u5206\u5e03acts.probs\u3002\n    \/\/ softmax\u786e\u4fdd\u4e86\u8f93\u51fa\u6982\u7387\u7684\u603b\u548c\u4e3a1\uff0c\u4fbf\u4e8e\u89e3\u91ca\u548c\u540e\u7eed\u5904\u7406\u3002\n    softmax_forward(acts.probs, acts.logits, B, T, V);\n\n    \/\/ also forward the cross-entropy loss function if we have the targets\n    \/\/ \u5982\u679c\u6709\u76ee\u6807tokens\uff0c\u8fd8\u4f1a\u8fdb\u884c\u4ea4\u53c9\u71b5\u635f\u5931\u7684\u8ba1\u7b97\n    if (targets != NULL) {\n        \/\/ 
\u4f7f\u7528crossentropy_forward\u51fd\u6570\u8ba1\u7b97\u9884\u6d4b\u6982\u7387\uff08model-&gt;acts.probs\uff09\u548c\u76ee\u6807tokens\u4e4b\u95f4\u7684\u4ea4\u53c9\u71b5\u635f\u5931\u3002\n        \/\/ \u8fd9\u4e00\u6b65\u9aa4\u8bc4\u4f30\u4e86\u6a21\u578b\u9884\u6d4b\u7684\u6982\u7387\u5206\u5e03\u4e0e\u5b9e\u9645\u53d1\u751f\u7684tokens\u4e4b\u95f4\u7684\u5dee\u5f02\uff0c\u635f\u5931\u503c\u5b58\u50a8\u5728model-&gt;acts.losses\u4e2d\n        crossentropy_forward(model-&gt;acts.losses, model-&gt;acts.probs, targets, B, T, V);\n        \/\/ for convenience also evaluate the mean loss\n        \/\/ \u4e3a\u65b9\u4fbf\u8d77\u89c1\uff0c\u540c\u65f6\u8ba1\u7b97\u5e73\u5747\u635f\u5931\n        float mean_loss = 0.0f;\n        \/\/ \u4e3a\u4e86\u83b7\u5f97\u5355\u4e2a\u8bad\u7ec3\u6837\u672c\u4e0a\u7684\u5e73\u5747\u635f\u5931\uff0c\u5c06\u6240\u6709\u635f\u5931\u503c\u76f8\u52a0\u7136\u540e\u9664\u4ee5tokens\u7684\u603b\u6570\uff08\u5373B*T\uff0c\u5176\u4e2dB\u662fbatch\u5927\u5c0f\uff0cT\u662f\u5e8f\u5217\u957f\u5ea6\uff09\u3002\n        \/\/ \u8fd9\u4e2a\u5e73\u5747\u635f\u5931\u503c\u4e3amean_loss\uff0c\u53cd\u6620\u4e86\u6a21\u578b\u5728\u5f53\u524d\u6279\u6b21\u8bad\u7ec3\u6570\u636e\u4e0a\u7684\u5e73\u5747\u8868\u73b0\u3002\n        for (int i=0; i&lt;B*T; i++) { mean_loss += model-&gt;acts.losses[i]; }\n        \/\/ \u66f4\u65b0\u6a21\u578b\u7684\u5e73\u5747\u635f\u5931\u5c5e\u6027\uff1a \n        \/\/ \u5c06\u8ba1\u7b97\u51fa\u7684\u5e73\u5747\u635f\u5931\u503c\u8d4b\u7ed9\u6a21\u578b\u7684mean_loss\u5c5e\u6027\uff0c\u4ee5\u4fbf\u4e8e\u540e\u7eed\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u7684\u4f7f\u7528\u3002\n        mean_loss \/= B*T;\n        model-&gt;mean_loss = mean_loss;\n    } else {\n        \/\/ if we don't have targets, we don't have a loss\n        \/\/ \u5982\u679c\u6ca1\u6709\u76ee\u6807tokens\uff0c\u90a3\u4e48\u5c31\u6ca1\u6709\u635f\u5931\u503c\n        model-&gt;mean_loss = -1.0f;\n    
}\n}\n<\/pre><\/div>\n\n\n\n<p>\u6b64\u51fd\u6570\u9996\u5148\u786e\u4fdd\u6a21\u578b\u5df2\u6b63\u786e\u521d\u59cb\u5316\uff0c\u7136\u540e\u9a8c\u8bc1\u8f93\u5165\u548c\u76ee\u6807\u7684\u6709\u6548\u6027\u3002\u5982\u679c\u6fc0\u6d3b\u503c\u5c1a\u672a\u5206\u914d\u5185\u5b58\uff0c\u5219\u8fdb\u884c\u5206\u914d\uff0c\u5e76\u6839\u636e\u5f53\u524d\u7684\u6279\u91cf\u5927\u5c0f\u548c\u5e8f\u5217\u957f\u5ea6\u8c03\u6574\u6a21\u578b\u914d\u7f6e\u3002\u63a5\u4e0b\u6765\uff0c\u51fd\u6570\u6267\u884c\u6a21\u578b\u7684\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8ba1\u7b97\u7f16\u7801\u3001\u81ea\u6ce8\u610f\u529b\u7b49\u3002<\/p>\n\n\n\n<p>\u5f53\u6211\u4eec\u8c08\u5230GPT-2\u6a21\u578b\u7684\u524d\u5411\u4f20\u64ad\u65f6\uff0c\u6211\u4eec\u6307\u7684\u662f\u6a21\u578b\u6839\u636e\u8f93\u5165\u6570\u636e\uff08\u5982\u6587\u672c\u5e8f\u5217\uff09\u8ba1\u7b97\u9884\u6d4b\u8f93\u51fa\uff08\u5982\u4e0b\u4e00\u4e2a\u5355\u8bcd\u7684\u6982\u7387\u5206\u5e03\uff09\u7684\u8fc7\u7a0b\u3002\u8fd9\u4e2a\u8fc7\u7a0b\u6d89\u53ca\u5230\u6a21\u578b\u5185\u90e8\u5404\u5c42\u7684\u987a\u5e8f\u6fc0\u6d3b\u548c\u53c2\u6570\u7684\u4f7f\u7528\u3002\u4ee5\u4e0b\u662f\u5bf9\u524d\u9762\u4ee3\u7801\u4e2d\u51e0\u4e2a\u5173\u952e\u6b65\u9aa4\u7684\u518d\u89e3\u91ca\uff1a<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>\u83b7\u53d6\u6743\u91cd\u6307\u9488\uff1a<\/strong> \u5728\u6bcf\u4e00\u5c42\u7684\u5f00\u59cb\uff0c\u6211\u4eec\u9700\u8981\u83b7\u53d6\u5f53\u524d\u5c42\u6240\u4f7f\u7528\u7684\u6240\u6709\u6743\u91cd\u548c\u504f\u7f6e\u7684\u6307\u9488\u3002\u8fd9\u5305\u62ec\u4e86\u81ea\u6ce8\u610f\u529b\u5c42\u7684\u67e5\u8be2\u3001\u952e\u3001\u503c\u6743\u91cd\uff08<code>qkvw<\/code>, <code>qkvb<\/code>\uff09\u3001\u6ce8\u610f\u529b\u8f93\u51fa\u7684\u6295\u5f71\u6743\u91cd\uff08<code>attprojw<\/code>, 
<code>attprojb<\/code>\uff09\u7b49\u3002\u8fd9\u6837\u53ef\u4ee5\u76f4\u63a5\u4f7f\u7528\u8fd9\u4e9b\u53c2\u6570\u8fdb\u884c\u8ba1\u7b97\uff0c\u800c\u65e0\u9700\u5728\u6bcf\u4e00\u6b65\u67e5\u627e\u5b83\u4eec\u7684\u4f4d\u7f6e\u3002<\/li>\n\n\n\n<li><strong>\u83b7\u53d6\u6fc0\u6d3b\u503c\u6307\u9488\uff1a<\/strong> \u540c\u6837\u5730\uff0c\u6211\u4eec\u4e5f\u9700\u8981\u83b7\u53d6\u4fdd\u5b58\u4e2d\u95f4\u8ba1\u7b97\u7ed3\u679c\u7684\u6fc0\u6d3b\u503c\u7684\u6307\u9488\uff0c\u4f8b\u5982\u7f16\u7801\u540e\u7684\u8f93\u5165\uff08<code>encoded<\/code>\uff09\u3001\u81ea\u6ce8\u610f\u529b\u7684\u8f93\u51fa\uff08<code>atty<\/code>\uff09\u3001\u5404\u79cd\u5c42\u5f52\u4e00\u5316\u7684\u7ed3\u679c\u7b49\u3002\u8fd9\u4e9b\u6fc0\u6d3b\u503c\u5728\u524d\u5411\u4f20\u64ad\u7684\u4e0d\u540c\u9636\u6bb5\u88ab\u8ba1\u7b97\u51fa\u6765\uff0c\u5e76\u88ab\u540e\u7eed\u7684\u5c42\u4f7f\u7528\u3002<\/li>\n\n\n\n<li><strong>\u8fdb\u884c\u524d\u5411\u4f20\u64ad\uff1a<\/strong> \u5728\u51c6\u5907\u597d\u6240\u6709\u5fc5\u9700\u7684\u6743\u91cd\u548c\u6fc0\u6d3b\u503c\u4e4b\u540e\uff0c\u6211\u4eec\u6309\u7167\u6a21\u578b\u7684\u67b6\u6784\u9010\u5c42\u8fdb\u884c\u8ba1\u7b97\u3002\u8fd9\u5305\u62ec\uff1a<\/li>\n<\/ol>\n\n\n\n<ul class=\"wp-block-list\">\n<li>\u5bf9\u8f93\u5165\u5e8f\u5217\u8fdb\u884c\u7f16\u7801\uff0c\u751f\u6210\u7f16\u7801\u540e\u7684\u5d4c\u5165\u3002<\/li>\n\n\n\n<li>\u901a\u8fc7\u81ea\u6ce8\u610f\u529b\u5c42\u5904\u7406\u7f16\u7801\u540e\u7684\u5d4c\u5165\uff0c\u5f97\u5230\u6ce8\u610f\u529b\u52a0\u6743\u7684\u8f93\u51fa\u3002<\/li>\n\n\n\n<li>\u5e94\u7528\u524d\u9988\u795e\u7ecf\u7f51\u7edc\uff08Feedforward Neural Network, FNN\uff09\u5230\u81ea\u6ce8\u610f\u529b\u5c42\u7684\u8f93\u51fa\u4e0a\u3002<\/li>\n\n\n\n<li>\u4f7f\u7528\u6b8b\u5dee\u8fde\u63a5\u548c\u5c42\u5f52\u4e00\u5316\u6765\u7a33\u5b9a\u8bad\u7ec3\u8fc7\u7a0b\u5e76\u63d0\u9ad8\u6a21\u578b\u6027\u80fd\u3002<\/li>\n<\/ul>\n\n\n\n<ol 
start=\"4\" class=\"wp-block-list\">\n<li><strong>\u5904\u7406\u6700\u540e\u7684\u6b8b\u5dee\uff1a<\/strong> \u6a21\u578b\u7684\u6bcf\u4e00\u5c42\u8f93\u51fa\u90fd\u4f1a\u4e0e\u8f93\u5165\u8fdb\u884c\u6b8b\u5dee\u8fde\u63a5\uff0c\u6700\u540e\u4e00\u5c42\u7684\u6b8b\u5dee\u8fde\u63a5\u8f93\u51fa\u5b58\u50a8\u5728<code>residual3<\/code>\u4e2d\u3002\u8fd9\u4e2a\u6b8b\u5dee\u8fde\u63a5\u7684\u8f93\u51fa\u63a5\u4e0b\u6765\u4f1a\u901a\u8fc7\u6700\u540e\u4e00\u5c42\u7684\u5c42\u5f52\u4e00\u5316\u548c\u7ebf\u6027\u5c42\uff0c\u6700\u7ec8\u751f\u6210\u6a21\u578b\u7684\u8f93\u51fa<code>logits<\/code>\u3002<\/li>\n\n\n\n<li><strong>\u8ba1\u7b97\u635f\u5931\uff1a<\/strong> \u5982\u679c\u63d0\u4f9b\u4e86\u76ee\u6807\uff08\u5982\u6b63\u786e\u7684\u4e0b\u4e00\u4e2a\u5355\u8bcd\uff09\uff0c\u6a21\u578b\u4f1a\u8ba1\u7b97\u9884\u6d4b\u8f93\u51fa\u4e0e\u5b9e\u9645\u76ee\u6807\u4e4b\u95f4\u7684\u635f\u5931\uff0c\u901a\u5e38\u662f\u4f7f\u7528\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u3002\u8fd9\u4e2a\u635f\u5931\u503c\u53ef\u7528\u4e8e\u540e\u7eed\u7684\u6a21\u578b\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\uff0c\u901a\u8fc7\u53cd\u5411\u4f20\u64ad\u66f4\u65b0\u6a21\u578b\u53c2\u6570\u4ee5\u63d0\u9ad8\u6a21\u578b\u7684\u9884\u6d4b\u51c6\u786e\u6027\u3002<\/li>\n<\/ol>\n\n\n\n<p>\u6574\u4e2a\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\u901a\u8fc7\u6a21\u578b\u7684\u5c42\u6b21\u7ed3\u6784\u9010\u6b65\u8fdb\u884c\uff0c\u6bcf\u4e00\u6b65\u90fd\u5efa\u7acb\u5728\u524d\u4e00\u6b65\u7684\u8f93\u51fa\u4e4b\u4e0a\uff0c\u6700\u7ec8\u4ea7\u751f\u6a21\u578b\u5bf9\u8f93\u5165\u6570\u636e\u7684\u9884\u6d4b\u8f93\u51fa\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.24 
gpt2_zero_grad<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u662fGPT-2\u6a21\u578b\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u7528\u4e8e\u91cd\u7f6e\u6a21\u578b\u68af\u5ea6\u7684\u51fd\u6570\u3002\u5728\u6bcf\u6b21\u8bad\u7ec3\u8fed\u4ee3\u5f00\u59cb\u4e4b\u524d\uff0c\u9700\u8981\u5c06\u4e4b\u524d\u8ba1\u7b97\u7684\u68af\u5ea6\u6e05\u96f6\uff0c\u4ee5\u4fbf\u4e8e\u65b0\u7684\u8bad\u7ec3\u8fed\u4ee3\u4e2d\u6b63\u786e\u7d2f\u8ba1\u68af\u5ea6\u3002\u5177\u4f53\u6765\u8bf4\uff0c\u8fd9\u4e2a\u51fd\u6570\u505a\u4e86\u4ee5\u4e0b\u64cd\u4f5c\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void gpt2_zero_grad(GPT2 *model) {\n    \/\/ \u68c0\u67e5\u662f\u5426\u5b58\u5728\u6743\u91cd\u68af\u5ea6\uff08grads_memory\uff09\uff1a \n    \/\/ \u5982\u679c\u6a21\u578b\u7684\u6743\u91cd\u68af\u5ea6\u5185\u5b58(grads_memory)\u5df2\u7ecf\u88ab\u5206\u914d\uff0c\n    \/\/ \u90a3\u4e48\u4f7f\u7528memset\u51fd\u6570\u5c06\u6240\u6709\u6743\u91cd\u68af\u5ea6\u8bbe\u4e3a0\u3002\n    \/\/ \u8fd9\u91cc\u7684model-&gt;num_parameters\u8868\u793a\u6a21\u578b\u6240\u6709\u53c2\u6570\u7684\u603b\u6570\uff0c\n    \/\/ sizeof(float)\u8868\u793a\u6bcf\u4e2a\u68af\u5ea6\u503c\u5360\u7528\u7684\u5b57\u8282\u6570\u3002\n    \/\/ \u8fd9\u4e00\u6b65\u786e\u4fdd\u4e86\u5728\u5f00\u59cb\u65b0\u7684\u8bad\u7ec3\u6b65\u9aa4\u4e4b\u524d\uff0c\u6240\u6709\u7684\u6743\u91cd\u68af\u5ea6\u90fd\u88ab\u91cd\u7f6e\u3002\n    if(model-&gt;grads_memory != NULL) { memset(model-&gt;grads_memory, 0, model-&gt;num_parameters * sizeof(float)); }\n    \/\/ \u68c0\u67e5\u662f\u5426\u5b58\u5728\u6fc0\u6d3b\u68af\u5ea6\uff08grads_acts_memory\uff09\uff1a \n    \/\/ \u7c7b\u4f3c\u5730\uff0c\u5982\u679c\u6a21\u578b\u7684\u6fc0\u6d3b\u68af\u5ea6\u5185\u5b58(grads_acts_memory)\u5df2\u7ecf\u88ab\u5206\u914d\uff0c\n    \/\/ \u90a3\u4e48\u4e5f\u5c06\u6240\u6709\u6fc0\u6d3b\u68af\u5ea6\u6e05\u96f6\u3002\n    \/\/ 
model-&gt;num_activations\u8868\u793a\u6240\u6709\u6fc0\u6d3b\u503c\u7684\u603b\u6570\uff0c\n    \/\/ \u6bcf\u4e2a\u6fc0\u6d3b\u503c\u7684\u68af\u5ea6\u4e5f\u88ab\u91cd\u7f6e\u4e3a0\u3002\n    if(model-&gt;grads_acts_memory != NULL) { memset(model-&gt;grads_acts_memory, 0, model-&gt;num_activations * sizeof(float)); }\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e2a\u8fc7\u7a0b\u662f\u6df1\u5ea6\u5b66\u4e60\u8bad\u7ec3\u4e2d\u7684\u6807\u51c6\u6b65\u9aa4\uff0c\u786e\u4fdd\u6bcf\u6b21\u53cd\u5411\u4f20\u64ad\u8ba1\u7b97\u7684\u68af\u5ea6\u4e0d\u4f1a\u4e0e\u524d\u4e00\u6b21\u8fed\u4ee3\u7684\u68af\u5ea6\u6df7\u6dc6\uff0c\u4ece\u800c\u4fdd\u969c\u8bad\u7ec3\u8fc7\u7a0b\u7684\u6b63\u786e\u6027\u548c\u7a33\u5b9a\u6027\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.25 gpt2_backward<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5c55\u793a\u4e86GPT-2\u6a21\u578b\u7684\u53cd\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u5173\u952e\u6b65\u9aa4\u5982\u4e0b\uff1a<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>\u68c0\u67e5\u524d\u5411\u4f20\u64ad\u662f\u5426\u5305\u542b\u76ee\u6807\uff1a<\/strong> \u901a\u8fc7\u68c0\u9a8c<code>model-&gt;mean_loss<\/code>\u662f\u5426\u4e3a-1.0f\u6765\u786e\u8ba4\u662f\u5426\u5df2\u8fdb\u884c\u5305\u542b\u76ee\u6807tokens\u7684\u524d\u5411\u4f20\u64ad\u3002\u5982\u679c\u672a\u8fdb\u884c\uff0c\u7a0b\u5e8f\u5c06\u62a5\u9519\u5e76\u9000\u51fa\u3002<\/li>\n\n\n\n<li><strong>\u5ef6\u8fdf\u5206\u914d\u68af\u5ea6\u5185\u5b58\uff1a<\/strong> \u5982\u679c\u5c1a\u672a\u4e3a\u6743\u91cd\u548c\u6fc0\u6d3b\u7684\u68af\u5ea6\u5206\u914d\u5185\u5b58\uff08<code>model-&gt;grads_memory<\/code>\u548c<code>model-&gt;grads_acts_memory<\/code>\u4e3a\u7a7a\uff09\uff0c\u5219\u8fdb\u884c\u5206\u914d\u5e76\u901a\u8fc7<code>gpt2_zero_grad(model)<\/code>\u521d\u59cb\u5316\u4e3a\u96f6\u3002<\/li>\n\n\n\n<li><strong>\u5b9a\u4e49\u4fbf\u6377\u53d8\u91cf\uff1a<\/strong> 
\u5b9a\u4e49\u4e86\u51e0\u4e2a\u4fbf\u6377\u53d8\u91cf\uff0c\u5982\u6279\u91cf\u5927\u5c0f\uff08B\uff09\u3001\u5e8f\u5217\u957f\u5ea6\uff08T\uff09\u3001\u8bcd\u6c47\u91cf\uff08V\uff09\u3001\u5c42\u6570\uff08L\uff09\u3001\u5934\u6570\uff08NH\uff09\u548c\u901a\u9053\u6570\uff08C\uff09\uff0c\u7b80\u5316\u4e86\u4ee3\u7801\u7684\u9605\u8bfb\u3002<\/li>\n\n\n\n<li><strong>\u521d\u59cb\u5316\u68af\u5ea6\uff1a<\/strong> \u4ee51.0\/(B*T)\u521d\u59cb\u5316<code>grads_acts.losses<\/code>\uff0c\u542f\u52a8\u94fe\u5f0f\u6cd5\u5219\u3002<\/li>\n\n\n\n<li><strong>\u53cd\u5411\u4f20\u64ad\u4ea4\u53c9\u71b5\u548cSoftmax\uff1a<\/strong> \u9996\u5148\u53cd\u5411\u4f20\u64ad\u4ea4\u53c9\u71b5\u548cSoftmax\u5c42\uff0c\u66f4\u65b0<code>grads_acts.logits<\/code>\u3002<\/li>\n\n\n\n<li><strong>\u9010\u5c42\u53cd\u5411\u4f20\u64ad\uff1a<\/strong> \u4ece\u6700\u540e\u4e00\u5c42\u5f00\u59cb\uff0c\u9006\u5e8f\u904d\u5386\u6bcf\u4e00\u5c42\uff0c\u5bf9\u6bcf\u4e00\u5c42\u6267\u884c\u4ee5\u4e0b\u64cd\u4f5c\uff1a<\/li>\n<\/ol>\n\n\n\n<ul class=\"wp-block-list\">\n<li>\u4f7f\u7528<code>residual_backward<\/code>\u5904\u7406\u6b8b\u5dee\u8fde\u63a5\u7684\u53cd\u5411\u4f20\u64ad\u3002<\/li>\n\n\n\n<li>\u901a\u8fc7<code>matmul_backward<\/code>\u3001<code>gelu_backward<\/code>\u3001<code>layernorm_backward<\/code>\u7b49\u51fd\u6570\u53cd\u5411\u4f20\u64ad\u8be5\u5c42\u7684\u7ebf\u6027\u53d8\u6362\u3001GELU\u975e\u7ebf\u6027\u548c\u5c42\u5f52\u4e00\u5316\u64cd\u4f5c\u3002<\/li>\n\n\n\n<li>\u66f4\u65b0\u6743\u91cd\u548c\u6fc0\u6d3b\u7684\u68af\u5ea6\u3002<\/li>\n<\/ul>\n\n\n\n<ol start=\"7\" class=\"wp-block-list\">\n<li><strong>\u6743\u91cd\u548c\u6fc0\u6d3b\u7684\u68af\u5ea6\u66f4\u65b0\uff1a<\/strong> \u5728\u53cd\u5411\u4f20\u64ad\u7684\u6bcf\u4e00\u6b65\uff0c\u66f4\u65b0\u6a21\u578b\u53c2\u6570\u548c\u6fc0\u6d3b\u51fd\u6570\u7684\u68af\u5ea6\u3002<\/li>\n\n\n\n<li><strong>\u7f16\u7801\u5668\u7684\u53cd\u5411\u4f20\u64ad\uff1a<\/strong> 
\u6700\u540e\uff0c\u6267\u884c\u7f16\u7801\u5668\u7684\u53cd\u5411\u4f20\u64ad\uff0c\u66f4\u65b0\u8bcd\u5d4c\u5165\u548c\u4f4d\u7f6e\u5d4c\u5165\u7684\u68af\u5ea6\u3002<\/li>\n<\/ol>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void gpt2_backward(GPT2 *model) {\n\n    \/\/ double check we forwarded previously, with targets\n    \/\/ \u786e\u4fdd\u5df2\u7ecf\u6267\u884c\u4e86\u5e26\u6709\u76ee\u6807tokens\u7684\u524d\u5411\u4f20\u64ad\n    if (model-&gt;mean_loss == -1.0f) {\n        printf(\"Error: must forward with targets before backward\\n\");\n        exit(1);\n    }\n\n    \/\/ lazily allocate the memory for gradients of the weights and activations, if needed\n    \/\/ \u5982\u679c\u68af\u5ea6\u7684\u5185\u5b58\u8fd8\u6ca1\u6709\u88ab\u5206\u914d\uff0c\u5219\u8fdb\u884c\u61d2\u52a0\u8f7d\u5206\u914d\u5e76\u521d\u59cb\u5316\u4e3a\u96f6\n    if (model-&gt;grads_memory == NULL) {\n        model-&gt;grads_memory = malloc_and_point_parameters(&amp;model-&gt;grads, model-&gt;param_sizes);\n        model-&gt;grads_acts_memory = malloc_and_point_activations(&amp;model-&gt;grads_acts, model-&gt;act_sizes);\n        \/\/ \u5c06\u68af\u5ea6\u521d\u59cb\u5316\u4e3a\u96f6\n        gpt2_zero_grad(model);\n    }\n\n    \/\/ convenience shortcuts\n    \/\/ \u5b9a\u4e49\u4fbf\u5229\u53d8\u91cf\n    int B = model-&gt;batch_size;\n    int T = model-&gt;seq_len;\n    int V = model-&gt;config.vocab_size;\n    int L = model-&gt;config.num_layers;\n    int NH = model-&gt;config.num_heads;\n    int C = model-&gt;config.channels;\n\n    \/\/ backward pass: go in the reverse order of the forward pass, and call backward() functions\n    \/\/ \u53cd\u5411\u4f20\u64ad\uff1a\u6309\u7167\u524d\u5411\u4f20\u64ad\u7684\u9006\u5e8f\u6267\u884c\uff0c\u5e76\u8c03\u7528\u76f8\u5e94\u7684\u53cd\u5411\u4f20\u64ad\u51fd\u6570\n    ParameterTensors params = model-&gt;params; \/\/ for brevity\n    ParameterTensors grads = model-&gt;grads;\n    
ActivationTensors acts = model-&gt;acts;\n    ActivationTensors grads_acts = model-&gt;grads_acts;\n\n    \/\/ we kick off the chain rule by filling in dlosses with 1.0f\/(B*T)\n    \/\/ technically this is a small, inline backward() pass of calculating\n    \/\/ total, final loss as the mean over all losses over all (B,T) positions in the batch\n    \/\/ \u6211\u4eec\u901a\u8fc7\u75281.0f\/(B*T)\u586b\u5145dlosses\u6765\u542f\u52a8\u94fe\u5f0f\u6cd5\u5219\n    \/\/ \u4ece\u6280\u672f\u4e0a\u8bb2\uff0c\u8fd9\u662f\u4e00\u4e2a\u5c0f\u578b\u7684\u5185\u8054\u53cd\u5411\u4f20\u64ad\u6b65\u9aa4\uff0c\u7528\u4e8e\u8ba1\u7b97\u6279\u6b21\u4e2d\u6240\u6709(B,T)\u4f4d\u7f6e\u4e0a\u6240\u6709\u635f\u5931\u7684\u603b\u548c\u7684\u5e73\u5747\u503c\u4f5c\u4e3a\u6700\u7ec8\u7684\u635f\u5931\n    \/\/ \u5bf9\u6240\u6709(B,T)\u4f4d\u7f6e\u7684\u635f\u5931\u6c42\u5747\u503c\n    float dloss_mean = 1.0f \/ (B*T);\n    for (int i = 0; i &lt; B*T; i++) { grads_acts.losses[i] = dloss_mean; }\n\n    \/\/ \u4ea4\u53c9\u71b5\u548cSoftmax\u5c42\u7684\u53cd\u5411\u4f20\u64ad\n    crossentropy_softmax_backward(grads_acts.logits, grads_acts.losses, acts.probs, model-&gt;targets, B, T, V);\n    matmul_backward(grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, V);\n    \/\/ \u6700\u540e\u4e00\u5c42\u7684\u6b8b\u5dee\n    float* residual = acts.residual3 + (L-1) * B * T * C; \/\/ last layer's residual\n    \/\/ \u53cd\u5411\u4f20\u64ad\u5230\u6700\u540e\u4e00\u5c42\u7684\u6b8b\u5dee\n    float* dresidual = grads_acts.residual3 + (L-1) * B * T * C; \/\/ write to last layer's residual\n\n    \/\/ \u5c42\u5f52\u4e00\u5316\u7684\u53cd\u5411\u4f20\u64ad\n    layernorm_backward(dresidual, grads.lnfw, grads.lnfb, grads_acts.lnf, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C);\n\n    \/\/ \u9006\u5e8f\u904d\u5386\u6bcf\u4e00\u5c42\u8fdb\u884c\u53cd\u5411\u4f20\u64ad\n    for (int l = L-1; l &gt;= 0; l--) {\n\n        residual = l == 0 ? 
acts.encoded : acts.residual3 + (l-1) * B * T * C;\n        dresidual = l == 0 ? grads_acts.encoded : grads_acts.residual3 + (l-1) * B * T * C;\n\n        \/\/ get the pointers of the weights for this layer\n        \/\/ \u83b7\u53d6\u8fd9\u4e00\u5c42\u7684\u6743\u91cd\u6307\u9488\u3002\n        float* l_ln1w = params.ln1w + l * C;\n        float* l_qkvw = params.qkvw + l * 3*C * C;\n        float* l_attprojw = params.attprojw + l * C * C;\n        float* l_ln2w = params.ln2w + l * C;\n        float* l_fcw = params.fcw + l * 4*C * C;\n        float* l_fcprojw = params.fcprojw + l * C * 4*C;\n        \/\/ get the pointers of the gradients of the weights for this layer\n        \/\/ \u83b7\u53d6\u8fd9\u4e00\u5c42\u6743\u91cd\u68af\u5ea6\u7684\u6307\u9488\n        float* dl_ln1w = grads.ln1w + l * C;\n        float* dl_ln1b = grads.ln1b + l * C;\n        float* dl_qkvw = grads.qkvw + l * 3*C * C;\n        float* dl_qkvb = grads.qkvb + l * 3*C;\n        float* dl_attprojw = grads.attprojw + l * C * C;\n        float* dl_attprojb = grads.attprojb + l * C;\n        float* dl_ln2w = grads.ln2w + l * C;\n        float* dl_ln2b = grads.ln2b + l * C;\n        float* dl_fcw = grads.fcw + l * 4*C * C;\n        float* dl_fcb = grads.fcb + l * 4*C;\n        float* dl_fcprojw = grads.fcprojw + l * C * 4*C;\n        float* dl_fcprojb = grads.fcprojb + l * C;\n        \/\/ get the pointers of the activations for this layer\n        \/\/ \u83b7\u53d6\u8fd9\u4e00\u5c42\u7684\u6fc0\u6d3b\u503c\u6307\u9488\u3002\n        float* l_ln1 = acts.ln1 + l * B * T * C;\n        float* l_ln1_mean = acts.ln1_mean + l * B * T;\n        float* l_ln1_rstd = acts.ln1_rstd + l * B * T;\n        float* l_qkv = acts.qkv + l * B * T * 3*C;\n        float* l_atty = acts.atty + l * B * T * C;\n        float* l_att = acts.att + l * B * NH * T * T;\n        float* l_residual2 = acts.residual2 + l * B * T * C;\n        float* l_ln2 = acts.ln2 + l * B * T * C;\n        float* l_ln2_mean = acts.ln2_mean + l * 
B * T;\n        float* l_ln2_rstd = acts.ln2_rstd + l * B * T;\n        float* l_fch = acts.fch + l * B * T * 4*C;\n        float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;\n        \/\/ get the pointers of the gradients of the activations for this layer\n        \/\/ \u83b7\u53d6\u8fd9\u4e00\u5c42\u6fc0\u6d3b\u68af\u5ea6\u7684\u6307\u9488\n        float* dl_ln1 = grads_acts.ln1 + l * B * T * C;\n        float* dl_qkv = grads_acts.qkv + l * B * T * 3*C;\n        float* dl_atty = grads_acts.atty + l * B * T * C;\n        float* dl_preatt = grads_acts.preatt + l * B * NH * T * T;\n        float* dl_att = grads_acts.att + l * B * NH * T * T;\n        float* dl_attproj = grads_acts.attproj + l * B * T * C;\n        float* dl_residual2 = grads_acts.residual2 + l * B * T * C;\n        float* dl_ln2 = grads_acts.ln2 + l * B * T * C;\n        float* dl_fch = grads_acts.fch + l * B * T * 4*C;\n        float* dl_fch_gelu = grads_acts.fch_gelu + l * B * T * 4*C;\n        float* dl_fcproj = grads_acts.fcproj + l * B * T * C;\n        float* dl_residual3 = grads_acts.residual3 + l * B * T * C;\n\n        \/\/ backprop this layer\n        \/\/ \u5bf9\u6b8b\u5dee\u8fde\u63a5\u8fdb\u884c\u53cd\u5411\u4f20\u64ad\n        residual_backward(dl_residual2, dl_fcproj, dl_residual3, B*T*C);\n        \/\/ \u5bf9\u5168\u8fde\u63a5\u5c42\uff08\u4f7f\u7528GELU\u6fc0\u6d3b\u51fd\u6570\u4e4b\u540e\uff09\u8fdb\u884c\u53cd\u5411\u4f20\u64ad\n        matmul_backward(dl_fch_gelu, dl_fcprojw, dl_fcprojb, dl_fcproj, l_fch_gelu, l_fcprojw, B, T, 4*C, C);\n        \/\/ \u5bf9GELU\u6fc0\u6d3b\u51fd\u6570\u8fdb\u884c\u53cd\u5411\u4f20\u64ad\n        gelu_backward(dl_fch, l_fch, dl_fch_gelu, B*T*4*C);\n        \/\/ \u5bf9\u5168\u8fde\u63a5\u5c42\uff08\u8f93\u5165\u5230GELU\u6fc0\u6d3b\u51fd\u6570\u4e4b\u524d\uff09\u8fdb\u884c\u53cd\u5411\u4f20\u64ad\n        matmul_backward(dl_ln2, dl_fcw, dl_fcb, dl_fch, l_ln2, l_fcw, B, T, C, 4*C);\n        \/\/ 
\u5bf9\u5c42\u5f52\u4e00\u5316\u8fdb\u884c\u53cd\u5411\u4f20\u64ad\uff0c\u8fd9\u662f\u5e94\u7528\u5728\u52a0\u6743\u6b8b\u5dee\u8fde\u63a5\u4e4b\u540e\u7684\n        layernorm_backward(dl_residual2, dl_ln2w, dl_ln2b, dl_ln2, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C);\n        \/\/ \u5bf9\u53e6\u4e00\u4e2a\u6b8b\u5dee\u8fde\u63a5\u8fdb\u884c\u53cd\u5411\u4f20\u64ad\n        residual_backward(dresidual, dl_attproj, dl_residual2, B*T*C);\n        \/\/ \u5bf9\u5e94\u7528\u5728\u6ce8\u610f\u529b\u673a\u5236\u8f93\u51fa\u4e0a\u7684\u5168\u8fde\u63a5\u5c42\u8fdb\u884c\u53cd\u5411\u4f20\u64ad\n        matmul_backward(dl_atty, dl_attprojw, dl_attprojb, dl_attproj, l_atty, l_attprojw, B, T, C, C);\n        \/\/ \u5bf9\u6ce8\u610f\u529b\u673a\u5236\u672c\u8eab\u8fdb\u884c\u53cd\u5411\u4f20\u64ad\n        attention_backward(dl_qkv, dl_preatt, dl_att, dl_atty, l_qkv, l_att, B, T, C, NH);\n        \/\/ \u5bf9\u5e94\u7528\u5728\u5c42\u5f52\u4e00\u5316\u4e4b\u540e\u7684\u5168\u8fde\u63a5\u5c42\uff08QKV\u6295\u5f71\uff09\u8fdb\u884c\u53cd\u5411\u4f20\u64ad\n        matmul_backward(dl_ln1, dl_qkvw, dl_qkvb, dl_qkv, l_ln1, l_qkvw, B, T, C, 3*C);\n        \/\/ \u5bf9\u53e6\u4e00\u4e2a\u5c42\u5f52\u4e00\u5316\u8fdb\u884c\u53cd\u5411\u4f20\u64ad\uff0c\u8fd9\u662f\u5e94\u7528\u5728\u6b8b\u5dee\u8fde\u63a5\u4e4b\u524d\u7684\n        layernorm_backward(dresidual, dl_ln1w, dl_ln1b, dl_ln1, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C);\n    }\n    \/\/ \u7f16\u7801\u5668\u6743\u91cd\u68af\u5ea6\u7684\u66f4\u65b0\n    encoder_backward(grads.wte, grads.wpe, grads_acts.encoded, model-&gt;inputs, B, T, 
C);\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e2a\u8fc7\u7a0b\u662f\u6df1\u5ea6\u5b66\u4e60\u8bad\u7ec3\u7684\u6838\u5fc3\uff0c\u901a\u8fc7\u8ba1\u7b97\u635f\u5931\u51fd\u6570\u76f8\u5bf9\u4e8e\u6bcf\u4e2a\u53c2\u6570\u7684\u68af\u5ea6\uff0c\u5e76\u5229\u7528\u8fd9\u4e9b\u68af\u5ea6\u6765\u66f4\u65b0\u6a21\u578b\u53c2\u6570\uff0c\u4ece\u800c\u6700\u5c0f\u5316\u635f\u5931\u51fd\u6570\uff0c\u63d0\u5347\u6a21\u578b\u6027\u80fd\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.26 gpt2_update<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u662fGPT-2\u6a21\u578b\u7684\u53c2\u6570\u66f4\u65b0\u8fc7\u7a0b\uff0c\u4f7f\u7528AdamW\u4f18\u5316\u5668\u8fdb\u884c\u66f4\u65b0\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) {\n    \/\/ reference: https:\/\/pytorch.org\/docs\/stable\/generated\/torch.optim.AdamW.html\n\n    \/\/ lazily allocate the memory for m_memory and v_memory\n    \/\/ \u5982\u679c\u8fd8\u6ca1\u6709\u4e3am_memory\u548cv_memory\u5206\u914d\u5185\u5b58\uff0c\u5219\u61d2\u52a0\u8f7d\u5206\u914d\u5185\u5b58\n    if (model-&gt;m_memory == NULL) {\n        \/\/ \u4e3a\u7b2c\u4e00\u77e9\u5206\u914d\u5185\u5b58\n        model-&gt;m_memory = (float*)calloc(model-&gt;num_parameters, sizeof(float));\n        \/\/ \u4e3a\u7b2c\u4e8c\u77e9\u5206\u914d\u5185\u5b58\n        model-&gt;v_memory = (float*)calloc(model-&gt;num_parameters, sizeof(float));\n    }\n\n    for (int i = 0; i &lt; model-&gt;num_parameters; i++) {\n        \/\/ \u5f53\u524d\u53c2\u6570\u503c\n        float param = model-&gt;params_memory[i];\n        \/\/ \u5f53\u524d\u68af\u5ea6\u503c\n        float grad = model-&gt;grads_memory[i];\n\n        \/\/ update the first moment (momentum)\n        \/\/ \u66f4\u65b0\u7b2c\u4e00\u77e9\uff08\u52a8\u91cf\uff09\n        float m = beta1 * model-&gt;m_memory[i] + (1.0f - beta1) 
* grad;\n        \/\/ update the second moment (RMSprop)\n        \/\/ \u66f4\u65b0\u7b2c\u4e8c\u77e9\uff08RMSprop\uff09\n        float v = beta2 * model-&gt;v_memory[i] + (1.0f - beta2) * grad * grad;\n        \/\/ bias-correct both moments\n        \/\/ \u5bf9\u4e24\u4e2a\u77e9\u8fdb\u884c\u504f\u5dee\u4fee\u6b63\n        float m_hat = m \/ (1.0f - powf(beta1, t));\n        float v_hat = v \/ (1.0f - powf(beta2, t));\n\n        \/\/ update\n        \/\/ \u6839\u636eAdamW\u4f18\u5316\u7b97\u6cd5\u66f4\u65b0\u53c2\u6570\n        \/\/ \u66f4\u65b0\u52a8\u91cf\n        model-&gt;m_memory[i] = m;\n        \/\/ \u66f4\u65b0RMSprop\n        model-&gt;v_memory[i] = v;\n        \/\/ \u66f4\u65b0\u53c2\u6570\u503c\uff0c\u5305\u62ec\u5b66\u4e60\u7387\u3001\u504f\u5dee\u4fee\u6b63\u540e\u7684\u52a8\u91cf\u548cRMSprop\u4ee5\u53ca\u6743\u91cd\u8870\u51cf\n        model-&gt;params_memory[i] -= learning_rate * (m_hat \/ (sqrtf(v_hat) + eps) + weight_decay * param);\n    }\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e2a\u51fd\u6570\u6839\u636e\u6a21\u578b\u5f53\u524d\u7684\u68af\u5ea6\u548c\u5386\u53f2\u52a8\u91cf\u3001RMSprop\u503c\u6765\u66f4\u65b0\u6a21\u578b\u53c2\u6570\u3002\u5b83\u9996\u5148\u68c0\u67e5\u662f\u5426\u5df2\u4e3a\u52a8\u91cf(m_memory)\u548cRMSprop(v_memory)\u5206\u914d\u5185\u5b58\uff0c\u5982\u679c\u6ca1\u6709\uff0c\u5219\u8fdb\u884c\u5206\u914d\u3002\u63a5\u7740\uff0c\u5bf9\u6bcf\u4e2a\u53c2\u6570\uff0c\u8ba1\u7b97\u5176\u66f4\u65b0\u540e\u7684\u503c\uff0c\u5176\u4e2d\u5305\u62ec\u7b2c\u4e00\u77e9\u548c\u7b2c\u4e8c\u77e9\u7684\u66f4\u65b0\uff0c\u4ee5\u53ca\u5e94\u7528\u504f\u5dee\u4fee\u6b63\u548c\u6743\u91cd\u8870\u51cf\u3002\u8fd9\u6837\u53ef\u4ee5\u5728\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u9010\u6b65\u4f18\u5316\u6a21\u578b\u53c2\u6570\uff0c\u4ee5\u671f\u8fbe\u5230\u66f4\u597d\u7684\u6027\u80fd\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.27 
gpt2_free<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u91ca\u653e\u4e86GPT-2\u6a21\u578b\u4e2d\u4f7f\u7528\u7684\u6240\u6709\u52a8\u6001\u5206\u914d\u7684\u5185\u5b58\u3002\u5177\u4f53\u6765\u8bf4\uff0c\u5b83\u91ca\u653e\u4e86\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void gpt2_free(GPT2 *model) {\n    \/\/ \u91ca\u653e\u53c2\u6570\u5185\u5b58\n    free(model-&gt;params_memory);\n    \/\/ \u91ca\u653e\u68af\u5ea6\u5185\u5b58\n    free(model-&gt;grads_memory);\n    \/\/ \u91ca\u653e\u52a8\u91cf\u5185\u5b58\n    free(model-&gt;m_memory);\n    \/\/ \u91ca\u653eRMSprop\u5185\u5b58\n    free(model-&gt;v_memory);\n    \/\/ \u91ca\u653e\u6fc0\u6d3b\u5185\u5b58\n    free(model-&gt;acts_memory);\n    \/\/ \u91ca\u653e\u6fc0\u6d3b\u68af\u5ea6\u5185\u5b58\n    free(model-&gt;grads_acts_memory);\n    \/\/ \u91ca\u653e\u8f93\u5165\u5185\u5b58\n    free(model-&gt;inputs);\n    \/\/ \u91ca\u653e\u76ee\u6807token\u5185\u5b58\n    free(model-&gt;targets);\n}\n<\/pre><\/div>\n\n\n\n<p>\u6b64\u51fd\u6570\u901a\u5e38\u5728\u6a21\u578b\u8bad\u7ec3\u5b8c\u6210\u6216\u4e0d\u518d\u9700\u8981\u6a21\u578b\u65f6\u8c03\u7528\uff0c\u4ee5\u786e\u4fdd\u53ca\u65f6\u56de\u6536\u8d44\u6e90\uff0c\u907f\u514d\u5185\u5b58\u6cc4\u9732\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.28 DataLoader<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u4e2a\u7ed3\u6784\u4f53\u5b9a\u4e49\u4e86\u4e00\u4e2a\u6570\u636e\u52a0\u8f7d\u5668\uff08DataLoader\uff09\uff0c\u5b83\u8d1f\u8d23\u4ece\u6587\u4ef6\u4e2d\u52a0\u8f7d\u8bad\u7ec3\u6216\u9a8c\u8bc1\u6570\u636e\uff0c\u4ee5\u4fbf\u7528\u4e8e\u6a21\u578b\u7684\u8bad\u7ec3\u6216\u8bc4\u4f30\u3002\u5177\u4f53\u5305\u62ec\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">typedef struct {\n    \/\/ hyperparameters\n    \/\/ \u8d85\u53c2\u6570\n    \/\/ B\uff1a\u6279\u91cf\u5927\u5c0f\uff08batch 
size\uff09\uff0c\u5373\u6bcf\u6b21\u8bad\u7ec3\u6216\u8bc4\u4f30\u6a21\u578b\u65f6\u4e00\u6b21\u6027\u5904\u7406\u7684\u6570\u636e\u6570\u91cf\u3002\n    int B; \/\/ batch size\n    \/\/ T\uff1a\u5e8f\u5217\u957f\u5ea6\uff08sequence length\uff09\uff0c\u6307\u5355\u4e2a\u8f93\u5165\u6570\u636e\u7684\u957f\u5ea6\uff0c\u4f8b\u5982\u4e00\u4e2a\u53e5\u5b50\u4e2d\u7684\u5355\u8bcd\u6216\u5b57\u7b26\u6570\u91cf\u3002\n    int T; \/\/ sequence length\n    \/\/ input handling and its state\n    \/\/ \u6307\u5411\u5305\u542b\u8bad\u7ec3\u6216\u9a8c\u8bc1\u6570\u636e\u7684\u6587\u4ef6\u7684\u6307\u9488\u3002\n    FILE* tokens_file;\n    \/\/ \u6587\u4ef6\u7684\u5927\u5c0f\uff0c\u5355\u4f4d\u4e3a\u5b57\u8282\u3002\n    long file_size;\n    \/\/ \u5f53\u524d\u8bfb\u53d6\u4f4d\u7f6e\uff0c\u5728\u6587\u4ef6\u4e2d\u7684\u504f\u79fb\u91cf\n    long current_position;\n    \/\/ output memory\n    \/\/ \u8f93\u51fa\u5185\u5b58\n    \/\/ batch\uff1a\u6307\u5411\u5f53\u524d\u6279\u91cf\u6570\u636e\u7684\u6307\u9488\u3002\u5b83\u53ef\u80fd\u5305\u542b\u4e00\u6279\u8f93\u5165\u6570\u636e\uff08tokens\uff09\u3002\n    int* batch;\n    \/\/ inputs\uff1a\u6307\u5411\u5f53\u524d\u6279\u91cf\u7684\u8f93\u5165\u6570\u636e\u7684\u6307\u9488\u3002\n    \/\/ \u5728\u8bad\u7ec3GPT\u6a21\u578b\u65f6\uff0c\u8fd9\u5c06\u662f\u4e00\u7cfb\u5217token\u7684ID\u3002\n    int* inputs;\n    \/\/ targets\uff1a\u6307\u5411\u5f53\u524d\u6279\u91cf\u7684\u76ee\u6807\u6570\u636e\u7684\u6307\u9488\u3002\n    \/\/ \u5728\u8bad\u7ec3\u4e2d\uff0c\u76ee\u6807\u901a\u5e38\u662f\u9884\u6d4b\u4e0b\u4e00\u4e2atoken\u7684ID\u3002\n    int* targets;\n    \/\/ convenience variables\n    \/\/ \u4fbf\u5229\u53d8\u91cf\n    \/\/ num_batches\uff1a\u6839\u636e\u6587\u4ef6\u5927\u5c0f\u3001\u6279\u91cf\u5927\u5c0f\u548c\u5e8f\u5217\u957f\u5ea6\u8ba1\u7b97\u7684\u603b\u6279\u6b21\u6570\u3002\n    int num_batches;\n} 
DataLoader;<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e2a\u7ed3\u6784\u4f53\u662f\u5904\u7406\u548c\u51c6\u5907\u6570\u636e\u96c6\u4ee5\u4f9b\u6a21\u578b\u8bad\u7ec3\u548c\u8bc4\u4f30\u4f7f\u7528\u7684\u91cd\u8981\u7ec4\u6210\u90e8\u5206\uff0c\u901a\u8fc7\u4ece\u6587\u4ef6\u4e2d\u6309\u6279\u6b21\u52a0\u8f7d\u6570\u636e\uff0c\u80fd\u591f\u6709\u6548\u7ba1\u7406\u5185\u5b58\u4f7f\u7528\uff0c\u540c\u65f6\u4e5f\u652f\u6301\u5927\u89c4\u6a21\u6570\u636e\u96c6\u7684\u8bad\u7ec3\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.29 dataloader_init<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u521d\u59cb\u5316\u4e86\u4e00\u4e2a\u6570\u636e\u52a0\u8f7d\u5668\uff0c\u7528\u4e8e\u4ece\u6587\u4ef6\u4e2d\u52a0\u8f7d\u8bad\u7ec3\u6216\u9a8c\u8bc1\u6570\u636e\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void dataloader_init(DataLoader *loader, char* filename, int B, int T) {\n    \/\/ \u8bbe\u7f6e\u6279\u91cf\u5927\u5c0f\uff08B\uff09\u548c\u5e8f\u5217\u957f\u5ea6\uff08T\uff09\n    loader-&gt;B = B;\n    loader-&gt;T = T;\n\n    \/\/ open the input file for reading\n    \/\/ \u6253\u5f00\u6307\u5b9a\u7684\u6587\u4ef6\u8fdb\u884c\u8bfb\u53d6\u3002\u5982\u679c\u6587\u4ef6\u65e0\u6cd5\u6253\u5f00\uff0c\u7a0b\u5e8f\u5c06\u62a5\u9519\u5e76\u9000\u51fa\u3002\n    loader-&gt;tokens_file = fopen(filename, \"rb\");\n    if (loader-&gt;tokens_file == NULL) {\n        printf(\"Error opening tokens file\\n\");\n        exit(1);\n    }\n\n    \/\/ determine the file size\n    \/\/ \u786e\u5b9a\u6587\u4ef6\u5927\u5c0f\uff0c\u4ee5\u4fbf\u77e5\u9053\u6709\u591a\u5c11\u6570\u636e\u53ef\u7528\u3002\n    fseek(loader-&gt;tokens_file, 0, SEEK_END);\n    loader-&gt;file_size = ftell(loader-&gt;tokens_file);\n    fseek(loader-&gt;tokens_file, 0, SEEK_SET);\n    \/\/ 
\u68c0\u67e5\u6587\u4ef6\u5927\u5c0f\u662f\u5426\u8db3\u591f\u81f3\u5c11\u5305\u542b\u4e00\u4e2a\u6279\u6b21\u7684\u6570\u636e\u3002\u5982\u679c\u4e0d\u591f\uff0c\u7a0b\u5e8f\u5c06\u62a5\u9519\u5e76\u9000\u51fa\u3002\n    if (loader-&gt;file_size &lt; (B * T + 1) * sizeof(int)) {\n        printf(\"Error: file size is too small for the batch size and sequence length\\n\");\n        exit(1);\n    }\n    \/\/ \u8bbe\u7f6e\u5f53\u524d\u8bfb\u53d6\u4f4d\u7f6e\u4e3a\u6587\u4ef6\u5f00\u5934\u3002\n    loader-&gt;current_position = 0; \/\/ start at the beginning\n\n    \/\/ allocate space for B*T + 1 integers to store the inputs and targets\n    \/\/ \u4e3a\u6574\u4e2a\u6279\u6b21\u5206\u914d\u5185\u5b58\u7a7a\u95f4\uff0c\u5305\u62ec\u8f93\u5165\u548c\u76ee\u6807\u6570\u636e\u3002\n    \/\/ \u8fd9\u91cc\u591a\u5206\u914d\u4e86\u4e00\u4e2a\u6574\u6570\u7684\u7a7a\u95f4\uff0c\u7528\u4e8e\u5904\u7406\u76ee\u6807\u6570\u636e\u65f6\u7684\u4f4d\u79fb\u3002\n    loader-&gt;batch = (int*) malloc((B * T + 1) * sizeof(int));\n    loader-&gt;inputs = loader-&gt;batch;\n    loader-&gt;targets = loader-&gt;batch + 1; \/\/ targets are shifted by one\n    \/\/ \u8ba1\u7b97\u6587\u4ef6\u4e2d\u603b\u5171\u53ef\u4ee5\u5206\u6210\u591a\u5c11\u4e2a\u6279\u6b21\uff0c\u8fd9\u53d6\u51b3\u4e8e\u6587\u4ef6\u5927\u5c0f\u3001\u6279\u91cf\u5927\u5c0f\u548c\u5e8f\u5217\u957f\u5ea6\u3002\n    loader-&gt;num_batches = loader-&gt;file_size \/ (B * T * sizeof(int));\n}\n<\/pre><\/div>\n\n\n\n<p>\u901a\u8fc7\u8fd9\u4e2a\u521d\u59cb\u5316\u8fc7\u7a0b\uff0c\u6570\u636e\u52a0\u8f7d\u5668\u51c6\u5907\u597d\u4ece\u6587\u4ef6\u4e2d\u8bfb\u53d6\u6570\u636e\uff0c\u4ee5\u4f9b\u6a21\u578b\u8bad\u7ec3\u6216\u8bc4\u4f30\u4f7f\u7528\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.30 
dataloader_reset<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5c06\u6570\u636e\u52a0\u8f7d\u5668\u7684\u5f53\u524d\u8bfb\u53d6\u4f4d\u7f6e\u91cd\u7f6e\u4e3a\u6587\u4ef6\u7684\u5f00\u5934\u3002\u8fd9\u901a\u5e38\u5728\u6bcf\u6b21\u65b0\u7684\u6570\u636e\u904d\u5386\u5f00\u59cb\u65f6\u4f7f\u7528\uff0c\u786e\u4fdd\u4ece\u6587\u4ef6\u7684\u5f00\u59cb\u5904\u91cd\u65b0\u5f00\u59cb\u8bfb\u53d6\u6570\u636e\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void dataloader_reset(DataLoader *loader) {\n    \/\/ \u8bbe\u7f6e\u5f53\u524d\u8bfb\u53d6\u4f4d\u7f6e\u4e3a\u6587\u4ef6\u5f00\u5934\u3002\n    loader-&gt;current_position = 0;\n}\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.31 dataloader_next_batch<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u7528\u4e8e\u4ece\u6570\u636e\u6587\u4ef6\u4e2d\u8bfb\u53d6\u4e0b\u4e00\u4e2a\u6279\u6b21\u7684\u6570\u636e\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void dataloader_next_batch(DataLoader *loader) {\n    \/\/ \u6279\u91cf\u5927\u5c0f\n    int B = loader-&gt;B;\n    \/\/ \u5e8f\u5217\u957f\u5ea6\n    int T = loader-&gt;T;\n    \/\/ if we are at the end of the file, loop back to the beginning\n    \/\/ \u5982\u679c\u6587\u4ef6\u672b\u5c3e\u4e0d\u8db3\u4ee5\u63d0\u4f9b\u4e00\u4e2a\u5b8c\u6574\u7684\u6279\u6b21\uff0c\u5c06\u4f4d\u7f6e\u91cd\u7f6e\u5230\u6587\u4ef6\u5f00\u5934\n    if (loader-&gt;current_position + (B*T+1) * sizeof(int) &gt; loader-&gt;file_size) {\n        loader-&gt;current_position = 0;\n    }\n    \/\/ read the B*T+1 integers from the file into batch\n    \/\/ \u4ece\u5f53\u524d\u4f4d\u7f6e\u8bfb\u53d6B*T+1\u4e2a\u6574\u6570\u5230\u6279\u91cf\u6570\u636e\u533a\n    fseek(loader-&gt;tokens_file, loader-&gt;current_position, SEEK_SET);\n    fread(loader-&gt;batch, sizeof(int), B*T+1, loader-&gt;tokens_file);\n    \/\/ advance the current position by 
B*T integers\n    \/\/ \u66f4\u65b0\u5f53\u524d\u4f4d\u7f6e\uff0c\u524d\u8fdbB*T\u4e2a\u6574\u6570\u7684\u957f\u5ea6\uff0c\u4e3a\u4e0b\u4e00\u6b21\u8bfb\u53d6\u505a\u51c6\u5907\n    loader-&gt;current_position += B*T * sizeof(int);\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u6570\u636e\u7684\u5faa\u73af\u8bfb\u53d6\uff1a\u5f53\u6570\u636e\u8bfb\u53d6\u5230\u6587\u4ef6\u672b\u5c3e\u65f6\uff0c\u81ea\u52a8\u4ece\u6587\u4ef6\u5f00\u5934\u7ee7\u7eed\u8bfb\u53d6\uff0c\u4fdd\u8bc1\u8fde\u7eed\u7684\u6570\u636e\u6d41\u4f9b\u6a21\u578b\u8bad\u7ec3\u3002\u8fd9\u5bf9\u4e8e\u8bad\u7ec3\u5468\u671f\u591a\u6b21\u904d\u5386\u6570\u636e\u96c6\u65f6\u975e\u5e38\u6709\u7528\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.32 dataloader_free<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u7528\u4e8e\u91ca\u653e\u6570\u636e\u52a0\u8f7d\u5668\uff08<code>DataLoader<\/code>\uff09\u4f7f\u7528\u7684\u8d44\u6e90\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void dataloader_free(DataLoader *loader) {\n    \/\/ \u5173\u95ed\u6253\u5f00\u7684\u6587\u4ef6\n    fclose(loader-&gt;tokens_file);\n    \/\/ \u91ca\u653e\u5206\u914d\u7684\u5185\u5b58\u7a7a\u95f4\uff0c\u7528\u4e8e\u5b58\u50a8\u6279\u6b21\u6570\u636e\n    free(loader-&gt;batch);\n}\n<\/pre><\/div>\n\n\n\n<p>\u901a\u8fc7\u8c03\u7528\u6b64\u51fd\u6570\uff0c\u53ef\u4ee5\u786e\u4fdd\u5728\u6570\u636e\u52a0\u8f7d\u5668\u4e0d\u518d\u9700\u8981\u65f6\uff0c\u76f8\u5173\u8d44\u6e90\u88ab\u9002\u5f53\u5730\u91ca\u653e\uff0c\u907f\u514d\u5185\u5b58\u6cc4\u6f0f\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.33 
random_u32<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u968f\u673a\u6570\u751f\u6210\u5668\uff08RNG\uff09\u4f7f\u7528xorshift\u7b97\u6cd5\u3002<code>xorshift<\/code>\u662f\u4e00\u79cd\u5feb\u901f\u3001\u9ad8\u8d28\u91cf\u7684\u4f2a\u968f\u673a\u6570\u751f\u6210\u5668\uff0c\u5e7f\u6cdb\u7528\u4e8e\u5404\u79cd\u8ba1\u7b97\u573a\u666f\u3002\u8fd9\u4e2a\u7279\u5b9a\u7684\u7248\u672c\u4f7f\u7528\u4e86\u4e00\u4e2a64\u4f4d\u72b6\u6001\u53d8\u91cf\uff0c\u5e76\u901a\u8fc7\u4e00\u7cfb\u5217\u4f4d\u79fb\u548c\u5f02\u6216\u64cd\u4f5c\u6765\u751f\u6210\u65b0\u7684\u968f\u673a\u6570\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ the GPT-2 end-of-text token id\n\/\/ GPT-2\u6587\u672c\u7ed3\u675f\u6807\u8bb0\u7684id\n#define GPT2_EOT 50256\n\n\/\/ \u5b9e\u73b0\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u968f\u673a\u6570\u751f\u6210\u5668\uff08RNG\uff09\u4f7f\u7528xorshift\u7b97\u6cd5\u3002\n\/\/ xorshift\u662f\u4e00\u79cd\u5feb\u901f\u3001\u9ad8\u8d28\u91cf\u7684\u4f2a\u968f\u673a\u6570\u751f\u6210\u5668\uff0c\u5e7f\u6cdb\u7528\u4e8e\u5404\u79cd\u8ba1\u7b97\u573a\u666f\u3002\n\/\/ \u8fd9\u4e2a\u7279\u5b9a\u7684\u7248\u672c\u4f7f\u7528\u4e86\u4e00\u4e2a64\u4f4d\u72b6\u6001\u53d8\u91cf\uff0c\u5e76\u901a\u8fc7\u4e00\u7cfb\u5217\u4f4d\u79fb\u548c\u5f02\u6216\u64cd\u4f5c\u6765\u751f\u6210\u65b0\u7684\u968f\u673a\u6570\u3002\nunsigned int random_u32(unsigned long long *state) {\n    \/\/ xorshift rng: https:\/\/en.wikipedia.org\/wiki\/Xorshift#xorshift.2A\n    \/\/ \u5bf9\u72b6\u6001\u53d8\u91cf\u8fdb\u884c\u53f3\u79fb12\u4f4d\u5e76\u4e0e\u539f\u72b6\u6001\u8fdb\u884c\u5f02\u6216\u64cd\u4f5c\n    *state ^= *state &gt;&gt; 12;\n    \/\/ \u5bf9\u72b6\u6001\u53d8\u91cf\u8fdb\u884c\u5de6\u79fb25\u4f4d\u5e76\u4e0e\u5f53\u524d\u72b6\u6001\u8fdb\u884c\u5f02\u6216\u64cd\u4f5c\n    *state ^= *state &lt;&lt; 25;\n    \/\/ 
\u5bf9\u72b6\u6001\u53d8\u91cf\u8fdb\u884c\u53f3\u79fb27\u4f4d\u5e76\u4e0e\u5f53\u524d\u72b6\u6001\u8fdb\u884c\u5f02\u6216\u64cd\u4f5c\n    *state ^= *state &gt;&gt; 27;\n    \/\/ \u4f7f\u7528\u4e00\u4e2a\u9b54\u6cd5\u5e38\u6570\u4e58\u4ee5\u5f53\u524d\u72b6\u6001\uff0c\u5e76\u53f3\u79fb32\u4f4d\u6765\u751f\u6210\u6700\u7ec8\u7684\u968f\u673a\u6570\n    return (*state * 0x2545F4914F6CDD1Dull) &gt;&gt; 32;\n}\n<\/pre><\/div>\n\n\n\n<p>\u901a\u8fc7\u6539\u53d8\u72b6\u6001\u53d8\u91cf\uff0c\u8fd9\u4e2a\u51fd\u6570\u80fd\u591f\u5728\u6bcf\u6b21\u8c03\u7528\u65f6\u751f\u6210\u4e00\u4e2a\u65b0\u7684\u65e0\u7b26\u53f732\u4f4d\u6574\u6570\u4f5c\u4e3a\u968f\u673a\u6570\u3002\u8fd9\u79cd\u65b9\u6cd5\u7684\u4f18\u70b9\u662f\u901f\u5ea6\u5feb\u4e14\u5b9e\u73b0\u7b80\u5355\uff0c\u4f46\u7531\u4e8e\u5b83\u662f\u4f2a\u968f\u673a\u7684\uff0c\u751f\u6210\u7684\u968f\u673a\u5e8f\u5217\u662f\u53ef\u9884\u6d4b\u7684\uff0c\u56e0\u6b64\u4e0d\u9002\u7528\u4e8e\u9700\u8981\u9ad8\u5b89\u5168\u6027\u7684\u968f\u673a\u6570\u751f\u6210\u573a\u666f\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.34 random_f32<\/strong><\/h3>\n\n\n\n<p>\u751f\u6210\u4e00\u4e2a\u5728[0, 1)\u533a\u95f4\u5185\u7684\u968f\u673afloat32\u6570\u5b57\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">float random_f32(unsigned long long *state) { \/\/ random float32 in [0,1)\n    return (random_u32(state) &gt;&gt; 8) \/ 16777216.0f;\n}\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.35 sample_mult<\/strong><\/h3>\n\n\n\n<p>\u4ece\u6982\u7387\u5206\u5e03\u4e2d\u91c7\u6837\u4e00\u4e2a\u7d22\u5f15\u3002\u8fd9\u4e9b\u6982\u7387\u503c\u7684\u603b\u548c\u5fc5\u987b\u4e3a1\uff01\u53c2\u6570<code>coin<\/code>\u662f\u4e00\u4e2a\u5728[0, 
1)\u533a\u95f4\u5185\u7684\u968f\u673a\u6570\uff0c\u901a\u5e38\u7531<code>random_f32()<\/code>\u51fd\u6570\u751f\u6210\u3002\u6b64\u51fd\u6570\u8ba1\u7b97\u7d2f\u79ef\u5206\u5e03\u51fd\u6570\uff08CDF\uff09\uff0c\u5e76\u5728<code>coin<\/code>\u5c0f\u4e8eCDF\u7684\u5f53\u524d\u503c\u65f6\u8fd4\u56de\u5bf9\u5e94\u7684\u7d22\u5f15\u3002\u5982\u679c\u56e0\u4e3a\u820d\u5165\u8bef\u5dee\uff0c<code>coin<\/code>\u6ca1\u6709\u5c0f\u4e8e\u4efb\u4f55CDF\u7684\u503c\uff0c\u5219\u9ed8\u8ba4\u8fd4\u56de\u6700\u540e\u4e00\u4e2a\u7d22\u5f15\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">int sample_mult(float* probabilities, int n, float coin) {\n    \/\/ sample index from probabilities (they must sum to 1!)\n    \/\/ coin is a random number in [0, 1), usually from random_f32()\n    \/\/ \u4ece\u6982\u7387\u5206\u5e03\u4e2d\u91c7\u6837\u4e00\u4e2a\u7d22\u5f15\uff0c\u8fd9\u4e9b\u6982\u7387\u7684\u603b\u548c\u5fc5\u987b\u4e3a1\n    \/\/ coin\u662f\u4e00\u4e2a\u5728[0, 1)\u533a\u95f4\u5185\u7684\u968f\u673a\u6570\uff0c\u901a\u5e38\u7531random_f32()\u751f\u6210\n    \/\/ \u7d2f\u79ef\u5206\u5e03\u51fd\u6570\u7684\u521d\u59cb\u503c\n    float cdf = 0.0f;\n    for (int i = 0; i &lt; n; i++) {\n        \/\/ \u7d2f\u52a0\u6982\u7387\u5230CDF\n        cdf += probabilities[i];\n        if (coin &lt; cdf) {\n            \/\/ \u5982\u679c\u968f\u673a\u6570\u5c0f\u4e8e\u5f53\u524d\u7684CDF\u503c\uff0c\u5219\u8fd4\u56de\u5f53\u524d\u7684\u7d22\u5f15\n            return i;\n        }\n    }\n    \/\/ \u5982\u679c\u56e0\u4e3a\u820d\u5165\u8bef\u5dee\uff0ccoin\u6ca1\u6709\u5c0f\u4e8e\u4efb\u4f55CDF\u7684\u503c\uff0c\u5219\u8fd4\u56de\u6700\u540e\u4e00\u4e2a\u7d22\u5f15\n    return n - 1; \/\/ in case of rounding errors\n}\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.36 Tokenizer<\/strong><\/h3>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ Tokenizer 
(only supports decoding)\n\/\/ Tokenizer\uff08\u4ec5\u652f\u6301\u89e3\u7801\uff09\n\ntypedef struct {\n    \/\/ \u8bcd\u6c47\u8868\u5927\u5c0f\n    uint32_t vocab_size;\n    \/\/ \u4ee4\u724c\u8868\uff0c\u5b58\u50a8\u6bcf\u4e2a\u7d22\u5f15\u5bf9\u5e94\u7684\u5b57\u7b26\u4e32\n    char **token_table;\n    \/\/ \u521d\u59cb\u5316\u72b6\u6001\uff0c\u5982\u679c\u6210\u529f\u521d\u59cb\u5316\u5219\u4e3a1\uff0c\u5426\u5219\u4e3a0\n    int init_ok;\n} Tokenizer;\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.37 safe_printf<\/strong><\/h3>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void safe_printf(const char *piece) {\n    \/\/ the tokens are raw bytes, and we we only want to print the printable ones\n    \/\/ many bytes can be various control codes, backspace, etc.\n    \/\/ \u4ee4\u724c\u662f\u539f\u59cb\u5b57\u8282\uff0c\u6211\u4eec\u53ea\u60f3\u6253\u5370\u53ef\u6253\u5370\u7684\u90a3\u4e9b\u3002\n    \/\/ \u8bb8\u591a\u5b57\u8282\u53ef\u80fd\u662f\u5404\u79cd\u63a7\u5236\u7801\u3001\u9000\u683c\u7b49\u3002\n    if (piece == NULL) { return; }\n    if (piece[0] == '\\0') { return; }\n    \/\/ handle individual byte tokens\n    \/\/ every token is asserted to be at least one byte so doing piece[1] is ok\n    \/\/ \u5904\u7406\u5355\u5b57\u8282\u4ee4\u724c\n    \/\/ \u6bcf\u4e2a\u4ee4\u724c\u81f3\u5c11\u6709\u4e00\u4e2a\u5b57\u8282\uff0c\u6240\u4ee5\u76f4\u63a5\u68c0\u67e5 piece[1] \u662f\u53ef\u884c\u7684\n    if (piece[1] == '\\0') {\n        unsigned char byte_val = piece[0];\n        if (!(isprint(byte_val) || isspace(byte_val))) {\n            \/\/ \u5947\u602a\u7684\u5b57\u8282\uff0c\u4e0d\u6253\u5370\u5b83\n            return; \/\/ weird byte, don't print it\n        }\n    }\n    printf(\"%s\", piece);\n}\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.38 
tokenizer_init<\/strong><\/h3>\n\n\n\n<p>\u6b64\u51fd\u6570\u7684\u76ee\u7684\u662f\u521d\u59cb\u5316Tokenizer\u7ed3\u6784\u4f53\uff0c\u901a\u8fc7\u4ece\u6307\u5b9a\u6587\u4ef6\u4e2d\u52a0\u8f7d\u4ee4\u724c\u4fe1\u606f\u3002\u9996\u5148\uff0c\u5b83\u5c1d\u8bd5\u6253\u5f00\u6307\u5b9a\u7684\u6587\u4ef6\uff1b\u5982\u679c\u5931\u8d25\uff0c\u5219\u6253\u5370\u4e00\u6761\u8b66\u544a\u6d88\u606f\uff0c\u5e76\u8bbe\u7f6e<code>init_ok<\/code>\u6807\u5fd7\u4e3a0\uff0c\u8868\u793a\u521d\u59cb\u5316\u5931\u8d25\u3002\u5982\u679c\u6587\u4ef6\u6253\u5f00\u6210\u529f\uff0c\u5b83\u5c06\u8bfb\u53d6\u6587\u4ef6\u5934\u6765\u83b7\u53d6\u8bcd\u6c47\u8868\u7684\u5927\u5c0f\uff0c\u5e76\u4e3a\u6bcf\u4e2a\u4ee4\u724c\u8bfb\u53d6\u5176\u957f\u5ea6\u548c\u5185\u5bb9\uff0c\u5c06\u6bcf\u4e2a\u4ee4\u724c\u5b58\u50a8\u4e3a\u4e00\u4e2a\u4ee5\u7a7a\u5b57\u7b26\u7ec8\u6b62\u7684\u5b57\u7b26\u4e32\u3002\u6700\u540e\uff0c\u5b83\u5173\u95ed\u6587\u4ef6\u5e76\u5c06<code>init_ok<\/code>\u6807\u5fd7\u8bbe\u7f6e\u4e3a1\uff0c\u8868\u793a\u521d\u59cb\u5316\u6210\u529f\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void tokenizer_init(Tokenizer *tokenizer, const char *filename) {\n    FILE *file = fopen(filename, \"rb\");\n    if (file == NULL) {\n        \/\/ try to be more helpful as we just added this feature, erase later\n        \/\/ \u5c1d\u8bd5\u63d0\u4f9b\u66f4\u6709\u7528\u7684\u4fe1\u606f\uff0c\u56e0\u4e3a\u8fd9\u4e2a\u7279\u6027\u521a\u521a\u6dfb\u52a0\uff0c\u7a0d\u540e\u5220\u9664\u8fd9\u90e8\u5206\u63d0\u793a\n        printf(\"---\\n\");\n        printf(\"WARNING: Failed to open the tokenizer file %s\\n\", filename);\n        printf(\"The Tokenizer is a new feature added April 14 2024.\\n\");\n        printf(\"Re-run `python train_gpt2.py` to write it\\n\");\n        printf(\"---\\n\");\n        tokenizer-&gt;init_ok = 0;\n        return;\n    }\n    \/\/ read in the header\n    \/\/ \u8bfb\u53d6\u5934\u90e8\u4fe1\u606f\n    
uint32_t header[256];\n    fread(header, sizeof(uint32_t), 256, file);\n    \/\/ \u65ad\u8a00\u6587\u4ef6\u683c\u5f0f\u6b63\u786e\n    assert(header[0] == 20240328);\n    \/\/ \u65ad\u8a00\u6587\u4ef6\u7248\u672c\u4e3a1\n    assert(header[1] == 1);\n    \/\/ \u8bfb\u53d6\u8bcd\u6c47\u8868\u5927\u5c0f\n    tokenizer-&gt;vocab_size = header[2];\n    \/\/ read in all the tokens\n    \/\/ \u8bfb\u53d6\u6240\u6709\u4ee4\u724c\n    \/\/ \u4ee4\u724c\u957f\u5ea6\n    unsigned char length;\n    tokenizer-&gt;token_table = (char **)malloc(tokenizer-&gt;vocab_size * sizeof(char *));\n    for (uint32_t i = 0; i &lt; tokenizer-&gt;vocab_size; i++) {\n        fread(&amp;length, sizeof(unsigned char), 1, file);\n        \/\/ \u65ad\u8a00\u6bcf\u4e2a\u4ee4\u724c\u81f3\u5c11\u6709\u4e00\u4e2a\u5b57\u7b26\n        assert(length &gt; 0); \/\/ every token should be at least one character\n        \/\/ \u5206\u914d\u5185\u5b58\n        char *token_bytes = (char *)malloc(length + 1);\n        fread(token_bytes, sizeof(char), length, file);\n        \/\/ \u6dfb\u52a0\u7a7a\u5b57\u7b26\u7ec8\u6b62\u7b26\u4ee5\u4fbf\u6253\u5370\n        token_bytes[length] = '\\0';  \/\/ Add null terminator for printing\n        \/\/ \u5b58\u50a8\u4ee4\u724c\n        tokenizer-&gt;token_table[i] = token_bytes;\n    }\n    \/\/ cleanups\n    \/\/ \u6e05\u7406\u5de5\u4f5c\n    fclose(file);\n    \/\/ \u521d\u59cb\u5316\u6210\u529f\u6807\u5fd7\n    tokenizer-&gt;init_ok = 1;\n}\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.39 
tokenizer_decode<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u4e2a\u51fd\u6570\u7684\u76ee\u7684\u662f\u6839\u636e\u7ed9\u5b9a\u7684\u4ee4\u724cID\u89e3\u7801\u5e76\u8fd4\u56de\u5bf9\u5e94\u7684\u4ee4\u724c\u5b57\u7b26\u4e32\u3002\u9996\u5148\u68c0\u67e5\u4ee4\u724c\u5668\u662f\u5426\u521d\u59cb\u5316\u6210\u529f\uff0c\u5982\u679c\u6ca1\u6709\uff0c\u5219\u76f4\u63a5\u8fd4\u56de<code>NULL<\/code>\u3002\u5982\u679c\u521d\u59cb\u5316\u6210\u529f\uff0c\u5e76\u4e14\u4ee4\u724cID\u5728\u8bcd\u6c47\u8868\u7684\u8303\u56f4\u5185\uff0c\u5219\u8fd4\u56de\u5bf9\u5e94\u7684\u4ee4\u724c\u5b57\u7b26\u4e32\u3002\u5982\u679c\u4ee4\u724cID\u8d85\u51fa\u4e86\u8bcd\u6c47\u8868\u7684\u8303\u56f4\uff0c\u5219\u6253\u5370\u4e00\u6761\u9519\u8bef\u4fe1\u606f\uff0c\u5e76\u8fd4\u56de<code>NULL<\/code>\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">const char *tokenizer_decode(Tokenizer *tokenizer, uint32_t token_id) {\n    if (tokenizer-&gt;init_ok == 0) {\n        \/\/ \u5982\u679c\u4ee4\u724c\u5668\u672a\u6210\u529f\u521d\u59cb\u5316\uff0c\u5219\u8fd4\u56de\u7a7a\n        return NULL;\n    }\n    if (token_id &lt; tokenizer-&gt;vocab_size) {\n        \/\/ \u5982\u679c\u4ee4\u724cID\u5728\u8bcd\u6c47\u8868\u5927\u5c0f\u8303\u56f4\u5185\uff0c\u5219\u8fd4\u56de\u5bf9\u5e94\u7684\u4ee4\u724c\n        return tokenizer-&gt;token_table[token_id];\n    } else {\n        \/\/ \u5982\u679c\u4ee4\u724cID\u8d85\u51fa\u8bcd\u6c47\u8868\u8303\u56f4\uff0c\u5219\u6253\u5370\u9519\u8bef\u4fe1\u606f\u5e76\u8fd4\u56de\u7a7a\n        printf(\"invalid token id %d!\\n\", token_id);\n        return NULL;\n    }\n}\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.40 
tokenizer_free<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u4e2a\u51fd\u6570\u7528\u4e8e\u91ca\u653e\u4ee4\u724c\u5668\u5206\u914d\u7684\u8d44\u6e90\u3002\u9996\u5148\u68c0\u67e5\u4ee4\u724c\u5668\u662f\u5426\u5df2\u6210\u529f\u521d\u59cb\u5316\uff0c\u5982\u679c\u662f\uff0c\u5c31\u91ca\u653e\u6bcf\u4e2a\u4ee4\u724c\u5b57\u7b26\u4e32\u5360\u7528\u7684\u5185\u5b58\uff0c\u5e76\u6700\u7ec8\u91ca\u653e\u5b58\u50a8\u4ee4\u724c\u5b57\u7b26\u4e32\u6307\u9488\u7684\u6570\u7ec4\u5360\u7528\u7684\u5185\u5b58\u3002\u8fd9\u662f\u5728\u4ee4\u724c\u5668\u4e0d\u518d\u9700\u8981\u65f6\uff0c\u907f\u514d\u5185\u5b58\u6cc4\u6f0f\u7684\u91cd\u8981\u6b65\u9aa4\u3002<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">void tokenizer_free(Tokenizer *tokenizer) {\n    if (tokenizer-&gt;init_ok) {\n        \/\/ \u5982\u679c\u4ee4\u724c\u5668\u5df2\u6210\u529f\u521d\u59cb\u5316\n        for (uint32_t i = 0; i &lt; tokenizer-&gt;vocab_size; i++) {\n            \/\/ \u904d\u5386\u6bcf\u4e2a\u4ee4\u724c\uff0c\u5e76\u91ca\u653e\u5176\u5360\u7528\u7684\u5185\u5b58\n            free(tokenizer-&gt;token_table[i]);\n        }\n        \/\/ \u91ca\u653e\u4ee4\u724c\u8868\u6570\u7ec4\u672c\u8eab\u5360\u7528\u7684\u5185\u5b58\n        free(tokenizer-&gt;token_table);\n    }\n}\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1.41 main<\/strong><\/h3>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ main training loop\n\/\/ \u4e3b\u8bad\u7ec3\u5faa\u73af\nint main() {\n\n    \/\/ build the GPT-2 model from a checkpoint\n    \/\/ \u4ece\u68c0\u67e5\u70b9\u6784\u5efaGPT-2\u6a21\u578b\n    GPT2 model;\n    gpt2_build_from_checkpoint(&amp;model, \"gpt2_124M.bin\");\n\n    \/\/ build the DataLoaders from tokens files. 
for now use tiny_shakespeare if available, else tiny_stories\n    \/\/ \u4ecetokens\u6587\u4ef6\u6784\u5efaDataLoaders\u3002\u5982\u679c\u53ef\u7528\uff0c\u73b0\u5728\u4f7f\u7528tiny_shakespeare\uff0c\u5426\u5219\u4f7f\u7528tiny_stories\n    char* tiny_stories_train = \"data\/TinyStories_train.bin\";\n    char* tiny_stories_val = \"data\/TinyStories_val.bin\";\n    char* tiny_shakespeare_train = \"data\/tiny_shakespeare_train.bin\";\n    char* tiny_shakespeare_val = \"data\/tiny_shakespeare_val.bin\";\n    char* train_tokens = access(tiny_shakespeare_train, F_OK) != -1 ? tiny_shakespeare_train : tiny_stories_train;\n    char* val_tokens = access(tiny_shakespeare_val, F_OK) != -1 ? tiny_shakespeare_val : tiny_stories_val;\n    \/\/ \u6279\u91cf\u5927\u5c0f\u4e3a4\uff0c\u5373\u5c06\u5bf94\u4e2a\u72ec\u7acb\u7684token\u5e8f\u5217\u8fdb\u884c\u8bad\u7ec3\n    int B = 4; \/\/ batch size 4 (i.e. 4 independent token sequences will be trained on)\n    \/\/ \u5e8f\u5217\u957f\u5ea6\u4e3a64\uff0c\u6bcf\u4e2a\u5e8f\u5217\u662f64\u4e2atokens\u957f\uff0c\u5fc5\u987b\u5c0f\u4e8e\u7b49\u4e8eGPT-2\u7684maxT\uff0c\u53731024\n    int T = 64; \/\/ sequence length 64 (i.e. each sequence is 64 tokens long). 
must be &lt;= maxT, which is 1024 for GPT-2\n    DataLoader train_loader;\n    dataloader_init(&amp;train_loader, train_tokens, B, T);\n    printf(\"train dataset num_batches: %d\\n\", train_loader.num_batches);\n    DataLoader val_loader;\n    dataloader_init(&amp;val_loader, val_tokens, B, T);\n    printf(\"val dataset num_batches: %d\\n\", val_loader.num_batches);\n    int val_num_batches = 5;\n\n    \/\/ build the Tokenizer\n    \/\/ \u521d\u59cb\u5316\u5206\u8bcd\u5668\uff08Tokenizer\uff09\n    Tokenizer tokenizer;\n    tokenizer_init(&amp;tokenizer, \"gpt2_tokenizer.bin\");\n\n    \/\/ some memory for generating samples from the model\n    \/\/ \u7528\u4e8e\u4ece\u6a21\u578b\u751f\u6210\u6837\u672c\u7684\u4e00\u4e9b\u5185\u5b58\u7a7a\u95f4\n    unsigned long long rng_state = 1337;\n    int* gen_tokens = (int*)malloc(B * T * sizeof(int));\n    \/\/ \u5728\u63a8\u7406\u6b65\u9aa4\u4e2d\uff0c\u6211\u4eec\u5c06\u751f\u6210\u8fd9\u4e48\u591atokens\u7684\u5e8f\u5217\n    const int genT = 64; \/\/ number of steps of inference we will do\n\n    \/\/ train\n    \/\/ \u8bad\u7ec3\u8fc7\u7a0b\n    struct timespec start, end;\n    for (int step = 0; step &lt;= 40; step++) {\n\n        \/\/ once in a while estimate the validation loss\n        \/\/ \u5076\u5c14\u4f30\u8ba1\u9a8c\u8bc1\u635f\u5931\n        if (step % 10 == 0) {\n            \/\/ \u521d\u59cb\u5316\u9a8c\u8bc1\u635f\u5931\u4e3a0\n            float val_loss = 0.0f;\n            \/\/ \u91cd\u7f6e\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668\u7684\u5f53\u524d\u4f4d\u7f6e\uff0c\u51c6\u5907\u4ece\u5934\u5f00\u59cb\u8bfb\u53d6\u6570\u636e\n            dataloader_reset(&amp;val_loader);\n            \/\/ \u904d\u5386\u6240\u6709\u7684\u9a8c\u8bc1\u6279\u6b21\n            for (int i = 0; i &lt; val_num_batches; i++) {\n                \/\/ \u83b7\u53d6\u4e0b\u4e00\u6279\u9a8c\u8bc1\u6570\u636e\n                dataloader_next_batch(&amp;val_loader);\n                \/\/ 
\u5bf9\u9a8c\u8bc1\u6570\u636e\u6267\u884c\u524d\u5411\u4f20\u64ad\n                gpt2_forward(&amp;model, val_loader.inputs, val_loader.targets, B, T);\n                \/\/ \u7d2f\u52a0\u6a21\u578b\u5728\u5f53\u524d\u6279\u6b21\u7684\u5e73\u5747\u635f\u5931\n                val_loss += model.mean_loss;\n            }\n            \/\/ \u8ba1\u7b97\u6240\u6709\u9a8c\u8bc1\u6279\u6b21\u7684\u5e73\u5747\u635f\u5931\n            val_loss \/= val_num_batches;\n            \/\/ \u6253\u5370\u9a8c\u8bc1\u635f\u5931\n            printf(\"val loss %f\\n\", val_loss);\n        }\n\n        \/\/ once in a while do model inference to print generated text\n        \/\/ \u5076\u5c14\u8fdb\u884c\u6a21\u578b\u63a8\u7406\u4ee5\u6253\u5370\u751f\u6210\u7684\u6587\u672c\n        if (step &gt; 0 &amp;&amp; step % 20 == 0) {\n            \/\/ fill up gen_tokens with the GPT2_EOT, which kicks off the generation\n            \/\/ \u4f7f\u7528GPT2_EOT\u586b\u5145gen_tokens\uff0c\u8fd9\u5c06\u89e6\u53d1\u751f\u6210\u8fc7\u7a0b\n            for(int i = 0; i &lt; B * T; ++i) {\n                gen_tokens[i] = GPT2_EOT;\n            }\n            \/\/ now sample from the model autoregressively\n            \/\/ \u73b0\u5728\u4ece\u6a21\u578b\u4e2d\u81ea\u56de\u5f52\u5730\u8fdb\u884c\u62bd\u6837\n            printf(\"generating:\\n---\\n\");\n            for (int t = 1; t &lt; genT; t++) {\n                \/\/ note that inference is very wasteful here because for each token\n                \/\/ we re-calculate the forward pass for all of (B,T) positions from scratch\n                \/\/ but the inference here is just for sanity checking anyway\n                \/\/ and we can maybe optimize a bit more later, with careful tests\n                \/\/ furthermore, below we're only using b=0 (i.e. 
the first row) of all B rows\n                \/\/ we're in principle running B \"inference streams\" in parallel here\n                \/\/ but only using position 0\n                \/\/ get the V-dimensional vector probs[0, t-1, :]\n                 \/\/ \u6ce8\u610f\uff0c\u8fd9\u91cc\u7684\u63a8\u7406\u975e\u5e38\u6d6a\u8d39\u8d44\u6e90\uff0c\u56e0\u4e3a\u5bf9\u4e8e\u6bcf\u4e2atoken\n                \/\/ \u6211\u4eec\u4ece\u5934\u5f00\u59cb\u91cd\u65b0\u8ba1\u7b97\u6240\u6709(B,T)\u4f4d\u7f6e\u7684\u524d\u5411\u4f20\u64ad\n                \/\/ \u4f46\u8fd9\u91cc\u7684\u63a8\u7406\u4ec5\u4ec5\u662f\u4e3a\u4e86\u8fdb\u884c\u57fa\u672c\u7684\u68c0\u67e5\n                \/\/ \u6211\u4eec\u6216\u8bb8\u53ef\u4ee5\u5728\u4ee5\u540e\u901a\u8fc7\u4ed4\u7ec6\u6d4b\u8bd5\u6765\u8fdb\u884c\u4e00\u4e9b\u4f18\u5316\n                gpt2_forward(&amp;model, gen_tokens, NULL, B, T);\n                \/\/ furthermore, below we're only using b=0 (i.e. the first row) of all B rows\n                \/\/ we're in principle running B \"inference streams\" in parallel here\n                \/\/ but only using position 0\n                \/\/ get the V-dimensional vector probs[0, t-1, :]\n                \/\/ \u6b64\u5916\uff0c\u5728\u4e0b\u9762\u6211\u4eec\u4ec5\u4f7f\u7528\u4e86b=0\uff08\u5373\u6240\u6709B\u884c\u4e2d\u7684\u7b2c\u4e00\u884c\uff09\n                \/\/ \u539f\u5219\u4e0a\u6211\u4eec\u5728\u8fd9\u91cc\u5e76\u884c\u8fd0\u884cB\u4e2a\u201c\u63a8\u7406\u6d41\u201d\n                \/\/ \u4f46\u53ea\u4f7f\u7528\u4e86\u4f4d\u7f6e0\n                \/\/ \u83b7\u53d6V\u7ef4\u5411\u91cfprobs[0, t-1, :]\n                float* probs = model.acts.probs + (t-1) * model.config.vocab_size;\n                float coin = random_f32(&amp;rng_state);\n                int next_token = sample_mult(probs, model.config.vocab_size, coin);\n                gen_tokens[t] = next_token;\n                \/\/ print the generated token, either using the Tokenizer or a fallback\n                
\/\/ \u6253\u5370\u751f\u6210\u7684token\uff0c\u53ef\u4ee5\u4f7f\u7528Tokenizer\u6216\u8005\u4e00\u4e2a\u5907\u7528\u65b9\u6848\n                if (tokenizer.init_ok) {\n                    const char* token_str = tokenizer_decode(&amp;tokenizer, next_token);\n                    safe_printf(token_str);\n                } else {\n                    \/\/ fall back to printing the token id\n                    \/\/ \u56de\u9000\u5230\u6253\u5370token\u7684id\n                    printf(\"%d \", next_token);\n                }\n                fflush(stdout);\n            }\n            printf(\"\\n---\\n\");\n        }\n\n        \/\/ do a training step\n        \/\/ \u6267\u884c\u4e00\u4e2a\u8bad\u7ec3\u6b65\u9aa4\n        \/\/ \u83b7\u53d6\u5f53\u524d\u65f6\u95f4\uff0c\u4f5c\u4e3a\u5f00\u59cb\u65f6\u95f4\n        clock_gettime(CLOCK_MONOTONIC, &amp;start);\n        \/\/ \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u83b7\u53d6\u4e0b\u4e00\u6279\u8bad\u7ec3\u6570\u636e\n        dataloader_next_batch(&amp;train_loader);\n        \/\/ \u5bf9\u83b7\u53d6\u7684\u8bad\u7ec3\u6570\u636e\u6267\u884c\u524d\u5411\u4f20\u64ad\n        gpt2_forward(&amp;model, train_loader.inputs, train_loader.targets, B, T);\n        \/\/ \u6e05\u96f6\u6a21\u578b\u7684\u68af\u5ea6\n        gpt2_zero_grad(&amp;model);\n        \/\/ \u6267\u884c\u53cd\u5411\u4f20\u64ad\uff0c\u8ba1\u7b97\u68af\u5ea6\n        gpt2_backward(&amp;model);\n        \/\/ \u4f7f\u7528AdamW\u4f18\u5316\u5668\u66f4\u65b0\u6a21\u578b\u7684\u53c2\u6570\n        gpt2_update(&amp;model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);\n        \/\/ \u83b7\u53d6\u5f53\u524d\u65f6\u95f4\uff0c\u4f5c\u4e3a\u7ed3\u675f\u65f6\u95f4\n        clock_gettime(CLOCK_MONOTONIC, &amp;end);\n        \/\/ \u8ba1\u7b97\u8bad\u7ec3\u8fd9\u4e00\u6b65\u6240\u6d88\u8017\u7684\u65f6\u95f4\n        double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) \/ 1e9;\n        \/\/ 
\u6253\u5370\u5f53\u524d\u6b65\u9aa4\u7684\u8bad\u7ec3\u635f\u5931\u548c\u65f6\u95f4\u6d88\u8017\n        printf(\"step %d: train loss %f (took %f ms)\\n\", step, model.mean_loss, time_elapsed_s * 1000);\n    }\n\n    \/\/ free\n    \/\/ \u91ca\u653e\u8d44\u6e90\n    dataloader_free(&amp;train_loader);\n    dataloader_free(&amp;val_loader);\n    tokenizer_free(&amp;tokenizer);\n    gpt2_free(&amp;model);\n    free(gen_tokens);\n    return 0;\n}\n<\/pre><\/div>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>2. test_gpt2.c<\/strong><\/h2>\n\n\n\n<p>test_gpt2.c \u662fC\u8bed\u8a00\u7248\u7684\u6a21\u578b\u51c6\u786e\u6027\u9a8c\u8bc1\uff0c\u5305\u542b\u4e86 train_gpt2.c \u7684\u4ee3\u7801<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>2.1 check_tensor<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u662f\u4e00\u4e2a\u7528\u4e8e\u9a8c\u8bc1\u4e24\u4e2a\u5f20\u91cf\u662f\u5426\u5728\u4e00\u5b9a\u5bb9\u5dee\u8303\u56f4\u5185\u5927\u81f4\u76f8\u7b49\u7684\u51fd\u6570\u3002\u8fd9\u5728\u9a8c\u8bc1\u795e\u7ecf\u7f51\u7edc\u7684\u5b9e\u73b0\u6b63\u786e\u6027\u65f6\u975e\u5e38\u6709\u7528\uff0c\u7279\u522b\u662f\u5728\u5bf9\u6bd4\u524d\u5411\u4f20\u64ad\u548c\u53cd\u5411\u4f20\u64ad\u7684\u7ed3\u679c\u65f6\u3002\u4e0b\u9762\u662f\u4ee3\u7801\u7684\u8be6\u7ec6\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">\/\/ poor man's tensor checker\n\/\/ \u8fd9\u662f\u4e00\u4e2a\u7528\u4e8e\u6bd4\u8f83\u4e24\u4e2a\u5f20\u91cf\u5728\u7ed9\u5b9a\u5bb9\u5dee\u4e0b\u662f\u5426\u5927\u81f4\u76f8\u7b49\u7684\u51fd\u6570\nint check_tensor(float *a, float *b, int n, char* label) {\n    \/\/ \u5c06\u8981\u6253\u5370\u7684\u6700\u5927\u5143\u7d20\u6570\u91cf\uff0c\u7528\u4e8e\u9519\u8bef\u8c03\u8bd5\n    int print_upto = 5;\n    \/\/ \u6807\u8bb0\u6240\u6709\u5143\u7d20\u662f\u5426\u5728\u5bb9\u5dee\u5185\u76f8\u7b49\uff0c\u9ed8\u8ba4\u4e3a1\uff08\u76f8\u7b49\uff09\n    int ok = 1;\n    \/\/ 
\u8bb0\u5f55\u6240\u6709\u5143\u7d20\u5dee\u5f02\u7684\u6700\u5927\u503c\n    float maxdiff = 0.0f;\n    \/\/ \u5bb9\u5dee\u503c\uff0c\u4e24\u4e2a\u5143\u7d20\u7684\u5dee\u5f02\u5fc5\u987b\u5c0f\u4e8e\u6b64\u503c\u624d\u8ba4\u4e3a\u5b83\u4eec\u76f8\u7b49\n    float tol = 2e-2;\n    \/\/ \u6253\u5370\u5f53\u524d\u6b63\u5728\u68c0\u67e5\u7684\u5f20\u91cf\u6807\u7b7e\n    printf(\"%s\\n\", label);\n    \/\/ \u904d\u5386\u6bcf\u4e2a\u5143\u7d20\u8fdb\u884c\u6bd4\u8f83\n    for (int i = 0; i &lt; n; i++) {\n        \/\/ look at the diffence at position i of these two tensors\n        \/\/ \u8ba1\u7b97\u5f53\u524d\u5143\u7d20\u7684\u7edd\u5bf9\u5dee\u503c\n        float diff = fabsf(a[i] - b[i]);\n\n        \/\/ keep track of the overall error\n        \/\/ \u5982\u679c\u5f53\u524d\u5143\u7d20\u5728\u5bb9\u5dee\u5185\uff0c\u5219\u4fdd\u6301ok\u4e3a1\uff0c\u5426\u5219\u53d8\u4e3a0\n        ok = ok &amp;&amp; (diff &lt;= tol);\n        \/\/ \u66f4\u65b0\u6700\u5927\u5dee\u5f02\u503c\n        if (diff &gt; maxdiff) { maxdiff = diff; }\n\n        \/\/ for the first few elements of each tensor, pretty print\n        \/\/ the actual numbers, so we can do a visual, qualitative proof\/assessment\n        \/\/ \u5bf9\u524d\u51e0\u4e2a\u5143\u7d20\u8fdb\u884c\u8be6\u7ec6\u6253\u5370\uff0c\u4ee5\u4fbf\u8fdb\u884c\u4eba\u5de5\u68c0\u67e5\n        if (i &lt; print_upto) {\n            if (diff &lt;= tol) {\n                if (i &lt; print_upto) { printf(\"OK \"); }\n            } else {\n                if (i &lt; print_upto) { printf(\"NOT OK \"); }\n            }\n            printf(\"%f %f\\n\", a[i], b[i]);\n        }\n    }\n    \/\/ print the final result for this tensor\n    \/\/ \u6253\u5370\u6574\u4e2a\u5f20\u91cf\u7684\u68c0\u67e5\u7ed3\u679c\n    if (ok) {\n        \/\/ \u6240\u6709\u5143\u7d20\u90fd\u5728\u5bb9\u5dee\u5185\n        printf(\"TENSOR OK, maxdiff = %e\\n\", maxdiff);\n    } else {\n        \/\/ \u5b58\u5728\u5143\u7d20\u8d85\u51fa\u5bb9\u5dee\n        
printf(\"TENSOR NOT OK, maxdiff = %e\\n\", maxdiff);\n    }\n    \/\/ \u8fd4\u56de\u68c0\u67e5\u7ed3\u679c\uff0c1\u8868\u793a\u6240\u6709\u5143\u7d20\u90fd\u5728\u5bb9\u5dee\u5185\u76f8\u7b49\uff0c0\u8868\u793a\u81f3\u5c11\u6709\u4e00\u4e2a\u5143\u7d20\u4e0d\u76f8\u7b49\n    return ok;\n}\n<\/pre><\/div>\n\n\n\n<p>\u8fd9\u4e2a\u51fd\u6570\u63a5\u6536\u4e24\u4e2a\u5f20\u91cf<code>a<\/code>\u548c<code>b<\/code>\uff08\u4f5c\u4e3a\u4e00\u7ef4\u6570\u7ec4\uff09\uff0c\u5b83\u4eec\u7684\u5143\u7d20\u6570\u91cf<code>n<\/code>\uff0c\u4ee5\u53ca\u4e00\u4e2a\u7528\u4e8e\u5728\u6253\u5370\u65f6\u6807\u8bc6\u5f20\u91cf\u7684<code>label<\/code>\u5b57\u7b26\u4e32\u3002\u51fd\u6570\u904d\u5386\u8fd9\u4e9b\u5f20\u91cf\u7684\u6bcf\u4e2a\u5143\u7d20\uff0c\u8ba1\u7b97\u5b83\u4eec\u7684\u5dee\u5f02\uff0c\u5e76\u68c0\u67e5\u8fd9\u4e9b\u5dee\u5f02\u662f\u5426\u90fd\u5728\u5b9a\u4e49\u7684\u5bb9\u5dee<code>tol<\/code>\u5185\u3002\u5982\u679c\u6240\u6709\u5143\u7d20\u7684\u5dee\u5f02\u90fd\u5728\u5bb9\u5dee\u5185\uff0c\u5219\u51fd\u6570\u8fd4\u56de1\uff0c\u8868\u793a\u5f20\u91cf\u5927\u81f4\u76f8\u7b49\uff1b\u5982\u679c\u81f3\u5c11\u6709\u4e00\u4e2a\u5143\u7d20\u7684\u5dee\u5f02\u8d85\u51fa\u5bb9\u5dee\uff0c\u5219\u8fd4\u56de0\uff0c\u8868\u793a\u5f20\u91cf\u4e0d\u76f8\u7b49\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>2.2 main<\/strong><\/h3>\n\n\n\n<p>\u8fd9\u662f\u4e00\u4e2a\u6d4b\u8bd5\u811a\u672c\u7684\u4e3b\u51fd\u6570\uff0c\u7528\u4e8e\u52a0\u8f7dGPT-2\u6a21\u578b\u3001\u6267\u884c\u524d\u5411\u548c\u53cd\u5411\u4f20\u64ad\uff0c\u4ee5\u53ca\u66f4\u65b0\u6a21\u578b\u53c2\u6570\uff0c\u5e76\u8fdb\u884c\u4e00\u7cfb\u5217\u7684\u51c6\u786e\u6027\u9a8c\u8bc1\u3002\u8fd9\u901a\u5e38\u7528\u4e8e\u786e\u4fdd\u6a21\u578b\u5b9e\u73b0\u7684\u6b63\u786e\u6027\uff0c\u901a\u8fc7\u4e0e\u9884\u671f\u7684\u7ed3\u679c\u8fdb\u884c\u5bf9\u6bd4\u3002\u4e0b\u9762\u662f\u5bf9\u8fd9\u6bb5\u4ee3\u7801\u7684\u4e2d\u6587\u6ce8\u91ca\uff1a<\/p>\n\n\n\n<div 
class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:c decode:true \">int main(int argc, char *argv[]) {\n\n    \/\/ build the GPT-2 model from a checkpoint\n    \/\/ \u4ece\u68c0\u67e5\u70b9\u6587\u4ef6\u6784\u5efaGPT-2\u6a21\u578b\n    GPT2 model;\n    gpt2_build_from_checkpoint(&amp;model, \"gpt2_124M.bin\");\n\n    \/\/ \u4ece\u6a21\u578b\u914d\u7f6e\u4e2d\u83b7\u53d6\u4e00\u4e9b\u57fa\u672c\u4fe1\u606f\n    int C = model.config.channels;\n    int V = model.config.vocab_size;\n    int maxT = model.config.max_seq_len;\n    int L = model.config.num_layers;\n\n    \/\/ load additional information that we will use for debugging and error checking\n    \/\/ \u52a0\u8f7d\u8c03\u8bd5\u548c\u9519\u8bef\u68c0\u67e5\u7528\u7684\u9644\u52a0\u4fe1\u606f\n    FILE *state_file = fopen(\"gpt2_124M_debug_state.bin\", \"rb\");\n    if (state_file == NULL) { printf(\"Error opening state file\\n\"); return 1; }\n    int state_header[256];\n    fread(state_header, sizeof(int), 256, state_file);\n    if (state_header[0] != 20240327) { printf(\"Bad magic state file\"); return 1; }\n    if (state_header[1] != 1) { printf(\"Bad version in state file\"); return 1; }\n    \/\/ \u6279\u5927\u5c0f\n    int B = state_header[2]; \/\/ batch size, e.g. 4\n    \/\/ \u5e8f\u5217\u957f\u5ea6\n    int T = state_header[3]; \/\/ time \/ sequence length (e.g. 
64, up to maxT)\n    printf(\"[State]\\n\");\n    printf(\"batch_size: %d\\n\", B);\n    printf(\"seq_len: %d\\n\", T);\n\n    \/\/ \u9884\u671f\u68af\u5ea6\n    ParameterTensors expected_grads;\n    float* expected_grads_memory = malloc_and_point_parameters(&amp;expected_grads, model.param_sizes);\n\n    \/\/ inputs and expected outputs, only used for error checking\n    \/\/ \u8f93\u5165\u548c\u9884\u671f\u8f93\u51fa\uff0c\u4ec5\u7528\u4e8e\u9519\u8bef\u68c0\u67e5\n    int* x = (int*) malloc(B * T * sizeof(int));\n    int* y = (int*) malloc(B * T * sizeof(int));\n    float* expected_logits = (float*) malloc(B * T * V * sizeof(float));\n    float* expected_loss = (float*) malloc(1 * sizeof(float));\n\n    \/\/ read reference information from Python\n    \/\/ \u4ecePython\u4e2d\u8bfb\u53d6\u53c2\u8003\u4fe1\u606f\n    fread(x, sizeof(int), B*T, state_file);\n    fread(y, sizeof(int), B*T, state_file);\n    fread(expected_logits, sizeof(float), B*T*V, state_file);\n    fread(expected_loss, sizeof(float), 1, state_file);\n    fread(expected_grads_memory, sizeof(float), model.num_parameters, state_file);\n    fclose(state_file);\n\n    \/\/ overall OK signal for the test\n    \/\/ \u6574\u4f53\u7684OK\u6807\u5fd7\n    int allok = 1;\n\n    \/\/ let's do 10 training iterations, following the pytorch code\n    \/\/ \u8fdb\u884c10\u6b21\u8bad\u7ec3\u8fed\u4ee3\uff0c\u8ddf\u968fpytorch\u4ee3\u7801\n    float expected_losses[10] = {\n        5.270007133483887,\n        4.059706687927246,\n        3.3751230239868164,\n        2.8007826805114746,\n        2.315382242202759,\n        1.8490285873413086,\n        1.3946564197540283,\n        0.9991465210914612,\n        0.6240804195404053,\n        0.37651097774505615\n    };\n    for (int step = 0; step &lt; 10; step++) {\n\n        struct timespec start, end;\n        clock_gettime(CLOCK_MONOTONIC, &amp;start);\n\n        \/\/ \u6267\u884c\u524d\u5411\u4f20\u64ad\u3001\u68af\u5ea6\u5f52\u96f6\u3001\u53cd\u5411\u4f20\u64ad\n  
      gpt2_forward(&amp;model, x, y, B, T);\n        gpt2_zero_grad(&amp;model);\n        gpt2_backward(&amp;model);\n\n        clock_gettime(CLOCK_MONOTONIC, &amp;end);\n        double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) \/ 1e9;\n\n        \/\/ \u7b2c0\u6b65\u8fdb\u884c\u9519\u8bef\u68c0\u67e5\n        if (step == 0) {\n            \/\/ error checking at step 0 for reference activations\/gradients\n            \/\/ \u5728\u7b2c0\u6b65\u8fdb\u884c\u53c2\u8003\u6fc0\u6d3b\/\u68af\u5ea6\u7684\u9519\u8bef\u68c0\u67e5\u3002\n\n            \/\/ at this point, target should be equal to expected_logits, let's compare\n            \/\/ \u6b64\u65f6\uff0c\u76ee\u6807\u5e94\u8be5\u7b49\u4e8e\u9884\u671f\u7684logits\uff0c\u8ba9\u6211\u4eec\u6765\u8fdb\u884c\u6bd4\u8f83\u3002\n            \/\/ \u6bd4\u8f83\u76ee\u6807\u548c\u9884\u671f\u7684logits\n            int logits_ok = 1;\n            \/\/ \u5faa\u73af\u68c0\u67e5logits\u662f\u5426\u5339\u914d\n            for (int i=0; i&lt;B*T*V; i++) {\n                if(i &lt; 3) {\n                    printf(\"%f %f\\n\", expected_logits[i], model.acts.logits[i]);\n                }\n                if (fabsf(expected_logits[i] - model.acts.logits[i]) &gt;= 1e-2) {\n                    printf(\"MISMATCH AT INDEX %d: \", i);\n                    printf(\"%f %f\\n\", expected_logits[i],model.acts.logits[i]);\n                    logits_ok = 0;\n                    break;\n                }\n            }\n            if(!logits_ok) { printf(\"NOT \"); }\n            printf(\"OK (LOGITS)\\n\");\n            allok = allok &amp;&amp; logits_ok;\n\n            \/\/ compare the achieved loss\n            \/\/ \u6bd4\u8f83\u5b9e\u9645\u635f\u5931\u548c\u9884\u671f\u635f\u5931\n            if (fabsf(model.mean_loss - *expected_loss) &gt;= 1e-2) {\n                printf(\"LOSS MISMATCH: %f %f\\n\", model.mean_loss, *expected_loss);\n                allok = 0;\n            } else {\n                
printf(\"LOSS OK: %f %f\\n\", model.mean_loss, *expected_loss);\n            }\n\n            \/\/ finally check all the gradients\n            \/\/ \u68c0\u67e5\u6240\u6709\u68af\u5ea6\n            int gradoks[16];\n            \/\/ \u5bf9\u6bcf\u4e2a\u68af\u5ea6\u8fdb\u884c\u68c0\u67e5\n            ParameterTensors grads = model.grads;\n            gradoks[0] = check_tensor(grads.wte, expected_grads.wte, V*C, \"dwte\");\n            gradoks[1] = check_tensor(grads.wpe, expected_grads.wpe, maxT*C, \"dwpe\");\n            gradoks[2] = check_tensor(grads.ln1w, expected_grads.ln1w, L*C, \"dln1w\");\n            gradoks[3] = check_tensor(grads.ln1b, expected_grads.ln1b, L*C, \"dln1b\");\n            gradoks[4] = check_tensor(grads.qkvw, expected_grads.qkvw, L*3*C*C, \"dqkvw\");\n            gradoks[5] = check_tensor(grads.qkvb, expected_grads.qkvb, L*3*C, \"dqkvb\");\n            gradoks[6] = check_tensor(grads.attprojw, expected_grads.attprojw, L*C*C, \"dattprojw\");\n            gradoks[7] = check_tensor(grads.attprojb, expected_grads.attprojb, L*C, \"dattprojb\");\n            gradoks[8] = check_tensor(grads.ln2w, expected_grads.ln2w, L*C, \"dln2w\");\n            gradoks[9] = check_tensor(grads.ln2b, expected_grads.ln2b, L*C, \"dln2b\");\n            gradoks[10] = check_tensor(grads.fcw, expected_grads.fcw, L*4*C*C, \"dfcw\");\n            gradoks[11] = check_tensor(grads.fcb, expected_grads.fcb, L*4*C, \"dfcb\");\n            gradoks[12] = check_tensor(grads.fcprojw, expected_grads.fcprojw, L*C*4*C, \"dfcprojw\");\n            gradoks[13] = check_tensor(grads.fcprojb, expected_grads.fcprojb, L*C, \"dfcprojb\");\n            gradoks[14] = check_tensor(grads.lnfw, expected_grads.lnfw, C, \"dlnfw\");\n            gradoks[15] = check_tensor(grads.lnfb, expected_grads.lnfb, C, \"dlnfb\");\n            for (int i = 0; i &lt; 16; i++) {\n                allok = allok &amp;&amp; gradoks[i];\n            }\n        }\n\n        \/\/ \u66f4\u65b0\u6a21\u578b\u53c2\u6570\n  
      gpt2_update(&amp;model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.01f, step+1);\n\n        \/\/ compare the losses\n        \/\/ \u6bd4\u8f83\u635f\u5931\n        float expected_loss = expected_losses[step];\n        float actual_loss = model.mean_loss;\n        int step_loss_ok = fabsf(expected_loss - actual_loss) &lt; 1e-2;\n        allok = allok &amp;&amp; step_loss_ok;\n\n        \/\/ print the timing information at the end\n        \/\/ \u6253\u5370\u65f6\u95f4\u4fe1\u606f\n        printf(\"step %d: loss %f (took %f ms) OK = %d\\n\", step, model.mean_loss, time_elapsed_s * 1000, step_loss_ok);\n    }\n\n    \/\/ final judgement\n    \/\/ \u7ed9\u51fa\u6700\u7ec8\u5224\u65ad\n    printf(\"overall okay: %d\\n\", allok);\n\n    \/\/ free everything\n    \/\/ \u91ca\u653e\u6240\u6709\u8d44\u6e90\n    free(x);\n    free(y);\n    free(expected_logits);\n    free(expected_loss);\n    free(expected_grads_memory);\n    gpt2_free(&amp;model);\n    return 0;\n}\n<\/pre><\/div>\n\n\n\n<p>\u6b64\u4ee3\u7801\u6bb5\u4e3b\u8981\u6267\u884c\u4ee5\u4e0b\u64cd\u4f5c\uff1a<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>\u52a0\u8f7d\u6a21\u578b\u3001\u8f93\u5165\u6570\u636e\u3001\u9884\u671f\u8f93\u51fa\u548c\u9884\u671f\u68af\u5ea6\u3002<\/li>\n\n\n\n<li>\u5bf9\u6a21\u578b\u8fdb\u884c\u591a\u6b21\u8bad\u7ec3\u8fed\u4ee3\uff0c\u6bcf\u6b21\u8fed\u4ee3\u540e\u90fd\u4f1a\u66f4\u65b0\u6a21\u578b\u7684\u53c2\u6570\u3002<\/li>\n\n\n\n<li>\u5728\u7b2c0\u6b65\u66f4\u65b0\u6a21\u578b\u53c2\u6570\u524d\uff0c\u901a\u8fc7\u6bd4\u8f83\u5b9e\u9645\u7684\u8f93\u51fa\u3001\u635f\u5931\u548c\u68af\u5ea6\u4e0e\u9884\u671f\u503c\u6765\u9a8c\u8bc1\u6a21\u578b\u7684\u51c6\u786e\u6027\uff1b\u6bcf\u4e00\u6b65\u7684\u635f\u5931\u4e5f\u4f1a\u4e0e\u9884\u671f\u635f\u5931\u6bd4\u8f83\u3002<\/li>\n\n\n\n<li>\u6839\u636e\u9a8c\u8bc1\u7ed3\u679c\uff0c\u7ed9\u51fa\u6a21\u578b\u6574\u4f53\u7684\u51c6\u786e\u6027\u8bc4\u4ef7\u3002<\/li>\n\n\n\n<li>\u6700\u540e\u91ca\u653e\u6240\u6709\u5206\u914d\u7684\u8d44\u6e90\u3002<\/li>\n<\/ol>\n","protected":false},"excerpt":{"rendered":"<p>llm.c \u7b80\u5355\u3001\u7eaf C\/CUDA \u7684 LLM 
\u8bad\u7ec3\u3002\u4e0d\u9700\u8981 245MB \u7684 PyTorch \u6216 107MB  [&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"site-sidebar-layout":"default","site-content-layout":"","ast-site-content-layout":"default","site-content-style":"default","site-sidebar-style":"default","ast-global-header-display":"","ast-banner-title-visibility":"","ast-main-header-display":"","ast-hfb-above-header-display":"","ast-hfb-below-header-display":"","ast-hfb-mobile-header-display":"","site-post-title":"","ast-breadcrumbs-content":"","ast-featured-img":"","footer-sml-layout":"","theme-transparent-header-meta":"","adv-header-id-meta":"","stick-header-meta":"","header-above-stick-meta":"","header-main-stick-meta":"","header-below-stick-meta":"","astra-migrate-meta-layouts":"set","ast-page-background-enabled":"default","ast-page-background-meta":{"desktop":{"background-color":"var(--ast-global-color-4)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"ast-content-background-meta":{"desktop":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[443,442],"tags":[242,431,314],"class_list":["post-3229","post","type-post","status-publish","format-standard","hentry","category-llm","category-llms","tag-chatgpt","tag-llm-c","tag-openai-api"],"views":5968,"jetpack_sharing_enabled":true,"jetpack_featured_media_url":"","_links":{"self":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/3229","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=3229"}],"version-history":[{"count":75,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/3229\/revisions"}],"predecessor-version":[{"id":3316,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/3229\/revisions\/3316"}],"wp:attachment":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=3229"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=3229"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=3229"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}