{"id":2038,"date":"2024-02-16T23:05:37","date_gmt":"2024-02-16T15:05:37","guid":{"rendered":"https:\/\/www.aqwu.net\/wp\/?p=2038"},"modified":"2024-04-28T20:04:55","modified_gmt":"2024-04-28T12:04:55","slug":"%e5%a6%82%e4%bd%95%e5%88%9b%e5%bb%ba%e4%b8%80%e4%b8%aa-gpt-%e6%a8%a1%e5%9e%8b","status":"publish","type":"post","link":"https:\/\/www.aqwu.net\/wp\/?p=2038","title":{"rendered":"\u5982\u4f55\u521b\u5efa\u4e00\u4e2a GPT \u6a21\u578b(\u5b57\u7b26\u7ea7\u522b)"},"content":{"rendered":"\n<h2 class=\"wp-block-heading\"><strong>0.GPT \u6a21\u578b\u6982\u8ff0<\/strong><\/h2>\n\n\n\n<p>GPT \u6a21\u578b\u662f Generative Pretrained Transformer \u7684\u7f29\u5199\uff0c\u662f\u4e13\u4e3a\u751f\u6210\u7c7b\u4f3c\u4eba\u7c7b\u7684\u6587\u672c\u800c\u8bbe\u8ba1\u7684\u9ad8\u7ea7\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u3002\u8fd9\u4e9b\u7531 OpenAI \u5f00\u53d1\u7684\u6a21\u578b\u5df2\u7ecf\u8fdb\u884c\u4e86\u591a\u6b21\u8fed\u4ee3\uff1aGPT-1\u3001GPT-2\u3001GPT-3\uff0c\u4ee5\u53ca\u6700\u8fd1\u7684 GPT-4\u3002<\/p>\n\n\n\n<p>GPT \u662f\u4e00\u79cd\u57fa\u4e8e transformer \u67b6\u6784\u7684 AI \u8bed\u8a00\u6a21\u578b\uff0c\u8be5\u6a21\u578b\u7ecf\u8fc7\u9884\u8bad\u7ec3\u3001\u751f\u6210\u3001\u65e0\u76d1\u7763\uff0c\u80fd\u591f\u5728\u96f6\/\u4e00\/\u5c11\u591a\u4efb\u52a1\u8bbe\u7f6e\u4e2d\u8868\u73b0\u826f\u597d\u3002\u5b83\u4ece NLP \u4efb\u52a1\u7684\u6807\u8bb0\u5e8f\u5217\u4e2d\u9884\u6d4b\u4e0b\u4e00\u4e2a\u6807\u8bb0\uff08\u5b57\u7b26\u5e8f\u5217\u7684\u5b9e\u4f8b\uff09\uff0c\u5b83\u5c1a\u672a\u7ecf\u8fc7\u8bad\u7ec3\u3002\u5728\u53ea\u770b\u5230\u51e0\u4e2a\u4f8b\u5b50\u4e4b\u540e\uff0c\u5b83\u53ef\u4ee5\u5728\u67d0\u4e9b\u57fa\u51c6\u6d4b\u8bd5\u4e2d\u8fbe\u5230\u9884\u671f\u7684\u7ed3\u679c\uff0c\u5305\u62ec\u673a\u5668\u7ffb\u8bd1\u3001\u95ee\u7b54\u548c\u5b8c\u5f62\u586b\u7a7a\u4efb\u52a1\u3002GPT 
\u6a21\u578b\u4e3b\u8981\u57fa\u4e8e\u6761\u4ef6\u6982\u7387\u8ba1\u7b97\u4e00\u4e2a\u5355\u8bcd\u51fa\u73b0\u5728\u53e6\u4e00\u4e2a\u6587\u672c\u4e2d\u7684\u53ef\u80fd\u6027\uff0c\u56e0\u4e3a\u5b83\u51fa\u73b0\u5728\u53e6\u4e00\u4e2a\u6587\u672c\u4e2d\u3002\u4f8b\u5982\uff0c\u5728\u53e5\u5b50\u4e2d\uff0c\u201c\u739b\u683c\u4e3d\u7279\u6b63\u5728\u7ec4\u7ec7\u8f66\u5e93\u9500\u552e&#8230;&#8230;\u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u4e70\u90a3\u4e2a\u65e7\u7684&#8230;&#8230;\u201d\u6905\u5b50\u8fd9\u4e2a\u8bcd\u53ef\u80fd\u6bd4\u201c\u5927\u8c61\u201d\u8fd9\u4e2a\u8bcd\u66f4\u5408\u9002\u3002\u6b64\u5916\uff0c\u8f6c\u6362\u5668\u6a21\u578b\u4f7f\u7528\u591a\u4e2a\u79f0\u4e3a\u6ce8\u610f\u529b\u5757\u7684\u5355\u5143\u6765\u5b66\u4e60\u8981\u5173\u6ce8\u6587\u672c\u5e8f\u5217\u7684\u54ea\u4e9b\u90e8\u5206\u3002\u4e00\u4e2a\u8f6c\u6362\u5668\u53ef\u80fd\u6709\u591a\u4e2a\u6ce8\u610f\u529b\u5757\uff0c\u6bcf\u4e2a\u6ce8\u610f\u529b\u5757\u5b66\u4e60\u4e00\u95e8\u8bed\u8a00\u7684\u4e0d\u540c\u65b9\u9762\u3002<\/p>\n\n\n\n<p>\u6b64\u5916\uff0cGPT \u6a21\u578b\u8fd8\u5177\u6709\u8bb8\u591a\u529f\u80fd\uff0c\u4f8b\u5982\u751f\u6210\u524d\u6240\u672a\u6709\u7684\u9ad8\u8d28\u91cf\u5408\u6210\u6587\u672c\u6837\u672c\u3002\u5982\u679c\u7528\u8f93\u5165\u542f\u52a8\u6a21\u578b\uff0c\u5b83\u5c06\u751f\u6210\u4e00\u4e2a\u8f83\u957f\u7684\u5ef6\u7eed\u3002GPT \u6a21\u578b\u5728\u4e0d\u4f7f\u7528\u7279\u5b9a\u9886\u57df\u8bad\u7ec3\u6570\u636e\u7684\u60c5\u51b5\u4e0b\uff0c\u4f18\u4e8e\u5728\u7ef4\u57fa\u767e\u79d1\u3001\u65b0\u95fb\u548c\u4e66\u7c4d\u7b49\u9886\u57df\u8bad\u7ec3\u7684\u5176\u4ed6\u8bed\u8a00\u6a21\u578b\u3002GPT 
\u4ec5\u4ece\u6587\u672c\u4e2d\u5b66\u4e60\u8bed\u8a00\u4efb\u52a1\uff0c\u4f8b\u5982\u9605\u8bfb\u7406\u89e3\u3001\u603b\u7ed3\u548c\u95ee\u7b54\uff0c\u800c\u65e0\u9700\u7279\u5b9a\u4efb\u52a1\u7684\u8bad\u7ec3\u6570\u636e\u3002\u8fd9\u4e9b\u4efb\u52a1\u7684\u5206\u6570\uff08\u201c\u5206\u6570\u201d\u662f\u6307\u6a21\u578b\u5206\u914d\u7684\u6570\u503c\uff0c\u7528\u4e8e\u8868\u793a\u7ed9\u5b9a\u8f93\u51fa\u6216\u7ed3\u679c\u7684\u53ef\u80fd\u6027\u6216\u6982\u7387\uff09\u4e0d\u662f\u6700\u597d\u7684\uff0c\u4f46\u5b83\u4eec\u8868\u660e\u5177\u6709\u8db3\u591f\u6570\u636e\u548c\u8ba1\u7b97\u7684\u65e0\u76d1\u7763\u6280\u672f\u53ef\u4ee5\u4f7f\u4efb\u52a1\u53d7\u76ca\u3002<\/p>\n\n\n\n<p>\u4ee5\u4e0b\u6d4b\u8bd5\u73af\u5883\u4e3a ubuntu 22.04.03 tls<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>1.\u6570\u636e\u96c6\u51c6\u5907<\/strong><\/h2>\n\n\n\n<h2 class=\"wp-block-heading has-medium-font-size\">1.0 \u83b7\u53d6\u6570\u636e<\/h2>\n\n\n\n<p>\u4f7f\u7528Tiny Shakespeare \u6570\u636e\u96c6<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ md CreateGPT \n$ cd CreateGPT\n$ mkdir data\n$ cd data\n$ wget https:\/\/raw.githubusercontent.com\/karpathy\/char-rnn\/master\/data\/tinyshakespeare\/input.txt\n\n$ cd ..<\/code><\/pre>\n\n\n\n<p class=\"has-medium-font-size\">1.1 \u8bfb\u53d6 Tiny Shakespeare \u6570\u636e\u96c6\uff0c\u5e76\u6253\u5370\u6570\u636e\u96c6\u957f\u5ea6<\/p>\n\n\n\n<p>setup1.1.py \u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># read it in to inspect it\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\nprint(\"length of dataset in characters: \", len(text))\n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup1.1.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup1.1.py\nlength of dataset in characters:  1115394<\/code><\/pre>\n\n\n\n<p class=\"has-medium-font-size\">1.2 
\u7edf\u8ba1\u6570\u636e\u96c6\u4e2d\u90fd\u5305\u542b\u54ea\u4e9b\u5b57\u7b26\u79cd\u7c7b\uff1a<\/p>\n\n\n\n<p>setup1.2.py \u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># read it in to inspect it\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\n# print(\"length of dataset in characters: \", len(text))\n\n# here are all the unique characters that occur in this text\nchars = sorted(list(set(text)))\nvocab_size = len(chars)\nprint(''.join(chars))\nprint(vocab_size)<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup1.2.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup1.2.py\n\n !$&amp;',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n65\n<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>2.Tokenize<\/strong><\/h2>\n\n\n\n<p>Tokenize \u6307\u7684\u662f\uff0c\u5c06\u539f\u59cb\u6587\u672c\u8f6c\u5316\u4e3a\u4e00\u4e9b\u6570\u503c\u7684\u5e8f\u5217\uff0c\u5373 Token \u5e8f\u5217\u3002<\/p>\n\n\n\n<p>\u5982\u4f55\u5c06\u81ea\u7136\u8bed\u8a00\u6587\u672c\u53d8\u4e3a Token \u5e8f\u5217\uff0c\u6709\u5f88\u591a\u9ad8\u7ea7\u7b97\u6cd5\uff0c\u6bd4\u5982\uff1a<a href=\"https:\/\/github.com\/google\/sentencepiece\" target=\"_blank\" rel=\"noreferrer noopener\">google\/sentencepiece<\/a>\u3001<a href=\"https:\/\/github.com\/openai\/tiktoken\" target=\"_blank\" rel=\"noreferrer noopener\">openai\/tiktoken<\/a>\u3002<\/p>\n\n\n\n<p>\u524d\u9762\u8bf4\u5230\uff0c\u672c\u6587\u4e2d\u91c7\u7528\u7684\u57fa\u4e8e\u5b57\u7b26\u7ea7\u522b\u7684\u8bed\u8a00\u6a21\u578b\uff0c\u5b83\u7684 Tokennizer \u7b97\u6cd5\u5341\u5206\u7b80\u5355\u3002\u524d\u9762\u4ee3\u7801\u4e2d\u7684&nbsp;<code>chars<\/code>&nbsp;\u5305\u542b\u4e86\u8bed\u6599\u4e2d\u6240\u6709\u5b57\u7b26\u7684\u79cd\u7c7b\uff0c\u7ed9\u51fa\u4e00\u4e2a\u5b57\u7b26\uff0c\u53ea\u8981\u770b\u8be5\u5b57\u7b26\u5728&nbsp;<code>chars<\/code>&nbsp;\u4e2d\u7684 
indexOf\uff0c\u5c31\u5f97\u5230\u4e86\u4e00\u79cd\u6570\u503c\u5316\u7684\u65b9\u6cd5\u3002<\/p>\n\n\n\n<p class=\"has-medium-font-size\">2.1 \u5c06\u4e0a\u9762\u7684\u8bcd\u6c47\u8868\uff08<code>chars<\/code>\uff09\u6620\u5c04\u4e3a\u6574\u6570<\/p>\n\n\n\n<p><code>stoi<\/code>&nbsp;\u5b57\u7b26\u5230\u6574\u6570\u7684\u6620\u5c04\uff0c<code>itos<\/code>&nbsp;\u6574\u6570\u5230\u5b57\u7b26\u7684\u6620\u5c04\u3002<code>encode<\/code>&nbsp;\u548c&nbsp;<code>decode<\/code>&nbsp;\u5206\u522b\u662f\u5bf9\u5b57\u7b26\u4e32\u7684\u7f16\u89e3\u7801\u3002<\/p>\n\n\n\n<p>setup2.1.py \u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># read it in to inspect it\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\n# print(\"length of dataset in characters: \", len(text))\n\n# here are all the unique characters that occur in this text\nchars = sorted(list(set(text)))\nvocab_size = len(chars)\n# print(''.join(chars))\n# print(vocab_size)\n\n# create a mapping from characters to integers\nstoi = { ch:i for i,ch in enumerate(chars) }\nitos = { i:ch for i,ch in enumerate(chars) }\n# encoder: take a string, output a list of integers\nencode = lambda s: &#91;stoi&#91;c] for c in s]\n# decoder: take a list of integers, output a string\ndecode = lambda l: ''.join(&#91;itos&#91;i] for i in l])\n\nprint(encode(\"hii there\"))\nprint(decode(encode(\"hii there\")))<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup2.1.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup2.1.py\n&#91;46, 47, 47, 1, 58, 46, 43, 56, 43]\nhii there<\/code><\/pre>\n\n\n\n<p>\u5728\u4e0a\u9762\u4ee3\u7801\u4e2d\uff0c\u5b9e\u73b0\u4e86\u4e00\u4e2a\u7f16\u89e3\u7801\u65b9\u6cd5\uff0c\u80fd\u591f\u5c06\u6587\u672c\u7f16\u7801\u4e3a &#8220;Token&#8221; \u5e8f\u5217\u3002\u4e4b\u6240\u4ee5 Token 
\u8981\u6253\u5f15\u53f7\uff0c\u56e0\u4e3a\u5728\u57fa\u4e8e\u5b57\u7b26\u7ea7\u522b\u7684\u7c92\u5ea6\u4e0b\uff0c\u7b97\u6cd5\u76f4\u89c2\u4f46\u662f\u8fc7\u4e8e\u7b80\u5355\u3002\u8fd8\u6709\u4e00\u70b9\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u63a5\u4e0b\u6765\u4f7f\u7528\u7684\u81ea\u7136\u8bed\u8a00\uff0c\u5fc5\u987b\u662f&nbsp;<code>chars<\/code>&nbsp;\u4e2d\u7684\u5b57\u7b26\uff0c\u4e0d\u80fd\u8d85\u51fa\u8fd9\u4e2a\u8303\u56f4\u3002<\/p>\n\n\n\n<p>\u4e0d\u8bba\u662f\u9ad8\u7ea7\u7b97\u6cd5\u8fd8\u662f\u672c\u6587\u4e2d\u7684\u7b80\u5316\u65b9\u6cd5\uff1a\u539f\u7406\u90fd\u662f\u4e00\u6837\u7684\uff0c<mark>\u5c06\u6587\u672c\u8f6c\u4e3a\u6570\u503c\u5e8f\u5217<\/mark>\u3002<\/p>\n\n\n\n<p class=\"has-medium-font-size\">2.2 \u4ee5 tiktoken\uff08\u6709 50257 \u79cd Tokens\uff09\u4e3a\u4f8b\uff1a<\/p>\n\n\n\n<p>setup2.2.py \u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>import tiktoken\n\nenc = tiktoken.get_encoding('gpt2')\n\nprint(enc.n_vocab)\n# 50257\n\nprint(enc.encode('hii there'))\n# &#91;71, 4178, 612]\n\nprint(enc.decode(&#91;71, 4178, 612]))\n# 'hii there'\n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup2.2.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup2.2.py\n50257\n&#91;71, 4178, 612]\nhii there<\/code><\/pre>\n\n\n\n<p>\u8fd9\u91cc\u53ef\u80fd\u9700\u8981\u4f60\u5b89\u88c5 tiktoken \u6a21\u5757<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>$ pip install tiktoken<\/code><\/pre>\n\n\n\n<p>\u53ef\u4ee5\u770b\u5230\uff0c\u4f7f\u7528\u8d77\u6765\u4e0e\u672c\u6587\u662f\u4e00\u6837\u7684\u3002\u4f46\u662f\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>\u901a\u8fc7\u9ad8\u7ea7 Tokenizer \u7f16\u7801\u540e\uff0c\u5e8f\u5217\u7684\u957f\u5ea6\u53d8\u77ed\u3002Tiny Shakespeare \u8981\u6bcf\u4e2a\u5b57\u7b26\u4e00\u4e2a Token\uff0c\u800c\u8bed\u8a00\u6a21\u578b\u7684\u4e0a\u4e0b\u6587\u6709\u9650\uff0c\u80fd\u591f\u652f\u6301\u7684 Token 
\u957f\u5ea6\u4e5f\u6709\u9650\u3002\u56e0\u6b64\u9ad8\u7ea7\u7684 Tokenizer\uff0c\u63d0\u5347\u4e86\u8868\u8fbe\u539f\u59cb\u4fe1\u606f\u7684\u5bc6\u5ea6\u548c\u6548\u7387\u3002\n\u6ce8\uff1a\u5728\u9ad8\u7ea7\u7b97\u6cd5\u4e2d\uff0c\u62c6\u89e3\u51fa\u6765\u7684\u662f subwords\uff08\u5b50\u8bcd\u5355\u5143\uff09\u3002 \u65e2\u4e0d\u662f\u5bf9\u6574\u4e2a\u5355\u8bcd\u7f16\u7801\uff0c\u4e5f\u4e0d\u662f\u5bf9\u5355\u4e2a\u5b57\u7b26\u7f16\u7801\uff0c\u800c\u662f\u6309\u7167\u7edf\u8ba1\u8bad\u7ec3\uff0c\u5bf9\u5b50\u8bcd\u7f16\u7801\u3002<\/code><\/pre>\n\n\n\n<p>2.3 \u5c06\u6574\u4e2a Tiny Shakespeare \u7f16\u7801\u540e\uff0c\u8f6c\u4e3a PyTorch \u5e8f\u5217<\/p>\n\n\n\n<p>setup2.3.py \u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># let's now encode the entire text dataset and store it into a torch.Tensor\nimport torch # we use PyTorch: https:\/\/pytorch.org\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\ndata = torch.tensor(encode(text), dtype=torch.long)\nprint(data.shape, data.dtype)\n# the 1000 characters we looked at earier will to the GPT look like this\nprint(data&#91;:1000]) \n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup2.3.py<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" decoding=\"async\" width=\"706\" height=\"1024\" src=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-706x1024.png\" alt=\"\" class=\"wp-image-2052\" srcset=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-706x1024.png 706w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-207x300.png 207w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-768x1114.png 768w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-1059x1536.png 1059w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247.png 1208w\" sizes=\"auto, (max-width: 706px) 100vw, 
706px\" \/><\/figure>\n\n\n\n<p>\u8fd9\u91cc\u53ef\u80fd\u9700\u8981\u4f60\u5b89\u88c5 torch \u6a21\u5757<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>$ pip install torch<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\">3. \u8bad\u7ec3\u96c6\u3001\u9a8c\u8bc1\u96c6\u3001\u6570\u636e\u5207\u5206<\/h2>\n\n\n\n<p>Tiny Shakespeare \u7684\u524d 90% \u7528\u4e8e\u8bad\u7ec3\uff0c\u540e 10% \u7528\u4e8e\u9a8c\u8bc1<\/p>\n\n\n\n<p>\u5206\u7247\uff08Chunk\uff09\u79f0\u4e4b\u4e3a&nbsp;<strong>block<\/strong>\uff0c\u5206\u7247\u7684\u5927\u5c0f\u79f0\u4e4b\u4e3a&nbsp;<strong>block_size<\/strong><\/p>\n\n\n\n<p>setup3.1.py \u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># let's now encode the entire text dataset and store it into a torch.Tensor\nimport torch # we use PyTorch: https:\/\/pytorch.org\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\n\nchars = sorted(list(set(text)))\n\n# create a mapping from characters to integers\nstoi = { ch:i for i,ch in enumerate(chars) }\nitos = { i:ch for i,ch in enumerate(chars) }\n# encoder: take a string, output a list of integers\nencode = lambda s: &#91;stoi&#91;c] for c in s]\n\ndata = torch.tensor(encode(text), dtype=torch.long)\n\n# Let's now split up the data into train and validation sets\nn = int(0.9*len(data)) # first 90% will be train, rest val\ntrain_data = data&#91;:n]\nval_data = data&#91;n:]\n\nblock_size = 8\n\nprint(train_data&#91;:block_size+1]) <\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup3.1.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup3.1.py\ntensor(&#91;18, 47, 56, 57, 58,  1, 15, 47, 58])<\/code><\/pre>\n\n\n\n<p>\u8fd9\u91cc\u53d6\u4e86\u8bad\u7ec3\u96c6\u4e2d\u7684\u7b2c\u4e00\u4e2a Block\u3002<strong>block_size<\/strong>&nbsp;\u5927\u5c0f\u4e3a 8\uff0c\u4e3a\u4ec0\u4e48\u6211\u4eec\u53d6\u4e86 9 \u4e2a Token 
\u5462\uff1f<\/p>\n\n\n\n<p>\u4e3a\u4e86\u7406\u89e3\u8fd9\u4e2a\u95ee\u9898\uff0c\u9996\u5148\u770b\u8bad\u7ec3\u65b9\u5f0f\uff0c\u5c06 Chunck \u62c6\u5206\u4e3a\u4e24\u4e2a\u5b50\u96c6 x \u548c y\uff0c\u5176\u4e2d x \u8868\u793a\u8f93\u5165 Token \u5e8f\u5217\uff0c\u5728\u4f7f\u7528\u65f6\u662f\u7d2f\u589e\u7684\uff0cy \u8868\u793a\u57fa\u4e8e\u8be5\u8f93\u5165\uff0c\u4e0e\u5176\u7684\u8f93\u51fa\u3002<\/p>\n\n\n\n<p>setup3.2.py \u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># let's now encode the entire text dataset and store it into a torch.Tensor\nimport torch # we use PyTorch: https:\/\/pytorch.org\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\n\nchars = sorted(list(set(text)))\n\n# create a mapping from characters to integers\nstoi = { ch:i for i,ch in enumerate(chars) }\nitos = { i:ch for i,ch in enumerate(chars) }\n# encoder: take a string, output a list of integers\nencode = lambda s: &#91;stoi&#91;c] for c in s]\n\ndata = torch.tensor(encode(text), dtype=torch.long)\n\n# Let's now split up the data into train and validation sets\nn = int(0.9*len(data)) # first 90% will be train, rest val\ntrain_data = data&#91;:n]\nval_data = data&#91;n:]\n\nblock_size = 8\n\nx = train_data&#91;:block_size]\ny = train_data&#91;1:block_size+1]\nfor t in range(block_size):\n    context = x&#91;:t+1]\n    target = y&#91;t]\n    print(f\"when input is {context} the target: {target}\")<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup3.2.py<\/p>\n\n\n\n<pre class=\"wp-block-code\" style=\"font-size:10px\"><code>$ python setup3.2.py\nwhen input is tensor(&#91;18]) the target: 47\nwhen input is tensor(&#91;18, 47]) the target: 56\nwhen input is tensor(&#91;18, 47, 56]) the target: 57\nwhen input is tensor(&#91;18, 47, 56, 57]) the target: 58\nwhen input is tensor(&#91;18, 47, 56, 57, 58]) the target: 1\nwhen input is tensor(&#91;18, 47, 56, 57, 58,  1]) the target: 15\nwhen input is 
tensor(&#91;18, 47, 56, 57, 58,  1, 15]) the target: 47\nwhen input is tensor(&#91;18, 47, 56, 57, 58,  1, 15, 47]) the target: 58<\/code><\/pre>\n\n\n\n<p>\u7ed9\u51fa\u4e00\u4e2a Block\uff0c\u5206\u4e3a\u51e0\u8f6e\u3002\u7b2c\u4e00\u8f6e\uff0c\u7528\u7b2c\u4e00\u4e2a Token \u63a8\u6d4b\u7b2c\u4e8c\u4e2a Token\u3002\u7b2c\u4e8c\u8f6e\uff0c\u7528<mark>\u524d\u4e24\u4e2a Token<\/mark>\u63a8\u6d4b\u7b2c\u4e09\u4e2a Token\u3002\u4ee5\u6b64\u7c7b\u63a8\uff0c\u5230\u4e86\u7b2c\u516b\u8f6e\uff0c\u7528\u524d\u516b\u4e2a Token \u63a8\u6d4b\u7b2c\u4e5d\u4e2a Token\u3002<\/p>\n\n\n\n<p><strong>block_size<\/strong>&nbsp;\u5927\u5c0f\u4e3a 8\uff0c\u8868\u793a\u6211\u4eec\u7684\u6700\u5927\u8bad\u7ec3\u957f\u5ea6\u4e3a 8\u3002\u6bcf\u4e00\u6279\u6570\u636e\u6709 9 \u4e2a\u5143\u7d20\uff0c\u5176\u4e2d\u7b2c\u4e5d\u4e2a\u5143\u7d20\u4e0d\u53c2\u4e0e\u8bad\u7ec3\uff0c\u53ea\u53c2\u4e0e\u9a8c\u8bc1\u3002<\/p>\n\n\n\n<p>\u6211\u4eec\u5c06 Tiny Shakespeare \u5207\u5206\u79f0\u4e00\u7cfb\u5217 Block\uff0c\u5c31\u76f8\u5f53\u4e8e\u4e00\u7cfb\u5217\u8003\u8bd5\u9898\uff0c\u6bcf\u9053\u9898\u662f\u4e00\u4e2a\u957f\u5ea6\u4e3a 9 \u7684\u8fde\u7eed Token \u5e8f\u5217\uff0c\u6309\u7167\u4e0a\u56fe\u65b9\u5f0f\uff0c\u8003\u8bed\u8a00\u6a21\u578b\u3002<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>4. 
Batch \u5212\u5206<\/strong><\/h2>\n\n\n\n<p>\u5c06\u8bad\u7ec3\u96c6\u8fdb\u884c Block \u5207\u5206\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4e00\u4e2a\u4e00\u4e2a\u5411 GPU \u6295\u5582\uff08\u8bad\u7ec3\uff09\u3002\u4f46\u662f\uff0c\u6211\u4eec\u60f3\uff0cGPU \u4ec0\u4e48\u80fd\u529b\u6700\u5f3a\u5927\uff1f\u5e76\u884c\u8ba1\u7b97\u80fd\u529b\uff01\u4e00\u4e2a\u4e00\u4e2a\u5411 GPU \u6295\u5582\u5582\u4e0d\u9971\u3002\u4e3a\u4e86\u80fd\u591f\u5145\u5206\u53d1\u6325\u51fa GPU \u7684\u5e76\u884c\u8fd0\u7b97\u80fd\u529b\uff0c\u6211\u4eec\u5c06\u591a\u4e2a Block \u6253\u5305\u6210<mark>\u4e00\u6279\uff08Batch\uff09<\/mark>\uff0c\u4e00\u6279\u4e00\u6279\u5411 GPU \u6295\u5582\u3002\u603b\u4e4b\u4e00\u53e5\u8bdd\uff0c\u4e0d\u80fd\u8ba9 GPU \u95f2\u7740\uff0c\u63d0\u5347\u8bad\u7ec3\u6548\u7387\u3002<\/p>\n\n\n\n<p>\u503c\u5f97\u4e00\u63d0\u7684\u662f\uff0c\u5c3d\u7ba1\u4e00\u4e2a Batch \u5185\u7684 Blocks \u662f\u4e00\u6279\u8fdb\u5165 GPU \u7684\uff0c\u4f46\u662f\u5b83\u4eec\u4e4b\u95f4\u76f8\u4e92\u9694\u79bb\uff0c\u4e92\u76f8\u4e0d\u77e5\u9053\u5bf9\u65b9\u7684\u5b58\u5728\uff0c<strong>\u4e92\u4e0d\u5e72\u6270<\/strong>\u3002<\/p>\n\n\n\n<p>setup4.1.py \u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># let's now encode the entire text dataset and store it into a torch.Tensor\nimport torch # we use PyTorch: https:\/\/pytorch.org\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\n\nchars = sorted(list(set(text)))\n\n# create a mapping from characters to integers\nstoi = { ch:i for i,ch in enumerate(chars) }\nitos = { i:ch for i,ch in enumerate(chars) }\n# encoder: take a string, output a list of integers\nencode = lambda s: &#91;stoi&#91;c] for c in s]\n\ndata = torch.tensor(encode(text), dtype=torch.long)\n\n# Let's now split up the data into train and validation sets\nn = int(0.9*len(data)) # first 90% will be train, rest val\ntrain_data = data&#91;:n]\nval_data = 
data&#91;n:]\n\ntorch.manual_seed(1337)\n# how many independent sequences will we process in parallel?\nbatch_size = 4\n# what is the maximum context length for predictions?\nblock_size = 8\n\ndef get_batch(split):\n    # generate a small batch of data of inputs x and targets y\n    data = train_data if split == 'train' else val_data\n    ix = torch.randint(len(data) - block_size, (batch_size,))\n    x = torch.stack(&#91;data&#91;i:i+block_size] for i in ix])\n    y = torch.stack(&#91;data&#91;i+1:i+block_size+1] for i in ix])\n    return x, y\n\nxb, yb = get_batch('train')\nprint('inputs:')\nprint(xb.shape)\nprint(xb)\nprint('targets:')\nprint(yb.shape)\nprint(yb)\n\nprint('----')\n\nfor b in range(batch_size): # batch dimension\n    for t in range(block_size): # time dimension\n        context = xb&#91;b, :t+1]\n        target = yb&#91;b,t]\n        print(f\"when input is {context.tolist()} the target: {target}\")<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup4.1.py<\/p>\n\n\n\n<pre class=\"wp-block-code\" style=\"font-size:10px\"><code>$ python setup4.1.py\ninputs:\ntorch.Size(&#91;4, 8])\ntensor(&#91;&#91;24, 43, 58,  5, 57,  1, 46, 43],\n        &#91;44, 53, 56,  1, 58, 46, 39, 58],\n        &#91;52, 58,  1, 58, 46, 39, 58,  1],\n        &#91;25, 17, 27, 10,  0, 21,  1, 54]])\ntargets:\ntorch.Size(&#91;4, 8])\ntensor(&#91;&#91;43, 58,  5, 57,  1, 46, 43, 39],\n        &#91;53, 56,  1, 58, 46, 39, 58,  1],\n        &#91;58,  1, 58, 46, 39, 58,  1, 46],\n        &#91;17, 27, 10,  0, 21,  1, 54, 39]])\n----\nwhen input is &#91;24] the target: 43\nwhen input is &#91;24, 43] the target: 58\nwhen input is &#91;24, 43, 58] the target: 5\nwhen input is &#91;24, 43, 58, 5] the target: 57\nwhen input is &#91;24, 43, 58, 5, 57] the target: 1\nwhen input is &#91;24, 43, 58, 5, 57, 1] the target: 46\nwhen input is &#91;24, 43, 58, 5, 57, 1, 46] the target: 43\nwhen input is &#91;24, 43, 58, 5, 57, 1, 46, 43] the target: 39\nwhen input is &#91;44] the target: 53\nwhen 
input is &#91;44, 53] the target: 56\nwhen input is &#91;44, 53, 56] the target: 1\nwhen input is &#91;44, 53, 56, 1] the target: 58\nwhen input is &#91;44, 53, 56, 1, 58] the target: 46\nwhen input is &#91;44, 53, 56, 1, 58, 46] the target: 39\nwhen input is &#91;44, 53, 56, 1, 58, 46, 39] the target: 58\nwhen input is &#91;44, 53, 56, 1, 58, 46, 39, 58] the target: 1\nwhen input is &#91;52] the target: 58\nwhen input is &#91;52, 58] the target: 1\nwhen input is &#91;52, 58, 1] the target: 58\nwhen input is &#91;52, 58, 1, 58] the target: 46\nwhen input is &#91;52, 58, 1, 58, 46] the target: 39\nwhen input is &#91;52, 58, 1, 58, 46, 39] the target: 58\nwhen input is &#91;52, 58, 1, 58, 46, 39, 58] the target: 1\nwhen input is &#91;52, 58, 1, 58, 46, 39, 58, 1] the target: 46\nwhen input is &#91;25] the target: 17\nwhen input is &#91;25, 17] the target: 27\nwhen input is &#91;25, 17, 27] the target: 10\nwhen input is &#91;25, 17, 27, 10] the target: 0\nwhen input is &#91;25, 17, 27, 10, 0] the target: 21\nwhen input is &#91;25, 17, 27, 10, 0, 21] the target: 1\nwhen input is &#91;25, 17, 27, 10, 0, 21, 1] the target: 54\nwhen input is &#91;25, 17, 27, 10, 0, 21, 1, 54] the target: 39<\/code><\/pre>\n\n\n\n<p>\u4ece\u4ee3\u7801\u4e2d\u53ef\u4ee5\u770b\u51fa\uff1aBlock \u7684\u5927\u5c0f\u4e3a 8\u3002Batch \u7684\u5927\u5c0f\u4e3a 4\uff0c\u5373\u4e00\u4e2a Batch \u5305\u542b 4 \u4e2a Blocks\u3002\u53e6\u5916\u9501\u5b9a\u4e86\u968f\u673a\u6570\u79cd\u5b50\u4e3a&nbsp;<code>1337<\/code>\uff0c\u8fd9\u6837\u6211\u4eec\u90fd\u80fd\u590d\u73b0\u8ddf&nbsp;Andrej Karpathy&nbsp;\u4e00\u6837\u7684\u8bad\u7ec3\u6548\u679c\u3002<\/p>\n\n\n\n<p>\u4e0a\u8ff0\u4ee3\u7801\uff0c\u8fd0\u884c\u540e\u7684\u65e5\u5fd7\u8f93\u51fa\u5982\u4e0a\u3002\u53ef\u4ee5\u770b\u51fa\uff1a\u8f93\u5165\u7531 1 \u4e2a 8 \u5143\u7d20\u5411\u91cf\u53d8\u4e3a 4 \u4e2a\u3002\u9a8c\u8bc1\u5411\u91cf\u4e5f\u53d8\u4e3a 4 \u4e2a\u3002\u90fd Batch 
\u5316\u4e86\u3002\u5728\u540e\u7eed\u7684\u63a8\u7406\u91ca\u4e49\u4e2d\uff0c\u4e5f\u662f\u5c06 Batch \u5185\u6bcf\u4e2a Block \u7684\u63a8\u7406\u8fc7\u7a0b\u6253\u5370\u51fa\u6765\u3002<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>5. BigramLanguageModel V1<\/strong><\/h2>\n\n\n\n<p>\u4e8c\u5143\u8bed\u8a00\u6a21\u578b\uff08BigramLanguageModel\uff09\uff0c\u6982\u62ec\u8bf4\uff1a\u6839\u636e\u524d\u4e00\u4e2a\u8bcd\uff0c\u6765\u63a8\u6d4b\u4e0b\u4e00\u4e2a\u8bcd\u3002\u4e3e\u4f8b\u6765\u8bf4\uff1a\u4f8b\u5982\uff0c\u5bf9\u4e8e\u53e5\u5b50 &#8220;I love to play football&#8221;\uff0c\u4f1a\u5f97\u5230\u4ee5\u4e0b\u7684\u8bcd\u7ec4\uff1a&#8221;I love&#8221;, &#8220;love to&#8221;, &#8220;to play&#8221;, &#8220;play football&#8221;\u3002<\/p>\n\n\n\n<p>\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u6765\u5b9e\u73b0\u7b2c\u4e00\u4e2a BigramLanguageModel\uff0c\u4e0e\u89c6\u9891\u4e0d\u540c\u4e4b\u5904\u5728\u4e8e\uff0c\u6211\u79f0\u4e4b\u4e3a BigramLanguageModelV1\uff0c\u540e\u7eed\u6bcf\u8fdb\u884c\u4e00\u6b21\u66f4\u6539\uff0c\u90fd\u4f1a\u521b\u5efa\u4e00\u4e2a\u65b0\u7c7b\uff0c\u5e76\u63d0\u5347\u7248\u672c\u3002<\/p>\n\n\n\n<p>\u6a21\u578b\u7ee7\u627f\u81ea Pytorch \u7684 Module\uff0c\u5728\u6784\u9020\u65b9\u6cd5\u4e2d\u58f0\u660e\u6a21\u578b\u5185\u90e8\u5305\u542b\u7684\u5c42\u3002\u53ef\u89c1\u8be5\u6a21\u578b\u53ea\u6709\u4e00\u5c42\uff08nn.Embedding\uff09\u3002\u6a21\u578b\u8fd8\u5305\u62ec forward\uff0c\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u7528\u4e8e\u8bad\u7ec3\u3002\u6a21\u578b\u4e00\u65e6\u8bad\u7ec3\u597d\u540e\uff0c\u901a\u8fc7 generate \u53ef\u8fdb\u884c\u6587\u672c\u751f\u6210\u3002<\/p>\n\n\n\n<p>setup5.1.py \u4ee3\u7801\u5b9e\u73b0\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\ntorch.manual_seed(1337)\n\n# \u4e8c\u5143\u8bed\u8a00\u6a21\u578b\u5b9e\u73b0\nclass BigramLanguageModelV1(nn.Module):\n\n    def __init__(self, vocab_size):\n  
      super().__init__()\n        # \u6bcf\u4e2a\u8bcd\u76f4\u63a5\u4ece\u4e00\u4e2a\u67e5\u627e\u8868\u4e2d\u83b7\u53d6\u4e0b\u4e00\u4e2a\u8bcd\u7684logits\u503c\n        # logits\u662f\u6a21\u578b\u505a\u51fa\u9884\u6d4b\u524d\u7684\u4e00\u7ec4\u672a\u7ecf\u5f52\u4e00\u5316\u7684\u5206\u6570\uff0c\u53cd\u6620\u4e86\u4e0d\u540c\u7ed3\u679c\u7684\u76f8\u5bf9\u53ef\u80fd\u6027\n        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n\n\t# \u6a21\u578b\u524d\u5411\u4f20\u64ad\n\t# idx\uff1a\u5373\u524d\u9762\u7684 x\uff0c\u8868\u793a\u8f93\u5165\u6570\u636e\uff0c\u8bcd\u5728\u8bcd\u6c47\u8868\u4e2d\u7684\u7d22\u5f15\u7684\u5411\u91cf\n\t# targets\uff1a\u8bad\u7ec3\u7684\u76ee\u6807\u8f93\u51fa\uff0c\u6bd4\u5982\u6b63\u786e\u7684\u4e0b\u4e00\u4e2a\u8bcd\u7684\u7d22\u5f15\n    def forward(self, idx, targets=None):\n        # idx and targets are both (B,T) tensor of integers\n        logits = self.token_embedding_table(idx) # (B,T,C)\n\n        if targets is None:\n            loss = None\n        else:\n            B, T, C = logits.shape\n            logits = logits.view(B*T, C)\n            targets = targets.view(B*T)\n            loss = F.cross_entropy(logits, targets)\n\n        return logits, loss\n\n\t# \u5728\u6a21\u578b\u5df2\u7ecf\u8bad\u7ec3\u597d\u4e4b\u540e\uff0c\u6839\u636e\u7ed9\u5b9a\u7684\u8f93\u5165\u751f\u6210\u6587\u672c\u7684\u65b9\u6cd5\u3002\n    def generate(self, idx, max_new_tokens):\n        # idx is (B, T) array of indices in the current context\n        for _ in range(max_new_tokens):\n            # get the predictions\n            logits, loss = self(idx)\n            # focus only on the last time step\n            logits = logits&#91;:, -1, :] # becomes (B, C)\n            # apply softmax to get probabilities\n            probs = F.softmax(logits, dim=-1) # (B, C)\n            # sample from the distribution\n            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n            # append sampled index to the 
running sequence\n            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n        return idx\n        \n# let's now encode the entire text dataset and store it into a torch.Tensor\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\n\nchars = sorted(list(set(text)))\nvocab_size = len(chars)\n\n# create a mapping from characters to integers\nstoi = { ch:i for i,ch in enumerate(chars) }\nitos = { i:ch for i,ch in enumerate(chars) }\n# encoder: take a string, output a list of integers\nencode = lambda s: &#91;stoi&#91;c] for c in s]\n\ndata = torch.tensor(encode(text), dtype=torch.long)\n\n# Let's now split up the data into train and validation sets\nn = int(0.9*len(data)) # first 90% will be train, rest val\ntrain_data = data&#91;:n]\nval_data = data&#91;n:]\n\ntorch.manual_seed(1337)\n# how many independent sequences will we process in parallel?\nbatch_size = 4\n# what is the maximum context length for predictions?\nblock_size = 8\n\ndef get_batch(split):\n    # generate a small batch of data of inputs x and targets y\n    data = train_data if split == 'train' else val_data\n    ix = torch.randint(len(data) - block_size, (batch_size,))\n    x = torch.stack(&#91;data&#91;i:i+block_size] for i in ix])\n    y = torch.stack(&#91;data&#91;i+1:i+block_size+1] for i in ix])\n    return x, y\n\nxb, yb = get_batch('train')\nprint('inputs:')\nprint(xb.shape)\nprint(xb)\nprint('targets:')\nprint(yb.shape)\nprint(yb)\n\n\n# get device\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n# create model\nm = BigramLanguageModelV1(vocab_size).to(device)\n\nlogits, loss = m(xb.to(device), \n                 yb.to(device))\nprint(logits.shape)\nprint(loss)\n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup5.1.py<\/p>\n\n\n\n<pre class=\"wp-block-code\" style=\"font-size:10px\"><code>$ python setup5.1.py\ninputs:\ntorch.Size(&#91;4, 8])\ntensor(&#91;&#91;24, 43, 58,  5, 57,  1, 46, 43],\n        &#91;44, 53, 56,  1, 58, 46, 
39, 58],\n        &#91;52, 58,  1, 58, 46, 39, 58,  1],\n        &#91;25, 17, 27, 10,  0, 21,  1, 54]])\ntargets:\ntorch.Size(&#91;4, 8])\ntensor(&#91;&#91;43, 58,  5, 57,  1, 46, 43, 39],\n        &#91;53, 56,  1, 58, 46, 39, 58,  1],\n        &#91;58,  1, 58, 46, 39, 58,  1, 46],\n        &#91;17, 27, 10,  0, 21,  1, 54, 39]])\ntorch.Size(&#91;32, 65])\ntensor(5.0364, device='cuda:0', grad_fn=&lt;NllLossBackward0&gt;)<\/code><\/pre>\n\n\n\n<p><mark>logits \u662f\u6a21\u578b\u505a\u51fa\u9884\u6d4b\u524d\u7684\u4e00\u7ec4\u672a\u7ecf\u5f52\u4e00\u5316\u7684\u5206\u6570\uff0c\u53cd\u6620\u4e86\u4e0d\u540c\u7ed3\u679c\u7684\u76f8\u5bf9\u53ef\u80fd\u6027<\/mark>\u3002\u5982\u4f55\u7406\u89e3 logits \u7684 shape \u5462\uff1fxb\uff084&#215;8\uff09\u7684\u6bcf\u4e2a\u5143\u7d20\uff08Token\uff0c\u5b57\u6bcd\u5728\u8bcd\u6c47\u8868\u4e2d\u7684\u6392\u5e8f\uff09\uff0c\u5728&nbsp;<code>forward<\/code>&nbsp;\u4e2d\uff0c\u90fd\u8981\u8f93\u5165&nbsp;<code>nn.Embedding<\/code>\uff0c\u5f97\u5230\u4e00\u4e2a\u5927\u5c0f\u4e3a 65\uff08\u8bcd\u6c47\u8868\u5927\u5c0f\uff09\u7684\u5411\u91cf\u3002\u8be5\u5411\u91cf\u4e2d\u7684\u6bcf\u4e2a\u5143\u7d20\uff0c\u8868\u793a\u6709\u5f53\u524d Token\uff0c\u63a8\u6d4b\u51fa\u8be5\u5411\u91cf\u8868\u793a Token \u7684\u53ef\u80fd\u6027\uff08\u672a\u5f52\u4e00\u5316\uff09\u3002<\/p>\n\n\n\n<p>\u4e0b\u9762\u4ee5\u7b2c\u4e00\u4e2a\u8bcd\u4e3a\u4f8b\uff0c\u5b83\u7684\u957f\u5ea6\u4e3a 65 \u7684\u8868\u793a 65 \u79cd\u5b57\u7b26\u53ef\u80fd\u6027\u7684\u5411\u91cf\u4e3a\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>print(logits&#91;0].shape)\nprint(logits&#91;0])<\/code><\/pre>\n\n\n\n<p>setup5.2.py \u5b8c\u6574\u7684\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\ntorch.manual_seed(1337)\n\n# \u4e8c\u5143\u8bed\u8a00\u6a21\u578b\u5b9e\u73b0\nclass BigramLanguageModelV1(nn.Module):\n\n    def 
__init__(self, vocab_size):\n        super().__init__()\n        # \u6bcf\u4e2a\u8bcd\u76f4\u63a5\u4ece\u4e00\u4e2a\u67e5\u627e\u8868\u4e2d\u83b7\u53d6\u4e0b\u4e00\u4e2a\u8bcd\u7684logits\u503c\n        # logits\u662f\u6a21\u578b\u505a\u51fa\u9884\u6d4b\u524d\u7684\u4e00\u7ec4\u672a\u7ecf\u5f52\u4e00\u5316\u7684\u5206\u6570\uff0c\u53cd\u6620\u4e86\u4e0d\u540c\u7ed3\u679c\u7684\u76f8\u5bf9\u53ef\u80fd\u6027\n        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n\n\t# \u6a21\u578b\u524d\u5411\u4f20\u64ad\n\t# idx\uff1a\u5373\u524d\u9762\u7684 x\uff0c\u8868\u793a\u8f93\u5165\u6570\u636e\uff0c\u8bcd\u5728\u8bcd\u6c47\u8868\u4e2d\u7684\u7d22\u5f15\u7684\u5411\u91cf\n\t# targets\uff1a\u8bad\u7ec3\u7684\u76ee\u6807\u8f93\u51fa\uff0c\u6bd4\u5982\u6b63\u786e\u7684\u4e0b\u4e00\u4e2a\u8bcd\u7684\u7d22\u5f15\n    def forward(self, idx, targets=None):\n        # idx and targets are both (B,T) tensor of integers\n        logits = self.token_embedding_table(idx) # (B,T,C)\n\n        if targets is None:\n            loss = None\n        else:\n            B, T, C = logits.shape\n            logits = logits.view(B*T, C)\n            targets = targets.view(B*T)\n            loss = F.cross_entropy(logits, targets)\n\n        return logits, loss\n\n\t# \u5728\u6a21\u578b\u5df2\u7ecf\u8bad\u7ec3\u597d\u4e4b\u540e\uff0c\u6839\u636e\u7ed9\u5b9a\u7684\u8f93\u5165\u751f\u6210\u6587\u672c\u7684\u65b9\u6cd5\u3002\n    def generate(self, idx, max_new_tokens):\n        # idx is (B, T) array of indices in the current context\n        for _ in range(max_new_tokens):\n            # get the predictions\n            logits, loss = self(idx)\n            # focus only on the last time step\n            logits = logits&#91;:, -1, :] # becomes (B, C)\n            # apply softmax to get probabilities\n            probs = F.softmax(logits, dim=-1) # (B, C)\n            # sample from the distribution\n            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n            
# append sampled index to the running sequence\n            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n        return idx\n        \n# let's now encode the entire text dataset and store it into a torch.Tensor\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\n\nchars = sorted(list(set(text)))\nvocab_size = len(chars)\n\n# create a mapping from characters to integers\nstoi = { ch:i for i,ch in enumerate(chars) }\nitos = { i:ch for i,ch in enumerate(chars) }\n# encoder: take a string, output a list of integers\nencode = lambda s: &#91;stoi&#91;c] for c in s]\n\ndata = torch.tensor(encode(text), dtype=torch.long)\n\n# Let's now split up the data into train and validation sets\nn = int(0.9*len(data)) # first 90% will be train, rest val\ntrain_data = data&#91;:n]\nval_data = data&#91;n:]\n\ntorch.manual_seed(1337)\n# how many independent sequences will we process in parallel?\nbatch_size = 4\n# what is the maximum context length for predictions?\nblock_size = 8\n\ndef get_batch(split):\n    # generate a small batch of data of inputs x and targets y\n    data = train_data if split == 'train' else val_data\n    ix = torch.randint(len(data) - block_size, (batch_size,))\n    x = torch.stack(&#91;data&#91;i:i+block_size] for i in ix])\n    y = torch.stack(&#91;data&#91;i+1:i+block_size+1] for i in ix])\n    return x, y\n\nxb, yb = get_batch('train')\nprint('inputs:')\nprint(xb.shape)\nprint(xb)\nprint('targets:')\nprint(yb.shape)\nprint(yb)\n\n\n# get device\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n# create model\nm = BigramLanguageModelV1(vocab_size).to(device)\n\nlogits, loss = m(xb.to(device), \n                 yb.to(device))\n\nprint(logits&#91;0].shape)\nprint(logits&#91;0])\n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup5.2.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup5.2.py\ninputs:\ntorch.Size(&#91;4, 8])\ntensor(&#91;&#91;24, 43, 58,  5, 57,  1, 46, 
43],\n        &#91;44, 53, 56,  1, 58, 46, 39, 58],\n        &#91;52, 58,  1, 58, 46, 39, 58,  1],\n        &#91;25, 17, 27, 10,  0, 21,  1, 54]])\ntargets:\ntorch.Size(&#91;4, 8])\ntensor(&#91;&#91;43, 58,  5, 57,  1, 46, 43, 39],\n        &#91;53, 56,  1, 58, 46, 39, 58,  1],\n        &#91;58,  1, 58, 46, 39, 58,  1, 46],\n        &#91;17, 27, 10,  0, 21,  1, 54, 39]])\ntorch.Size(&#91;65])\ntensor(&#91; 1.6347, -0.0518,  0.4996,  0.7216,  0.5085, -0.7719,  0.2388,  0.3138,\n         0.2178,  0.0328, -0.1699,  1.0659,  0.7200, -0.6166,  0.0806,  2.5231,\n        -1.4623,  2.1707,  0.1624,  1.0296, -1.1377,  0.5856,  0.0173,  0.3136,\n         1.0124,  1.5122, -0.3359,  0.2456, -0.3773,  0.1587,  2.1503, -1.5131,\n        -0.9552, -0.8995, -0.9583, -0.5945,  0.5850,  0.5266,  0.7615,  0.5331,\n         1.1796,  1.3316, -0.2094,  0.0960, -0.6945,  0.5669, -0.5883,  1.4064,\n        -1.2537, -1.5195,  0.7446,  1.1914,  0.1801,  1.2333, -0.2299, -0.1531,\n         0.8408, -0.3993, -0.6126, -0.6597,  0.5906,  1.1219,  0.2432,  1.1519,\n         0.9950], device='cuda:0', grad_fn=&lt;SelectBackward0&gt;)<\/code><\/pre>\n\n\n\n<p>\u5176\u4e2d\uff0c\u503c\u6700\u5927\u7684\u5143\u7d20\u7684\u5e8f\u53f7\uff0c\u5c31\u662f\u6700\u53ef\u80fd\u7684\u90a3\u4e2a\u5b57\u6bcd\u3002\u6ce8\u610f\uff0c\u6a21\u578b<strong>\u8fd8\u6ca1\u6709\u8fdb\u884c\u4efb\u4f55\u8bad\u7ec3<\/strong>\uff0c\u5904\u4e8e\u795e\u7ecf\u9519\u4e71\u72b6\u6001\uff0c\u9884\u6d4b\u5730\u4e0e\u4e8b\u5b9e\u4e0d\u7b26\u662f\u6b63\u5e38\u7684\u3002<\/p>\n\n\n\n<p>\u6709\u4e86\u8fd9\u4e2a\u8bcd\u5d4c\u5165\u5411\u91cf\u540e\uff0c\u8ba1\u7b97\u51fa\u6a21\u578b\u4e0e Target\uff08\u4e8b\u5b9e\u7684\u4e0b\u4e00\u4e2a Token\uff09\u4e4b\u95f4\u7684\u8bef\u5dee\u4e86\u3002\u8fd9\u91cc\u4f7f\u7528\u4e86\u4ea4\u53c9\u71b5\u8bef\u5dee\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>\u4ea4\u53c9\u71b5\u8bef\u5dee\uff08Cross-Entropy 
Loss\uff0c\u7b80\u79f0CE\uff09\u662f\u4e00\u79cd\u5e38\u7528\u7684\u635f\u5931\u51fd\u6570\uff08loss function\uff09\uff0c\u5c24\u5176\u5728\u673a\u5668\u5b66\u4e60\u548c\u6df1\u5ea6\u5b66\u4e60\u4e2d\u7684\u5206\u7c7b\u95ee\u9898\u3002\u5b83\u662f\u7528\u6765\u8861\u91cf\u6a21\u578b\u9884\u6d4b\u6982\u7387\u5206\u5e03\u4e0e\u771f\u5b9e\u6982\u7387\u5206\u5e03\u4e4b\u95f4\u7684\u76f8\u4f3c\u5ea6\u3002\u4ea4\u53c9\u71b5\u8bef\u5dee\u7684\u503c\u8d8a\u5c0f\uff0c\u8868\u793a\u6a21\u578b\u9884\u6d4b\u7684\u6982\u7387\u5206\u5e03\u4e0e\u771f\u5b9e\u6982\u7387\u5206\u5e03\u8d8a\u63a5\u8fd1\uff0c\u6a21\u578b\u7684\u6027\u80fd\u8d8a\u597d\u3002\n<\/code><\/pre>\n\n\n\n<p>\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\u770b\u5b8c\u4e86\uff0c\u63a5\u4e0b\u6765\u5c1d\u8bd5\u8c03\u7528&nbsp;<code>generate<\/code>&nbsp;\u8fdb\u884c\u4e00\u6b21\u6587\u672c\u751f\u6210\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>print(\n\tdecode(\n\t\tm.generate(\n\t\t\tidx = torch.zeros((1, 1), dtype=torch.long).to(device), \n\t\t\tmax_new_tokens=12)&#91;0].tolist()))<\/code><\/pre>\n\n\n\n<p>setup5.3.py \u5b8c\u6574\u7684\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\ntorch.manual_seed(1337)\n\n# \u4e8c\u5143\u8bed\u8a00\u6a21\u578b\u5b9e\u73b0\nclass BigramLanguageModelV1(nn.Module):\n\n    def __init__(self, vocab_size):\n        super().__init__()\n        # \u6bcf\u4e2a\u8bcd\u76f4\u63a5\u4ece\u4e00\u4e2a\u67e5\u627e\u8868\u4e2d\u83b7\u53d6\u4e0b\u4e00\u4e2a\u8bcd\u7684logits\u503c\n        # logits\u662f\u6a21\u578b\u505a\u51fa\u9884\u6d4b\u524d\u7684\u4e00\u7ec4\u672a\u7ecf\u5f52\u4e00\u5316\u7684\u5206\u6570\uff0c\u53cd\u6620\u4e86\u4e0d\u540c\u7ed3\u679c\u7684\u76f8\u5bf9\u53ef\u80fd\u6027\n        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n\n\t# \u6a21\u578b\u524d\u5411\u4f20\u64ad\n\t# idx\uff1a\u5373\u524d\u9762\u7684 
x\uff0c\u8868\u793a\u8f93\u5165\u6570\u636e\uff0c\u8bcd\u5728\u8bcd\u6c47\u8868\u4e2d\u7684\u7d22\u5f15\u7684\u5411\u91cf\n\t# targets\uff1a\u8bad\u7ec3\u7684\u76ee\u6807\u8f93\u51fa\uff0c\u6bd4\u5982\u6b63\u786e\u7684\u4e0b\u4e00\u4e2a\u8bcd\u7684\u7d22\u5f15\n    def forward(self, idx, targets=None):\n        # idx and targets are both (B,T) tensor of integers\n        logits = self.token_embedding_table(idx) # (B,T,C)\n\n        if targets is None:\n            loss = None\n        else:\n            B, T, C = logits.shape\n            logits = logits.view(B*T, C)\n            targets = targets.view(B*T)\n            loss = F.cross_entropy(logits, targets)\n\n        return logits, loss\n\n\t# \u5728\u6a21\u578b\u5df2\u7ecf\u8bad\u7ec3\u597d\u4e4b\u540e\uff0c\u6839\u636e\u7ed9\u5b9a\u7684\u8f93\u5165\u751f\u6210\u6587\u672c\u7684\u65b9\u6cd5\u3002\n    def generate(self, idx, max_new_tokens):\n        # idx is (B, T) array of indices in the current context\n        for _ in range(max_new_tokens):\n            # get the predictions\n            logits, loss = self(idx)\n            # focus only on the last time step\n            logits = logits&#91;:, -1, :] # becomes (B, C)\n            # apply softmax to get probabilities\n            probs = F.softmax(logits, dim=-1) # (B, C)\n            # sample from the distribution\n            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n            # append sampled index to the running sequence\n            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n        return idx\n        \n# let's now encode the entire text dataset and store it into a torch.Tensor\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\n\nchars = sorted(list(set(text)))\nvocab_size = len(chars)\n\n# create a mapping from characters to integers\nstoi = { ch:i for i,ch in enumerate(chars) }\nitos = { i:ch for i,ch in enumerate(chars) }\n# encoder: take a string, output a list of integers\nencode = 
lambda s: &#91;stoi&#91;c] for c in s]\n# decoder: take a list of integers, output a string\ndecode = lambda l: ''.join(&#91;itos&#91;i] for i in l])\n\ndata = torch.tensor(encode(text), dtype=torch.long)\n\n# Let's now split up the data into train and validation sets\nn = int(0.9*len(data)) # first 90% will be train, rest val\ntrain_data = data&#91;:n]\nval_data = data&#91;n:]\n\ntorch.manual_seed(1337)\n# how many independent sequences will we process in parallel?\nbatch_size = 4\n# what is the maximum context length for predictions?\nblock_size = 8\n\ndef get_batch(split):\n    # generate a small batch of data of inputs x and targets y\n    data = train_data if split == 'train' else val_data\n    ix = torch.randint(len(data) - block_size, (batch_size,))\n    x = torch.stack(&#91;data&#91;i:i+block_size] for i in ix])\n    y = torch.stack(&#91;data&#91;i+1:i+block_size+1] for i in ix])\n    return x, y\n\nxb, yb = get_batch('train')\nprint('inputs:')\nprint(xb.shape)\nprint(xb)\nprint('targets:')\nprint(yb.shape)\nprint(yb)\n\n\n# get device\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n# create model\nm = BigramLanguageModelV1(vocab_size).to(device)\n\nlogits, loss = m(xb.to(device), \n                 yb.to(device))\nprint(logits.shape)\nprint(loss)\n\n#print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)&#91;0].tolist()))\nprint(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long).to(device), max_new_tokens=12)&#91;0].tolist()))\n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup5.3.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup5.3.py\ninputs:\ntorch.Size(&#91;4, 8])\ntensor(&#91;&#91;24, 43, 58,  5, 57,  1, 46, 43],\n        &#91;44, 53, 56,  1, 58, 46, 39, 58],\n        &#91;52, 58,  1, 58, 46, 39, 58,  1],\n        &#91;25, 17, 27, 10,  0, 21,  1, 54]])\ntargets:\ntorch.Size(&#91;4, 8])\ntensor(&#91;&#91;43, 58,  5, 57,  1, 46, 43, 39],\n        
&#91;53, 56,  1, 58, 46, 39, 58,  1],\n        &#91;58,  1, 58, 46, 39, 58,  1, 46],\n        &#91;17, 27, 10,  0, 21,  1, 54, 39]])\ntorch.Size(&#91;32, 65])\ntensor(5.0364, device='cuda:0', grad_fn=&lt;NllLossBackward0&gt;)\n\nyq$;tfBfROkN<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>6. BigramLanguageModel V1 \u8bad\u7ec3<\/strong><\/h2>\n\n\n\n<p>\u4f7f\u7528\u5982\u4e0b\u4ee3\u7801\u5bf9\u6a21\u578b\u8fdb\u884c\u4e00\u4e07\u6b21\u8bad\u7ec3,\u7528\u8bad\u7ec3\u540e\u6a21\u578b\uff0c\u751f\u6210\u4e00\u4e2a\u957f\u5ea6\u4e3a 100 \u7684\u5e8f\u5217\u770b\u770b<\/p>\n\n\n\n<p>setup6.1.py \u5b8c\u6574\u7684\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\ntorch.manual_seed(1337)\n\n# \u4e8c\u5143\u8bed\u8a00\u6a21\u578b\u5b9e\u73b0\nclass BigramLanguageModelV1(nn.Module):\n\n    def __init__(self, vocab_size):\n        super().__init__()\n        # \u6bcf\u4e2a\u8bcd\u76f4\u63a5\u4ece\u4e00\u4e2a\u67e5\u627e\u8868\u4e2d\u83b7\u53d6\u4e0b\u4e00\u4e2a\u8bcd\u7684logits\u503c\n        # logits\u662f\u6a21\u578b\u505a\u51fa\u9884\u6d4b\u524d\u7684\u4e00\u7ec4\u672a\u7ecf\u5f52\u4e00\u5316\u7684\u5206\u6570\uff0c\u53cd\u6620\u4e86\u4e0d\u540c\u7ed3\u679c\u7684\u76f8\u5bf9\u53ef\u80fd\u6027\n        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n\n\t# \u6a21\u578b\u524d\u5411\u4f20\u64ad\n\t# idx\uff1a\u5373\u524d\u9762\u7684 x\uff0c\u8868\u793a\u8f93\u5165\u6570\u636e\uff0c\u8bcd\u5728\u8bcd\u6c47\u8868\u4e2d\u7684\u7d22\u5f15\u7684\u5411\u91cf\n\t# targets\uff1a\u8bad\u7ec3\u7684\u76ee\u6807\u8f93\u51fa\uff0c\u6bd4\u5982\u6b63\u786e\u7684\u4e0b\u4e00\u4e2a\u8bcd\u7684\u7d22\u5f15\n    def forward(self, idx, targets=None):\n        # idx and targets are both (B,T) tensor of integers\n        logits = self.token_embedding_table(idx) # (B,T,C)\n\n        if targets is None:\n            loss = None\n        else:\n  
          B, T, C = logits.shape\n            logits = logits.view(B*T, C)\n            targets = targets.view(B*T)\n            loss = F.cross_entropy(logits, targets)\n\n        return logits, loss\n\n\t# \u5728\u6a21\u578b\u5df2\u7ecf\u8bad\u7ec3\u597d\u4e4b\u540e\uff0c\u6839\u636e\u7ed9\u5b9a\u7684\u8f93\u5165\u751f\u6210\u6587\u672c\u7684\u65b9\u6cd5\u3002\n    def generate(self, idx, max_new_tokens):\n        # idx is (B, T) array of indices in the current context\n        for _ in range(max_new_tokens):\n            # get the predictions\n            logits, loss = self(idx)\n            # focus only on the last time step\n            logits = logits&#91;:, -1, :] # becomes (B, C)\n            # apply softmax to get probabilities\n            probs = F.softmax(logits, dim=-1) # (B, C)\n            # sample from the distribution\n            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n            # append sampled index to the running sequence\n            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n        return idx\n        \n# let's now encode the entire text dataset and store it into a torch.Tensor\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\n\nchars = sorted(list(set(text)))\nvocab_size = len(chars)\n\n# create a mapping from characters to integers\nstoi = { ch:i for i,ch in enumerate(chars) }\nitos = { i:ch for i,ch in enumerate(chars) }\n# encoder: take a string, output a list of integers\nencode = lambda s: &#91;stoi&#91;c] for c in s]\n# decoder: take a list of integers, output a string\ndecode = lambda l: ''.join(&#91;itos&#91;i] for i in l])\n\ndata = torch.tensor(encode(text), dtype=torch.long)\n\n# Let's now split up the data into train and validation sets\nn = int(0.9*len(data)) # first 90% will be train, rest val\ntrain_data = data&#91;:n]\nval_data = data&#91;n:]\n\ntorch.manual_seed(1337)\n# how many independent sequences will we process in parallel?\nbatch_size = 4\n# what is the 
maximum context length for predictions?\nblock_size = 8\n\ndef get_batch(split):\n    # generate a small batch of data of inputs x and targets y\n    data = train_data if split == 'train' else val_data\n    ix = torch.randint(len(data) - block_size, (batch_size,))\n    x = torch.stack(&#91;data&#91;i:i+block_size] for i in ix])\n    y = torch.stack(&#91;data&#91;i+1:i+block_size+1] for i in ix])\n    return x, y\n\n# get device\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n# create model\nm = BigramLanguageModelV1(vocab_size).to(device)\n\n# create a PyTorch optimizer\noptimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)\n\nfrom tqdm import tqdm\nfor steps in tqdm(range(10000)): # increase number of steps for good results...\n    # sample a batch of data\n    xb, yb = get_batch('train')\n\n    # evaluate the loss\n    logits, loss = m(xb.to(device), yb.to(device))\n    optimizer.zero_grad(set_to_none=True)\n    loss.backward()\n    optimizer.step()\n\nprint(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long).to(device), max_new_tokens=100)&#91;0].tolist()))\n    <\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup6.1.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup6.1.py\n100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 10000\/10000 &#91;00:06&lt;00:00, 1452.75it\/s]\n2.527495861053467\n\n\n\nCExfikRO:\nwcowindakOLOLETHAK\n\nHAPOFouBayou e.\nS:gO:33SA:\n\n\nLTauss:\nWanthafNusqhe, vet?ar dXlasoate<\/code><\/pre>\n\n\n\n<p>\u5c3d\u7ba1\u8fd8\u662f\u4e71\u7801\uff0c\u4f46\u662f\u6709\u70b9 Tiny Shakespeare \u5267\u672c\u5bf9\u8bdd\u7684\u610f\u601d\u4e86\u3002<\/p>\n\n\n\n<p>\u518d\u8bad\u7ec3\u4e00\u4e07\u6b21\uff08\u7d2f\u8ba1 2w \u6b21\uff09<\/p>\n\n\n\n<p>setup6.1.py 
\u4fee\u6539\u540e\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>for steps in tqdm(range(20000)): # increase number of steps for good results...<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup6.1.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup6.1.py\n100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 20000\/20000 &#91;00:13&lt;00:00, 1531.06it\/s]\n2.4393372535705566\n\n\nCExfik brid owindakis by bth\n\nHAPOFourayou e.\nS:\nO:33SA:\n\n\nLUCous:\nWanthar u qur, vet?\nF dilasoate\ntony@TONYP15GEN2:\/mnt\/d\/OpenAI\/CreateGPT$\n<\/code><\/pre>\n\n\n\n<p>\u8bad\u7ec310\u4e07\u6b21<\/p>\n\n\n\n<p>setup6.1.py \u4fee\u6539\u540e\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>for steps in tqdm(range(100000)): # increase number of steps for good results...<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup6.1.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup6.1.py\n100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 100000\/100000 &#91;01:06&lt;00:00, 1513.27it\/s]\n2.6740522384643555\n\n\nCExthy brid owindakis by bth\n\nHiset bube d e.\nS:\nO:\nIS:\nFalatanss:\nWanthar u qur, vet?\nF dilasoate<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>7. 
\u4f7f\u7528\u77e9\u9635\u4e58\u6cd5\u5b9e\u73b0\u7d2f\u589e\u8fd0\u7b97<\/strong><\/h2>\n\n\n\n<p>\u4e00\u4e2a\u793a\u4f8b\uff0c\u5c55\u793a\u5982\u4f55\u901a\u8fc7\u77e9\u9635\u4e58\u6cd5\u8fdb\u884c\u2018\u52a0\u6743\u805a\u5408<\/p>\n\n\n\n<p>setup7.1.py \u5b8c\u6574\u7684\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># toy example illustrating how matrix multiplication can be used for a \"weighted aggregation\"\nimport torch\n\ntorch.manual_seed(42)\na = torch.tril(torch.ones(3, 3))\na = a \/ torch.sum(a, 1, keepdim=True)\nb = torch.randint(0,10,(3,2)).float()\nc = a @ b\nprint('a=')\nprint(a)\nprint('--')\nprint('b=')\nprint(b)\nprint('--')\nprint('c=')\nprint(c)\n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup7.1.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup7.1.py\na=\ntensor(&#91;&#91;1.0000, 0.0000, 0.0000],\n        &#91;0.5000, 0.5000, 0.0000],\n        &#91;0.3333, 0.3333, 0.3333]])\n--\nb=\ntensor(&#91;&#91;2., 7.],\n        &#91;6., 4.],\n        &#91;6., 5.]])\n--\nc=\ntensor(&#91;&#91;2.0000, 7.0000],\n        &#91;4.0000, 5.5000],\n        &#91;4.6667, 5.3333]])<\/code><\/pre>\n\n\n\n<p>GPT \u4e2d\u5305\u542b&nbsp;<a href=\"https:\/\/garden.maxieewong.com\/000.wiki\/Attention\/\">Attention<\/a>&nbsp;\u81ea\u6ce8\u610f\u529b\u673a\u5236\uff08<a href=\"https:\/\/garden.maxieewong.com\/404\">self-attention<\/a>\uff09\uff0c\u7b80\u5355\u6765\u8bf4\uff0c\u5bf9\u4e0a\u9762\u7684\u6bcf\u4e00\u8f6e\u8fdb\u884c\u52a0\u6743\uff1a<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"213\" src=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-1-1024x213.png\" alt=\"\" class=\"wp-image-2069\" srcset=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-1-1024x213.png 1024w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-1-300x62.png 
300w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-1-768x160.png 768w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-1.png 1269w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<p>\u5982\u4f55\u5b9e\u73b0\u4e0a\u8ff0\u7d2f\u589e\u8fd0\u7b97\u5462\uff1f\u4e00\u79cd\u76f4\u89c2\u65b9\u6cd5\u662f\u4f7f\u7528\u5faa\u73af\uff0c\u4f46\u662f\u8fd9\u6837\u6548\u7387\u4f4e\u3002The mathematical trick \u6307\u7684\u5c31\u662f\u4f7f\u7528\u4e00\u4e2a\u77e9\u9635\u8fd0\u7b97\u6765\u66ff\u4ee3\u5faa\u73af\uff0c\u77e9\u9635\u8fd0\u7b97\u6548\u7387\u66f4\u9ad8\uff0c\u556a\u7684\u4e00\u4e0b\u5c31\u5168\u7b97\u5b8c\u4e86\u3002<\/p>\n\n\n\n<p>\u8fd9\u91cc\u4f7f\u7528\u7684\u77e9\u9635\u662f\u4e09\u89d2\u9635\uff0c\u4e0b\u4e09\u89d2\u662f\u6743\u91cd\uff0c\u4e0a\u4e09\u89d2\u90fd\u662f 0\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>\u6743\u91cd  0   0   0\n\u6743\u91cd \u6743\u91cd  0   0\n\u6743\u91cd \u6743\u91cd \u6743\u91cd  0\n\u6743\u91cd \u6743\u91cd \u6743\u91cd \u6743\u91cd<\/code><\/pre>\n\n\n\n<p>\u6211\u4eec\u4ee5\u8fd9\u4e2a\u9635\u7684\u6bcf\u4e00\u884c\uff0c\u4e0e idx \u5217\u5411\u91cf\u76f8\u4e58\uff0c\u662f\u4e0d\u662f\u5c31\u628a\u8fd9\u4e00\u8f6e\u4e2d\u7684\u5934\u51e0\u4e2a\u5143\u7d20\uff0c\u4e0e\u6743\u91cd\u76f8\u4e58\u4e86\uff1f<\/p>\n\n\n\n<p>\u4e3a\u6b64\uff0c\u5f15\u5165\u4e00\u4e2a\u65b0\u7684\u8d85\u53c2\u6570 Channels\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>B,T,C = 4,8,2 # batch, time, channels\nx = torch.randn(B,T,C)\nx.shape<\/code><\/pre>\n\n\n\n<p>&#8220;Channel&#8221; 
\u53c2\u6570\u6307\u7684\u662f\u5728\u795e\u7ecf\u7f51\u7edc\uff0c\u5c24\u5176\u662f\u5728\u5904\u7406\u81ea\u6ce8\u610f\u529b\uff08Self-Attention\uff09\u673a\u5236\u65f6\uff0c\u6570\u636e\u7684\u4e00\u4e2a\u7ef4\u5ea6\uff0c\u5b83\u8868\u793a<mark>\u8f93\u5165\u6570\u636e\u4e2d\u7684\u7279\u5f81\u6570\u91cf<\/mark>\u3002<\/p>\n\n\n\n<p>\u4f8b\u5982\uff0c\u5728\u8ba1\u7b97\u673a\u89c6\u89c9\u4efb\u52a1\u4e2d\uff0c\u5bf9\u4e8e\u5f69\u8272\u56fe\u50cf\uff0c\u5e38\u89c1\u7684\u901a\u9053\u6570\u4e3a3\uff0c\u5206\u522b\u4ee3\u8868\u7ea2\u3001\u7eff\u3001\u84dd\uff08RGB\uff09\u989c\u8272\u901a\u9053\u3002\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\uff08NLP\uff09\u548cTransformer\u6a21\u578b\u7684\u4e0a\u4e0b\u6587\u4e2d\uff0c<mark>&#8220;Channel&#8221; \u901a\u5e38\u6307\u7684\u662f\u5d4c\u5165\u5411\u91cf\uff08embedding vector\uff09\u7684\u7ef4\u5ea6<\/mark>\uff0c\u6216\u8005\u8bf4\uff0c\u6bcf\u4e2a\u5355\u8bcd\u6216\u6807\u8bb0\uff08token\uff09\u88ab\u8868\u793a\u6210\u7684\u5411\u91cf\u7684\u5927\u5c0f\u3002\u8fd9\u4e9b\u5d4c\u5165\u5411\u91cf\u662f\u9ad8\u7ef4\u7a7a\u95f4\u4e2d\u7684\u70b9\uff0c\u6bcf\u4e00\u4e2a\u7ef4\u5ea6\uff08\u6216&#8221;channel&#8221;\uff09\u53ef\u4ee5\u88ab\u770b\u4f5c\u662f\u6355\u6349\u8f93\u5165\u6570\u636e\u4e2d\u67d0\u79cd\u7279\u5b9a\u65b9\u9762\u7684\u7279\u5f81\u3002<\/p>\n\n\n\n<p>\u8fd9\u91cc\u4ee5\u901a\u9053\u6570&nbsp;<code>C=2<\/code>&nbsp;\u4f5c\u4e3a\u793a\u610f\u3002<\/p>\n\n\n\n<p>\u4e0b\u9762\u4ecb\u7ecd\u4e24\u79cd\u77e9\u9635\u8fd0\u7b97\u65b9\u6cd5\u3002\u7b2c\u4e00\u79cd\u8fd0\u7b97\uff1a<\/p>\n\n\n\n<p>setup7.2.py \u5b8c\u6574\u7684\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>import torch\n\nB,T,C = 4,8,2 # batch, time, channels\nx = torch.randn(B,T,C)\n#x.shape\n\n# version 2: using matrix multiply for a weighted aggregation\n# \u521b\u5efa\u4e00\u4e2a 8x8 \u7684\u4e0b\u4e09\u89d2\u9635\n# 
\u5728\u4e0b\u4e09\u89d2\u9635\u4e2d\uff0c\u4e3b\u5bf9\u89d2\u7ebf\u4e0a\u65b9\u7684\u6240\u6709\u5143\u7d20\u90fd\u88ab\u8bbe\u7f6e\u4e3a0\nwei = torch.tril(torch.ones(T, T))\n\n# \u5c06\u6bcf\u4e00\u884c\u7684\u5143\u7d20\u9664\u4ee5\u8be5\u884c\u5143\u7d20\u7684\u548c\uff0c\u4ee5\u786e\u4fdd\u6bcf\u4e00\u884c\u7684\u5143\u7d20\u548c\u4e3a1\n# \u8fd9\u6837\u505a\u7684\u76ee\u7684\u662f\u5c06`wei`\u8f6c\u6362\u4e3a\u4e00\u4e2a\u6743\u91cd\u77e9\u9635\uff0c\u53ef\u4ee5\u7528\u4e8e\u5bf9\u8f93\u5165\u6570\u636e`x`\u8fdb\u884c\u52a0\u6743\u5e73\u5747\u3002\nwei = wei \/ wei.sum(1, keepdim=True)\n\n# \u8fd9\u884c\u4ee3\u7801\u4f7f\u7528\u77e9\u9635\u4e58\u6cd5\uff08`@`\u64cd\u4f5c\u7b26\uff09\u5c06\u6743\u91cd\u77e9\u9635`wei`\u5e94\u7528\u4e8e\u8f93\u5165\u6570\u636e`x`\n# \u8fd9\u5b9e\u9645\u4e0a\u662f\u5bf9`x`\u7684\u6bcf\u4e00\u884c\u8fdb\u884c\u52a0\u6743\u5e73\u5747\uff0c\u6743\u91cd\u7531`wei`\u7684\u5bf9\u5e94\u884c\u7ed9\u51fa\u3002\n# \u7ed3\u679c`xbow2`\u7684\u5f62\u72b6\u4e3a`(B, T, C)`\u3002\nxbow2 = wei @ x # (B, T, T) @ (B, T, C) ----&gt; (B, T, C)\nprint(xbow2)<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup7.2.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup7.2.py\ntensor(&#91;&#91;&#91;-2.2734,  1.9014],\n         &#91;-1.2950,  0.5992],\n         &#91;-0.3070,  0.6050],\n         &#91;-0.1546,  0.5280],\n         &#91;-0.0619,  0.6683],\n         &#91;-0.0036,  0.7324],\n         &#91; 0.2498,  0.5098],\n         &#91; 0.1783,  0.4924]],\n\n        &#91;&#91;-1.4634, -1.2901],\n         &#91;-1.4572, -0.7144],\n         &#91;-0.8114, -0.4710],\n         &#91;-0.2247, -0.5062],\n         &#91;-0.2281, -0.8140],\n         &#91;-0.2089, -0.7699],\n         &#91;-0.3723, -0.6599],\n         &#91;-0.3884, -0.5638]],\n\n        &#91;&#91;-0.4514,  0.1597],\n         &#91; 0.4560, -0.0235],\n         &#91;-0.3107,  0.9944],\n         &#91;-0.5348,  0.9982],\n         &#91;-0.5057,  0.9492],\n         &#91;-0.2476,  
0.8629],\n         &#91;-0.1491,  0.8491],\n         &#91;-0.1891,  0.5784]],\n\n        &#91;&#91; 0.8197,  0.9400],\n         &#91; 0.7682,  0.0250],\n         &#91; 0.4424,  0.3952],\n         &#91; 0.1836,  0.5423],\n         &#91; 0.1700,  0.1112],\n         &#91; 0.1238, -0.2225],\n         &#91;-0.0805, -0.2310],\n         &#91;-0.1978,  0.0533]]])<\/code><\/pre>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u7684\u4e3b\u8981\u76ee\u7684\u662f\u521b\u5efa\u4e00\u4e2a\u4e0b\u4e09\u89d2\u77e9\u9635\uff0c\u5e76\u7528\u5b83\u6765\u5bf9\u8f93\u5165\u6570\u636e<code>x<\/code>\u8fdb\u884c\u52a0\u6743\u5e73\u5747\u3002<\/p>\n\n\n\n<p>\u8fd9\u6837\uff0c\u4fbf\u5b8c\u6210\u4e86\u5bf9\u8f93\u5165\u5e8f\u5217\u7684\u6bcf\u8f6e\u7d2f\u589e\u5904\u7406\uff0c\u5e76\u5728\u6bcf\u8f6e\u7d2f\u8fdb\u4e2d\u8fdb\u884c\u52a0\u6743\u3002\u4e0b\u9762\u518d\u4ecb\u7ecd\u7b2c\u4e8c\u79cd\u7b49\u6548\u8fd0\u7b97\uff1a<\/p>\n\n\n\n<p>setup7.3.py \u5b8c\u6574\u7684\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>import torch\nfrom torch.nn import functional as F\n\nB,T,C = 4,8,2 # batch, time, channels\nx = torch.randn(B,T,C)\n#x.shape\n\n# \u9996\u5148\u4f7f\u7528`torch.ones(T, T)`\u521b\u5efa\u4e00\u4e2a\u5927\u5c0f\u4e3a`T x T`\u7684\u51681\u77e9\u9635\n# \u7136\u540e\u4f7f\u7528`torch.tril`\u5c06\u8fd9\u4e2a\u77e9\u9635\u8f6c\u6362\u4e3a\u4e0b\u4e09\u89d2\u77e9\u9635\u3002\n# \u5728\u4e0b\u4e09\u89d2\u77e9\u9635\u4e2d\uff0c\u4e3b\u5bf9\u89d2\u7ebf\u4e0a\u65b9\u7684\u6240\u6709\u5143\u7d20\u90fd\u88ab\u8bbe\u7f6e\u4e3a0\u3002\ntril = torch.tril(torch.ones(T, T))\n\n# \u8fd9\u884c\u4ee3\u7801\u521b\u5efa\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a`T x T`\u7684\u51680\u77e9\u9635\uff0c\u7528\u4e8e\u5b58\u50a8\u6743\u91cd\u3002\nwei = torch.zeros((T,T))\n\n# \u8fd9\u884c\u4ee3\u7801\u4f7f\u7528`masked_fill`\u51fd\u6570\u5c06`wei`\u4e2d\u5bf9\u5e94`tril`\u4e3a0\u7684\u4f4d\u7f6e\u586b\u5145\u4e3a\u8d1f\u65e0\u7a77\u3002\n# 
\u8fd9\u6837\u505a\u7684\u76ee\u7684\u662f\u5728\u63a5\u4e0b\u6765\u7684softmax\u64cd\u4f5c\u4e2d\uff0c\u8fd9\u4e9b\u4f4d\u7f6e\u7684\u6743\u91cd\u5c06\u88ab\u8bbe\u7f6e\u4e3a0\u3002\nwei = wei.masked_fill(tril == 0, float('-inf'))\n\n# \u8fd9\u884c\u4ee3\u7801\u4f7f\u7528softmax\u51fd\u6570\u5c06`wei`\u8f6c\u6362\u4e3a\u4e00\u4e2a\u6743\u91cd\u77e9\u9635\uff0c\u53ef\u4ee5\u7528\u4e8e\u5bf9\u8f93\u5165\u6570\u636e`x`\u8fdb\u884c\u52a0\u6743\u5e73\u5747\u3002\n# softmax\u51fd\u6570\u4f1a\u5c06\u6bcf\u4e00\u884c\u7684\u5143\u7d20\u8f6c\u6362\u4e3a\u6b63\u503c\uff0c\u5e76\u4e14\u786e\u4fdd\u6bcf\u4e00\u884c\u7684\u5143\u7d20\u548c\u4e3a1\u3002\nwei = F.softmax(wei, dim=-1)\n\n# \u8fd9\u884c\u4ee3\u7801\u4f7f\u7528\u77e9\u9635\u4e58\u6cd5\uff08`@`\u64cd\u4f5c\u7b26\uff09\u5c06\u6743\u91cd\u77e9\u9635`wei`\u5e94\u7528\u4e8e\u8f93\u5165\u6570\u636e`x`\u3002\n# \u8fd9\u5b9e\u9645\u4e0a\u662f\u5bf9`x`\u7684\u6bcf\u4e00\u884c\u8fdb\u884c\u52a0\u6743\u5e73\u5747\uff0c\u6743\u91cd\u7531`wei`\u7684\u5bf9\u5e94\u884c\u7ed9\u51fa\u3002\n# \u7ed3\u679c`xbow3`\u7684\u5f62\u72b6\u4e3a`(B, T, C)`\u3002\nxbow3 = wei @ x\nprint(xbow3)\n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup7.3.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup7.3.py\ntensor(&#91;&#91;&#91; 0.4182,  0.7875],\n         &#91;-0.5942,  0.5962],\n         &#91;-0.9075,  0.1195],\n         &#91;-0.4145, -0.1805],\n         &#91;-0.9673,  0.2452],\n         &#91;-0.7444,  0.3172],\n         &#91;-0.6002,  0.3314],\n         &#91;-0.5139,  0.2772]],\n\n        &#91;&#91; 1.0123,  0.4578],\n         &#91; 0.2628,  0.7642],\n         &#91; 0.5665,  0.6021],\n         &#91; 0.4955,  0.6483],\n         &#91; 0.8661,  0.6308],\n         &#91; 0.7084,  0.4466],\n         &#91; 0.5995,  0.5258],\n         &#91; 0.4570,  0.5571]],\n\n        &#91;&#91; 0.6116, -0.9643],\n         &#91; 0.0972,  1.0087],\n         &#91; 0.5342,  0.8562],\n         &#91; 0.4989,  0.6730],\n         &#91; 
0.4848,  0.5810],\n         &#91; 0.3251,  0.3983],\n         &#91; 0.2218,  0.3912],\n         &#91; 0.2258,  0.3659]],\n\n        &#91;&#91; 0.1742, -0.7519],\n         &#91; 0.9666, -0.1328],\n         &#91; 0.6983, -0.2136],\n         &#91; 0.4448, -0.7202],\n         &#91; 0.3096, -0.5783],\n         &#91; 0.2398, -0.6006],\n         &#91; 0.1999, -0.3986],\n         &#91; 0.2412, -0.2943]]])<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>8. \u5b9e\u73b0 Masked Self-Attention<\/strong><\/h2>\n\n\n\n<p>\u4e0b\u9762\u6765\u5b9e\u73b0 Masked Self-Attention\u3002<\/p>\n\n\n\n<p>setup8.1.py \u5b8c\u6574\u7684\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n# version 4: self-attention!\ntorch.manual_seed(1337)\n\n# \u8fd9\u4e24\u884c\u4ee3\u7801\u9996\u5148\u5b9a\u4e49\u4e86\u4e00\u4e9b\u53d8\u91cf\uff0c\u5305\u62ec\u6279\u6b21\u5927\u5c0f\uff08B\uff09\u3001\u65f6\u95f4\u6b65\u957f\uff08T\uff09\u548c\u901a\u9053\u6570\uff08C\uff09\n# \u7136\u540e\u751f\u6210\u4e86\u4e00\u4e2a\u968f\u673a\u7684\u5f20\u91cf`x`\uff0c\u5176\u5f62\u72b6\u4e3a`(B, T, C)`\u3002\u8fd9\u662f Mock \u8f93\u5165\nB,T,C = 4,8,32 # batch, time, channels\nx = torch.randn(B,T,C)\n\n# let's see a single Head perform self-attention\n# \u8fd9\u4e9b\u884c\u5b9a\u4e49\u4e86\u81ea\u6ce8\u610f\u529b\u673a\u5236\u4e2d\u7684\u5173\u952e\u90e8\u5206\uff1a\u952e\u3001\u67e5\u8be2\u548c\u503c\u3002\n# \u6bcf\u4e2a\u90e8\u5206\u90fd\u662f\u4e00\u4e2a\u7ebf\u6027\u53d8\u6362\uff0c\u5c06\u8f93\u5165\u7684\u7279\u5f81\u7ef4\u5ea6\uff08C\uff09\u8f6c\u6362\u4e3a\u5934\u5927\u5c0f\uff08head_size\uff09\u3002\nhead_size = 16\nkey = nn.Linear(C, head_size, bias=False)\nquery = nn.Linear(C, head_size, bias=False)\nvalue = nn.Linear(C, head_size, bias=False)\n\n# 
\u8fd9\u4e24\u884c\u4ee3\u7801\u5c06\u8f93\u5165`x`\u901a\u8fc7\u952e\u548c\u67e5\u8be2\u7684\u7ebf\u6027\u53d8\u6362\uff0c\u5f97\u5230\u65b0\u7684\u952e\u548c\u67e5\u8be2\u3002\nk = key(x)   # (B, T, 16)\nq = query(x) # (B, T, 16)\n\n# \u8fd9\u884c\u4ee3\u7801\u8ba1\u7b97\u4e86\u67e5\u8be2\u548c\u952e\u7684\u70b9\u79ef\uff0c\u5f97\u5230\u4e86\u6743\u91cd\u77e9\u9635`wei`\u3002\nwei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---&gt; (B, T, T)\n\n# \u8fd9\u884c\u4ee3\u7801\u521b\u5efa\u4e86\u4e00\u4e2a\u4e0b\u4e09\u89d2\u77e9\u9635\u3002\ntril = torch.tril(torch.ones(T, T))\n#wei = torch.zeros((T,T))\n# \u8fd9\u884c\u4ee3\u7801\u5c06\u6743\u91cd\u77e9\u9635`wei`\u4e2d\u5bf9\u5e94\u4e0b\u4e09\u89d2\u77e9\u9635\u4e3a0\u7684\u4f4d\u7f6e\u586b\u5145\u4e3a\u8d1f\u65e0\u7a77\u3002\nwei = wei.masked_fill(tril == 0, float('-inf'))\n# \u8fd9\u884c\u4ee3\u7801\u5bf9\u6743\u91cd\u77e9\u9635`wei`\u8fdb\u884c\u4e86softmax\u64cd\u4f5c\uff0c\u4f7f\u5f97\u6bcf\u4e00\u884c\u7684\u548c\u4e3a1\u3002\nwei = F.softmax(wei, dim=-1)\n\n# \u8fd9\u884c\u4ee3\u7801\u5c06\u8f93\u5165`x`\u901a\u8fc7\u503c\u7684\u7ebf\u6027\u53d8\u6362\uff0c\u5f97\u5230\u65b0\u7684\u503c\u3002\nv = value(x)\n\n# \u8fd9\u884c\u4ee3\u7801\u5c06\u6743\u91cd\u77e9\u9635`wei`\u548c\u503c`v`\u8fdb\u884c\u77e9\u9635\u4e58\u6cd5\uff0c\u5f97\u5230\u8f93\u51fa`out`\u3002\nout = wei @ v\nprint(out)\n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c python setup8.1.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup8.1.py\ntensor(&#91;&#91;&#91;-1.5713e-01,  8.8009e-01,  1.6152e-01, -7.8239e-01, -1.4289e-01,\n           7.4676e-01,  1.0068e-01, -5.2395e-01, -8.8726e-01,  1.9068e-01,\n           1.7616e-01, -5.9426e-01, -4.8124e-01, -4.8599e-01,  2.8623e-01,\n           5.7099e-01],\n         &#91; 6.7643e-01, -5.4770e-01, -2.4780e-01,  3.1430e-01, -1.2799e-01,\n          -2.9521e-01, -4.2962e-01, -1.0891e-01, -4.9282e-02,  7.2679e-01,\n           7.1296e-01, -1.1639e-01,  3.2665e-01,  
3.4315e-01, -7.0975e-02,\n           1.2716e+00],\n         &#91; 4.8227e-01, -1.0688e-01, -4.0555e-01,  1.7696e-01,  1.5811e-01,\n          -1.6967e-01,  1.6217e-02,  2.1509e-02, -2.4903e-01, -3.7725e-01,\n           2.7867e-01,  1.6295e-01, -2.8951e-01, -6.7610e-02, -1.4162e-01,\n           1.2194e+00],\n         &#91; 1.9708e-01,  2.8561e-01, -1.3028e-01, -2.6552e-01,  6.6781e-02,\n           1.9535e-01,  2.8073e-02, -2.4511e-01, -4.6466e-01,  6.9287e-02,\n           1.5284e-01, -2.0324e-01, -2.4789e-01, -1.6213e-01,  1.9474e-01,\n           7.6778e-01],\n         &#91; 2.5104e-01,  7.3457e-01,  5.9385e-01,  2.5159e-01,  2.6064e-01,\n           7.5820e-01,  5.5947e-01,  3.5387e-01, -5.9338e-01, -1.0807e+00,\n          -3.1110e-01, -2.7809e-01, -9.0541e-01,  1.3181e-01, -1.3818e-01,\n           6.3715e-01],\n         &#91; 3.4277e-01,  4.9605e-01,  4.7248e-01,  3.0277e-01,  1.8440e-01,\n           5.8144e-01,  3.8245e-01,  2.9521e-01, -4.8969e-01, -7.7051e-01,\n          -1.1721e-01, -2.5412e-01, -6.8921e-01,  1.9795e-01, -1.5135e-01,\n           7.6659e-01],\n         &#91; 1.8658e-01, -9.6351e-02, -1.4300e-01,  3.0587e-01,  8.3441e-02,\n          -6.8646e-03, -2.0472e-01, -1.5350e-01, -7.6250e-02,  3.2689e-01,\n           3.0896e-01,  7.6626e-02,  9.9243e-02,  1.6560e-01,  1.9745e-01,\n           7.6248e-01],\n         &#91; 1.3013e-01, -3.2832e-02, -4.9645e-01,  2.8652e-01,  2.7042e-01,\n          -2.6357e-01, -7.3756e-02,  3.7857e-01,  7.4580e-02,  3.3827e-02,\n           1.4695e-02,  3.1937e-01,  2.9926e-01, -1.6530e-01, -3.8630e-02,\n           3.3748e-01]],\n\n        &#91;&#91;-1.3254e+00,  1.1236e+00,  2.2927e-01, -2.9970e-01, -7.6267e-03,\n           7.9364e-01,  8.9581e-01,  3.9650e-01, -6.6613e-01, -2.1844e-01,\n          -1.3539e+00,  4.1245e-01,  9.6011e-01, -1.0805e+00, -3.9751e-01,\n          -4.4439e-01],\n         &#91;-3.8338e-01, -1.9659e-01,  8.8455e-02,  1.8560e-01, -8.7010e-02,\n           1.3239e-01,  3.0841e-01, -2.4350e-01, -1.9396e-01, 
-1.7634e-02,\n           4.8439e-01,  5.4210e-01, -2.0407e-02, -4.2467e-01, -2.3463e-01,\n          -4.6465e-01],\n         &#91;-1.1100e+00,  3.2334e-01,  4.7054e-01, -6.3595e-02,  2.5443e-01,\n           1.5352e-01,  2.5186e-01,  2.6286e-01,  2.7916e-01, -3.1662e-03,\n          -3.2880e-02,  4.8191e-01,  7.4431e-01, -1.9921e-01,  2.7134e-01,\n          -8.5871e-02],\n         &#91;-9.7190e-01,  4.6124e-01,  4.2349e-01, -1.7230e-02,  1.5847e-01,\n           4.1175e-01,  4.0764e-01,  2.4982e-01, -5.0322e-02,  4.1514e-03,\n          -3.9853e-01,  4.3551e-01,  7.0285e-01, -4.3081e-01,  2.6684e-02,\n          -2.0169e-01],\n         &#91; 3.3586e-01, -8.5915e-02,  9.3660e-01,  7.7311e-01,  1.8037e-01,\n           8.2853e-01, -6.9183e-02,  2.8814e-01,  1.1734e-01,  6.8448e-01,\n          -5.8500e-02,  1.2726e-01,  2.9780e-01,  1.9324e-01,  1.5655e-01,\n          -9.3004e-03],\n         &#91; 1.6984e-01,  3.0993e-02,  8.1557e-01,  6.1679e-01,  1.0429e-01,\n           7.4573e-01,  2.3072e-02,  3.0572e-01,  5.8163e-02,  5.7122e-01,\n          -4.5275e-02,  1.5051e-01,  3.2901e-01,  5.6984e-02,  1.0311e-01,\n          -9.9174e-02],\n         &#91; 4.6497e-02,  1.5765e-01,  3.9760e-01,  1.7619e-01, -2.1168e-01,\n           2.3365e-01, -6.2083e-02,  2.1726e-01, -7.8725e-03,  4.5389e-01,\n           3.4349e-01, -5.5631e-02,  3.3726e-01, -3.7591e-01, -1.0140e-02,\n          -4.5806e-01],\n         &#91;-5.3896e-01,  7.5555e-01,  3.3034e-01, -1.5849e-01, -2.6740e-01,\n           4.3495e-01,  3.7772e-01,  5.5794e-01, -1.8369e-01,  1.5938e-01,\n          -2.1042e-01,  5.5790e-02,  6.3184e-01, -6.4884e-01, -9.6084e-02,\n          -5.0751e-01]],\n\n        &#91;&#91; 6.8925e-02,  1.2248e+00, -4.1194e-01, -1.7046e-01, -6.9224e-01,\n          -2.9201e-01,  1.2704e+00, -6.8596e-01,  4.3798e-01, -2.6366e-01,\n           1.1528e-01,  1.1676e+00, -7.2138e-01, -1.2308e+00,  8.3821e-01,\n          -5.5987e-01],\n         &#91;-4.6375e-01,  6.3807e-01, -1.5842e-01, -1.3309e-01, 
-5.9402e-01,\n          -5.0374e-01,  2.3289e-01, -3.2126e-01,  4.5781e-01, -1.8590e-01,\n           1.9215e-01,  3.7566e-01, -3.5905e-01, -7.7262e-01,  3.5036e-01,\n           6.9694e-02],\n         &#91;-6.4044e-01,  1.3831e-01, -6.1007e-02, -1.1112e-01, -4.5228e-01,\n          -6.2271e-01, -1.7030e-01, -2.4949e-01,  5.0670e-01, -9.6444e-02,\n           4.8315e-01,  9.4986e-02, -2.9810e-01, -3.6538e-01,  3.9458e-01,\n           4.1512e-01],\n         &#91;-6.7193e-01,  1.2516e-01,  7.3386e-02, -1.3198e-01, -1.7880e-01,\n          -5.6740e-01, -6.8226e-01,  5.0844e-02,  3.3051e-01,  7.8242e-02,\n           6.8022e-02, -2.4041e-01, -6.6864e-02, -1.8411e-01, -5.3514e-02,\n           4.5113e-01],\n         &#91;-1.4270e-02,  1.0195e+00, -3.4792e-01, -1.6421e-01, -5.5846e-01,\n          -3.2457e-01,  9.9404e-01, -5.6891e-01,  4.0097e-01, -1.8123e-01,\n           1.1856e-01,  9.8704e-01, -6.4057e-01, -1.0320e+00,  7.3320e-01,\n          -4.3167e-01],\n         &#91;-6.3858e-01, -7.6533e-02, -3.6510e-01,  1.7782e-01, -6.5426e-02,\n          -3.5158e-01,  7.9591e-02,  1.7384e-01,  3.6676e-01, -4.2302e-02,\n           2.4923e-01,  4.8239e-01, -2.1295e-01, -2.9492e-01,  3.4749e-01,\n          -1.7111e-01],\n         &#91;-2.2366e-01, -5.5317e-02, -1.8296e-01,  2.4258e-01,  2.5357e-01,\n          -1.6154e-01, -2.3908e-01,  3.3243e-01,  1.0304e-01,  2.6067e-01,\n          -5.0670e-02,  3.6947e-01, -4.9856e-02,  1.1197e-01,  1.1752e-01,\n          -2.5078e-01],\n         &#91;-2.4821e-01,  1.4845e-01, -3.5033e-01,  1.7102e-01,  1.6613e-01,\n          -2.0643e-01,  8.6633e-02,  8.8414e-02,  2.1188e-01,  2.5805e-01,\n           5.5146e-02,  4.2668e-01, -2.0443e-01, -1.7372e-01,  3.8899e-01,\n           5.1725e-02]],\n\n        &#91;&#91; 9.7183e-02,  5.7301e-02, -1.0468e-01, -4.6654e-02, -1.4006e-01,\n          -8.4126e-01, -1.3625e-01, -6.7465e-01, -2.1541e-01,  1.0993e+00,\n           2.3427e-01,  3.2605e-02, -1.8521e-01,  1.4780e-01, -6.1045e-01,\n           1.5391e+00],\n   
      &#91; 1.9305e-01, -2.1031e-01, -3.4658e-01,  2.0567e-01, -1.7799e-01,\n          -7.4604e-01, -6.4427e-01, -6.9183e-01, -2.0558e-01,  7.0413e-01,\n           2.3632e-01,  9.8800e-04, -1.7015e-01,  1.1203e-01, -7.1064e-01,\n           1.2431e+00],\n         &#91; 2.9114e-01, -4.8343e-01, -5.9254e-01,  4.6477e-01, -2.1832e-01,\n          -6.4460e-01, -1.1627e+00, -7.0993e-01, -1.9703e-01,  2.9262e-01,\n           2.3669e-01, -3.1050e-02, -1.5471e-01,  7.7153e-02, -8.1137e-01,\n           9.3578e-01],\n         &#91; 1.7549e-01, -3.4260e-02, -2.0523e-01,  2.7644e-02, -2.1312e-01,\n          -5.6022e-01, -3.5273e-01, -6.2722e-01, -3.0037e-01,  4.6061e-01,\n           1.5004e-01,  1.9040e-02, -1.4646e-01,  1.7220e-01, -6.2559e-01,\n           1.0722e+00],\n         &#91; 1.7354e-01, -1.7962e-01, -2.7874e-01, -1.0590e-01, -1.2952e-01,\n          -3.5086e-01, -5.5830e-01, -3.8638e-01, -2.9719e-01,  3.3368e-02,\n           1.7392e-01,  5.5898e-02, -7.2007e-02,  1.3182e-02, -6.6710e-01,\n           5.4229e-01],\n         &#91; 2.4678e-01, -4.7274e-01, -5.2827e-01,  3.1212e-01, -1.7528e-01,\n          -4.8636e-01, -1.1223e+00, -5.4196e-01, -2.0142e-01,  4.0103e-02,\n           2.2231e-01, -2.9380e-02, -9.4353e-02,  2.6374e-02, -7.8726e-01,\n           6.2836e-01],\n         &#91;-3.9784e-01,  2.5915e-01,  5.0358e-01, -4.6864e-01, -2.2024e-02,\n          -3.2242e-01, -1.2578e-01,  1.0634e-01,  1.3618e-01,  1.7780e-01,\n           1.0391e-01, -6.2540e-01,  3.8904e-01,  3.3690e-01, -5.5140e-01,\n           5.2246e-01],\n         &#91;-3.5927e-01,  3.3935e-02, -2.9863e-02, -1.5019e-01, -6.0354e-03,\n          -6.5733e-02, -3.9659e-01, -6.0435e-02, -5.7551e-01, -2.9157e-01,\n           1.4899e-01, -7.5002e-02,  7.3228e-02, -4.7413e-02, -6.4394e-01,\n           2.8560e-01]]], grad_fn=&lt;UnsafeViewBackward0&gt;)<\/code><\/pre>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5b9e\u73b0\u4e86\u4e00\u4e2a\u5e26\u6709\u63a9\u7801\u7684\u81ea\u6ce8\u610f\u529b\uff08Masked 
Self-Attention\uff09\u673a\u5236\u3002Masked Self-Attention \u5141\u8bb8\u6a21\u578b\u5728\u5904\u7406\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u4ec5\u8003\u8651\u5f53\u524d\u4f4d\u7f6e\u4e4b\u524d\u7684\u4fe1\u606f\uff0c\u5e38\u7528\u4e8e\u5982\u751f\u6210\u6587\u672c\u7684\u4efb\u52a1\u4e2d\uff0c\u4ee5\u907f\u514d\u672a\u6765\u4fe1\u606f\u7684\u6cc4\u9732\u3002<\/p>\n\n\n\n<p>\u81ea\u6ce8\u610f\u529b\u673a\u5236\u7684\u4e09\u4e2a\u6838\u5fc3\u7ec4\u4ef6\uff1a\u67e5\u8be2\uff08query\uff09\u3001\u952e\uff08key\uff09\u548c\u503c\uff08value\uff09\uff0c\u5b83\u4eec\u90fd\u6765\u6e90\u4e8e\u540c\u4e00\u4e2a\u8f93\u5165\u6570\u636e&nbsp;<code>x<\/code>\u3002\u8fd9\u91cc\u4f7f\u7528&nbsp;<code>nn.Linear()<\/code>&nbsp;\u5bf9\u6bcf\u4e2a\u7ec4\u4ef6\u8fdb\u884c\u7ebf\u6027\u53d8\u6362\uff08\u6620\u5c04\uff09\uff0c\u4ee5\u751f\u6210\u4e0d\u540c\u7684\u8868\u793a\u7a7a\u95f4\u3002\u8fd9\u662f\u5b9e\u73b0\u6ce8\u610f\u529b\u673a\u5236\u7684\u6807\u51c6\u505a\u6cd5\uff0c\u901a\u8fc7\u8fd9\u79cd\u65b9\u5f0f\uff0c\u53ef\u4ee5\u8ba9\u6a21\u578b\u5b66\u4e60\u5230\u5982\u4f55\u6700\u6709\u6548\u5730\u8868\u793a\u6570\u636e\u3002<\/p>\n\n\n\n<p>\u5982\u679c\u79fb\u9664\u7528\u4e8e\u5c06\u6743\u91cd\u77e9\u9635&nbsp;<code>wei<\/code>&nbsp;\u4e2d\u7279\u5b9a\u4f4d\u7f6e\u8bbe\u7f6e\u4e3a\u8d1f\u65e0\u7a77\u7684\u4ee3\u7801\u884c\uff08<code>wei.masked_fill(tril == 0, 
float('-inf'))<\/code>\uff09\uff0c\u90a3\u4e48\u8be5\u5b9e\u73b0\u5c06\u4e0d\u518d\u662f\u4e00\u4e2a\u5e26\u6709\u63a9\u7801\u7684\u81ea\u6ce8\u610f\u529b\u673a\u5236\uff0c\u800c\u662f\u53d8\u56de\u4e00\u4e2a\u6807\u51c6\u7684\u81ea\u6ce8\u610f\u529b\u673a\u5236\u3002\u6807\u51c6\u7684\u81ea\u6ce8\u610f\u529b\u5141\u8bb8\u6bcf\u4e2a\u5e8f\u5217\u5143\u7d20\u201c\u6ce8\u610f\u201d\u5e8f\u5217\u4e2d\u7684\u6240\u6709\u5176\u4ed6\u5143\u7d20\uff0c\u800c\u4e0d\u662f\u4ec5\u4ec5\u662f\u4e4b\u524d\u7684\u5143\u7d20\u3002<\/p>\n\n\n\n<p>\u81ea\u6ce8\u610f\u529b\u673a\u5236\u7684\u4e00\u4e2a\u5173\u952e\u7279\u6027\uff1a\u67e5\u8be2\uff08query\uff09\u3001\u952e\uff08key\uff09\u548c\u503c\uff08value\uff09\u5411\u91cf\u90fd\u6765\u6e90\u4e8e\u540c\u4e00\u4e2a\u8f93\u5165&nbsp;<code>x<\/code>\u3002\u8fd9\u610f\u5473\u7740\u81ea\u6ce8\u610f\u529b\u673a\u5236\u80fd\u591f\u5728\u8f93\u5165\u6570\u636e\u7684\u5185\u90e8\u627e\u5230\u5143\u7d20\u4e4b\u95f4\u7684\u5173\u7cfb\u3002<\/p>\n\n\n\n<p>\u6ce8\uff1a\u5982\u679c\u5c06 query\u8f93\u5165\u4e3ax\uff0ckey,value\u8f93\u5165\u4e3a 
y\uff0c\u4fbf\u6210\u4e3a\u53e6\u4e00\u79cd\u6ce8\u610f\u529b\u673a\u5236\u2014\u2014\u4ea4\u53c9\u6ce8\u610f\u529b\uff08cross-attention\uff09\u3002\u5728\u4ea4\u53c9\u6ce8\u610f\u529b\u8bbe\u7f6e\u4e2d\uff0c\u67e5\u8be2\uff08query\uff09\u5411\u91cf\u6765\u81ea\u4e8e\u4e00\u4e2a\u8f93\u5165\uff08\u4f8b\u5982\u00a0<code>x<\/code>\uff09\uff0c\u800c\u952e\uff08key\uff09\u548c\u503c\uff08value\uff09\u5411\u91cf\u6765\u81ea\u4e8e\u53e6\u4e00\u4e2a\u4e0d\u540c\u7684\u8f93\u5165\uff08\u4f8b\u5982\u00a0<code>y<\/code>\uff09\u3002\u8fd9\u79cd\u673a\u5236\u5e38\u7528\u4e8e\u5904\u7406\u4e24\u79cd\u4e0d\u540c\u7684\u5e8f\u5217\uff0c\u4f8b\u5982\u5728\u673a\u5668\u7ffb\u8bd1\u4efb\u52a1\u4e2d\uff0c\u6a21\u578b\u9700\u8981\u8003\u8651\u6e90\u8bed\u8a00\u53e5\u5b50\uff08\u4f5c\u4e3a\u00a0<code>x<\/code>\uff09\u548c\u76ee\u6807\u8bed\u8a00\u53e5\u5b50\uff08\u4f5c\u4e3a\u00a0<code>y<\/code>\uff09\u4e4b\u95f4\u7684\u5173\u7cfb\u3002<\/p>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0c\u6700\u540e\u7684\u51e0\u884c\u4ee3\u7801\uff08\u4ece\u751f\u6210\u4e0b\u4e09\u89d2\u9635\u5230 softmax normalize\uff09\u6211\u4eec\u5df2\u7ecf\u6bd4\u8f83\u719f\u6089\u4e86\u3002\u65b0\u589e\u7684\u90e8\u5206\u662f\u5f15\u5165\u81ea\u6ce8\u610f\u529b\u7684 query, key, value\uff0c\u6784\u6210\u4e86\u4e00\u4e2a<mark>\u5355\u5934\u7684\u81ea\u6ce8\u610f\u529b\u673a\u5236<\/mark>\u3002\u6ce8\uff1a\u53ef\u4ee5\u770b\u5230 Channels \u53d8\u6210\u4e86 32\uff0c\u8bcd\u5411\u91cf\u591a\u5927\uff0c\u8fd9\u91cc\u7684 C \u5c31\u8ddf\u7740\u591a\u5927\u3002<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>9. 
Weight Normalization for Softmax<\/strong><\/h2>\n\n\n\n<p>\u5728\u539f\u7248\u8bba\u6587\u7684\u516c\u5f0f\u4e2d\uff0c\u6709\u4e00\u4e2a\u5206\u6bcd\uff1a<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"721\" height=\"105\" src=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-2.png\" alt=\"\" class=\"wp-image-2074\" srcset=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-2.png 721w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-2-300x44.png 300w\" sizes=\"auto, (max-width: 721px) 100vw, 721px\" \/><\/figure>\n\n\n\n<p><strong>Softmax\u51fd\u6570<\/strong>\uff1aSoftmax\u51fd\u6570\u662f\u4e00\u79cd\u5c06\u5b9e\u6570\u5411\u91cf\u8f6c\u6362\u4e3a\u6982\u7387\u5206\u5e03\u7684\u51fd\u6570\u3002\u5bf9\u4e8e\u4efb\u610f\u5b9e\u6570\u5411\u91cf\uff0cSoftmax\u51fd\u6570\u4f1a\u538b\u7f29\u6bcf\u4e2a\u5143\u7d20\u7684\u8303\u56f4\u5230[0, 1]\uff0c\u5e76\u4e14\u4f7f\u5f97\u6240\u6709\u5143\u7d20\u7684\u548c\u4e3a1\u3002\u8fd9\u5728\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u4e2d\u975e\u5e38\u6709\u7528\uff0c\u7279\u522b\u662f\u5728\u6a21\u578b\u7684\u8f93\u51fa\u5c42\uff0c\u53ef\u4ee5\u7528\u6765\u4ee3\u8868\u6982\u7387\u5206\u5e03\u3002<\/p>\n\n\n\n<p><strong>\u6ce8\u610f\u529b\u673a\u5236\u4e2d\u7684Softmax<\/strong>\uff1a\u5728\u6ce8\u610f\u529b\u673a\u5236\u4e2d\uff0cSoftmax\u7528\u4e8e\u8ba1\u7b97\u6ce8\u610f\u529b\u6743\u91cd\uff0c\u5373\u786e\u5b9a\u5728\u751f\u6210\u8f93\u51fa\u65f6\u5e94\u8be5\u7ed9\u4e88\u5e8f\u5217\u4e2d\u6bcf\u4e2a\u5143\u7d20\u591a\u5c11\u201c\u6ce8\u610f\u529b\u201d\u3002\u901a\u8fc7Softmax\uff0c\u6a21\u578b\u80fd\u591f\u51b3\u5b9a\u5728\u805a\u5408\u4fe1\u606f\u65f6\u5bf9\u54ea\u4e9b\u5143\u7d20\u7ed9\u4e88\u66f4\u591a\u7684\u91cd\u89c6\u3002<\/p>\n\n\n\n<p><strong>Weight Normalization for 
Softmax<\/strong>\uff1a\u6743\u91cd\u6b63\u89c4\u5316\u662f\u4e00\u79cd\u6280\u672f\uff0c\u7528\u4e8e\u8c03\u6574\u6743\u91cd\u5411\u91cf\u7684\u5c3a\u5ea6\uff0c\u4f7f\u5176\u5177\u6709\u4e00\u5b9a\u7684\u7edf\u8ba1\u6027\u8d28\uff08\u4f8b\u5982\uff0c\u4f7f\u65b9\u5dee\u4e3a1\uff09\u3002\u5728\u6ce8\u610f\u529b\u673a\u5236\u7684\u4e0a\u4e0b\u6587\u4e2d\uff0c\u8fd9\u662f\u901a\u8fc7\u8c03\u6574\u67e5\u8be2\uff08query\uff09\u548c\u952e\uff08key\uff09\u7684\u70b9\u79ef\u7ed3\u679c\u6765\u5b9e\u73b0\u7684\uff0c\u4ece\u800c\u5f71\u54cdSoftmax\u51fd\u6570\u7684\u8f93\u5165\u3002<\/p>\n\n\n\n<p>\u4e3a\u4ec0\u4e48\u9700\u8981\u6743\u91cd\u6b63\u89c4\u5316\uff1f<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u907f\u514dSoftmax\u9971\u548c<\/strong>\uff1a\u5982\u679c\u6ca1\u6709\u6b63\u89c4\u5316\uff0c\u5f53head_size\uff08\u5373\uff0c\u6bcf\u4e2a\u6ce8\u610f\u529b\u5934\u7684\u7ef4\u5ea6\uff09\u5f88\u5927\u65f6\uff0c\u67e5\u8be2\u548c\u952e\u7684\u70b9\u79ef\u7ed3\u679c\u53ef\u80fd\u4f1a\u975e\u5e38\u5927\uff0c\u5bfc\u81f4Softmax\u8f93\u5165\u7684\u503c\u57df\u8fc7\u5927\u3002\u8fd9\u4f1a\u4f7f\u5f97Softmax\u51fd\u6570\u7684\u8f93\u51fa\u53d8\u5f97\u6781\u7aef\uff0c\u5373\u5927\u591a\u6570\u7684\u6ce8\u610f\u529b\u6743\u91cd\u90fd\u96c6\u4e2d\u5728\u5c11\u6570\u51e0\u4e2a\u503c\u4e0a\uff0c\u800c\u5176\u4ed6\u503c\u51e0\u4e4e\u88ab\u5ffd\u7565\u3002<\/li>\n\n\n\n<li><strong>\u4fdd\u6301\u68af\u5ea6\u7a33\u5b9a<\/strong>\uff1a\u901a\u8fc7\u63a7\u5236Softmax\u8f93\u5165\u7684\u5c3a\u5ea6\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4fdd\u6301\u68af\u5ea6\u7684\u7a33\u5b9a\u6027\uff0c\u4ece\u800c\u907f\u514d\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u7684\u68af\u5ea6\u7206\u70b8\u6216\u6d88\u5931\u95ee\u9898\u3002<\/li>\n<\/ul>\n\n\n\n<p>\u5982\u4f55\u5b9e\u73b0\u6743\u91cd\u6b63\u89c4\u5316\uff1f<\/p>\n\n\n\n<ul 
class=\"wp-block-list\">\n<li>\u6743\u91cd\u6b63\u89c4\u5316\u53ef\u4ee5\u901a\u8fc7\u9664\u4ee5\u00a0d<sub><em>k<\/em><\/sub>^1\/2\u6765\u5b9e\u73b0\uff0c\u5176\u4e2d\u00a0d<sub><em>k<\/em><\/sub>\u00a0\u662fhead_size\u3002\u8fd9\u4e2a\u64cd\u4f5c\u786e\u4fdd\u4e86\u5f53head_size\u5f88\u5927\u65f6\uff0c\u70b9\u79ef\u7ed3\u679c\u7684\u65b9\u5dee\u5927\u7ea6\u662f1\uff0c\u4ece\u800c\u7f29\u5c0f\u4e86Softmax\u8f93\u5165\u7684\u503c\u57df\u3002<\/li>\n<\/ul>\n\n\n\n<p>\u5bf9\u5e94\u7684\u4ee3\u7801\u5b9e\u73b0\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># compute attention scores (\"affinities\")\nwei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>10. \u5355\u5934\u81ea\u6ce8\u610f\u529b\u6a21\u5757<\/strong><\/h2>\n\n\n\n<p>\u57fa\u4e8e\u524d\u9762\u7684\u77e5\u8bc6\u50a8\u5907\uff0c\u5355\u5934\u6ce8\u610f\u529b\u6a21\u5757\u5b9e\u73b0\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>class Head(nn.Module):\n    \"\"\" one head of self-attention \"\"\"\n\n    def __init__(self, head_size):\n        super().__init__()\n        self.key = nn.Linear(n_embd, head_size, bias=False)\n        self.query = nn.Linear(n_embd, head_size, bias=False)\n        self.value = nn.Linear(n_embd, head_size, bias=False)\n        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        B,T,C = x.shape\n        k = self.key(x)   # (B,T,C)\n        q = self.query(x) # (B,T,C)\n        # compute attention scores (\"affinities\")\n        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)\n        wei = wei.masked_fill(self.tril&#91;:T, :T] == 0, float('-inf')) # (B, T, T)\n        wei = F.softmax(wei, dim=-1) # (B, T, T)\n        wei = self.dropout(wei)\n        # perform the weighted aggregation of the 
values\n        v = self.value(x) # (B,T,C)\n        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)\n        return out<\/code><\/pre>\n\n\n\n<p>\u5176\u4e2d\uff0c\u5f15\u5165\u4e86 Dropout\uff0c\u5728\u8bad\u7ec3\u65f6\u968f\u673a\u4e22\u6389\u90e8\u5206\u6743\u91cd\uff0c\u6765\u63d0\u5347\u8bad\u7ec3\u6548\u679c\uff0c\u907f\u514d overfitting.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>11. \u591a\u5934\u81ea\u6ce8\u610f\u529b\u6a21\u5757<\/strong><\/h2>\n\n\n\n<p>\u7ec4\u88c5\u591a\u4e2a\u5355\u5934\u81ea\u6ce8\u610f\u529b\u6a21\u5757\uff0c\u4fbf\u5f97\u5230\u4e86\u591a\u5934\u81ea\u6ce8\u610f\u529b\u6a21\u5757\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>class MultiHeadAttention(nn.Module):\n    \"\"\" multiple heads of self-attention in parallel \"\"\"\n\n    def __init__(self, num_heads, head_size):\n        super().__init__()\n        self.heads = nn.ModuleList(&#91;Head(head_size) for _ in range(num_heads)])\n        self.proj = nn.Linear(n_embd, n_embd)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        out = torch.cat(&#91;h(x) for h in self.heads], dim=-1)\n        out = self.dropout(self.proj(out))\n        return 
out<\/code><\/pre>\n\n\n\n<p>\u591a\u5934\u81ea\u6ce8\u610f\u529b\u901a\u8fc7\u5e76\u884c\u8fd0\u884c\u591a\u4e2a\u81ea\u6ce8\u610f\u529b\u673a\u5236\u6765\u589e\u52a0\u6a21\u578b\u7684\u8868\u8fbe\u80fd\u529b\u3002\u6bcf\u4e2a\u5934\u5173\u6ce8\u8f93\u5165\u6570\u636e\u7684\u4e0d\u540c\u90e8\u5206\uff0c\u4ece\u800c\u80fd\u591f\u6355\u83b7\u4e0d\u540c\u7684\u4fe1\u606f\u548c\u7279\u5f81\u3002\u8fd9\u4e9b\u4e0d\u540c\u5934\u7684\u8f93\u51fa\u4f1a\u6709\u4e0d\u540c\u7684\u8868\u793a\u7a7a\u95f4\u548c\u7ef4\u5ea6\u3002\u901a\u8fc7\u62fc\u63a5\u8fd9\u4e9b\u8f93\u51fa\uff0c\u6211\u4eec\u83b7\u5f97\u4e86\u4e00\u4e2a\u7efc\u5408\u4e86\u6240\u6709\u5934\u4fe1\u606f\u7684\u8868\u793a\uff0c\u4f46\u8fd9\u4e2a\u7efc\u5408\u540e\u7684\u8868\u793a\u7684\u7ef4\u5ea6\u4f1a\u6bd4\u539f\u59cb\u8f93\u5165\u5927\u3002<\/p>\n\n\n\n<p>\u7ebf\u6027\u53d8\u6362\uff08<code>self.proj<\/code>\uff09\u5728\u8fd9\u91cc\u7684\u4f5c\u7528\u662f\u5c06\u8fd9\u4e2a\u7ef4\u5ea6\u66f4\u5927\u7684\u8868\u793a\u538b\u7f29\u56de\u539f\u59cb\u8f93\u5165\u6570\u636e\u7684\u7ef4\u5ea6\u3002\u8fd9\u4e0d\u4ec5\u4f7f\u5f97\u591a\u5934\u81ea\u6ce8\u610f\u529b\u6a21\u5757\u7684\u8f93\u51fa\u53ef\u4ee5\u65e0\u7f1d\u5730\u878d\u5165\u540e\u7eed\u5c42\uff0c\u800c\u4e14\u8fd8\u901a\u8fc7\u8fd9\u4e2a\u8fc7\u7a0b\u6574\u5408\u4e86\u6765\u81ea\u4e0d\u540c\u5934\u7684\u4fe1\u606f\uff0c\u589e\u5f3a\u4e86\u6a21\u578b\u5bf9\u8f93\u5165\u6570\u636e\u7684\u7406\u89e3\u80fd\u529b\u3002<\/p>\n\n\n\n<p>\u6b64\u5916\uff0c\u7ebf\u6027\u53d8\u6362\u8fd8\u63d0\u4f9b\u4e86\u989d\u5916\u7684\u53c2\u6570\uff0c\u4e3a\u6a21\u578b\u7684\u5b66\u4e60\u63d0\u4f9b\u4e86\u66f4\u591a\u7684\u7075\u6d3b\u6027\u548c\u80fd\u529b\uff0c\u6709\u52a9\u4e8e\u6a21\u578b\u66f4\u597d\u5730\u62df\u5408\u548c\u7406\u89e3\u6570\u636e\u3002\u901a\u8fc7\u8bad\u7ec3\uff0c\u8fd9\u4e9b\u53c2\u6570\u53ef\u4ee5\u8c03\u6574\u4ee5\u4f18\u5316\u6a21\u578b\u7684\u6027\u80fd\uff0c\u4ece\u800c\u63d0\u9ad8\u6a21\u578b\u5bf9\u4e8e\u7279\u5b9a\u4efb\u52a1\u7684\u51c6\u786e\u
6027\u548c\u6548\u7387\u3002<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>12. FeedForward Layer<\/strong><\/h2>\n\n\n\n<p>\u5bf9\u591a\u5934\u81ea\u6ce8\u610f\u529b\u6a21\u5757\u8fdb\u884c\u6574\u5408\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>class FeedFoward(nn.Module):\n    \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n\n    def __init__(self, n_embd):\n        super().__init__()\n        self.net = nn.Sequential(\n            nn.Linear(n_embd, 4 * n_embd),\n            nn.ReLU(),\n            nn.Linear(4 * n_embd, n_embd),\n            nn.Dropout(dropout),\n        )\n\n    def forward(self, x):\n        return self.net(x)\n<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>13. LayerNorm<\/strong><\/h2>\n\n\n\n<p>\u5173\u4e8e\u00a0LayerNorm\uff08\u5c42\u5f52\u4e00\u5316\uff09\u5177\u4f53\u53ef\u70b9\u51fb\u9605\u8bfb\u7b14\u8bb0\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class LayerNorm1d: # (used to be BatchNorm1d)\n\n  def __init__(self, dim, eps=1e-5, momentum=0.1):\n    self.eps = eps\n    self.gamma = torch.ones(dim)\n    self.beta = torch.zeros(dim)\n\n  def __call__(self, x):\n    # calculate the forward pass\n    xmean = x.mean(1, keepdim=True) # batch mean\n    xvar = x.var(1, keepdim=True) # batch variance\n    xhat = (x - xmean) \/ torch.sqrt(xvar + self.eps) # normalize to unit variance\n    self.out = self.gamma * xhat + self.beta\n    return self.out\n\n  def parameters(self):\n    return &#91;self.gamma, self.beta]<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>14. 
Positional encoding<\/strong><\/h2>\n\n\n\n<p>Attention \u673a\u5236\u901a\u8fc7\u6ce8\u610f\u5230\u5e8f\u5217\u4e2d\u7684\u5176\u5b83\u5143\u7d20\u5b9e\u73b0\u4e86\u80fd\u529b\u63d0\u5347\u3002\u4f46\u662f\uff0c<mark>Attention \u672c\u8eab\u662f\u4e0d\u8003\u8651\u5143\u7d20\u5728\u5e8f\u5217\u4e2d\u7684\u987a\u5e8f\u7684<\/mark>\u3002Positional Encoding\u00a0\u53ef\u89e3\u51b3\u8fd9\u4e00\u95ee\u9898\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>\u5728\u8bb8\u591a\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4efb\u52a1\u4e2d\uff0c\u8bcd\u7684\u987a\u5e8f\u548c\u4f4d\u7f6e\u5bf9\u4e8e\u8bed\u4e49\u7684\u7406\u89e3\u81f3\u5173\u91cd\u8981\u3002\u7136\u800c\uff0c\u5728\u4f7f\u7528 Transformer \u6a21\u578b\u65f6\uff0c\u7531\u4e8e\u5176\u591a\u5934\u81ea\u6ce8\u610f\u529b\u5c42\u7684\u7279\u6027\uff0c\u6a21\u578b\u5bf9\u8f93\u5165\u6570\u636e\u7684\u987a\u5e8f\u5e76\u4e0d\u654f\u611f\u3002\u4e3a\u4e86\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\uff0c\u4f4d\u7f6e\u7f16\u7801\uff08Positional Encoding\uff09\u88ab\u5f15\u5165 Transformer \u6a21\u578b\u4e2d\uff0c\u4f7f\u5f97\u6a21\u578b\u80fd\u591f\u7406\u89e3\u8f93\u5165\u6570\u636e\u4e2d\u8bcd\u7684\u987a\u5e8f\u548c\u76f8\u5bf9\u4f4d\u7f6e\u3002<\/code><\/pre>\n\n\n\n<p>\u5728\u89c6\u9891\u7ed9\u51fa\u4e86\u4e00\u79cd\u4f4d\u7f6e\u7f16\u7801\u65b9\u6cd5\uff0c\u4f7f\u7528\u00a0<code>torch.arange<\/code>\u00a0\u4e0e\u00a0<code>nn.Embedding<\/code>\u00a0\u751f\u6210\u4f4d\u7f6e\u5411\u91cf<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def forward(self, idx, targets=None):\n    tok_emb = self.token_embedding_table(idx) # (B,T,C)\n    pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n    x = tok_emb + pos_emb # (B,T,C)\n    ......<\/code><\/pre>\n\n\n\n<p>\u5176\u4e2d\uff1a<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u5355\u8bcd\u5d4c\u5165(Token Embeddings)<\/strong>:&nbsp;<code>tok_emb = 
self.token_embedding_table(idx)<\/code>&nbsp;\u8fd9\u4e00\u884c\u4ee3\u7801\u901a\u8fc7\u67e5\u627e\u5d4c\u5165\u8868\u5c06\u8f93\u5165\u7684\u5355\u8bcd\u7d22\u5f15<code>idx<\/code>\u8f6c\u6362\u6210\u5bf9\u5e94\u7684\u5d4c\u5165\u5411\u91cf<code>tok_emb<\/code>\u3002\u5d4c\u5165\u8868\u662f\u4e00\u4e2a\u9884\u5148\u8bad\u7ec3\u597d\u7684\uff0c\u53ef\u4ee5\u5c06\u6bcf\u4e2a\u552f\u4e00\u5355\u8bcd\u6620\u5c04\u5230\u4e00\u4e2a\u9ad8\u7ef4\u7a7a\u95f4\u4e2d\u7684\u5411\u91cf\u7684\u8868\u3002\u8fd9\u91cc\u7684<code>(B,T,C)<\/code>\u8868\u793a\u6279\u6b21\u5927\u5c0f\u4e3aB\uff0c\u5e8f\u5217\u957f\u5ea6\u4e3aT\uff0c\u5d4c\u5165\u7ef4\u5ea6\u4e3aC\u3002<\/li>\n\n\n\n<li><strong>\u4f4d\u7f6e\u5d4c\u5165(Positional Embeddings)<\/strong>:&nbsp;<code>pos_emb = self.position_embedding_table(torch.arange(T, device=device))<\/code>&nbsp;\u8fd9\u884c\u4ee3\u7801\u751f\u6210\u4e00\u4e2a\u4f4d\u7f6e\u5d4c\u5165\uff0c\u5176\u4e2d<code>torch.arange(T)<\/code>\u751f\u6210\u4e00\u4e2a\u4ece0\u5230T-1\u7684\u5e8f\u5217\uff0c\u5bf9\u5e94\u4e8e\u8f93\u5165\u5e8f\u5217\u4e2d\u6bcf\u4e2a\u4f4d\u7f6e\u7684\u7d22\u5f15\u3002<code>self.position_embedding_table<\/code>\u662f\u4e00\u4e2a\u9884\u5148\u5b9a\u4e49\u7684\u5d4c\u5165\u8868\uff0c\u5b83\u5c06\u8fd9\u4e9b\u4f4d\u7f6e\u7d22\u5f15\u6620\u5c04\u5230C\u7ef4\u7684\u5411\u91cf\u4e0a\uff0c\u8fd9\u6837\u6bcf\u4e2a\u4f4d\u7f6e\u5c31\u6709\u4e86\u81ea\u5df1\u7684\u4f4d\u7f6e\u5d4c\u5165\u3002\u8fd9\u4e2a\u5d4c\u5165\u5411\u91cf\u80fd\u591f\u4ee3\u8868\u6216\u7f16\u7801\u8be5\u4f4d\u7f6e\u5728\u5e8f\u5217\u4e2d\u7684\u76f8\u5bf9\u6216\u7edd\u5bf9\u4f4d\u7f6e\u4fe1\u606f\u3002<\/li>\n\n\n\n<li><strong>\u5408\u5e76\u5d4c\u5165<\/strong>:&nbsp;<code>x = tok_emb + 
pos_emb<\/code>&nbsp;\u6700\u540e\uff0c\u901a\u8fc7\u5c06\u5355\u8bcd\u5d4c\u5165\u548c\u4f4d\u7f6e\u5d4c\u5165\u76f8\u52a0\uff0c\u4e3a\u6bcf\u4e2a\u5355\u8bcd\u751f\u6210\u4e86\u4e00\u4e2a\u5305\u542b\u4e86\u4f4d\u7f6e\u4fe1\u606f\u7684\u6700\u7ec8\u5d4c\u5165\u3002\u8fd9\u4e2a\u64cd\u4f5c\u786e\u4fdd\u4e86\u6a21\u578b\u7684\u8f93\u5165\u65e2\u5305\u542b\u4e86\u5355\u8bcd\u7684\u8bed\u4e49\u4fe1\u606f\uff08\u901a\u8fc7\u5355\u8bcd\u5d4c\u5165\uff09\uff0c\u4e5f\u5305\u542b\u4e86\u5355\u8bcd\u7684\u4f4d\u7f6e\u4fe1\u606f\uff08\u901a\u8fc7\u4f4d\u7f6e\u5d4c\u5165\uff09\u3002\u8fd9\u6837\uff0c\u5373\u4f7f\u5728\u5904\u7406\u5e8f\u5217\u7684\u65f6\u5019\uff0c\u6a21\u578b\u4e5f\u80fd\u591f\u8bc6\u522b\u51fa\u5355\u8bcd\u7684\u987a\u5e8f\uff0c\u4ece\u800c\u66f4\u597d\u5730\u7406\u89e3\u8bed\u8a00\u6216\u5e8f\u5217\u6570\u636e\u7684\u7ed3\u6784\u548c\u542b\u4e49\u3002<\/li>\n<\/ul>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>15. GPT Block \u7ec4\u4ef6<\/strong><\/h2>\n\n\n\n<p>GPT \u662f\u7531\u591a\u4e2a Block \u7ec4\u4ef6\u4e32\u8d77\u6765\u7684\u3002\uff08\u6ce8\uff0c\u8fd9\u91cc\u8bf4\u7684 Block \u4e0d\u662f\u524d\u9762\u7684\u5e8f\u5217\u5207\u7247\uff0c\u8fd9\u91cc\u6307 GPT \u7684\u7ec4\u6210\u6a21\u5757\uff09\u3002\u5b83\u7684\u6784\u9020\u5982\u4e0b\uff1a<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>\u52a0\u5165\u591a\u5934\u81ea\u6ce8\u610f\u529b<\/li>\n\n\n\n<li>\u591a\u5934\u81ea\u6ce8\u610f\u529b\u540e\u9762\u52a0 Feed Forward Layer<\/li>\n\n\n\n<li>\u52a0\u5165 residual connection\uff08\u6b8b\u5dee\u8fde\u63a5\uff09<\/li>\n\n\n\n<li>\u52a0\u5165 LayerNorm\uff0cPre-LayerNorm\uff0c\u5176\u4e2d\uff0c\u540e\u8005\u662f\u5728\u8fdb\u5165\u591a\u5934\u81ea\u6ce8\u610f\u529b\u4e4b\u524d\uff0c\u5c31\u5148\u8fdb\u884c\u5c42\u5f52\u4e00\u5316<\/li>\n<\/ul>\n\n\n\n<p>\u5177\u4f53\u4ee3\u7801\u5b9e\u73b0\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>class Block(nn.Module):\n    \"\"\" Transformer block: communication followed 
by computation \"\"\"\n\n    def __init__(self, n_embd, n_head):\n        # n_embd: embedding dimension, n_head: the number of heads we'd like\n        super().__init__()\n        head_size = n_embd \/\/ n_head\n        self.sa = MultiHeadAttention(n_head, head_size)\n        self.ffwd = FeedFoward(n_embd)\n        self.ln1 = nn.LayerNorm(n_embd)\n        self.ln2 = nn.LayerNorm(n_embd)\n\n    def forward(self, x):\n        x = x + self.sa(self.ln1(x))\n        x = x + self.ffwd(self.ln2(x))\n        return x<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>16. \u57fa\u4e8e BigramLanguageModel \u9b54\u6539 GPT<\/strong><\/h2>\n\n\n\n<p>\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u57fa\u4e8e\u5df2\u6709 BigramLanguageModel \u7684\u6846\u67b6\uff0c\u52a0\u4e0a\u524d\u9762\u51e0\u8282\u4e2d\u7684\u77e5\u8bc6\uff0c\u9b54\u6539\u51fa GPT\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># super simple bigram model\nclass BigramLanguageModelV2(nn.Module):\n    def __init__(self):\n        super().__init__()\n        # each token directly reads off the logits for the next token from a lookup table\n        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n        self.blocks = nn.Sequential(*&#91;Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n        self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n        self.lm_head = nn.Linear(n_embd, vocab_size)\n\n    def forward(self, idx, targets=None):\n        B, T = idx.shape\n\n        # idx and targets are both (B,T) tensor of integers\n        tok_emb = self.token_embedding_table(idx) # (B,T,C)\n        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n        x = tok_emb + pos_emb # (B,T,C)\n        x = self.blocks(x) # (B,T,C)\n        x = self.ln_f(x) # (B,T,C)\n        logits = self.lm_head(x) # (B,T,vocab_size)\n\n        if targets is None:\n            loss = 
None\n        else:\n            B, T, C = logits.shape\n            logits = logits.view(B*T, C)\n            targets = targets.view(B*T)\n            loss = F.cross_entropy(logits, targets)\n\n        return logits, loss\n\n    def generate(self, idx, max_new_tokens):\n        # idx is (B, T) array of indices in the current context\n        for _ in range(max_new_tokens):\n            # crop idx to the last block_size tokens\n            idx_cond = idx&#91;:, -block_size:]\n            # get the predictions\n            logits, loss = self(idx_cond)\n            # focus only on the last time step\n            logits = logits&#91;:, -1, :] # becomes (B, C)\n            # apply softmax to get probabilities\n            probs = F.softmax(logits, dim=-1) # (B, C)\n            # sample from the distribution\n            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n            # append sampled index to the running sequence\n            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n        return idx<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>17. 
\u8bad\u7ec3 GPT<\/strong><\/h2>\n\n\n\n<p>\u4e0b\u9762\u662f\u5b8c\u6574\u7684\u8bad\u7ec3\u4ee3\u7801<\/p>\n\n\n\n<p>\u8fd9\u4e2a\u811a\u672c\u6982\u8ff0\u4e86\u4f7f\u7528\u7b80\u5316\u7684Transformer\u67b6\u6784\u521b\u5efa\u548c\u8bad\u7ec3\u4e00\u4e2a\u5b57\u7b26\u7ea7\u522b\u7684\u8bed\u8a00\u6a21\u578b\u7684\u8fc7\u7a0b\u3002\u8be5\u6a21\u578b\u5728\u4e00\u4e2a\u6570\u636e\u96c6\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u8fd9\u4e2a\u6570\u636e\u96c6\u5f88\u53ef\u80fd\u6765\u81ea\u4e8eTiny Shakespeare\u8bed\u6599\u5e93\uff0c\u4ee5\u751f\u6210\u7c7b\u4f3c\u98ce\u683c\u7684\u6587\u672c\u3002\u4ee5\u4e0b\u662f\u811a\u672c\u4e2d\u5173\u952e\u7ec4\u4ef6\u548c\u8fc7\u7a0b\u7684\u9010\u6b65\u89e3\u6790\uff1a<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>\u8d85\u53c2\u6570\u8bbe\u7f6e<\/strong>\uff1a\u5b9a\u4e49\u4e86\u6279\u5904\u7406\u5927\u5c0f\u3001\u5757\u5927\u5c0f\u3001\u5b66\u4e60\u7387\u4ee5\u53ca\u6a21\u578b\u67b6\u6784\u7ec6\u8282\uff0c\u5982\u5d4c\u5165\u5c42\u7684\u6570\u91cf\u3001\u5934\u90e8\u6570\u3001\u5c42\u6570\u548c\u4e22\u5f03\u7387\u7b49\u8bad\u7ec3\u53c2\u6570\u3002<\/li>\n\n\n\n<li><strong>\u6570\u636e\u51c6\u5907<\/strong>\uff1a\n<ul 
class=\"wp-block-list\">\n<li>\u4ece\u6587\u4ef6\u4e2d\u52a0\u8f7d\u6587\u672c\u6570\u636e\u3002<\/li>\n\n\n\n<li>\u4ece\u6587\u672c\u4e2d\u521b\u5efa\u4e00\u4e2a\u72ec\u7279\u5b57\u7b26\u7684\u8bcd\u6c47\u8868\uff0c\u5e76\u5c06\u5b57\u7b26\u6620\u5c04\u4e3a\u6574\u6570\uff08\u4ee5\u53ca\u76f8\u53cd\u7684\u6620\u5c04\uff09\u4ee5\u4fbf\u5904\u7406\u3002<\/li>\n\n\n\n<li>\u5c06\u6570\u636e\u5206\u5272\u4e3a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002<\/li>\n<\/ul>\n<\/li>\n\n\n\n<li><strong>\u6279\u5904\u7406\u51c6\u5907<\/strong>\uff1a\u5b9e\u73b0\u4e86\u4e00\u4e2a\u51fd\u6570\uff0c\u4e3a\u8bad\u7ec3\u548c\u9a8c\u8bc1\u751f\u6210\u6570\u636e\u6279\u6b21\u3002\u6bcf\u4e2a\u6279\u6b21\u7531\u8f93\u5165\u5e8f\u5217\u53ca\u5176\u5bf9\u5e94\u7684\u76ee\u6807\u5e8f\u5217\u7ec4\u6210\uff0c\u76ee\u6807\u5e8f\u5217\u672c\u8d28\u4e0a\u662f\u8f93\u5165\u5e8f\u5217\u5411\u53f3\u79fb\u52a8\u4e00\u4e2a\u5b57\u7b26\u3002<\/li>\n\n\n\n<li><strong>\u6a21\u578b\u7ec4\u4ef6<\/strong>\uff1a\u5b9a\u4e49\u4e86Transformer\u6a21\u578b\u7684\u5173\u952e\u7ec4\u6210\u90e8\u5206\uff1a\n<ul 
class=\"wp-block-list\">\n<li><strong>Head<\/strong>\uff1a\u5b9e\u73b0\u4e86\u5355\u4e2a\u81ea\u6ce8\u610f\u529b\u5934\u3002<\/li>\n\n\n\n<li><strong>MultiHeadAttention<\/strong>\uff1a\u6c47\u603b\u591a\u4e2a\u81ea\u6ce8\u610f\u529b\u5934\u3002<\/li>\n\n\n\n<li><strong>FeedForward<\/strong>\uff1a\u4e00\u4e2a\u7b80\u5355\u7684\u7ebf\u6027\u5c42\uff0c\u540e\u8ddfReLU\u6fc0\u6d3b\u3002<\/li>\n\n\n\n<li><strong>Block<\/strong>\uff1a\u5c06\u6ce8\u610f\u529b\u548c\u524d\u9988\u7ec4\u4ef6\u7ec4\u5408\u6210\u5355\u4e2aTransformer\u5757\u3002<\/li>\n<\/ul>\n<\/li>\n\n\n\n<li><strong>\u6a21\u578b\u67b6\u6784<\/strong>\uff1a\u4f7f\u7528\u5d4c\u5165\u5c42\u6784\u5efa\u8bed\u8a00\u6a21\u578b\uff0c\u7528\u4e8e\u4ee4\u724c\u548c\u4f4d\u7f6e\u5d4c\u5165\uff0c\u591a\u4e2aTransformer\u5757\uff0c\u6700\u540e\u7684\u5c42\u6b63\u5219\u5316\uff0c\u4ee5\u53ca\u4e00\u4e2a\u7ebf\u6027\u5c42\u6765\u9884\u6d4b\u4e0b\u4e00\u4e2a\u5b57\u7b26\u3002<\/li>\n\n\n\n<li><strong>\u635f\u5931\u4f30\u8ba1<\/strong>\uff1a\u5b9a\u4e49\u4e86\u4e00\u4e2a\u51fd\u6570\uff0c\u4ee5\u5728\u4e0d\u66f4\u65b0\u6a21\u578b\u6743\u91cd\u7684\u60c5\u51b5\u4e0b\uff0c\u8bc4\u4f30\u6a21\u578b\u5728\u8bad\u7ec3\u548c\u9a8c\u8bc1\u96c6\u4e0a\u7684\u6027\u80fd\u3002<\/li>\n\n\n\n<li><strong>\u8bad\u7ec3\u5faa\u73af<\/strong>\uff1a\u901a\u8fc7\u4ee5\u4e0b\u6b65\u9aa4\u8fed\u4ee3\u8bad\u7ec3\u6a21\u578b\uff1a\n<ul 
class=\"wp-block-list\">\n<li>\u62bd\u53d6\u6570\u636e\u6279\u6b21\u3002<\/li>\n\n\n\n<li>\u8ba1\u7b97\u635f\u5931\u3002<\/li>\n\n\n\n<li>\u6267\u884c\u53cd\u5411\u4f20\u64ad\u3002<\/li>\n\n\n\n<li>\u66f4\u65b0\u6a21\u578b\u7684\u6743\u91cd\u3002<\/li>\n\n\n\n<li>\u5b9a\u671f\u5728\u8bad\u7ec3\u548c\u9a8c\u8bc1\u96c6\u4e0a\u8bc4\u4f30\u6a21\u578b\uff0c\u4ee5\u76d1\u63a7\u6027\u80fd\u3002<\/li>\n<\/ul>\n<\/li>\n\n\n\n<li><strong>\u6587\u672c\u751f\u6210<\/strong>\uff1a\u5b9e\u73b0\u4e86\u4e00\u79cd\u65b9\u6cd5\uff0c\u4ece\u7ed9\u5b9a\u5f00\u59cb\u4e0a\u4e0b\u6587\u7684\u6a21\u578b\u4e2d\u751f\u6210\u6587\u672c\u3002\u5b83\u901a\u8fc7\u4ece\u6a21\u578b\u7684\u9884\u6d4b\u4e2d\u91c7\u6837\uff0c\u8fed\u4ee3\u6dfb\u52a0\u65b0\u5b57\u7b26\u6765\u6269\u5c55\u4e0a\u4e0b\u6587\uff0c\u4ee5\u751f\u6210\u6307\u5b9a\u957f\u5ea6\u7684\u5e8f\u5217\u3002<\/li>\n\n\n\n<li><strong>\u6267\u884c<\/strong>\uff1a\u6700\u540e\uff0c\u811a\u672c\u521d\u59cb\u5316\u6a21\u578b\uff0c\u5c06\u5176\u79fb\u52a8\u5230\u9002\u5f53\u7684\u8bbe\u5907\uff08\u5982\u679c\u53ef\u7528\uff0c\u5219\u4e3aGPU\uff09\uff0c\u5e76\u6253\u5370\u51fa\u6a21\u578b\u7684\u53c2\u6570\u6570\u91cf\u3002\u7136\u540e\u8fdb\u5165\u8bad\u7ec3\u5faa\u73af\uff0c\u5b9a\u671f\u62a5\u544a\u635f\u5931\uff0c\u5b8c\u6210\u540e\uff0c\u4ece\u8bad\u7ec3\u597d\u7684\u6a21\u578b\u751f\u6210\u6587\u672c\u5e8f\u5217\u3002<\/li>\n<\/ol>\n\n\n\n<p>\u8be5\u811a\u672c\u5c55\u793a\u4e86\u5982\u4f55\u5b9e\u73b0\u4e00\u4e2a\u57fa\u4e8eTransformer\u7684\u6a21\u578b\uff0c\u7528\u4e8e\u5b57\u7b26\u7ea7\u6587\u672c\u751f\u6210\u4efb\u52a1\uff0c\u51f8\u663e\u4e86Transformer\u67b6\u6784\u5bf9\u4e8e\u5e8f\u5217\u5efa\u6a21\u4efb\u52a1\u7684\u7075\u6d3b\u6027\u548c\u6709\u6548\u6027\u3002<\/p>\n\n\n\n<p>setup17.1 \u5b8c\u6574\u4ee3\u7801<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n# hyperparameters\n# how many independent sequences will we process in 
parallel?\nbatch_size = 16\n# what is the maximum context length for predictions?\nblock_size = 32\nmax_iters = 5000\neval_interval = 100\nlearning_rate = 1e-3\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\neval_iters = 200\nn_embd = 64\nn_head = 4\nn_layer = 4\ndropout = 0.0\n# ------------\n\ntorch.manual_seed(1337)\n\n# wget https:\/\/raw.githubusercontent.com\/karpathy\/char-rnn\/master\/data\/tinyshakespeare\/input.txt\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\n\n# here are all the unique characters that occur in this text\nchars = sorted(list(set(text)))\nvocab_size = len(chars)\n# create a mapping from characters to integers\nstoi = { ch:i for i,ch in enumerate(chars) }\nitos = { i:ch for i,ch in enumerate(chars) }\n# encoder: take a string, output a list of integers\nencode = lambda s: &#91;stoi&#91;c] for c in s]\n# decoder: take a list of integers, output a string\ndecode = lambda l: ''.join(&#91;itos&#91;i] for i in l])\n\n# Train and test splits\ndata = torch.tensor(encode(text), dtype=torch.long)\n# first 90% will be train, rest val\nn = int(0.9*len(data))\ntrain_data = data&#91;:n]\nval_data = data&#91;n:]\n\n# data loading\ndef get_batch(split):\n    # generate a small batch of data of inputs x and targets y\n    data = train_data if split == 'train' else val_data\n    ix = torch.randint(len(data) - block_size, (batch_size,))\n    x = torch.stack(&#91;data&#91;i:i+block_size] for i in ix])\n    y = torch.stack(&#91;data&#91;i+1:i+block_size+1] for i in ix])\n    x, y = x.to(device), y.to(device)\n    return x, y\n\n@torch.no_grad()\ndef estimate_loss():\n    out = {}\n    model.eval()\n    for split in &#91;'train', 'val']:\n        losses = torch.zeros(eval_iters)\n        for k in range(eval_iters):\n            X, Y = get_batch(split)\n            logits, loss = model(X, Y)\n            losses&#91;k] = loss.item()\n        out&#91;split] = losses.mean()\n    model.train()\n    return out\n\nclass 
Head(nn.Module):\n    \"\"\" one head of self-attention \"\"\"\n\n    def __init__(self, head_size):\n        super().__init__()\n        self.key = nn.Linear(n_embd, head_size, bias=False)\n        self.query = nn.Linear(n_embd, head_size, bias=False)\n        self.value = nn.Linear(n_embd, head_size, bias=False)\n        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        B,T,C = x.shape\n        k = self.key(x)   # (B,T,C)\n        q = self.query(x) # (B,T,C)\n        # compute attention scores (\"affinities\")\n        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)\n        wei = wei.masked_fill(self.tril&#91;:T, :T] == 0, float('-inf')) # (B, T, T)\n        wei = F.softmax(wei, dim=-1) # (B, T, T)\n        wei = self.dropout(wei)\n        # perform the weighted aggregation of the values\n        v = self.value(x) # (B,T,C)\n        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)\n        return out\n\nclass MultiHeadAttention(nn.Module):\n    \"\"\" multiple heads of self-attention in parallel \"\"\"\n\n    def __init__(self, num_heads, head_size):\n        super().__init__()\n        self.heads = nn.ModuleList(&#91;Head(head_size) for _ in range(num_heads)])\n        self.proj = nn.Linear(n_embd, n_embd)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        out = torch.cat(&#91;h(x) for h in self.heads], dim=-1)\n        out = self.dropout(self.proj(out))\n        return out\n\nclass FeedFoward(nn.Module):\n    \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n\n    def __init__(self, n_embd):\n        super().__init__()\n        self.net = nn.Sequential(\n            nn.Linear(n_embd, 4 * n_embd),\n            nn.ReLU(),\n            nn.Linear(4 * n_embd, n_embd),\n            nn.Dropout(dropout),\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\nclass 
Block(nn.Module):\n    \"\"\" Transformer block: communication followed by computation \"\"\"\n\n    def __init__(self, n_embd, n_head):\n        # n_embd: embedding dimension, n_head: the number of heads we'd like\n        super().__init__()\n        head_size = n_embd \/\/ n_head\n        self.sa = MultiHeadAttention(n_head, head_size)\n        self.ffwd = FeedFoward(n_embd)\n        self.ln1 = nn.LayerNorm(n_embd)\n        self.ln2 = nn.LayerNorm(n_embd)\n\n    def forward(self, x):\n        x = x + self.sa(self.ln1(x))\n        x = x + self.ffwd(self.ln2(x))\n        return x\n\n# super simple bigram model\nclass BigramLanguageModel(nn.Module):\n\n    def __init__(self):\n        super().__init__()\n        # each token directly reads off the logits for the next token from a lookup table\n        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n        self.blocks = nn.Sequential(*&#91;Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n        self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n        self.lm_head = nn.Linear(n_embd, vocab_size)\n\n    def forward(self, idx, targets=None):\n        B, T = idx.shape\n\n        # idx and targets are both (B,T) tensor of integers\n        tok_emb = self.token_embedding_table(idx) # (B,T,C)\n        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n        x = tok_emb + pos_emb # (B,T,C)\n        x = self.blocks(x) # (B,T,C)\n        x = self.ln_f(x) # (B,T,C)\n        logits = self.lm_head(x) # (B,T,vocab_size)\n\n        if targets is None:\n            loss = None\n        else:\n            B, T, C = logits.shape\n            logits = logits.view(B*T, C)\n            targets = targets.view(B*T)\n            loss = F.cross_entropy(logits, targets)\n\n        return logits, loss\n\n    def generate(self, idx, max_new_tokens):\n        # idx is (B, T) array of indices in the current 
context\n        for _ in range(max_new_tokens):\n            # crop idx to the last block_size tokens\n            idx_cond = idx&#91;:, -block_size:]\n            # get the predictions\n            logits, loss = self(idx_cond)\n            # focus only on the last time step\n            logits = logits&#91;:, -1, :] # becomes (B, C)\n            # apply softmax to get probabilities\n            probs = F.softmax(logits, dim=-1) # (B, C)\n            # sample from the distribution\n            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n            # append sampled index to the running sequence\n            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n        return idx\n\nmodel = BigramLanguageModel()\nm = model.to(device)\n# print the number of parameters in the model\nprint(sum(p.numel() for p in m.parameters())\/1e6, 'M parameters')\n\n# create a PyTorch optimizer\noptimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n\nfor iter in range(max_iters):\n\n    # every once in a while evaluate the loss on train and val sets\n    if iter % eval_interval == 0 or iter == max_iters - 1:\n        losses = estimate_loss()\n        print(f\"step {iter}: train loss {losses&#91;'train']:.4f}, val loss {losses&#91;'val']:.4f}\")\n\n    # sample a batch of data\n    xb, yb = get_batch('train')\n\n    # evaluate the loss\n    logits, loss = model(xb, yb)\n    optimizer.zero_grad(set_to_none=True)\n    loss.backward()\n    optimizer.step()\n\n# generate from the model\ncontext = torch.zeros((1, 1), dtype=torch.long, device=device)\nprint(decode(m.generate(context, max_new_tokens=2000)&#91;0].tolist()))\n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c setup17.1<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>\ntony@TONYP15GEN2:\/mnt\/d\/OpenAI\/CreateGPT$ python setup17.1.py\n0.209729 M parameters\nstep 0: train loss 4.4116, val loss 4.4022\nstep 100: train loss 2.6568, val loss 2.6670\nstep 200: train loss 2.5091, val loss 2.5058\n....\nstep 4900: 
train loss 1.6678, val loss 1.8338\nstep 4999: train loss 1.6627, val loss 1.8233\n\nAnd they bride will to lovest made\nTo bube toest the dest day, and bartht he us his vetward that a enswing my feanst,\nAn yentreath, Lot fortth bettly would but\nWith entends will is that Glost and the now our wabs!\nAll in you his husberd with at princess,\nwhy holvings nor\nTo this destrittle, demath kneoul---on her pribest, and doth will now;\nBut poor of his butt known, rupt for to to his shall do allood,\nThat Prive my of.\n\nHENRY BOLINGS:\nYou ardsables!\nEghts, hois courtear tear rests I\ncommand.\nO, no to Pome, griving and your mast a cempres-ennom betwer'd madant thou such\nBut not usinne, will confessy.\nWhich migh.\n......\n<\/code><\/pre>\n\n\n\n<pre id=\"block-d3a23fee-6e2a-40fc-84d6-29f6e8ff155a\" class=\"wp-block-code\"><code>\u589e\u52a0 max_iters \u6b21\u6570\u5230 50000\u548c\u5176\u4ed6\u53c2\u6570<\/code><\/pre>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># how many independent sequences will we process in parallel?\nbatch_size = 64\n# what is the maximum context length for predictions?\nblock_size = 64\nmax_iters = 50000\neval_interval = 1000\n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c setup17.1<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>$ python setup17.1.py\n0.211777 M parameters\nstep 0: train loss 4.3393, val loss 4.3480\nstep 1000: train loss 1.9105, val loss 1.9978\nstep 2000: train loss 1.6546, val loss 1.8139\nstep 3000: train loss 1.5592, val loss 1.7467\n......\nstep 49999: train loss 1.2681, val loss 1.5861\n\nWhat thy bridal?\n\nSTANLEY:\nHe madest my best eyes.\nOne stride, and be that fought thee by their\ncaues with my face, and zoke he on\nthe office commandion will beg it ended mine\nStirs in oversumed; the next is waked. 
Anly die;\nFor humble comforwater and plaw you:\nThe sensemoumes are we'll like to thee\nTo Wicknes do evimes to them, and Tieuted Keep ContemenDusgued.\n......<\/code><\/pre>\n\n\n\n<p>18.\u4fdd\u5b58 model<\/p>\n\n\n\n<p>setup18.1 \u5b8c\u6574\u4ee3\u7801<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>import os\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n# hyperparameters\n# how many independent sequences will we process in parallel?\nbatch_size = 64\n# what is the maximum context length for predictions?\nblock_size = 64\nmax_iters = 500000 #50000\neval_interval = 1000\nlearning_rate = 1e-3\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\neval_iters = 200\nn_embd = 64\nn_head = 4\nn_layer = 4\ndropout = 0.0\n# ------------\n\ntorch.manual_seed(1337)\n\n# wget https:\/\/raw.githubusercontent.com\/karpathy\/char-rnn\/master\/data\/tinyshakespeare\/input.txt\nwith open('data\/input.txt', 'r', encoding='utf-8') as f:\n    text = f.read()\n\n# here are all the unique characters that occur in this text\nchars = sorted(list(set(text)))\nvocab_size = len(chars)\n# create a mapping from characters to integers\nstoi = { ch:i for i,ch in enumerate(chars) }\nitos = { i:ch for i,ch in enumerate(chars) }\n# encoder: take a string, output a list of integers\nencode = lambda s: &#91;stoi&#91;c] for c in s]\n# decoder: take a list of integers, output a string\ndecode = lambda l: ''.join(&#91;itos&#91;i] for i in l])\n\n# Train and test splits\ndata = torch.tensor(encode(text), dtype=torch.long)\n# first 90% will be train, rest val\nn = int(0.9*len(data))\ntrain_data = data&#91;:n]\nval_data = data&#91;n:]\n\n# data loading\ndef get_batch(split):\n    # generate a small batch of data of inputs x and targets y\n    data = train_data if split == 'train' else val_data\n    ix = torch.randint(len(data) - block_size, (batch_size,))\n    x = torch.stack(&#91;data&#91;i:i+block_size] for i in ix])\n    y = 
torch.stack(&#91;data&#91;i+1:i+block_size+1] for i in ix])\n    x, y = x.to(device), y.to(device)\n    return x, y\n\n@torch.no_grad()\ndef estimate_loss():\n    out = {}\n    model.eval()\n    for split in &#91;'train', 'val']:\n        losses = torch.zeros(eval_iters)\n        for k in range(eval_iters):\n            X, Y = get_batch(split)\n            logits, loss = model(X, Y)\n            losses&#91;k] = loss.item()\n        out&#91;split] = losses.mean()\n    model.train()\n    return out\n\nclass Head(nn.Module):\n    \"\"\" one head of self-attention \"\"\"\n\n    def __init__(self, head_size):\n        super().__init__()\n        self.key = nn.Linear(n_embd, head_size, bias=False)\n        self.query = nn.Linear(n_embd, head_size, bias=False)\n        self.value = nn.Linear(n_embd, head_size, bias=False)\n        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        B,T,C = x.shape\n        k = self.key(x)   # (B,T,C)\n        q = self.query(x) # (B,T,C)\n        # compute attention scores (\"affinities\")\n        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)\n        wei = wei.masked_fill(self.tril&#91;:T, :T] == 0, float('-inf')) # (B, T, T)\n        wei = F.softmax(wei, dim=-1) # (B, T, T)\n        wei = self.dropout(wei)\n        # perform the weighted aggregation of the values\n        v = self.value(x) # (B,T,C)\n        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)\n        return out\n\nclass MultiHeadAttention(nn.Module):\n    \"\"\" multiple heads of self-attention in parallel \"\"\"\n\n    def __init__(self, num_heads, head_size):\n        super().__init__()\n        self.heads = nn.ModuleList(&#91;Head(head_size) for _ in range(num_heads)])\n        self.proj = nn.Linear(n_embd, n_embd)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        out = torch.cat(&#91;h(x) for h in 
self.heads], dim=-1)\n        out = self.dropout(self.proj(out))\n        return out\n\nclass FeedFoward(nn.Module):\n    \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n\n    def __init__(self, n_embd):\n        super().__init__()\n        self.net = nn.Sequential(\n            nn.Linear(n_embd, 4 * n_embd),\n            nn.ReLU(),\n            nn.Linear(4 * n_embd, n_embd),\n            nn.Dropout(dropout),\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\nclass Block(nn.Module):\n    \"\"\" Transformer block: communication followed by computation \"\"\"\n\n    def __init__(self, n_embd, n_head):\n        # n_embd: embedding dimension, n_head: the number of heads we'd like\n        super().__init__()\n        head_size = n_embd \/\/ n_head\n        self.sa = MultiHeadAttention(n_head, head_size)\n        self.ffwd = FeedFoward(n_embd)\n        self.ln1 = nn.LayerNorm(n_embd)\n        self.ln2 = nn.LayerNorm(n_embd)\n\n    def forward(self, x):\n        x = x + self.sa(self.ln1(x))\n        x = x + self.ffwd(self.ln2(x))\n        return x\n\n# super simple bigram model\nclass BigramLanguageModel(nn.Module):\n\n    def __init__(self):\n        super().__init__()\n        # each token directly reads off the logits for the next token from a lookup table\n        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n        self.blocks = nn.Sequential(*&#91;Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n        self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n        self.lm_head = nn.Linear(n_embd, vocab_size)\n\n    def forward(self, idx, targets=None):\n        B, T = idx.shape\n\n        # idx and targets are both (B,T) tensor of integers\n        tok_emb = self.token_embedding_table(idx) # (B,T,C)\n        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n        x = tok_emb + pos_emb # (B,T,C)\n      
  x = self.blocks(x) # (B,T,C)\n        x = self.ln_f(x) # (B,T,C)\n        logits = self.lm_head(x) # (B,T,vocab_size)\n\n        if targets is None:\n            loss = None\n        else:\n            B, T, C = logits.shape\n            logits = logits.view(B*T, C)\n            targets = targets.view(B*T)\n            loss = F.cross_entropy(logits, targets)\n\n        return logits, loss\n\n    def generate(self, idx, max_new_tokens):\n        # idx is (B, T) array of indices in the current context\n        for _ in range(max_new_tokens):\n            # crop idx to the last block_size tokens\n            idx_cond = idx&#91;:, -block_size:]\n            # get the predictions\n            logits, loss = self(idx_cond)\n            # focus only on the last time step\n            logits = logits&#91;:, -1, :] # becomes (B, C)\n            # apply softmax to get probabilities\n            probs = F.softmax(logits, dim=-1) # (B, C)\n            # sample from the distribution\n            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n            # append sampled index to the running sequence\n            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n        return idx\n\nmodel = BigramLanguageModel()\nm = model.to(device)\n# print the number of parameters in the model\nprint(sum(p.numel() for p in m.parameters())\/1e6, 'M parameters')\n\n# create a PyTorch optimizer\noptimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n\nbest_val_loss = float('inf')  # \u521d\u59cb\u5316\u6700\u4f73\u9a8c\u8bc1\u635f\u5931\u4e3a\u65e0\u7a77\u5927\n\nfor iter in range(max_iters):\n\n    # every once in a while evaluate the loss on train and val sets\n    if iter % eval_interval == 0 or iter == max_iters - 1:\n        losses = estimate_loss()\n        print(f\"step {iter}: train loss {losses&#91;'train']:.4f}, val loss {losses&#91;'val']:.4f}\")\n        val_loss = losses&#91;'val']\n        if val_loss &lt; best_val_loss:\n            best_val_loss 
= val_loss\n        \n    # sample a batch of data\n    xb, yb = get_batch('train')\n\n    # evaluate the loss\n    logits, loss = model(xb, yb)\n    optimizer.zero_grad(set_to_none=True)\n    loss.backward()\n    optimizer.step()\n\n# \u6307\u5b9a\u6a21\u578b\u4fdd\u5b58\u7684\u76ee\u5f55\u548c\u6587\u4ef6\u540d\nmodel_dir = 'model'\nmodel_filename = 'model.pth'\nmodel_path = os.path.join(model_dir, model_filename)\n\n# \u786e\u4fdd\u76ee\u5f55\u5b58\u5728\uff0c\u5982\u679c\u4e0d\u5b58\u5728\uff0c\u5219\u521b\u5efa\nif not os.path.exists(model_dir):\n    os.makedirs(model_dir)\n    \n# \u4fdd\u5b58\u6a21\u578b\u72b6\u6001\u548c\u5176\u5b83\u4fe1\u606f\ntorch.save({\n    'epoch': max_iters,\n    'model_state_dict': model.state_dict(),\n    'optimizer_state_dict': optimizer.state_dict(),\n    #'scheduler_state_dict': scheduler.state_dict() if scheduler else None,\n    'best_val_loss': best_val_loss,\n    'hyperparameters': {\n        'learning_rate': learning_rate,\n        'batch_size': batch_size,\n        'n_layer': n_layer,\n        'n_head': n_head,\n        'dropout': dropout,\n    },\n    #'loss_history': {\n    #    'train': train_loss_history,\n    #    'val': val_loss_history,\n    #},\n    # \u5176\u5b83\u9700\u8981\u4fdd\u5b58\u7684\u4fe1\u606f\n}, model_path)\n\nprint(f\"Model saved to {model_path}\")\n\n# generate from the model\ncontext = torch.zeros((1, 1), dtype=torch.long, device=device)\nprint(decode(m.generate(context, max_new_tokens=2000)&#91;0].tolist()))\n<\/code><\/pre>\n\n\n\n<p>\u8fd0\u884c setup18.1.py<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>$ python setup18.1.py\nstep 0: train loss 4.3393, val loss 4.3480\nstep 1000: train loss 1.9105, val loss 1.9978\nstep 2000: train loss 1.6546, val loss 1.8139\nstep 3000: train loss 1.5592, val loss 1.7467\n......\n<\/code><\/pre>\n\n\n\n<p class=\"has-small-font-size\">step 49000: train loss 1.2699, val loss 1.5956<br>step 49999: train loss 1.2681, val loss 
1.5861<\/p>\n\n\n\n<p>What thy bridal?<\/p>\n\n\n\n<p>STANLEY:<br>He madest my best eyes.<br>One stride, and be that fought thee by their<br>caues with my face, and zoke he on<br>the office commandion will beg it ended mine<br>Stirs in oversumed; the next is waked. Anly die;<br>For humble comforwater and plaw you:<br>The sensemoumes are we&#8217;ll like to thee<br>To Wicknes do evimes to them, and Tieuted Keep ContemenDusgued.<\/p>\n\n\n\n<p><\/p>\n\n\n\n<p>\u53c2\u8003\u94fe\u63a5\uff1a<\/p>\n\n\n\n<p>https:\/\/medium.com\/@fareedkhandev\/understanding-transformers-a-step-by-step-math-example-part-1-a7809015150a<br>https:\/\/medium.com\/@fareedkhandev\/create-gpt-from-scratch-using-python-part-1-bd89ccf6206a<br>https:\/\/www.leewayhertz.com\/build-a-gpt-model\/<br>https:\/\/garden.maxieewong.com\/087.%E8%A7%86%E9%A2%91%E5%BA%93\/YouTube\/Andrej%20Karpathy\/Let&#8217;s%20build%20GPT%EF%BC%9Afrom%20scratch,%20in%20code,%20spelled%20out.\/<br>https:\/\/colab.research.google.com\/drive\/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing#scrollTo=nql_1ER53oCf<\/p>\n","protected":false},"excerpt":{"rendered":"<p>0.GPT \u6a21\u578b\u6982\u8ff0 GPT \u6a21\u578b\u662f Generative Pretrained Transformer \u7684\u7f29 
[&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"site-sidebar-layout":"default","site-content-layout":"","ast-site-content-layout":"default","site-content-style":"default","site-sidebar-style":"default","ast-global-header-display":"","ast-banner-title-visibility":"","ast-main-header-display":"","ast-hfb-above-header-display":"","ast-hfb-below-header-display":"","ast-hfb-mobile-header-display":"","site-post-title":"","ast-breadcrumbs-content":"","ast-featured-img":"","footer-sml-layout":"","theme-transparent-header-meta":"","adv-header-id-meta":"","stick-header-meta":"","header-above-stick-meta":"","header-main-stick-meta":"","header-below-stick-meta":"","astra-migrate-meta-layouts":"set","ast-page-background-enabled":"default","ast-page-background-meta":{"desktop":{"background-color":"var(--ast-global-color-4)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"ast-content-background-meta":{"desktop":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[313,289,443,442,312],"tags":[242,314],"class_list":["post-2038","post","type-post","status-publish","format-standard","hentry","category-chatgpt","category-gpt","category-llm","category-llms","category-openai","tag-chatgpt","tag-openai-api"],"views":2529,"jetpack_sharing_enabled":true,"jetpack_featured_media_url":"","_links":{"self":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/2038","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=2038"}],"version-history":[{"count":28,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/2038\/revisions"}],"predecessor-version":[{"id":2079,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest
_route=\/wp\/v2\/posts\/2038\/revisions\/2079"}],"wp:attachment":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=2038"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=2038"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=2038"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}