{"id":2123,"date":"2024-02-26T09:53:20","date_gmt":"2024-02-26T01:53:20","guid":{"rendered":"https:\/\/www.aqwu.net\/wp\/?p=2123"},"modified":"2024-04-28T20:04:15","modified_gmt":"2024-04-28T12:04:15","slug":"60-%e8%a1%8c-numpy-%e4%b8%ad%e7%9a%84-gpt","status":"publish","type":"post","link":"https:\/\/www.aqwu.net\/wp\/?p=2123","title":{"rendered":"GPT in 60 Lines of NumPy"},"content":{"rendered":"\n<p>This post again comes from <a href=\"https:\/\/jaykmody.com\/\" target=\"_blank\" rel=\"noreferrer noopener\">Jay Mody<\/a>: it is <a href=\"https:\/\/jaykmody.com\/blog\/gpt-from-scratch\/\" target=\"_blank\" rel=\"noreferrer noopener\">GPT in 60 Lines of NumPy<\/a>, the article that <a href=\"https:\/\/twitter.com\/karpathy\/status\/1627729834821701633\" target=\"_blank\" rel=\"noreferrer noopener\">Andrej Karpathy personally liked<\/a>.<\/p>\n\n\n\n<p>LLMs are everywhere, yet most GPT models still look like murky black boxes, and many people have started to mystify the technology. I think the most effective way to understand a technology is to dive straight into the math and the code and see what actually happens. As Julian Schrittwieser of DeepMind put it:<\/p>\n\n\n\n<blockquote class=\"wp-block-quote is-layout-flow 
wp-block-quote-is-layout-flow\">\n<p>These are all just computer programs.<\/p>\n<\/blockquote>\n\n\n\n<p>The article explains the core components and principles of a GPT model in detail and hand-builds a complete implementation in NumPy (one that actually runs), which makes it a refreshing read. The project code is fully open source and is called <a href=\"https:\/\/github.com\/jaymody\/picoGPT\" target=\"_blank\" rel=\"noreferrer noopener\">picoGPT<\/a> (pico: a GPT that could hardly be any smaller).<\/p>\n\n\n\n<p>Original article: <a href=\"https:\/\/jaykmody.com\/blog\/gpt-from-scratch\/\" target=\"_blank\" rel=\"noreferrer noopener\">GPT in 60 Lines of NumPy<\/a><\/p>\n\n\n\n<p>Chinese translation: <a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/\">60\u884cNumPy\u624b\u6413GPT<\/a><\/p>\n\n\n\n<p>(Translated with the original author's permission)<\/p>\n\n\n\n<p>A few notes on the translation:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>The translation largely follows the original author's wording and logic; in a few places the translator has added supplementary notes and opinions.<\/li>\n\n\n\n<li>Some English terms are hard to translate and are effectively proper nouns in this field, so they are kept as-is, for example transformer.<\/li>\n<\/ul>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<p>In this post, we will use just <a href=\"https:\/\/github.com\/jaymody\/picoGPT\/blob\/29e78cc52b58ed2c1c483ffea2eb46ff6bdec785\/gpt2_pico.py#L3-L58\" target=\"_blank\" 
rel=\"noreferrer noopener\">60 lines of NumPy<\/a> to implement a GPT from scratch. We will then load the GPT-2 model weights released by OpenAI into our implementation and generate some text.<\/p>\n\n\n\n<p><strong>Note:<\/strong><\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>This post assumes familiarity with Python and NumPy, plus some basic experience training neural networks.<\/li>\n\n\n\n<li>To keep the implementation as simple as possible while remaining complete, it deliberately leaves out many features of the original model. The goal is simply to provide <strong>a simple yet complete technical introduction to GPT for educational use.<\/strong><\/li>\n\n\n\n<li>The GPT architecture is only one small part of what makes LLMs what they are today.<sup>[1]<\/sup><\/li>\n\n\n\n<li>All the code in this post can be found at: <code>https:\/\/github.com\/jaymody\/picoGPT<\/code><\/li>\n\n\n\n<li><a href=\"https:\/\/news.ycombinator.com\/item?id=34726115\" target=\"_blank\" rel=\"noreferrer noopener\">Hacker News discussion of this post<\/a><\/li>\n<\/ul>\n\n\n\n<p><strong>Update (2023\/2\/9):<\/strong> Added a \u201cWhat Next?\u201d section and updated the introduction.<\/p>\n\n\n\n<p><strong>Update (2023\/2\/28):<\/strong> Added more content to the \u201cWhat Next?\u201d section.<\/p>\n\n\n\n<hr 
class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h2 class=\"wp-block-heading\" id=\"GPT\u662f\u4ec0\u4e48\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#GPT%E6%98%AF%E4%BB%80%E4%B9%88\"><\/a>1. What is a GPT?<\/strong><\/h2>\n\n\n\n<p>GPT stands for <strong>Generative Pre-trained Transformer<\/strong>. It is a class of neural network architectures based on the <a href=\"https:\/\/arxiv.org\/pdf\/1706.03762.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">Transformer<\/a>. <a href=\"https:\/\/jalammar.github.io\/how-gpt3-works-visualizations-animations\/\" target=\"_blank\" rel=\"noreferrer noopener\">Jay Alammar's \u201cHow GPT3 Works\u201d<\/a> is an excellent high-level introduction to GPTs, but in short:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>Generative:<\/strong> A GPT generates text.<\/li>\n\n\n\n<li><strong>Pre-trained:<\/strong> A GPT is trained on massive amounts of text from books, the internet, and other sources.<\/li>\n\n\n\n<li><strong>Transformer:<\/strong> A GPT is a <em>decoder-only transformer<\/em> neural network. (Translator's note: a Transformer is simply one particular neural network architecture.)<\/li>\n<\/ul>\n\n\n\n<p>Large language models such as <a href=\"https:\/\/en.wikipedia.org\/wiki\/GPT-3\" target=\"_blank\" rel=\"noreferrer noopener\">OpenAI's GPT-3<\/a>,&nbsp;<a href=\"https:\/\/blog.google\/technology\/ai\/lamda\/\" target=\"_blank\" rel=\"noreferrer noopener\">Google's LaMDA<\/a>, and <a href=\"https:\/\/docs.cohere.ai\/docs\/command-beta\" target=\"_blank\" rel=\"noreferrer 
noopener\">Cohere's Command XLarge<\/a> are all GPTs under the hood. What makes them special is that <strong>1)<\/strong> they are very large (tens to hundreds of billions of parameters) and <strong>2)<\/strong> they are trained on massive amounts of data (hundreds to thousands of gigabytes of text).<\/p>\n\n\n\n<p>Fundamentally, a GPT <strong>generates text<\/strong> given a <strong>prompt<\/strong>. Even with this very simple API (input = text, output = text), a well-trained GPT can do some impressive things, such as <a href=\"https:\/\/machinelearningknowledge.ai\/ezoimgfmt\/b2611031.smushcdn.com\/2611031\/wp-content\/uploads\/2022\/12\/ChatGPT-Demo-of-Drafting-an-Email.png?lossy=0&amp;strip=1&amp;webp=1&amp;ezimgfmt=ng:webp\/ngcb1\" target=\"_blank\" rel=\"noreferrer noopener\">write your emails<\/a>, <a href=\"https:\/\/machinelearningknowledge.ai\/ezoimgfmt\/b2611031.smushcdn.com\/2611031\/wp-content\/uploads\/2022\/12\/ChatGPT-Example-Book-Summarization.png?lossy=0&amp;strip=1&amp;webp=1&amp;ezimgfmt=ng:webp\/ngcb1\" target=\"_blank\" rel=\"noreferrer noopener\">summarize a book<\/a>, <a href=\"https:\/\/khrisdigital.com\/wp-content\/uploads\/2022\/12\/image-1.png\" target=\"_blank\" rel=\"noreferrer noopener\">write captions for your Instagram posts<\/a>, <a 
href=\"https:\/\/machinelearningknowledge.ai\/ezoimgfmt\/b2611031.smushcdn.com\/2611031\/wp-content\/uploads\/2022\/12\/ChatGPT-Examples-Explaining-Black-Holes.png?lossy=0&amp;strip=1&amp;webp=1&amp;ezimgfmt=ng:webp\/ngcb1\" target=\"_blank\" rel=\"noreferrer noopener\">explain black holes to a 5-year-old<\/a>, <a href=\"https:\/\/machinelearningknowledge.ai\/ezoimgfmt\/b2611031.smushcdn.com\/2611031\/wp-content\/uploads\/2022\/12\/ChatGPT-Demo-of-Writing-SQL-Queries.png?lossy=0&amp;strip=1&amp;webp=1&amp;ezimgfmt=ng:webp\/ngcb1\" target=\"_blank\" rel=\"noreferrer noopener\">write SQL code<\/a>, and <a href=\"https:\/\/machinelearningknowledge.ai\/ezoimgfmt\/b2611031.smushcdn.com\/2611031\/wp-content\/uploads\/2022\/12\/Chat-GPT-Example-Writing-a-Will.png?lossy=0&amp;strip=1&amp;webp=1&amp;ezimgfmt=ng:webp\/ngcb1\" target=\"_blank\" rel=\"noreferrer noopener\">even write your will<\/a>.<\/p>\n\n\n\n<p>That is the high-level overview of GPTs and what they can do. Now let's dig into the details.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u8f93\u5165-\u8f93\u5165\"><strong>1.1 Input\/Output<\/strong><\/h3>\n\n\n\n<p>The function signature of a GPT looks roughly like this:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \" title=\"The gpt function takes a list of integers as input and returns a 2D list of floats\">def gpt(inputs: list[int]) -&gt; list[list[float]]:\n    # inputs has shape [n_seq]\n    # output has shape [n_seq, n_vocab]\n    # The input parameter 'inputs' 
is a list of integers representing a sequence, with shape [n_seq],\n    # where 'n_seq' is the length of the sequence.\n\n    # The output is a 2D list of floats with shape [n_seq, n_vocab];\n    # 'n_vocab' is the vocabulary size, i.e. the number of distinct tokens the model can output.\n    # So for every position in the input sequence, the model outputs a probability distribution over the vocabulary.\n\n    output = ...  # beep boop neural network magic (placeholder)\n    # The placeholder above stands in for the core of the function: the neural network\n    # computation on the inputs, whose details are not shown in this snippet.\n\n    return output\n    # Return the model output.\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u8f93\u5165\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E8%BE%93%E5%85%A5\"><\/a>1.1.1 Input<\/strong><\/h4>\n\n\n\n<p>The input is some text represented as a <strong>sequence of integers<\/strong>, where each integer corresponds to a <strong>token<\/strong> in the text:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># integers represent tokens in 
our text, for example:\n# text   = \"not all heroes wear capes\":\n# tokens = \"not\"  \"all\" \"heroes\" \"wear\" \"capes\"\ninputs =   [1,     0,    2,      4,     6]\n# Here 'inputs' is a list of integers, each standing for one word of the text.\n# For example, 1 stands for \"not\", 0 stands for \"all\", and so on.\n# This word-to-integer mapping is usually fixed in advance by numbering the words in some way.\n# Converting text like this is a common first step in natural language processing:\n# it turns text into a numeric form the model can work with.<\/pre><\/div>\n\n\n\n<p>Tokens are sub-pieces of the text, produced by some kind of <strong>tokenizer<\/strong>. We can map tokens to integers using a <strong>vocabulary<\/strong>:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># the index of a token in the vocab is the integer id of that token\n# e.g. the integer id of \"heroes\" is 2, since vocab[2] = \"heroes\"\nvocab = [\"all\", \"not\", \"heroes\", \"the\", \"wear\", \".\", \"capes\"]\n\n# 
a pretend tokenizer that splits on whitespace\ntokenizer = WhitespaceTokenizer(vocab)\n\n# the encode() method converts a string into a list of integers\nids = tokenizer.encode(\"not all heroes wear\") # ids = [1, 0, 2, 4]\n# here \"not all heroes wear\" is tokenized and converted to the matching list of integer ids\n\n# we can see what the actual tokens are via the vocab mapping\ntokens = [tokenizer.vocab[i] for i in ids] # tokens = [\"not\", \"all\", \"heroes\", \"wear\"]\n# indexing into the vocab turns the ids back into their tokens (words)\n\n# the decode() method converts a list of integers back into a string\ntext = tokenizer.decode(ids) # text = \"not all heroes wear\"\n# the ids are restored to the original text<\/pre><\/div>\n\n\n\n<p>In short:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>We have a string.<\/li>\n\n\n\n<li>We use a tokenizer to break it down into smaller pieces called tokens.<\/li>\n\n\n\n<li>We use a vocabulary to map those tokens to integers.<\/li>\n<\/ul>\n\n\n\n<p>In practice, we do not simply split on whitespace; we use more advanced methods such as <a href=\"https:\/\/huggingface.co\/course\/chapter6\/5?fw=pt\" target=\"_blank\" rel=\"noreferrer noopener\">Byte-Pair Encoding<\/a> or <a 
href=\"https:\/\/huggingface.co\/course\/chapter6\/6?fw=pt\" target=\"_blank\" rel=\"noreferrer noopener\">WordPiece<\/a>, but the principle is the same:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>There is a <code>vocab<\/code>, a vocabulary that maps string tokens to integer indices.<\/li>\n\n\n\n<li>There is an <code>encode<\/code> method that converts <code>str -&gt; list[int]<\/code>.<\/li>\n\n\n\n<li>There is a <code>decode<\/code> method that converts <code>list[int] -&gt; str<\/code>.<sup><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fn:2\">[2]<\/a><\/sup><\/li>\n<\/ol>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u8f93\u51fa\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E8%BE%93%E5%87%BA\"><\/a>1.1.2 Output<\/strong><\/h4>\n\n\n\n<p>The output is a <strong>2D array<\/strong>, where <code>output[i][j]<\/code> is the model's <strong>predicted probability<\/strong> that the token at <code>vocab[j]<\/code> is the next token, <code>inputs[i+1]<\/code>. For example:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># define a vocabulary containing a set of words\nvocab = [\"all\", \"not\", \"heroes\", \"the\", \"wear\", \".\", \"capes\"]\n\n# 
define the input sequence as a list of integers, each the vocab index of a word in the text\ninputs = [1, 0, 2, 4] # corresponds to \"not\" \"all\" \"heroes\" \"wear\"\n\n# call the gpt function on the input sequence\noutput = gpt(inputs)\n# it returns a list in which each element is a probability distribution over the vocab for the next word\n\n# the output in detail:\n\n# output[0] =  [0.75    0.1     0.0       0.15    0.0   0.0    0.0  ]\n# given just \"not\", the model predicts \"all\" with the highest probability\n\n# output[1] =  [0.0     0.0      0.8     0.1    0.0    0.0   0.1  ]\n# given the sequence [\"not\", \"all\"], the model predicts \"heroes\" with the highest probability\n\n# output[-1] = [0.0     0.0     0.0     0.1     0.0    0.05  0.85  ]\n# given the whole sequence [\"not\", \"all\", \"heroes\", \"wear\"], the model predicts \"capes\" with the highest probability\n# in Python, index -1 refers to the last element, so output[-1] is the last element of output\n\n# these outputs show how the model predicts the most likely next word from the words given so far\n# 
each element of an output vector is the probability that the matching vocab word is the next word\n<\/pre><\/div>\n\n\n\n<p>To get a <strong>next token prediction<\/strong> for the whole sequence, we simply take the token with the highest probability in <code>output[-1]<\/code>:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">import numpy as np\n\n# define a vocabulary containing a set of predefined words\nvocab = [\"all\", \"not\", \"heroes\", \"the\", \"wear\", \".\", \"capes\"]\n\n# define the input sequence: a list of integers, each the vocab index of a word\n# here the input sequence stands for the text \"not all heroes wear\"\ninputs = [1, 0, 2, 4] # corresponds to \"not\" \"all\" \"heroes\" \"wear\"\n\n# call the gpt model function on the input sequence to predict the next-word distributions\n# each element of 'output' holds the probability of every vocab word being the next word\noutput = gpt(inputs)\n\n# use np.argmax to find the index of the most probable word in the distribution predicted at the last position\n# the index '-1' 
refers to the last element of the list, i.e. the prediction based on the whole input sequence\nnext_token_id = np.argmax(output[-1]) # next_token_id = 6, the vocab index of the most probable word\n\n# use the predicted index (next_token_id) to look up the actual word in the vocab\n# \"capes\" is the model's most likely next word for the input sequence \"not all heroes wear\"\nnext_token = vocab[next_token_id] # next_token = \"capes\"\n<\/pre><\/div>\n\n\n\n<p>Taking the token with the highest probability as our prediction is known as <a href=\"https:\/\/docs.cohere.ai\/docs\/controlling-generation-with-top-k-top-p#1-pick-the-top-token-greedy-decoding\" target=\"_blank\" rel=\"noreferrer noopener\">greedy decoding<\/a> or <strong>greedy sampling<\/strong>.<\/p>\n\n\n\n<p>The task of predicting the next logical word in a sequence is called <strong>language modeling<\/strong>, so we can call a GPT a <strong>language model<\/strong>.<\/p>\n\n\n\n<p>Generating a single word is cool (and that is about it), but what about entire sentences or whole articles?<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u751f\u6210\u6587\u672c\"><strong>1.2 Generating Text<\/strong><\/h3>\n\n\n\n<h4 
class=\"wp-block-heading\" id=\"\u81ea\u56de\u5f52\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E8%87%AA%E5%9B%9E%E5%BD%92\"><\/a>1.2.1 Autoregressive<\/strong><\/h4>\n\n\n\n<p>We can generate full sentences by iteratively getting the next token prediction from the model. At each iteration, we append the predicted token back to the input:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># generate tokens from an initial input; takes the input ids and the number of tokens to generate\n# the underscore _ in the for loop is a conventional name for a throwaway variable:\n# it tells readers that the loop variable is not used inside the loop body\ndef generate(inputs, n_tokens_to_generate):\n    # iterate once per token to generate\n    for _ in range(n_tokens_to_generate):  # auto-regressive decode loop\n        output = gpt(inputs)  # model forward pass\n        next_id = np.argmax(output[-1])  # 
greedy sampling: pick the index with the highest probability\n        inputs.append(int(next_id))  # append the prediction to the input for the next step (mutates the caller's list)\n    # return only the newly generated ids\n    return inputs[len(inputs) - n_tokens_to_generate :]  # only the generated ids\n\n# define the initial input sequence, standing for the text \"not all\"\ninput_ids = [1, 0] # \"not\" \"all\"\n\n# call generate to produce 3 tokens from the initial input\noutput_ids = generate(input_ids, 3) # output_ids = [2, 4, 6]\n\n# map the generated ids back to words via the vocab\noutput_tokens = [vocab[i] for i in output_ids] # \"heroes\" \"wear\" \"capes\"\n<\/pre><\/div>\n\n\n\n<p>This process of predicting a future value (regression) and adding it back to the input (auto) is why you may see a GPT described as <strong>autoregressive<\/strong>.<\/p>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u91c7\u6837\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E9%87%87%E6%A0%B7\"><\/a>1.2.2 
Sampling<\/strong><\/h4>\n\n\n\n<p>We can introduce some <strong>stochasticity<\/strong> (randomness) into our generations by sampling from the probability distribution instead of being greedy:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># define an input sequence of integers, each the vocab index of a word\ninputs = [1, 0, 2, 4] # corresponds to \"not\" \"all\" \"heroes\" \"wear\"\n\n# call the gpt function (assumed to be a pre-trained language model) on the input sequence\noutput = gpt(inputs)\n\n# the next lines use the model output (the distribution at the last position) to pick the next word at random\n# np.random.choice picks one element from the given range according to the probability distribution\n# np.arange(vocab_size) creates the array of integers from 0 to vocab_size - 1 (vocab_size is the vocabulary size)\n# p=output[-1] gives the probability of picking each word, taken from the model's prediction at the last position\n\nnp.random.choice(np.arange(vocab_size), p=output[-1]) # might pick \"capes\"\nnp.random.choice(np.arange(vocab_size), p=output[-1]) # might pick 
\"hats\"\nnp.random.choice(np.arange(vocab_size), p=output[-1]) # might pick \"capes\" again\nnp.random.choice(np.arange(vocab_size), p=output[-1]) # might pick \"capes\" again\nnp.random.choice(np.arange(vocab_size), p=output[-1]) # might pick \"pants\"\n\n# note: the words in the comments are only illustrative (\"hats\" and \"pants\" are not in the toy vocab above);\n# the result of np.random.choice is random, so each run may pick different words according to the probabilities\n<\/pre><\/div>\n\n\n\n<p>This lets us generate different sentences from the same input. When combined with techniques such as <a href=\"https:\/\/docs.cohere.ai\/docs\/controlling-generation-with-top-k-top-p#2-pick-from-amongst-the-top-tokens-top-k\" target=\"_blank\" rel=\"noreferrer noopener\">top-k<\/a>, <a href=\"https:\/\/docs.cohere.ai\/docs\/controlling-generation-with-top-k-top-p#3-pick-from-amongst-the-top-tokens-whose-probabilities-add-up-to-15-top-p\" target=\"_blank\" rel=\"noreferrer noopener\">top-p<\/a>, and <a href=\"https:\/\/docs.cohere.ai\/docs\/temperature\"><strong>temperature<\/strong><\/a> (<a href=\"https:\/\/docs.cohere.ai\/docs\/temperature\" target=\"_blank\" rel=\"noreferrer 
noopener\">\u6e29\u5ea6<\/a>)\u8fd9\u6837\u7684\u6280\u5de7\u7684\u65f6\u5019\uff0c\uff08\u8fd9\u4e9b\u6280\u5de7\u80fd\u591f\u80fd\u66f4\u6539\u91c7\u6837\u7684\u5206\u5e03\uff09\uff0c\u6211\u4eec\u8f93\u51fa\u7684\u8d28\u91cf\u4e5f\u4f1a\u6709\u5f88\u5927\u7684\u63d0\u9ad8\u3002\u8fd9\u4e9b\u6280\u5de7\u4e5f\u5f15\u5165\u4e86\u4e00\u4e9b\u8d85\u53c2\u6570\uff0c\u901a\u8fc7\u8c03\u6574\u8fd9\u4e9b\u8d85\u53c2\uff0c\u6211\u4eec\u53ef\u4ee5\u83b7\u5f97\u4e0d\u540c\u7684\u751f\u6210\u8868\u73b0(behaviors)\u3002\u6bd4\u5982\u63d0\u9ad8\u6e29\u5ea6\u8d85\u53c2\uff0c\u6211\u4eec\u7684\u6a21\u578b\u5c31\u4f1a\u66f4\u52a0\u5192\u8fdb\uff0c\u4ece\u800c\u53d8\u5f97\u66f4\u6709\u201c<strong>\u521b\u9020\u529b(creative)<\/strong>\u201d\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u8bad\u7ec3\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E8%AE%AD%E7%BB%83\"><\/a>1.3 \u8bad\u7ec3(Training)<\/strong><\/h3>\n\n\n\n<p>\u6211\u4eec\u4e0e\u8bad\u7ec3\u5176\u5b83\u795e\u7ecf\u7f51\u7edc\u4e00\u6837\uff0c\u9488\u5bf9\u7279\u5b9a\u7684<strong>\u635f\u5931\u51fd\u6570<\/strong>(<strong>loss function<\/strong>)\u4f7f\u7528<a href=\"https:\/\/arxiv.org\/pdf\/1609.04747.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">\u68af\u5ea6\u4e0b\u964d<\/a>(<a href=\"https:\/\/arxiv.org\/pdf\/1609.04747.pdf\"><strong>gradient descent<\/strong><\/a>)\u8bad\u7ec3GPT\u3002\u5bf9\u4e8eGPT\uff0c\u6211\u4eec\u4f7f\u7528<strong>\u8bed\u8a00\u5efa\u6a21\u4efb\u52a1\u7684<a href=\"https:\/\/www.youtube.com\/watch?v=ErfnhcEV1O8\" target=\"_blank\" rel=\"noreferrer noopener\">\u4ea4\u53c9\u71b5\u635f\u5931<\/a><\/strong>(<strong><a href=\"https:\/\/www.youtube.com\/watch?v=ErfnhcEV1O8\">cross entropy loss<\/a> over the language modeling task<\/strong>)\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># 
Compute the language modeling loss. Takes a list of integers (the token indices of the text) and the model parameters.\ndef lm_loss(inputs: list[int], params) -&gt; float:\n    # The labels y are just the inputs shifted one position to the left, i.e. for\n    # every token in the input, the label is the token that follows it.\n    #\n    # inputs = [not,     all,   heroes,   wear,   capes]\n    #      x = [not,     all,   heroes,   wear  ]\n    #      y = [all,  heroes,     wear,   capes ]\n    #\n    # For example, given the sequence [not, all, heroes, wear, capes],\n    # x (what the model sees) is [not, all, heroes, wear], and\n    # y (what we want the model to predict) is [all, heroes, wear, capes].\n    #\n    # The last input token has no next token to serve as its label, so inputs[-1] is excluded from x.\n    #\n    # As such, for N inputs, we have N - 1 language modeling example pairs.\n    # inputs[:-1]: every element of inputs from the first up to, but not including, the last\n    # 
inputs[1:]: every element of inputs from the second (index 1) through the last\n    x, y = inputs[:-1], inputs[1:]\n\n    # Forward pass\n    # Run gpt on x with the given parameters to get, for each position,\n    # the predicted probability distribution over the next token.\n    output = gpt(x, params)\n\n    # Cross entropy loss\n    # For each of the N - 1 examples, take the negative log of the probability assigned\n    # to the true label in y, then average over all examples.\n    loss = np.mean(-np.log(output[np.arange(len(y)), y]))\n\n    return loss\n\n# Training loop. Takes a list of texts (each a list of strings) and the initial model\n# parameters, and returns the updated parameters.\ndef train(texts: list[list[str]], params):\n    # Iterate over every text in the dataset.\n    for text in texts:\n        # Encode the text into a list of token ids with the tokenizer.\n        inputs = tokenizer.encode(text)\n        # Compute the loss for this text.\n        loss = lm_loss(inputs, params)\n        # Compute the gradients via backpropagation.\n        gradients = compute_gradients_via_backpropagation(loss, params)\n        # Update the parameters with a gradient descent step.\n        params = gradient_descent_update_step(gradients, 
params)\n    # Return the updated parameters.\n    return params<\/pre><\/div>\n\n\n\n<p>This is a heavily simplified training setup, but it covers the essentials. Note that the <code>gpt<\/code> function signature now takes <code>params<\/code> (we left it out in the previous section for simplicity). During each iteration of the training loop:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>We compute the language modeling loss for the given input text example<\/li>\n\n\n\n<li>The loss determines our gradients, which we compute via backpropagation<\/li>\n\n\n\n<li>We use the gradients to update our model parameters such that the loss is minimized (gradient descent)<\/li>\n<\/ol>\n\n\n\n<p>Notice that we don't use explicitly labelled data. Instead, we are able to produce the input\/label pairs from the raw text itself. This is referred to as <strong><a href=\"https:\/\/en.wikipedia.org\/wiki\/Self-supervised_learning\" target=\"_blank\" rel=\"noreferrer noopener\">self-supervised learning<\/a><\/strong>.<\/p>\n\n\n\n<p>Self-supervision lets us scale up the training data massively: we just get our hands on as much raw text as possible and feed it to the model. For example, GPT-3 was trained on <strong>300 billion tokens<\/strong> of text from the internet and books:<\/p>\n\n\n\n<figure class=\"wp-block-image\"><img decoding=\"async\" src=\"https:\/\/jiqihumanr.github.io\/images\/table.2.2.png\" alt=\"(table 2.2)\"\/><\/figure>\n\n\n\n<p>Table 2.2 from the GPT-3 paper<\/p>\n\n\n\n<p>Of course, you need a sufficiently large model to be able to learn from all this data, which is why GPT-3 has <strong>175 billion parameters<\/strong> and probably cost <a href=\"https:\/\/twitter.com\/eturner303\/status\/1266264358771757057\" target=\"_blank\" rel=\"noreferrer noopener\">between $1M and $10M in compute to train<\/a><sup><a 
href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fn:3\">[3]<\/a><\/sup>.<\/p>\n\n\n\n<p>This self-supervised training step is called <strong>pre-training<\/strong>, since we can reuse the pre-trained model weights to further train the model on downstream tasks, such as text classification (for example, classifying whether a tweet is toxic or not). Pre-trained models are also sometimes called <strong>foundation models<\/strong>.<\/p>\n\n\n\n<p>Training the model on downstream tasks is called <strong>fine-tuning<\/strong>: the weights are already pre-trained to understand language, so all that remains is to fine-tune them for the specific task at hand.<\/p>\n\n\n\n<p>This strategy of &#8220;pre-training on a general task + fine-tuning on a specific task&#8221; is called <a href=\"https:\/\/en.wikipedia.org\/wiki\/Transfer_learning\" target=\"_blank\" rel=\"noreferrer noopener\">transfer learning<\/a>.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u63d0\u793a\uff08prompting\uff09\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E6%8F%90%E7%A4%BA%EF%BC%88prompting%EF%BC%89\"><\/a>1.4 Prompting<\/strong><\/h3>\n\n\n\n<p>In principle, the original <a href=\"https:\/\/s3-us-west-2.amazonaws.com\/openai-assets\/research-covers\/language-unsupervised\/language_understanding_paper.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">GPT paper<\/a> was only about the benefits of pre-training a transformer model for transfer learning. The paper showed that a pre-trained 117M GPT achieved state-of-the-art performance on various <strong>NLP (natural language processing)<\/strong> tasks when fine-tuned on labelled datasets for those tasks.<\/p>\n\n\n\n<p>It wasn't until the <a href=\"https:\/\/d4mucfpksywv.cloudfront.net\/better-language-models\/language_models_are_unsupervised_multitask_learners.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">GPT-2<\/a> and <a href=\"https:\/\/arxiv.org\/abs\/2005.14165\" target=\"_blank\" rel=\"noreferrer noopener\">GPT-3<\/a> papers that we realized that a GPT model pre-trained on enough data with enough parameters is capable of performing all kinds of tasks <strong>by itself<\/strong>, with no fine-tuning needed. Just prompt the model, run autoregressive language modeling, and the model returns an appropriate response. This is referred to as <strong>in-context learning<\/strong>, because the model uses just the context of the prompt to perform the task. In-context learning can be <strong>zero shot<\/strong>, <strong>one shot<\/strong>, or <strong>few shot<\/strong>:<\/p>\n\n\n\n<blockquote class=\"wp-block-quote is-layout-flow wp-block-quote-is-layout-flow\">\n<p>Translator's note: put simply, zero shot means using the large model for our task directly, with no examples; one shot means giving the model one example of our specific task; few shot means giving the model a few examples of our specific task.<\/p>\n<\/blockquote>\n\n\n\n<figure class=\"wp-block-image\"><img decoding=\"async\" src=\"https:\/\/jiqihumanr.github.io\/images\/fig.2.1.png\" alt=\"(fig 2.1)\"\/><\/figure>\n\n\n\n<p>Figure 2.1 from the GPT-3 paper<\/p>\n\n\n\n<p>Generating text given a prompt is also referred to as <strong>conditional generation<\/strong>, since our model is generating output <em>conditioned<\/em> on some input.<\/p>\n\n\n\n<p>GPTs are not limited to NLP tasks. You can condition the model on anything you want. For example, you can turn a GPT into a chatbot (i.e. <a href=\"https:\/\/openai.com\/blog\/chatgpt\/\" target=\"_blank\" rel=\"noreferrer noopener\">ChatGPT<\/a>) by conditioning it on the conversation history. You can also further condition the chatbot to behave a certain way by prepending the prompt with some kind of description (for example: &#8220;You are a chatbot. Be polite, speak in full sentences, don't say harmful things, etc.&#8221;). Conditioning the model like this can even <a href=\"https:\/\/imgur.com\/a\/AbDFcgk\" target=\"_blank\" rel=\"noreferrer noopener\">give your chatbot a persona<\/a>. This is not robust, however (you can still <a 
href=\"https:\/\/twitter.com\/zswitten\/status\/1598380220943593472\">&#8220;jailbreak&#8221; the model and make it misbehave<\/a>).<\/p>\n\n\n\n<p>With all that out of the way, let's finally get to the actual implementation.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\" id=\"\u51c6\u5907\u5de5\u4f5c\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E5%87%86%E5%A4%87%E5%B7%A5%E4%BD%9C\"><\/a>2. Setup<\/strong><\/h2>\n\n\n\n<p>First, clone the repository for this tutorial:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:sh decode:true \">git clone https:\/\/github.com\/jaymody\/picoGPT\ncd picoGPT\n\nls -l\ntotal 32\n-rwxrwxrwx 1 tony tony 1065 Apr 25  2023 LICENSE\n-rwxrwxrwx 1 tony tony 2395 Apr 25  2023 README.md\n-rwxrwxrwx 1 tony tony 4318 Apr 25  2023 encoder.py\n-rwxrwxrwx 1 tony tony 4246 Apr 25  2023 gpt2.py\n-rwxrwxrwx 1 tony tony 2330 Apr 25  2023 gpt2_pico.py\n-rwxrwxrwx 1 tony tony  502 Apr 25  2023 requirements.txt\n-rwxrwxrwx 1 tony tony 2745 Apr 25  2023 utils.py<\/pre><\/div>\n\n\n\n<p>Then install the dependencies:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:sh decode:true \">pip install -r requirements.txt<\/pre><\/div>\n\n\n\n<p>Note: the code was tested with <code>Python 3.9.10<\/code>.<\/p>\n\n\n\n<p>A quick breakdown of each file:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><code>encoder.py<\/code> contains the code for OpenAI's BPE tokenizer, taken straight from their <a href=\"https:\/\/github.com\/openai\/gpt-2\/blob\/master\/src\/encoder.py\" target=\"_blank\" rel=\"noreferrer noopener\">gpt-2 repo<\/a><\/li>\n\n\n\n<li><code>utils.py<\/code> contains the code to download and load the GPT-2 model weights, tokenizer, and hyperparameters<\/li>\n\n\n\n<li><code>gpt2.py<\/code> contains the actual GPT model and generation code, which can be run directly as a Python script<\/li>\n\n\n\n<li><code>gpt2_pico.py<\/code> is the same as <code>gpt2.py<\/code>, but in fewer lines of code. Why, you ask? Take a guess<\/li>\n<\/ul>\n\n\n\n<p>We will reimplement <code>gpt2.py<\/code> from scratch, so delete it and recreate it as an empty file:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:sh decode:true \">rm gpt2.py\ntouch gpt2.py<\/pre><\/div>\n\n\n\n<p>To begin, paste the following code into <code>gpt2.py<\/code>:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">import numpy as np\n# Import NumPy under the alias np. NumPy is an open-source numerical computing library for Python\n# that provides N-dimensional arrays, matrix operations, and many higher-level math functions.\n# It is widely used in scientific computing, particularly in data analysis and machine learning.\n# 
Using np as the alias for NumPy is a widely adopted convention that keeps calls to NumPy functions concise.\n\n# Skeleton of the GPT-2 model; the internals are not implemented yet.\ndef gpt2(inputs, wte, wpe, blocks, ln_f, n_head):\n    pass # TODO: implement this\n\n# Generate text: an autoregressive decoding loop that produces the requested number of tokens.\ndef generate(inputs, params, n_head, n_tokens_to_generate):\n    from tqdm import tqdm\n    # Import the tqdm function from the tqdm library, a fast, extensible progress bar for Python.\n    # Wrapping any iterable as tqdm(iterable) adds a progress indicator to a long-running loop.\n    # It shows the current progress and the estimated time remaining, updated in place on one line.\n\n    # This gives direct feedback during long loops, so the user can tell that the\n    # program is still running rather than hung or crashed.\n\n    # 
For example, tqdm improves the user experience for slow operations such as data processing,\n    # model training, or file I\/O by showing the progress and the expected completion time.\n\n    # Iterate once per token to generate, with tqdm displaying a progress bar.\n    for _ in tqdm(range(n_tokens_to_generate), \"generating\"):  # autoregressive decode loop\n        logits = gpt2(inputs, **params, n_head=n_head)  # model forward pass\n        next_id = np.argmax(logits[-1])  # greedy sampling: pick the id of the most likely next token\n        inputs.append(int(next_id))  # append the prediction to the input for the next iteration\n\n    # Return only the newly generated token ids, not the prompt.\n    return inputs[len(inputs) - n_tokens_to_generate :]\n\n# Main function: runs the whole text generation pipeline.\ndef main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = \"124M\", models_dir: str = \"models\"):\n    from utils import load_encoder_hparams_and_params\n    # Import load_encoder_hparams_and_params from the utils module.\n    # It loads the encoder (tokenizer), the hyperparameters, and the parameters, that is,\n    # 
the configuration and weights of a pre-trained language model (GPT-2 here). These may include\n    # the model size, the hyperparameter settings, and the model weights. Keeping this logic in one\n    # place makes the code clearer and lets the model-loading logic be reused elsewhere.\n    # Modularizing model loading and configuration like this is common practice in machine learning projects.\n\n    # Load the encoder, hparams, and params from the released GPT-2 files\n    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)\n\n    # Encode the input string into token ids using the BPE tokenizer\n    input_ids = encoder.encode(prompt)\n\n    # Make sure the prompt length plus the number of generated tokens stays below the model's maximum sequence length.\n    # hparams[\"n_ctx\"] is the value stored under the n_ctx key of the hyperparameter dictionary (hparams).\n    # In neural network models, and language models such as GPT-2 in particular,\n    # n_ctx usually denotes the maximum context (sequence) length the model can handle.\n    # 
This length is the maximum number of tokens (for example, words or pieces of words) the model can accept and process at once.\n    assert len(input_ids) + n_tokens_to_generate &lt; hparams[\"n_ctx\"]\n\n    # Generate the output token ids.\n    # hparams[\"n_head\"] is the number of independent attention heads that run in parallel in the attention mechanism.\n    # More heads can improve the model's ability to capture different features,\n    # but also increase its complexity and compute requirements.\n    output_ids = generate(input_ids, params, hparams[\"n_head\"], n_tokens_to_generate)\n\n    # Decode the token ids back into a string\n    output_text = encoder.decode(output_ids)\n\n    return output_text\n\n# Run main when the script is executed directly\nif __name__ == \"__main__\":\n    import fire\n    # Import the Python fire library. Fire is an open-source library developed by Google for automatically generating command-line interfaces (CLIs).\n    # It can turn any Python object (a function, class, module, or even an object instance) into a CLI.\n    # With Fire, a Python script easily becomes a command-line tool that accepts arguments,\n    # 
without writing lots of argument-parsing code. Fire parses the command-line arguments automatically and maps them onto the function's parameters.\n    # This lets developers focus on the program logic rather than on building the CLI and parsing arguments.\n    # Fire is well suited to quickly building and iterating on command-line tools, and makes exposing a Python script on the command line very simple.\n\n    # fire.Fire(main) lets the script accept arguments when run from the command line\n    fire.Fire(main)\n<\/pre><\/div>\n\n\n\n<p>Breaking this down into its four parts:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>The <code>gpt2<\/code> function is the actual GPT code we'll be implementing. You'll notice that the function signature includes some extra parameters in addition to <code>inputs<\/code>:\n<ul 
class=\"wp-block-list\">\n<li><code>wte<\/code>,&nbsp;<code>wpe<\/code>,&nbsp;<code>blocks<\/code>, and&nbsp;<code>ln_f<\/code> are the parameters of our model<\/li>\n\n\n\n<li><code>n_head<\/code> is a hyperparameter that is needed during the forward pass<\/li>\n<\/ul>\n<\/li>\n\n\n\n<li>The <code>generate<\/code> function is the autoregressive decoding algorithm we saw earlier. We use greedy sampling for simplicity. <code>tqdm<\/code> is a progress bar library that helps us visualize the decoding process as it generates tokens one at a time.<\/li>\n\n\n\n<li>The <code>main<\/code> function handles:\n<ol class=\"wp-block-list\">\n<li>Loading the tokenizer (<code>encoder<\/code>), model weights (<code>params<\/code>), and hyperparameters (<code>hparams<\/code>)<\/li>\n\n\n\n<li>Encoding the input prompt into token IDs using the tokenizer<\/li>\n\n\n\n<li>Calling the generate function<\/li>\n\n\n\n<li>Decoding the output IDs into a string<\/li>\n<\/ol>\n<\/li>\n\n\n\n<li><code>fire.Fire(main)<\/code> turns our file into a CLI application, so we can eventually run our code with: <code>python gpt2.py \"some prompt here\"<\/code><\/li>\n<\/ol>\n\n\n\n<p>Let's take a closer look at <code>encoder<\/code>,&nbsp;<code>hparams<\/code>, and&nbsp;<code>params<\/code>. In a notebook or an interactive Python session, run:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">from utils import 
load_encoder_hparams_and_params\nencoder, hparams, params = load_encoder_hparams_and_params(\"124M\", \"models\")<\/pre><\/div>\n\n\n\n<p>This will <a href=\"https:\/\/github.com\/jaymody\/picoGPT\/blob\/a750c145ba4d09d5764806a6c78c71ffaff88e64\/utils.py#L13-L40\" target=\"_blank\" rel=\"noreferrer noopener\">download the necessary model and tokenizer files<\/a> to <code>models\/124M<\/code> and <a href=\"https:\/\/github.com\/jaymody\/picoGPT\/blob\/a750c145ba4d09d5764806a6c78c71ffaff88e64\/utils.py#L68-L82\" target=\"_blank\" rel=\"noreferrer noopener\">load <code>encoder<\/code>, <code>hparams<\/code>, and <code>params<\/code><\/a>.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u7f16\u7801\u5668\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E7%BC%96%E7%A0%81%E5%99%A8\"><\/a>2.1 Encoder<\/strong><\/h3>\n\n\n\n<p>Our <code>encoder<\/code> is the BPE tokenizer used by GPT-2:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">&gt;&gt;&gt; ids = encoder.encode(\"Not all heroes wear capes.\")\n&gt;&gt;&gt; ids\n[3673, 477, 10281, 5806, 1451, 274, 13]\n\n&gt;&gt;&gt; encoder.decode(ids)\n\"Not all heroes wear capes.\"<\/pre><\/div>\n\n\n\n<p>Using the vocabulary of the tokenizer (stored in <code>encoder.decoder<\/code>), we can take a peek at what the actual tokens look like:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">&gt;&gt;&gt; [encoder.decoder[i] for i in ids]\n['Not', '\u0120all', '\u0120heroes', '\u0120wear', '\u0120cap', 'es', 
'.']<\/pre><\/div>\n\n\n\n<p>Notice that sometimes our tokens are whole words (e.g. <code>Not<\/code>), sometimes they are words with a space in front of them (e.g. <code>\u0120all<\/code>; the&nbsp;<a href=\"https:\/\/github.com\/karpathy\/minGPT\/blob\/37baab71b9abea1b76ab957409a1cc2fbfba8a26\/mingpt\/bpe.py#L22-L33\" target=\"_blank\" rel=\"noreferrer noopener\"><code>\u0120<\/code> represents a space<\/a>), sometimes they are part of a word (e.g. capes is split into <code>\u0120cap<\/code> and <code>es<\/code>), and sometimes they are punctuation (e.g. <code>.<\/code>).<\/p>\n\n\n\n<p>One nice thing about BPE is that it can encode any arbitrary string. If it encounters something that is not present in the vocabulary, it simply breaks it down into substrings it does understand:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">&gt;&gt;&gt; [encoder.decoder[i] for i in encoder.encode(\"zjqfl\")]\n['z', 'j', 'q', 'fl']<\/pre><\/div>\n\n\n\n<p>We can also check the size of the vocabulary:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">&gt;&gt;&gt; len(encoder.decoder)\n50257<\/pre><\/div>\n\n\n\n<p>The vocabulary, as well as the <strong>byte-pair merges<\/strong> that determine how strings are broken down, is obtained by <em>training the tokenizer<\/em>. When we load the tokenizer, we load the already trained vocabulary and <strong>byte-pair merges<\/strong> from files that were downloaded alongside the model files when we ran <code>load_encoder_hparams_and_params<\/code>. See <code>models\/124M\/encoder.json<\/code> (the vocabulary) and <code>models\/124M\/vocab.bpe<\/code> (the byte-pair merges).<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u8d85\u53c2\u6570\"><strong>2.2 Hyperparameters<\/strong><\/h3>\n\n\n\n<p><code>hparams<\/code> is a dictionary that contains the hyperparameters of our model:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">&gt;&gt;&gt; hparams\n{\n  \"n_vocab\": 50257, # number of tokens in our vocabulary\n  \"n_ctx\": 1024, # maximum possible sequence length of the input\n  \"n_embd\": 768, # embedding dimension (determines the \"width\" of the network)\n  \"n_head\": 12, # number of attention heads (n_embd must be divisible by n_head)\n  \"n_layer\": 12 # 
number of layers (determines the \"depth\" of the network)\n}<\/pre><\/div>\n\n\n\n<p>This is a typical hyperparameter (hparams) dictionary for a transformer-based model such as the GPT family:<\/p>\n\n\n\n<p><strong>n_vocab<\/strong>: the size of the vocabulary, i.e. the number of distinct tokens the model knows about. It determines the number of rows in the token embedding matrix and the size of the output distribution.<\/p>\n\n\n\n<p><strong>n_ctx<\/strong>: the maximum sequence length the model can handle, which bounds how many tokens (prompt plus generated text) it can process at once.<\/p>\n\n\n\n<p><strong>n_embd<\/strong>: the embedding dimension, i.e. the length of the vector that represents each token. It determines the \"width\" of the network: a larger value lets the model capture richer features, at the cost of more parameters and more compute.<\/p>\n\n\n\n<p><strong>n_head<\/strong>: the number of heads in multi-head attention. Multiple heads let the model attend to information in different representation subspaces in parallel.<\/p>\n\n\n\n<p><strong>n_layer<\/strong>: the number of stacked transformer blocks. GPT is decoder-only, so these are all decoder blocks. More layers make the network \"deeper\" and increase its capacity, but they also make training harder.<\/p>\n\n\n\n<p>We'll use these symbols in code comments to show the underlying shape of things. We'll also use <code>n_seq<\/code> to denote the length of the input sequence (i.e. <code>n_seq = len(inputs)<\/code>).<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u53c2\u6570\"><strong>2.3 Parameters<\/strong><\/h3>\n\n\n\n<p><code>params<\/code> is a nested dictionary that holds the trained weights of our model. The leaf nodes are NumPy arrays. If we print <code>params<\/code>, replacing the arrays with their shapes, we get:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">&gt;&gt;&gt; import numpy as np\n&gt;&gt;&gt; \n&gt;&gt;&gt; def shape_tree(d):\n&gt;&gt;&gt;     # 
If d is a NumPy array, return its shape as a list.\n&gt;&gt;&gt;     if isinstance(d, np.ndarray):\n&gt;&gt;&gt;         return list(d.shape)\n&gt;&gt;&gt;     # If d is a list, recurse into each element and return a list of shapes.\n&gt;&gt;&gt;     elif isinstance(d, list):\n&gt;&gt;&gt;         return [shape_tree(v) for v in d]\n&gt;&gt;&gt;     # If d is a dict, recurse into each value and return a new dict of shapes.\n&gt;&gt;&gt;     elif isinstance(d, dict):\n&gt;&gt;&gt;         return {k: shape_tree(v) for k, v in d.items()}\n&gt;&gt;&gt;     # Anything else is unexpected, so raise a ValueError.\n&gt;&gt;&gt;     else:\n&gt;&gt;&gt;         raise ValueError(\"uh oh\")\n&gt;&gt;&gt; \n&gt;&gt;&gt; # Print the shape tree of params.\n&gt;&gt;&gt; print(shape_tree(params))\n{\n    \"wpe\": [1024, 768],\n    \"wte\": [50257, 768],\n    \"ln_f\": {\"b\": [768], \"g\": [768]},\n    \"blocks\": [\n        {\n            \"attn\": {\n                \"c_attn\": {\"b\": [2304], \"w\": [768, 2304]},\n                \"c_proj\": {\"b\": [768], \"w\": [768, 768]},\n            },\n            \"ln_1\": {\"b\": [768], \"g\": [768]},\n            \"ln_2\": {\"b\": [768], \"g\": [768]},\n            \"mlp\": {\n                \"c_fc\": {\"b\": [3072], \"w\": [768, 3072]},\n                \"c_proj\": {\"b\": [768], \"w\": [3072, 768]},\n            },\n        },\n        ... 
# repeat for n_layers\n    ]\n}<\/pre><\/div>\n\n\n\n<p>These are loaded from the original OpenAI TensorFlow checkpoint:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">&gt;&gt;&gt; import tensorflow as tf\n&gt;&gt;&gt; \n&gt;&gt;&gt; # Get the path of the latest checkpoint in the given directory.\n&gt;&gt;&gt; tf_ckpt_path = tf.train.latest_checkpoint(\"models\/124M\")\n&gt;&gt;&gt; \n&gt;&gt;&gt; # Iterate over every variable stored in the checkpoint.\n&gt;&gt;&gt; for name, _ in tf.train.list_variables(tf_ckpt_path):\n&gt;&gt;&gt;     # Load the variable's value and drop any size-1 dimensions.\n&gt;&gt;&gt;     arr = tf.train.load_variable(tf_ckpt_path, name).squeeze()\n&gt;&gt;&gt;     # Print the variable name and its shape.\n&gt;&gt;&gt;     print(f\"{name}: {arr.shape}\")\nmodel\/h0\/attn\/c_attn\/b: (2304,)\nmodel\/h0\/attn\/c_attn\/w: (768, 2304)\nmodel\/h0\/attn\/c_proj\/b: (768,)\nmodel\/h0\/attn\/c_proj\/w: (768, 768)\nmodel\/h0\/ln_1\/b: (768,)\nmodel\/h0\/ln_1\/g: (768,)\nmodel\/h0\/ln_2\/b: (768,)\nmodel\/h0\/ln_2\/g: (768,)\nmodel\/h0\/mlp\/c_fc\/b: (3072,)\nmodel\/h0\/mlp\/c_fc\/w: (768, 3072)\nmodel\/h0\/mlp\/c_proj\/b: (768,)\nmodel\/h0\/mlp\/c_proj\/w: (3072, 768)\nmodel\/h1\/attn\/c_attn\/b: (2304,)\nmodel\/h1\/attn\/c_attn\/w: (768, 2304)\n...\nmodel\/h9\/mlp\/c_proj\/b: (768,)\nmodel\/h9\/mlp\/c_proj\/w: (3072, 768)\nmodel\/ln_f\/b: (768,)\nmodel\/ln_f\/g: (768,)\nmodel\/wpe: (1024, 768)\nmodel\/wte: (50257, 768)\n<\/pre><\/div>\n\n\n\n<p><a href=\"https:\/\/github.com\/jaymody\/picoGPT\/blob\/29e78cc52b58ed2c1c483ffea2eb46ff6bdec785\/utils.py#L43-L65\" target=\"_blank\" rel=\"noreferrer 
noopener\">This code<\/a> converts the TensorFlow variables above into our <code>params<\/code> dictionary.<\/p>\n\n\n\n<p>For reference, here are the shapes of <code>params<\/code> again, with the numbers replaced by the <code>hparams<\/code> they represent:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">{\n    \"wpe\": [n_ctx, n_embd],\n    \"wte\": [n_vocab, n_embd],\n    \"ln_f\": {\"b\": [n_embd], \"g\": [n_embd]},\n    \"blocks\": [\n        {\n            \"attn\": {\n                \"c_attn\": {\"b\": [3*n_embd], \"w\": [n_embd, 3*n_embd]},\n                \"c_proj\": {\"b\": [n_embd], \"w\": [n_embd, n_embd]},\n            },\n            \"ln_1\": {\"b\": [n_embd], \"g\": [n_embd]},\n            \"ln_2\": {\"b\": [n_embd], \"g\": [n_embd]},\n            \"mlp\": {\n                \"c_fc\": {\"b\": [4*n_embd], \"w\": [n_embd, 4*n_embd]},\n                \"c_proj\": {\"b\": [n_embd], \"w\": [4*n_embd, n_embd]},\n            },\n        },\n        ... # repeat for n_layers\n    ]\n}<\/pre><\/div>\n\n\n\n<p>You'll probably want to come back to this dictionary to check the shapes of the weights as we implement our GPT. For consistency, we'll match the variable names in our code with the keys of this dictionary.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\" id=\"\u57fa\u7840\u5c42\"><strong>3. 
Basic Layers<\/strong><\/h2>\n\n\n\n<p>One last thing before we get into the actual GPT architecture: let's implement some of the more basic neural network layers. They are not specific to GPTs and are useful in many other settings.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"GELU\"><strong>3.1 GELU<\/strong><\/h3>\n\n\n\n<p>The non-linearity (<strong>activation function<\/strong>) of choice for GPT-2 is <a href=\"https:\/\/arxiv.org\/pdf\/1606.08415.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">GELU (Gaussian Error Linear Units)<\/a>, an alternative to ReLU:<img decoding=\"async\" src=\"https:\/\/jiqihumanr.github.io\/images\/GELU.png\" alt=\"\"><\/p>\n\n\n\n<p>Figure 1 from the GELU paper<\/p>\n\n\n\n<p>The code below defines a function named <code>gelu<\/code> that implements the GELU activation, a non-linearity that is widely used in transformer models:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def gelu(x):\n    # Compute the GELU activation of x.\n    # GELU is a smooth alternative to ReLU.\n    # 
Formula (tanh approximation of x * Phi(x), where Phi is the standard normal CDF):\n    #   0.5 * x * (1 + tanh(sqrt(2 \/ pi) * (x + 0.044715 * x^3)))\n    # - 0.5 * (1 + tanh(...)) approximates Phi(x) and lies between 0 and 1.\n    # - Multiplying by x gates the input: large positive x passes through almost unchanged,\n    #   large negative x is pushed towards 0.\n    #\n    # Args:\n    # x: a scalar or a NumPy array of any shape.\n    #\n    # Returns:\n    # The GELU activation applied element-wise, with the same shape as x.\n    return 0.5 * x * (1 + np.tanh(np.sqrt(2 \/ np.pi) * (x + 0.044715 * x**3)))\n<\/pre><\/div>\n\n\n\n<p>Like ReLU, GELU operates element-wise on the input:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">&gt;&gt;&gt; gelu(np.array([[1, 2], [-2, 0.5]]))\narray([[ 0.84119,  1.9546 ],\n       [-0.0454 ,  0.34571]])<\/pre><\/div>\n\n\n\n<p>That is:<\/p>\n\n\n\n<ul 
class=\"wp-block-list\">\n<li>For an input of 1, the GELU output is approximately 0.84119.<\/li>\n\n\n\n<li>For an input of 2, the GELU output is approximately 1.9546.<\/li>\n\n\n\n<li>For an input of -2, the GELU output is approximately -0.0454.<\/li>\n\n\n\n<li>For an input of 0.5, the GELU output is approximately 0.34571.<\/li>\n<\/ul>\n\n\n\n<p>Unlike ReLU, which zeroes out every negative input, GELU lets small negative inputs through in attenuated form, while large positive inputs pass through almost unchanged. The function is smooth everywhere.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"Softmax\"><strong>3.2 Softmax<\/strong><\/h3>\n\n\n\n<p>Here is the classic <a href=\"https:\/\/en.wikipedia.org\/wiki\/Softmax_function\" target=\"_blank\" rel=\"noreferrer noopener\">softmax<\/a>:<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"307\" height=\"73\" src=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-3.png\" alt=\"\" class=\"wp-image-2218\" srcset=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-3.png 307w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-3-300x71.png 300w\" sizes=\"auto, (max-width: 307px) 100vw, 307px\" 
\/><\/figure>\n\n\n\n<p>The code below defines a function named <code>softmax<\/code>. Softmax is widely used in machine learning, typically at the output of a classifier, to turn a model's raw scores into a probability distribution.<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def softmax(x):\n    # Subtract the maximum along the last axis (axis=-1) before exponentiating.\n    # This does not change the result but prevents overflow in np.exp.\n    # keepdims=True keeps the reduced axis so the result broadcasts against x.\n    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))\n    \n    # Divide each exponential by the sum of exponentials along the last axis.\n    # Every output value lies between 0 and 1, and the values along the last axis sum to 1.\n    return exp_x \/ np.sum(exp_x, axis=-1, 
keepdims=True)\n<\/pre><\/div>\n\n\n\n<p>We use the <a href=\"https:\/\/jaykmody.com\/blog\/stable-softmax\/\" target=\"_blank\" rel=\"noreferrer noopener\"><code>max(x)<\/code> trick<\/a> for numerical stability.<\/p>\n\n\n\n<p>Softmax converts a set of real numbers (between \u2212\u221e and \u221e) into probabilities (between 0 and 1, summing to 1). We apply <code>softmax<\/code> over the last axis of the input.<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">&gt;&gt;&gt; x = softmax(np.array([[2, 10], [-1, 0]]))\n&gt;&gt;&gt; x\narray([[0.00034, 0.99966],\n       [0.26894, 0.73106]])\n&gt;&gt;&gt; x.sum(axis=-1)\narray([1., 1.])<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u5c42\u5f52\u4e00\u5316\"><strong>3.3 Layer Normalization<\/strong><\/h3>\n\n\n\n<p><a href=\"https:\/\/arxiv.org\/pdf\/1607.06450.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">Layer normalization<\/a> standardizes values to have a mean of 0 and a variance of 1:<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"404\" height=\"95\" src=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-4.png\" alt=\"\" class=\"wp-image-2222\" srcset=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-4.png 404w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-4-300x71.png 300w\" sizes=\"auto, (max-width: 404px) 100vw, 404px\" \/><\/figure>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"710\" 
height=\"46\" src=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-6.png\" alt=\"\" class=\"wp-image-2224\" srcset=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-6.png 710w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-6-300x19.png 300w\" sizes=\"auto, (max-width: 710px) 100vw, 710px\" \/><\/figure>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def layer_norm(x, g, b, eps: float = 1e-5):\n    # Mean of x over the last axis (the feature dimension).\n    # keepdims=True keeps the reduced axis so the result broadcasts against x.\n    mean = np.mean(x, axis=-1, keepdims=True)\n\n    # Variance of x over the last axis.\n    variance = np.var(x, axis=-1, keepdims=True)\n\n    # Normalize x to have mean 0 and variance 1 over the last axis.\n    # eps (a small number) is added to the variance to avoid division by zero.\n    x = (x - mean) \/ np.sqrt(variance + eps)\n\n    # Scale by g (gamma) and shift by b (beta), both learnable parameters.\n    return g * x + b\n<\/pre><\/div>\n\n\n\n<p>Layer normalization ensures that the inputs to each layer are always within a consistent range, which is supposed to speed up and stabilize training. Like <a href=\"https:\/\/arxiv.org\/pdf\/1502.03167.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">batch normalization<\/a>, the normalized output is then scaled and offset with two learnable vectors, gamma&nbsp;and&nbsp;beta. The small <code>epsilon<\/code> term in the denominator avoids a division-by-zero error.<\/p>\n\n\n\n<p>Transformers use layer norm instead of batch norm for <a href=\"https:\/\/stats.stackexchange.com\/questions\/474440\/why-do-transformers-use-layer-norm-instead-of-batch-norm\" target=\"_blank\" rel=\"noreferrer noopener\">various reasons<\/a>. The differences between the various normalization techniques are outlined in <a href=\"https:\/\/tungmphung.com\/deep-learning-normalization-methods\/\" target=\"_blank\" rel=\"noreferrer noopener\">this excellent blog post<\/a>.<\/p>\n\n\n\n<p>We apply layer normalization over the last axis of the input:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">&gt;&gt;&gt; # Define a 2D array x that stands in for the input to a layer.\n&gt;&gt;&gt; x = np.array([[2, 2, 3], [-5, 0, 1]])\n&gt;&gt;&gt; \n&gt;&gt;&gt; # 
Apply layer normalization to x. g (gamma) is all ones and b (beta) is all zeros,\n&gt;&gt;&gt; # so the normalized output is neither scaled nor shifted.\n&gt;&gt;&gt; x = layer_norm(x, g=np.ones(x.shape[-1]), b=np.zeros(x.shape[-1]))\n&gt;&gt;&gt; x\narray([[-0.70709, -0.70709,  1.41418],\n       [-1.397  ,  0.508  ,  0.889  ]])\n&gt;&gt;&gt; x.var(axis=-1)\narray([0.99996, 1.     ]) # just below 1 in the first row because eps is added to the variance\n&gt;&gt;&gt; x.mean(axis=-1)\narray([-0., -0.])<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u7ebf\u6027\uff08\u53d8\u6362\uff09\"><strong>3.4 Linear<\/strong><\/h3>\n\n\n\n<p>Your standard matrix multiplication + bias:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def linear(x, w, b):  # [m, in], [in, out], [out] -&gt; [m, out]\n    # x: input matrix of shape [m, in], where m is the batch size or sequence length and in is the input feature dimension.\n    # w: weight matrix of shape [in, out], where out is the output feature dimension.\n    # b: bias vector of shape [out].\n\n    # Returns: 
the output matrix of shape [m, out], computed as x @ w + b,\n    # where @ is matrix multiplication and b is broadcast across the m rows.\n    return x @ w + b\n<\/pre><\/div>\n\n\n\n<p>Linear layers are often referred to as <strong>projections<\/strong> (since they project from one vector space to another).<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">&gt;&gt;&gt; x = np.random.normal(size=(64, 784)) # input dim = 784, batch\/sequence dim = 64\n&gt;&gt;&gt; w = np.random.normal(size=(784, 10)) # output dim = 10\n&gt;&gt;&gt; b = np.random.normal(size=(10,))\n&gt;&gt;&gt; x.shape # shape before linear projection\n(64, 784)\n&gt;&gt;&gt; linear(x, w, b).shape # shape after linear projection\n(64, 10)<\/pre><\/div>\n\n\n\n<h2 class=\"wp-block-heading\" id=\"GPT\u67b6\u6784\"><strong>4. 
GPT Architecture<\/strong><\/h2>\n\n\n\n<p>The GPT architecture is based on the <a href=\"https:\/\/arxiv.org\/pdf\/1706.03762.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">transformer<\/a>:<img decoding=\"async\" src=\"https:\/\/jiqihumanr.github.io\/images\/trans.png\" alt=\"\">Figure 1 from the Attention is All You Need paper<\/p>\n\n\n\n<p>But it uses only the decoder stack (the right-hand part of the diagram):<img decoding=\"async\" src=\"https:\/\/jiqihumanr.github.io\/images\/gpt.png\" alt=\"\">&nbsp;&nbsp;GPT architecture<\/p>\n\n\n\n<p>Note that because the encoder has been removed, the middle \u201ccross-attention\u201d layer is removed as well.<\/p>\n\n\n\n<p>At a high level, the GPT architecture has three parts:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Text + positional <strong>embeddings<\/strong><\/li>\n\n\n\n<li>A transformer <strong>decoder stack<\/strong><\/li>\n\n\n\n<li>A <strong>projection to vocab<\/strong> step<\/li>\n<\/ul>\n\n\n\n<p>In code, it looks like this:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):  # [n_seq] -&gt; [n_seq, n_vocab]\n    # inputs: token indices of the input sequence, shape [n_seq].\n    # wte: token embedding matrix, maps token indices to embedding vectors.\n    # wpe: 
positional embedding matrix, adds position information to each token in the input sequence.\n    # blocks: the transformer layers (blocks); each entry holds the parameters of one layer.\n    # ln_f: parameters of the layer normalization applied to the final output.\n    # n_head: number of attention heads.\n\n    # token + positional embeddings\n    # Look up the token embeddings in wte with the input indices and add the matching positional embeddings.\n    # range(len(inputs)) yields one position index per input token.\n    x = wte[inputs] + wpe[range(len(inputs))]  # [n_seq] -&gt; [n_seq, n_embd]\n\n    # forward pass through the n_layer transformer blocks\n    for block in blocks:\n        x = transformer_block(x, **block, n_head=n_head)  # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n\n    # apply layer normalization to the output\n    x = layer_norm(x, **ln_f)  # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n\n    # projection to vocab\n    # Multiply by the transposed token embedding matrix wte.T to project the embeddings back to vocabulary space.\n    return x @ wte.T  # [n_seq, n_embd] -&gt; [n_seq, n_vocab]\n<\/pre><\/div>\n\n\n\n<p>Now let's break each of these three parts down in more detail.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u5d4c\u5165\u5c42\"><strong>4.1 
Embeddings<\/strong><\/h3>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"Token-\u5d4c\u5165\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#Token-%E5%B5%8C%E5%85%A5\"><\/a>4.1.1 Token Embeddings<\/strong><\/h4>\n\n\n\n<p>Token IDs by themselves are not a good representation for a neural network. First, the relative magnitudes of the token IDs convey false information (for example, if <code>Apple = 5<\/code> and <code>Table = 10<\/code> in our vocabulary, that would imply <code>2 * Apple = Table<\/code>, which is clearly meaningless). Second, a single number does not offer enough <strong>dimensionality<\/strong> to feed a neural network.<\/p>\n\n\n\n<blockquote class=\"wp-block-quote is-layout-flow wp-block-quote-is-layout-flow\">\n<p>Translator's note: to expand on the second point, a single number simply does not carry rich enough feature information.<\/p>\n<\/blockquote>\n\n\n\n<p>To address these limitations, we use <a href=\"https:\/\/jaykmody.com\/blog\/attention-intuition\/#word-vectors-and-similarity\" target=\"_blank\" rel=\"noreferrer noopener\">word vectors<\/a>, specifically via a learned embedding matrix:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">wte[inputs] # [n_seq] -&gt; [n_seq, n_embd]\n# wte is the token embedding matrix,\n# 
inputs is an array holding the index of each token in the sequence<\/pre><\/div>\n\n\n\n<p>Recall that <code>wte<\/code> is a <code>[n_vocab, n_embd]<\/code> matrix. It acts as a lookup table: the <em>i<\/em>th row of the matrix is the learned vector representation of the <em>i<\/em>th token in our vocabulary. <code>wte[inputs]<\/code> uses <a href=\"https:\/\/numpy.org\/doc\/stable\/user\/basics.indexing.html#integer-array-indexing\" target=\"_blank\" rel=\"noreferrer noopener\">integer array indexing<\/a> to retrieve the vector for each token in our input.<\/p>\n\n\n\n<p>Like every other parameter in the network, <code>wte<\/code> is learned. That is, it is randomly initialized at the start of training and then updated by gradient descent as training proceeds.<\/p>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u4f4d\u7f6e\u5d4c\u5165\uff08Positional-Embeddings\uff09\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E4%BD%8D%E7%BD%AE%E5%B5%8C%E5%85%A5%EF%BC%88Positional-Embeddings%EF%BC%89\"><\/a>4.1.2 \u4f4d\u7f6e\u5d4c\u5165\uff08Positional 
Embeddings\uff09<\/strong><\/h4>\n\n\n\n<p>One quirk of the plain transformer architecture is that it takes no account of position. If we randomly shuffled the order of our input, the output could stay the same (the order of the inputs has no effect on the output).<\/p>\n\n\n\n<p>But word order is obviously a crucial part of language, so we need some way to encode positional information into our inputs. For this, we can use another learned embedding matrix:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">wpe[range(len(inputs))] # [n_seq] -&gt; [n_seq, n_embd]\n# wpe is the positional embedding matrix,\n# inputs is the list of token indices of the input sequence\n<\/pre><\/div>\n\n\n\n<p><code>wpe<\/code> is a <code>[n_ctx, n_embd]<\/code> matrix. The <em>i<\/em>th row of the matrix holds a vector that encodes information about the <em>i<\/em>th position in the input. Like <code>wte<\/code>, this matrix is learned through gradient descent.<\/p>\n\n\n\n<p>Note that this limits the model's maximum sequence length to <code>n_ctx<\/code><sup><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fn:4\">[4]<\/a><\/sup>. That is, <code>len(inputs) &lt;= n_ctx<\/code> must hold.<\/p>\n\n\n\n<h4 
class=\"wp-block-heading\" id=\"\u7ec4\u5408\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E7%BB%84%E5%90%88\"><\/a>4.1.3 Combined<\/strong><\/h4>\n\n\n\n<p>We can now add the token and positional embeddings together to get a combined embedding that encodes both token and positional information.<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># token + positional embeddings\nx = wte[inputs] + wpe[range(len(inputs))]  # [n_seq] -&gt; [n_seq, n_embd]\n# x[i] is the token embedding of the ith token plus the positional embedding of the ith position\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u89e3\u7801\u5c42\"><strong>4.2 Decoder Stack<\/strong><\/h3>\n\n\n\n<p>This is where the magic happens, and where the \u201cdeep\u201d in deep learning comes from. We pass the embeddings through a stack of <code>n_layer<\/code> transformer decoder blocks.<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># forward pass through the n_layer transformer blocks\nfor block in blocks:\n    # x: input to the current layer, shape [n_seq, n_embd], where n_seq is the sequence length and n_embd the embedding dimension.\n    # block: parameters of the current transformer layer, i.e. all of its weights and biases.\n    # n_head: 
number of attention heads, a hyperparameter of the transformer.\n    # transformer_block: the function that implements the computation of one transformer layer.\n    # **block: unpacks the block dict into keyword arguments for transformer_block.\n    # [n_seq, n_embd] -&gt; [n_seq, n_embd]: each block leaves the shape unchanged.\n    x = transformer_block(x, **block, n_head=n_head)\n<\/pre><\/div>\n\n\n\n<p>Stacking more layers controls how <em>\u201cdeep\u201d<\/em> our network is; GPT-3, for example, has <a href=\"https:\/\/preview.redd.it\/n9fgba8b0qr01.png?auto=webp&amp;s=e86d2d3447c777d3222016e81a0adfaec1a95592\" target=\"_blank\" rel=\"noreferrer noopener\">a whopping 96 layers<\/a>. Choosing a larger <code>n_embd<\/code> value, on the other hand, controls how <em>\u201cwide\u201d<\/em> the network is (GPT-3, again, uses an embedding size of 12288).<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u6295\u5f71\u4e3a\u8bcd\u6c47\u8868-projection-to-vocab\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E6%8A%95%E5%BD%B1%E4%B8%BA%E8%AF%8D%E6%B1%87%E8%A1%A8-projection-to-vocab\"><\/a>4.3 Projection to Vocab<\/strong><\/h3>\n\n\n\n<p>In the final step, we project the output of the last transformer block to a probability distribution over our vocabulary:<\/p>\n\n\n\n<div 
class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># projection to vocab\nx = layer_norm(x, **ln_f)  # [n_seq, n_embd] -&gt; [n_seq, n_embd]\nreturn x @ wte.T  # [n_seq, n_embd] -&gt; [n_seq, n_vocab]\n<\/pre><\/div>\n\n\n\n<p>A few things to note here:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>Before projecting to vocab, we first pass <code>x<\/code> through a <strong>final layer normalization<\/strong> layer. This is specific to the GPT-2 architecture (it does not appear in the original GPT and Transformer papers).<\/li>\n\n\n\n<li>We <strong>reuse the embedding matrix<\/strong> <code>wte<\/code> for the projection. Other GPT implementations may choose to use a separately learned weight matrix instead, but sharing the embedding matrix has a couple of advantages:\n<ul class=\"wp-block-list\">\n<li>You save some parameters (although at GPT-3 scale the saving is negligible).<\/li>\n\n\n\n<li>Since the matrix is responsible for mapping both <strong>to<\/strong> the words and <strong>from<\/strong> the words, in theory it may learn a richer representation than two separate matrices would.<\/li>\n<\/ul>\n<\/li>\n\n\n\n<li>Finally, we <strong>do not apply 
<code>softmax<\/code><\/strong> at the end, so our outputs are <a href=\"https:\/\/developers.google.com\/machine-learning\/glossary\/#logits\" target=\"_blank\" rel=\"noreferrer noopener\"><code>logits<\/code><\/a> rather than probabilities between 0 and 1. There are a few reasons for this:\n<ul class=\"wp-block-list\">\n<li><code>softmax<\/code> is <a href=\"https:\/\/en.wikipedia.org\/wiki\/Monotonic_function\" target=\"_blank\" rel=\"noreferrer noopener\">monotonic<\/a>, so for greedy sampling <code>np.argmax(logits)<\/code> is equivalent to <code>np.argmax(softmax(logits))<\/code>, which makes <code>softmax<\/code> redundant.<\/li>\n\n\n\n<li><code>softmax<\/code> is irreversible: we can always go from <code>logits<\/code> to <code>probabilities<\/code> by applying <code>softmax<\/code>, but we cannot go back from <code>probabilities<\/code> to <code>logits<\/code>. For maximum flexibility, we output the <code>logits<\/code> directly.<\/li>\n\n\n\n<li>Numerical stability. For example, when computing the cross-entropy loss, <a href=\"https:\/\/jaykmody.com\/blog\/stable-softmax\/#cross-entropy-and-log-softmax\" target=\"_blank\" rel=\"noreferrer noopener\">taking <code>log(softmax(logits))<\/code> is numerically unstable compared to <code>log_softmax(logits)<\/code><\/a>.<\/li>\n<\/ul>\n<\/li>\n<\/ol>\n\n\n\n<p>The projection to vocab step is also sometimes called the <strong>\u8bed\u8a00\u5efa\u6a21\u5934\uff08language modeling 
head\uff09<\/strong>. What does \u201chead\u201d mean here? Once your GPT is pre-trained, you can swap the language modeling head for some other kind of projection, for example a <strong>classification head<\/strong>, to fine-tune the model on a classification task. Your model can therefore have multiple heads, a bit like a <a href=\"https:\/\/en.wikipedia.org\/wiki\/Lernaean_Hydra\" target=\"_blank\" rel=\"noreferrer noopener\">hydra<\/a>.<\/p>\n\n\n\n<blockquote class=\"wp-block-quote is-layout-flow wp-block-quote-is-layout-flow\">\n<p>Translator's note: the hydra is the nine-headed serpent of Greek mythology, to give you the picture.<\/p>\n<\/blockquote>\n\n\n\n<p>So that is the GPT architecture at a high level. Now let's look at the decoder block in more detail.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u89e3\u7801\u5668\u6a21\u5757\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E8%A7%A3%E7%A0%81%E5%99%A8%E6%A8%A1%E5%9D%97\"><\/a>4.4 Decoder Block<\/strong><\/h3>\n\n\n\n<p>The transformer decoder block consists of two sub-layers:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>Multi-head causal self attention<\/li>\n\n\n\n<li>Position-wise feed forward neural network<\/li>\n<\/ol>\n\n\n\n<div 
class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):  # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n    # x: input embeddings, shape [n_seq, n_embd], where n_seq is the sequence length and n_embd the embedding dimension.\n    # mlp, attn: parameters of the feed forward network and of the attention mechanism.\n    # ln_1, ln_2: parameters of the layer norms applied before self attention and before the feed forward network.\n    # n_head: number of attention heads.\n\n    # multi-head causal self attention\n    # Layer-normalize x, then pass it to multi-head attention (mha) together with\n    # the attention parameters (**attn) and the number of heads (n_head).\n    # Adding the result back to the original x forms the residual connection.\n    # The shape stays [n_seq, n_embd].\n    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head)\n\n    # position-wise feed forward network\n    # Layer-normalize x again, then pass it to the feed forward network (ffn)\n    # together with its parameters (**mlp).\n    # 
As before, the output of the feed forward network is added to its input x, forming a second residual connection.\n    # The shape again stays [n_seq, n_embd].\n    x = x + ffn(layer_norm(x, **ln_2), **mlp)\n\n    # Return the output of this transformer block.\n    return x\n<\/pre><\/div>\n\n\n\n<p>Each sub-layer applies layer normalization to its input and uses a residual connection (i.e. the sub-layer's input is added to its output).<\/p>\n\n\n\n<p>A few things to note:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>Multi-head causal self attention<\/strong> is what enables communication between the inputs. Nowhere else in the network does the model allow the inputs to \u201csee\u201d each other. The embeddings, the position-wise feed forward network, the layer norms, and the projection to vocab all operate on our inputs position-wise. Modeling the relationships between inputs is left entirely to attention.<\/li>\n\n\n\n<li><strong>\u9010\u4f4d\u7f6e\u524d\u9988\u795e\u7ecf\u7f51\u7edc<\/strong>(<strong>Position-wise feed forward neural 
network<\/strong>) is just a regular two-layer fully connected neural network. It simply adds a set of learnable parameters for our model to work with, which facilitates learning.<\/li>\n\n\n\n<li>In the original transformer paper, layer norm is placed on the output, <code>layer_norm(x + sublayer(x))<\/code>, whereas here, to match GPT-2, we place it on the input, <code>x + sublayer(layer_norm(x))<\/code>. This is called <strong>pre-norm<\/strong> and has been <a href=\"https:\/\/arxiv.org\/pdf\/2002.04745.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">shown to be important in improving the performance of the transformer<\/a>.<\/li>\n\n\n\n<li><strong>Residual connections<\/strong> (popularized by <a href=\"https:\/\/arxiv.org\/pdf\/1512.03385.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">ResNet<\/a>) serve a few different purposes here:\n<ul 
class=\"wp-block-list\">\n<li>They make deep neural networks (i.e. networks with many layers) easier to optimize. The idea is to give gradients \u201cshortcuts\u201d so that they flow back more easily to the early layers of the network.<\/li>\n\n\n\n<li>Without residual connections, making a model deeper degrades its performance (possibly because it is hard for gradients to flow all the way back through a deep network without losing information). Residual connections seem to give deeper networks a bit of an accuracy boost.<\/li>\n\n\n\n<li>They can help with the <a href=\"https:\/\/programmathically.com\/understanding-the-exploding-and-vanishing-gradients-problem\/\" target=\"_blank\" rel=\"noreferrer noopener\">vanishing\/exploding gradients problem<\/a>.<\/li>\n<\/ul>\n<\/li>\n<\/ol>\n\n\n\n<p>Now let's dive a little deeper into the two sub-layers.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u9010\u4f4d\u7f6e\u524d\u9988\u7f51\u7edc\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E9%80%90%E4%BD%8D%E7%BD%AE%E5%89%8D%E9%A6%88%E7%BD%91%E7%BB%9C\"><\/a>4.5 Position-wise Feed Forward Network<\/strong><\/h3>\n\n\n\n<p>The position-wise feed forward network is a simple multi-layer perceptron with 2 layers:<\/p>\n\n\n\n<div 
class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def ffn(x, c_fc, c_proj):  # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n    # project up\n    # Apply a linear layer followed by the GELU (Gaussian Error Linear Unit) activation.\n    # Using the c_fc parameters (weight and bias), linear expands each position's embedding from n_embd to 4*n_embd dimensions.\n    # The wider hidden layer gives the network more representational capacity.\n    a = gelu(linear(x, **c_fc))  # [n_seq, n_embd] -&gt; [n_seq, 4*n_embd]\n\n    # project back down\n    # A second linear layer, with the c_proj parameters, maps the expanded vectors from 4*n_embd back to n_embd dimensions.\n    # Keeping the input and output dimensions equal is what lets us stack transformer blocks.\n    x = linear(a, **c_proj)  # [n_seq, 4*n_embd] -&gt; [n_seq, n_embd]\n\n    # Return the output of the feed forward network.\n    return x\n<\/pre><\/div>\n\n\n\n<p>Nothing fancy here: we just project from <code>n_embd<\/code> up to a higher dimension, <code>4*n_embd<\/code>, and then back down to <code>n_embd<\/code><sup><a 
href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fn:5\">[5]<\/a><\/sup>.<\/p>\n\n\n\n<p>Recall from our <code>params<\/code> dictionary that our <code>mlp<\/code> parameters look like this:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">\"mlp\": {\n    \"c_fc\": {\"b\": [4*n_embd], \"w\": [n_embd, 4*n_embd]},\n    \"c_proj\": {\"b\": [n_embd], \"w\": [4*n_embd, n_embd]},\n}<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u591a\u5934\u56e0\u679c\u81ea\u6ce8\u610f\u529b\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E5%A4%9A%E5%A4%B4%E5%9B%A0%E6%9E%9C%E8%87%AA%E6%B3%A8%E6%84%8F%E5%8A%9B\"><\/a>4.6 Multi-Head Causal Self Attention<\/strong><\/h3>\n\n\n\n<p>This layer is probably the hardest part of the transformer to understand. So we will work our way up to \u201cMulti-Head Causal Self Attention\u201d by breaking down each word into its own section:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>Attention<\/li>\n\n\n\n<li>Self<\/li>\n\n\n\n<li>Causal<\/li>\n\n\n\n<li>Multi-Head<\/li>\n<\/ol>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u6ce8\u610f\u529b\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E6%B3%A8%E6%84%8F%E5%8A%9B\"><\/a>4.6.1 Attention<\/strong><\/h4>\n\n\n\n<p>On this topic I have another <a href=\"https:\/\/jaykmody.com\/blog\/attention-intuition\/\" target=\"_blank\" rel=\"noreferrer 
noopener\">blog post<\/a>, in which I derive from scratch the scaled dot product equation proposed in the <a href=\"https:\/\/arxiv.org\/pdf\/1706.03762.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">original transformer paper<\/a>:<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"537\" height=\"97\" src=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-7.png\" alt=\"\" class=\"wp-image-2259\" srcset=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-7.png 537w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/02\/\u56fe\u7247-7-300x54.png 300w\" sizes=\"auto, (max-width: 537px) 100vw, 537px\" \/><\/figure>\n\n\n\n<p>So in this post I will skip the explanation of attention. You can also refer to Lilian Weng's&nbsp;<a href=\"https:\/\/lilianweng.github.io\/posts\/2018-06-24-attention\/\" target=\"_blank\" rel=\"noreferrer noopener\">Attention? 
Attention!<\/a> and Jay Alammar's <a href=\"https:\/\/jalammar.github.io\/visualizing-neural-machine-translation-mechanics-of-seq2seq-models-with-attention\/\" target=\"_blank\" rel=\"noreferrer noopener\">The Illustrated Transformer<\/a>, both of which explain the attention mechanism very well.<\/p>\n\n\n\n<p>We will simply adapt the attention implementation from my blog post:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def attention(q, k, v):  # [n_q, d_k], [n_k, d_k], [n_k, d_v] -&gt; [n_q, d_v]\n    # q: query matrix, shape [n_q, d_k], where n_q is the number of queries and d_k the key\/query dimension.\n    # k: key matrix, shape [n_k, d_k], where n_k is the number of keys.\n    # v: value matrix, shape [n_k, d_v], where d_v is the value dimension (one value per key).\n\n    # Compute the dot-product scores between queries and keys, scaled by 1\/sqrt(d_k).\n    # The scaling helps keep gradients stable early in training.\n    # softmax then normalizes the scores so that the weights for each query sum to 1,\n    # which lets us read them as a distribution over the keys.\n    attention_scores = softmax(q @ k.T \/ np.sqrt(q.shape[-1]))\n\n    # 
Multiply the normalized scores by the values v, i.e. take a weighted sum of the values for each query.\n    # This aggregates the information most relevant to each query into the output.\n    return attention_scores @ v\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u81ea\u8eab-Self\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E8%87%AA%E8%BA%AB-Self\"><\/a>4.6.2 Self<\/strong><\/h4>\n\n\n\n<p>When <code>q<\/code>,&nbsp;<code>k<\/code> and <code>v<\/code> all come from the same source, we are performing <a href=\"https:\/\/lilianweng.github.io\/posts\/2018-06-24-attention\/#self-attention\" target=\"_blank\" rel=\"noreferrer noopener\">self-attention<\/a> (that is, letting our input sequence attend to itself):<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def self_attention(x): # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n    # x: input matrix of shape [n_seq, n_embd], where n_seq is the sequence length and n_embd is the embedding dimension.\n\n    # Call attention to compute self-attention.\n    # In self-attention the queries (q), keys (k) and values (v) are all the same input matrix x.\n    # For a given input sequence, each position is therefore represented in terms of the other positions in the sequence.\n    # 
This lets the model capture long-range dependencies within the sequence without relying on a traditional recurrent structure.\n    return attention(q=x, k=x, v=x)\n<\/pre><\/div>\n\n\n\n<p>For example, if our input is \u201cJay went to the store, he bought 10 apples.\u201d, we let the word \u201che\u201d attend to all the other words, including \u201cJay\u201d, which means the model can learn that \u201che\u201d refers to \u201cJay\u201d.<\/p>\n\n\n\n<blockquote class=\"wp-block-quote is-layout-flow wp-block-quote-is-layout-flow\">\n<p>Translator's note: note that the example text here is English.<\/p>\n<\/blockquote>\n\n\n\n<p>We can enhance self-attention by introducing projections for <code>q<\/code>, <code>k<\/code>, <code>v<\/code> and the attention output:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def self_attention(x, w_k, w_q, w_v, w_proj): # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n    # qkv projections\n    # Linear transforms project the input x into the query (Q), key (K) and value (V) spaces.\n    # w_q, w_k and w_v are the projection matrices for Q, K and V respectively.\n    q = x @ w_q # project x with w_q to get the query matrix Q\n    k = x @ w_k # project x with w_k to get the key matrix K\n    v = x @ w_v # project x with w_v to get the value matrix V\n\n    # 
Perform self-attention\n    # Call attention with the Q, K and V matrices to compute the self-attention output.\n    x = attention(q, k, v) # the shape is unchanged: [n_seq, n_embd]\n\n    # Output projection\n    # Finally, apply a linear transform with w_proj to the attention output to get the final output.\n    x = x @ w_proj # [n_seq, n_embd] @ [n_embd, n_embd] -&gt; [n_seq, n_embd]\n\n    # Return the output of the self-attention layer, which has the same shape as the input: [n_seq, n_embd].\n    return x\n<\/pre><\/div>\n\n\n\n<p>This lets our model learn the mappings for <code>q<\/code>,&nbsp;<code>k<\/code> and&nbsp;<code>v<\/code> that best help attention distinguish relationships between inputs.<\/p>\n\n\n\n<p>If we combine <code>w_q<\/code>, <code>w_k<\/code> and <code>w_v<\/code> into a single matrix <code>w_fc<\/code>, perform the projection, and then split the result, we can reduce the number of matrix multiplications from 4 to 2:<\/p>\n\n\n\n<div 
class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def self_attention(x, w_fc, w_proj): # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n    # QKV projection\n    # A single weight matrix w_fc maps the input x into a space that holds the queries, keys and values side by side.\n    # w_fc has shape [n_embd, 3*n_embd], so each input row is mapped to 3*n_embd dimensions that are later split into Q, K and V.\n    x = x @ w_fc # [n_seq, n_embd] @ [n_embd, 3*n_embd] -&gt; [n_seq, 3*n_embd]\n\n    # Split into QKV\n    # Split the previous result along the last dimension into three equal parts: queries, keys and values.\n    # In np.split, 3 is the number of parts and axis=-1 selects the last dimension.\n    q, k, v = np.split(x, 3, axis=-1) # [n_seq, 3*n_embd] -&gt; 3 of [n_seq, n_embd]\n\n    # Perform self-attention\n    # Call the attention function defined earlier with the split Q, K and V.\n    # The output still has shape [n_seq, n_embd].\n    x = attention(q, k, v) # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n\n    # Output projection\n    # 
Apply a second weight matrix, w_proj, to the attention output to get the final output.\n    # w_proj has shape [n_embd, n_embd], so the output has the same shape as the input x.\n    x = x @ w_proj # [n_seq, n_embd] @ [n_embd, n_embd] -&gt; [n_seq, n_embd]\n\n    # Return the output of the self-attention layer.\n    return x\n<\/pre><\/div>\n\n\n\n<p>This is more efficient, because modern accelerators (such as GPUs) make better use of one large matrix multiplication than of 3 separate small matrix multiplications executed sequentially.<\/p>\n\n\n\n<p>Finally, we add bias vectors to match the GPT-2 implementation, use our <code>linear<\/code> function, and rename the parameters to match our <code>params<\/code> dictionary:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def self_attention(x, c_attn, c_proj): # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n    # QKV projection\n    # First, a linear transform (using the given c_attn parameters) projects the input x into the combined query\/key\/value space.\n    # This expands the shape from [n_seq, n_embd] to [n_seq, 3*n_embd].\n    x = linear(x, **c_attn)\n\n    # Split into QKV\n    # 
Then split the expanded matrix along the last dimension into three equal parts:\n    # the queries (Q), keys (K) and values (V).\n    q, k, v = np.split(x, 3, axis=-1)\n\n    # Perform self-attention\n    # Run self-attention on the split Q, K and V; the attention function computes the output.\n    x = attention(q, k, v)\n\n    # Output projection\n    # Finally, another linear transform (using the given c_proj parameters) projects the attention output back to the embedding space [n_seq, n_embd].\n    x = linear(x, **c_proj)\n\n    # Return the output of the self-attention layer, of shape [n_seq, n_embd].\n    return x\n<\/pre><\/div>\n\n\n\n<p>Recall that in our <code>params<\/code> dictionary, the <code>attn<\/code> parameters look like this:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">\"attn\": {\n    \"c_attn\": {\"b\": [3*n_embd], \"w\": [n_embd, 3*n_embd]},\n    \"c_proj\": {\"b\": [n_embd], \"w\": [n_embd, n_embd]},\n},<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u56e0\u679c\"><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E5%9B%A0%E6%9E%9C\"><\/a><strong>4.6.3 
<\/strong>Causal<\/h4>\n\n\n\n<p>Our current self-attention setup has a problem: our inputs can \u201csee\u201d into the future! For example, if our input is [\u201cnot\u201d, \u201call\u201d, \u201cheroes\u201d, \u201cwear\u201d, \u201ccapes\u201d], then during self-attention \u201cwear\u201d can see \u201ccapes\u201d. This means the output probabilities for \u201cwear\u201d will be biased, since the model already knows that the right answer is \u201ccapes\u201d. This is bad, because our model would simply learn that the correct answer for input <em>i<\/em> can be taken from input <em>i<\/em>+1.<\/p>\n\n\n\n<p>To prevent this, we need to modify the attention matrix to <em>hide<\/em> or <strong>mask<\/strong> our inputs so that they cannot see into the future. For example, suppose our attention matrix looks like this:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">       not    all    heroes wear   capes\n   not 0.116  0.159  0.055  0.226  0.443\n   all 0.180  0.397  0.142  0.106  0.175\nheroes 0.156  0.453  0.028  0.129  0.234\n  wear 0.499  0.055  0.133  0.017  0.295\n capes 0.089  0.290  0.240  0.228  
0.153<\/pre><\/div>\n\n\n\n<p>Each row corresponds to a query and each column to a key. In this example, looking at the row for \u201cwear\u201d, we can see that it attends to \u201ccapes\u201d in the last column with a weight of 0.295. To prevent this, we want to set that entry to <code>0.0<\/code>:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">       not    all    heroes wear   capes\n   not 0.116  0.159  0.055  0.226  0.443\n   all 0.180  0.397  0.142  0.106  0.175\nheroes 0.156  0.453  0.028  0.129  0.234\n  wear 0.499  0.055  0.133  0.017  0.\n capes 0.089  0.290  0.240  0.228  0.153<\/pre><\/div>\n\n\n\n<p>In general, to prevent every query in our input from seeing into the future, we set all positions <em>i<\/em>,&nbsp;<em>j<\/em> where <em>j<\/em>&gt;<em>i<\/em> to <code>0<\/code>:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">       not    all    heroes wear   capes\n   not 0.116  0.     0.     0.     0.\n   all 0.180  0.397  0.     0.     0.\nheroes 0.156  0.453  0.028  0.     
0.\n  wear 0.499  0.055  0.133  0.017  0.\n capes 0.089  0.290  0.240  0.228  0.153<\/pre><\/div>\n\n\n\n<p>We call this <strong>masking<\/strong>. One problem with this approach is that our rows no longer sum to 1 (because we set the entries to 0 only after applying <code>softmax<\/code>). To make sure the rows still sum to 1, we need to modify the attention matrix before applying <code>softmax<\/code>.<\/p>\n\n\n\n<p>This can be achieved by setting the entries to be masked to \u2212\u221e before the <code>softmax<\/code>, since the exponential of \u2212\u221e is 0 and the remaining entries of each row are then normalized to sum to 1<sup><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fn:6\">[6]<\/a><\/sup>:<\/p>\n\n\n\n<p>The code below defines a function named <code>attention<\/code> that implements attention with a mask. A mask excludes (or reduces) the influence of certain elements in the attention computation. It is commonly used to handle variable-length input sequences and to prevent information leakage (for example, keeping an autoregressive model from seeing future tokens ahead of time).<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -&gt; [n_q, d_v]\n    # q: query matrix of shape [n_q, 
d_k], where n_q is the number of queries and d_k is the key\/query dimension.\n    # k: key matrix of shape [n_k, d_k], where n_k is the number of keys.\n    # v: value matrix of shape [n_k, d_v], where n_k is the number of values and d_v is the value dimension.\n    # mask: mask matrix of shape [n_q, n_k], added to the scores to adjust or block specific query\/key pairs.\n\n    # Compute the dot product of q and k, then divide by sqrt(d_k) to keep gradients stable.\n    # Add the mask, which typically holds negative infinity (fully masked) or 0 (not masked).\n    # softmax normalizes each row so that it sums to 1 and can be read as a probability distribution.\n    # Note: the mask is added before softmax so that masked entries do not contribute to the result.\n    attention_scores = softmax(q @ k.T \/ np.sqrt(q.shape[-1]) + mask)\n\n    # Multiply the normalized scores by the values v, i.e. take a weighted sum of the values for each query.\n    # This aggregates the information most relevant to each query.\n    # The return value has shape [n_q, 
d_v]: one output vector per query, each a weighted sum of all the values.\n    return attention_scores @ v\n<\/pre><\/div>\n\n\n\n<p>where <code>mask<\/code> is the matrix (for <code>n_seq=5<\/code>):<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">0 -1e10 -1e10 -1e10 -1e10\n0   0   -1e10 -1e10 -1e10\n0   0     0   -1e10 -1e10\n0   0     0     0   -1e10\n0   0     0     0     0<\/pre><\/div>\n\n\n\n<p>We use <code>-1e10<\/code> instead of <code>-np.inf<\/code>, because <code>-np.inf<\/code> can cause <code>nans<\/code>.<\/p>\n\n\n\n<p>Adding <code>mask<\/code> to our attention matrix, rather than explicitly setting the values to <code>-1e10<\/code>, works because in practice any number plus <code>-inf<\/code> is still <code>-inf<\/code>.<\/p>\n\n\n\n<p>In NumPy, we can compute the <code>mask<\/code> matrix with <code>(1 - np.tri(n_seq)) * -1e10<\/code>.<\/p>\n\n\n\n<p>Putting all of this together, we get:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def causal_self_attention(x, c_attn, c_proj): # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n    # QKV projection\n    # First, a linear transform (with the c_attn parameters) maps the input x into the combined query (Q), key (K) and value (V) space.\n    x = linear(x, **c_attn) # [n_seq, n_embd] -&gt; [n_seq, 3*n_embd]\n\n    # Split into QKV\n    # 
Then split the result along the last dimension into three equal parts: the queries Q, keys K and values V.\n    q, k, v = np.split(x, 3, axis=-1) # [n_seq, 3*n_embd] -&gt; 3 of [n_seq, n_embd]\n\n    # Causal mask\n    # Build a causal mask that hides future inputs so the current position cannot attend to them.\n    # np.tri builds a lower-triangular matrix of ones; 1 - np.tri leaves ones only above the diagonal (the future positions),\n    # which are multiplied by a large negative number (-1e10) so that their weights are close to zero after softmax.\n    causal_mask = (1 - np.tri(x.shape[0], dtype=x.dtype)) * -1e10  # [n_seq, n_seq]\n\n    # Perform causal self-attention\n    # Call the modified attention function with Q, K, V and the causal mask.\n    x = attention(q, k, v, causal_mask) # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n\n    # Output projection\n    # Finally, another linear transform (with the c_proj parameters) maps the attention output back to the embedding space.\n    x = linear(x, **c_proj) # [n_seq, n_embd] @ [n_embd, n_embd] = [n_seq, n_embd]\n\n    # Return the output of the causal self-attention layer.\n    return x\n<\/pre><\/div>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u591a\u5934\"><strong>4.6.4 
Multi-Head<\/strong><\/h4>\n\n\n\n<p>We can further improve our implementation by performing <code>n_head<\/code> separate attention computations, splitting our queries, keys and values into multiple <strong>heads<\/strong>:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n    # QKV projection\n    # First, a linear transform (using the given c_attn parameters) maps the input x into the combined query (Q), key (K) and value (V) space.\n    x = linear(x, **c_attn)  # [n_seq, n_embd] -&gt; [n_seq, 3*n_embd]\n\n    # Split into QKV\n    # Then split the result along the last dimension into three equal parts: the queries Q, keys K and values V.\n    qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -&gt; 3 of [n_seq, n_embd]\n\n    # Split into heads\n    # Split each of Q, K and V further into n_head parts, giving every head a slice of the dimensions.\n    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -&gt; [3, n_head, n_seq, n_embd\/n_head]\n\n    # Causal mask\n    # Build a causal mask that hides future inputs so the current position cannot attend to them.\n    causal_mask = (1 - 
np.tri(x.shape[0], dtype=x.dtype)) * -1e10  # [n_seq, n_seq]\n\n    # Perform attention over each head\n    # Loop over the heads and run attention with the causal mask on each head's Q, K and V.\n    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [n_head, n_seq, n_embd\/n_head]\n\n    # Merge heads\n    # Join the outputs of all heads along the last dimension to restore the original embedding dimension.\n    x = np.hstack(out_heads)  # [n_seq, n_embd]\n\n    # Output projection\n    # Apply a linear transform (using the given c_proj parameters) to the merged output to get the final output.\n    x = linear(x, **c_proj)  # [n_seq, n_embd] -&gt; [n_seq, n_embd]\n\n    # Return the output of multi-head attention.\n    return x\n<\/pre><\/div>\n\n\n\n<p>There are three steps added here:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>Split <code>q<\/code>,&nbsp;<code>k<\/code>,&nbsp;<code>v<\/code> into <code>n_head<\/code> heads:<\/li>\n<\/ol>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># Split into heads\n# Split each of the queries (Q), keys (K) and values (V) further to support multi-head attention.\n# Splitting along the last dimension (axis=-1) gives every head a slice of the dimensions (n_embd\/n_head).\n# 
A lambda and map apply the same operation to each element of qkv (Q, K and V).\n# np.split performs the actual split, and n_head is the number of parts.\n# The result qkv_heads is a list with three elements (Q, K, V), each holding n_head arrays of shape [n_seq, n_embd\/n_head].\nqkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -&gt; [3, n_head, n_seq, n_embd\/n_head]\n<\/pre><\/div>\n\n\n\n<ol class=\"wp-block-list\" start=\"2\">\n<li>Compute attention for each head:<\/li>\n<\/ol>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># Perform attention over each head\n# A list comprehension loops over the heads in qkv_heads and runs the attention function on each one.\n# zip(*qkv_heads) unpacks qkv_heads into its Q, K and V components, each a collection of heads,\n# and yields the query (Q), key (K) and value (V) belonging to each head in turn.\n# attention computes the output of each head; it is assumed to be defined already\n# and to accept the Q, K and V of a single head.\n# 
Note: causal_mask is the [n_seq, n_seq] mask built earlier in mha; the same mask is passed to every head.\n# The result out_heads is a list holding the attention output of each head.\n# Each output has shape [n_seq, n_embd\/n_head]: the same sequence length as the input, with the embedding dimension divided among the heads.\nout_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [3, n_head, n_seq, n_embd\/n_head] -&gt; [n_head, n_seq, n_embd\/n_head]\n<\/pre><\/div>\n\n\n\n<ol class=\"wp-block-list\" start=\"3\">\n<li>Merge the outputs of each head:<\/li>\n<\/ol>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># Merge heads\n# Merge the per-head outputs of multi-head attention back into a single representation.\n# np.concatenate with axis=-1 joins the head outputs along the last (embedding) dimension.\n# For 2-D arrays this is equivalent to the np.hstack call used in mha,\n# because np.hstack concatenates 2-D arrays along axis 1, which is the last axis here.\n# out_heads is a list of n_head arrays of shape [n_seq, n_embd\/n_head]; the result we want is a\n# 
single array of shape [n_seq, n_embd], i.e. all head outputs joined along the embedding dimension.\nx = np.concatenate(out_heads, axis=-1)  # [n_head, n_seq, n_embd\/n_head] -&gt; [n_seq, n_embd]\n<\/pre><\/div>\n\n\n\n<p>Note that this reduces the dimension of each attention computation from <code>n_embd<\/code> to <code>n_embd\/n_head<\/code>. This is a tradeoff. In exchange for the reduced dimensionality, our model gets additional <code>subspaces<\/code> to work with when modeling relationships via attention. For example, maybe one attention head is responsible for connecting pronouns to the person the pronoun refers to; maybe another is responsible for grouping sentences by periods; another may simply identify which words are entities and which are not. Although this is probably just another neural network black box.<\/p>\n\n\n\n<p>The code we wrote performs the attention computation for each head sequentially in a loop (one at a time), which is of course not very efficient. In practice, you would want to run these computations in parallel. For simplicity, we will keep the sequential version in this post.<\/p>\n\n\n\n<p>With all of that, we have finally finished our GPT implementation! All that is left to do now is put everything together and run the code.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\" id=\"\u5c06\u6240\u6709\u4ee3\u7801\u7ec4\u5408\u8d77\u6765\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E5%B0%86%E6%89%80%E6%9C%89%E4%BB%A3%E7%A0%81%E7%BB%84%E5%90%88%E8%B5%B7%E6%9D%A5\"><\/a>5. Putting It All Together<\/strong><\/h2>\n\n\n\n<p>Putting all the code together, we get <a href=\"https:\/\/github.com\/jaymody\/picoGPT\/blob\/main\/gpt2.py\" target=\"_blank\" rel=\"noreferrer noopener\"><code>gpt2.py<\/code><\/a>, which is only 120 lines of code in total (<a href=\"https:\/\/github.com\/jaymody\/picoGPT\/blob\/a750c145ba4d09d5764806a6c78c71ffaff88e64\/gpt2_pico.py#L3-L58\" target=\"_blank\" rel=\"noreferrer noopener\">and only 60 lines if you remove comments, whitespace and the like<\/a>).<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">import numpy as np\nfrom tqdm import tqdm\n\n# GELU activation function.\ndef gelu(x):\n    # tanh approximation of GELU.\n    return 0.5 * x * (1 + np.tanh(np.sqrt(2 \/ np.pi) * (x + 0.044715 * x**3)))\n\n# softmax function.\ndef softmax(x):\n    # Apply softmax along the last axis to get a probability distribution; the row max is subtracted for numerical stability.\n    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))\n    return exp_x \/ np.sum(exp_x, axis=-1, keepdims=True)\n\n# Layer normalization.\ndef layer_norm(x, g, b, eps: float = 1e-5):\n    # Normalize x along the last axis, then scale by g and shift by b.\n    mean = np.mean(x, axis=-1, keepdims=True)\n    variance = np.var(x, axis=-1, keepdims=True)\n    x = (x - mean) \/ 
np.sqrt(variance + eps)\n    return g * x + b\n\n# Linear projection.\ndef linear(x, w, b):\n    # Apply the affine transform x @ w + b.\n    return x @ w + b\n\n# Position-wise feed forward network.\ndef ffn(x, c_fc, c_proj):\n    # Project up, apply GELU, then project back down.\n    a = gelu(linear(x, **c_fc))\n    x = linear(a, **c_proj)\n    return x\n\n# Scaled dot-product attention.\ndef attention(q, k, v, mask):\n    # The mask is added to the scaled scores before the softmax, so masked positions get near-zero weight.\n    return softmax(q @ k.T \/ np.sqrt(q.shape[-1]) + mask) @ v\n\n# Multi-head causal self-attention.\ndef mha(x, c_attn, c_proj, n_head):\n    # Project to q, k, v; split into heads; attend per head with a causal mask; merge the heads; project out.\n    x = linear(x, **c_attn)\n    qkv = np.split(x, 3, axis=-1)\n    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))\n    causal_mask = (1 - np.tri(x.shape[0], dtype=x.dtype)) * -1e10\n    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]\n    x = np.hstack(out_heads)\n    x = linear(x, **c_proj)\n    return x\n\n# Transformer block.\ndef transformer_block(x, mlp, attn, ln_1, ln_2, n_head):\n    # Two residual sublayers, each preceded by layer norm: attention, then feed forward.\n    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head)\n    x = x + ffn(layer_norm(x, **ln_2), **mlp)\n    return x\n\n# Simplified GPT-2 forward pass.\ndef gpt2(inputs, wte, wpe, blocks, ln_f, n_head):\n    # Token plus positional embeddings, then the transformer blocks, the final layer norm, and the projection to vocabulary logits.\n    x = wte[inputs] + wpe[range(len(inputs))]\n    for block in blocks:\n        x = transformer_block(x, **block, n_head=n_head)\n    x = layer_norm(x, **ln_f)\n    return x @ wte.T\n\n# Text generation.\ndef generate(inputs, params, n_head, 
n_tokens_to_generate):\n    # Autoregressive decoding loop with greedy sampling.\n    for _ in tqdm(range(n_tokens_to_generate), \"generating\"):\n        logits = gpt2(inputs, **params, n_head=n_head)\n        next_id = np.argmax(logits[-1])\n        inputs.append(int(next_id))\n    return inputs[len(inputs) - n_tokens_to_generate:]\n\n# Main function.\ndef main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = \"124M\", models_dir: str = \"models\"):\n    # Load the model parameters, encode the prompt, generate tokens, and decode the output.\n    from utils import load_encoder_hparams_and_params\n    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)\n    input_ids = encoder.encode(prompt)\n    assert len(input_ids) + n_tokens_to_generate &lt; hparams[\"n_ctx\"]\n    output_ids = generate(input_ids, params, hparams[\"n_head\"], n_tokens_to_generate)\n    output_text = encoder.decode(output_ids)\n    return output_text\n\n# Entry point.\nif __name__ == \"__main__\":\n    import fire\n    fire.Fire(main)\n<\/pre><\/div>\n\n\n\n<p>We can test our implementation with:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">python gpt2.py \\\n    \"Alan Turing theorized that computers would one day become\" \\\n    --n_tokens_to_generate 8<\/pre><\/div>\n\n\n\n<p>which gives the output:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">the most powerful machines on the planet.<\/pre><\/div>\n\n\n\n<p>It works!!!<\/p>\n\n\n\n<p>We can use the following <a href=\"https:\/\/gist.github.com\/jaymody\/9054ca64eeea7fad1b58a185696bb518\" target=\"_blank\" rel=\"noreferrer noopener\">Dockerfile<\/a> to verify that our implementation gives the same result as <a href=\"https:\/\/github.com\/openai\/gpt-2\" target=\"_blank\" rel=\"noreferrer noopener\">OpenAI's official GPT-2 repo<\/a> (note: this will not run on M1 MacBooks because of TensorFlow support issues; also be warned that it downloads all 4 GPT-2 models, which means many gigabytes of files):<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">docker build -t \"openai-gpt-2\" \"https:\/\/gist.githubusercontent.com\/jaymody\/9054ca64eeea7fad1b58a185696bb518\/raw\/Dockerfile\"\ndocker run -dt --name \"openai-gpt-2-app\" \"openai-gpt-2\"\ndocker exec -it \"openai-gpt-2-app\" \/bin\/bash -c 'python3 src\/interactive_conditional_samples.py --length 8 --model_type 124M --top_k 1'\n# paste \"Alan Turing theorized that computers would one day become\" when prompted<\/pre><\/div>\n\n\n\n<p>This should give an identical result:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">the most powerful machines on the planet.<\/pre><\/div>\n\n\n\n<h2 class=\"wp-block-heading\" id=\"\u4e0b\u4e00\u6b65\u5462\uff1f\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E4%B8%8B%E4%B8%80%E6%AD%A5%E5%91%A2%EF%BC%9F\"><\/a>6. 
What Next?<\/strong><\/h2>\n\n\n\n<p>This implementation is neat, but it is missing a lot of bells and whistles:<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"GPU-TPU-\u652f\u6301\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#GPU-TPU-%E6%94%AF%E6%8C%81\"><\/a>6.1 GPU\/TPU Support<\/strong><\/h3>\n\n\n\n<p>Replace NumPy with <a href=\"https:\/\/github.com\/google\/jax\" target=\"_blank\" rel=\"noreferrer noopener\">JAX<\/a>:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">import jax.numpy as np<\/pre><\/div>\n\n\n\n<p>That's it! You can now run the code on GPUs and even <a href=\"https:\/\/cloud.google.com\/tpu\/docs\/system-architecture-tpu-vm\" target=\"_blank\" rel=\"noreferrer noopener\">TPUs<\/a>, provided you <a href=\"https:\/\/github.com\/google\/jax#installation\" target=\"_blank\" rel=\"noreferrer noopener\">installed JAX correctly<\/a>.<\/p>\n\n\n\n<blockquote class=\"wp-block-quote is-layout-flow wp-block-quote-is-layout-flow\">\n<p>Translator's note: JAX is great :)<\/p>\n<\/blockquote>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u53cd\u5411\u4f20\u64ad\"><strong>6.2 Backpropagation<\/strong><\/h3>\n\n\n\n<p>If we replace NumPy with JAX:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">import jax.numpy as np<\/pre><\/div>\n\n\n\n<p>then computing gradients is as easy as:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">import jax.numpy as np\n\n# 
Language-modeling loss. jax itself is imported here because jax.nn.log_softmax and jax.grad are used below.\nimport jax\n\ndef lm_loss(params, inputs, n_head) -&gt; float:\n    # Split the token ids into inputs (x) and next-token targets (y): x drops the last token\n    # and y drops the first, so y[i] is the token that should follow x[i].\n    x, y = np.array(inputs[:-1]), np.array(inputs[1:])\n\n    # gpt2 (as defined above) returns unnormalized logits of shape [n_seq, n_vocab], not probabilities.\n    logits = gpt2(x, **params, n_head=n_head)\n\n    # Turn the logits into log-probabilities, then pick out the log-probability of each target token.\n    # Indexing with (row, y[row]) pairs is needed here; indexing with y alone would select whole rows instead.\n    log_probs = jax.nn.log_softmax(logits, axis=-1)\n    loss = -np.mean(log_probs[np.arange(len(y)), y])\n\n    # Mean negative log-likelihood (cross entropy) over the sequence.\n    return loss\n\n# Compute the gradient of lm_loss with respect to params (its first argument) using jax.grad.\n# 
Here params holds the gpt2 model parameters, inputs is the sequence of token ids, and n_head is the number of attention heads.\ngrads = jax.grad(lm_loss)(params, inputs, n_head)\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u6279\u5904\u7406\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E6%89%B9%E5%A4%84%E7%90%86\"><\/a>6.3 Batching<\/strong><\/h3>\n\n\n\n<p>Once again, if we replace NumPy with <a href=\"https:\/\/github.com\/google\/jax\" target=\"_blank\" rel=\"noreferrer noopener\">JAX<\/a><sup><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fn:7\">[7]<\/a><\/sup>:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">import jax.numpy as np<\/pre><\/div>\n\n\n\n<p>then making the <code>gpt2<\/code> function batched is as easy as:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">import jax\n\n# Batch the gpt2 function with jax.vmap. in_axes gives the batch axis for each positional argument.\n# For batched_inputs the batch axis is 0 (the examples are stacked along the first dimension).\n# The remaining arguments (model parameters and configuration) are shared across the batch, so they are set to None (not batched).\ngpt2_batched = jax.vmap(gpt2, in_axes=[0, None, None, None, None, None])\n\n# 
Call the batched gpt2 function on a batch of inputs. Because in_axes has six entries, all six arguments are passed positionally.\n# batched_inputs is an array of shape [batch, seq_len], where batch is the batch size and seq_len is the sequence length.\n# The output has shape [batch, seq_len, vocab], where vocab is the vocabulary size: for every sequence in the batch,\n# the model returns one vector of vocabulary logits per position.\ngpt2_batched(batched_inputs, params[\"wte\"], params[\"wpe\"], params[\"blocks\"], params[\"ln_f\"], n_head) # [batch, seq_len] -&gt; [batch, seq_len, vocab]\n<\/pre><\/div>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u63a8\u65ad\u4f18\u5316\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E6%8E%A8%E6%96%AD%E4%BC%98%E5%8C%96\"><\/a>6.4 Inference Optimization<\/strong><\/h3>\n\n\n\n<p>Our implementation is quite inefficient. Aside from GPU support and batching, the quickest and most impactful optimization would probably be to implement a <a href=\"https:\/\/kipp.ly\/blog\/transformer-inference-arithmetic\/#kv-cache\" target=\"_blank\" rel=\"noreferrer noopener\">key-value cache<\/a>. Also, we compute the attention heads sequentially, when we should really be computing them in parallel<sup><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fn:8\">[8]<\/a><\/sup>.<\/p>\n\n\n\n<p>There are many, many more inference optimizations. I recommend Lilian Weng's <a href=\"https:\/\/lilianweng.github.io\/posts\/2023-01-10-inference-optimization\/\" target=\"_blank\" rel=\"noreferrer noopener\">Large Transformer Model Inference Optimization<\/a> and Kipply's <a href=\"https:\/\/kipp.ly\/blog\/transformer-inference-arithmetic\/\" target=\"_blank\" rel=\"noreferrer noopener\">Transformer Inference Arithmetic<\/a> as starting points.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u8bad\u7ec3-1\"><strong>6.5 Training<\/strong><\/h3>\n\n\n\n<p>Training a GPT is pretty standard for a neural network (gradient descent on a loss function). Of course, you also need the usual bag of tricks (use the Adam optimizer, find the optimal learning rate, regularize via dropout and\/or weight decay, use a learning rate scheduler, use the correct weight initialization, batch the data, and so on).<\/p>\n\n\n\n<p>The real secret to training a good GPT model is the ability to <strong>scale the data and the model<\/strong>, which is also where the real challenge lies.<\/p>\n\n\n\n<p>To scale the data, you need a text corpus that is large, high quality, and diverse.<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>Large<\/strong> means billions of tokens (terabytes of data). For example, check out <a href=\"https:\/\/pile.eleuther.ai\/\" target=\"_blank\" rel=\"noreferrer noopener\">The Pile<\/a>, an open source pre-training dataset for large language models.<\/li>\n\n\n\n<li><strong>High quality<\/strong> means filtering out duplicate examples, unformatted text, incoherent text, garbage text, and so on.<\/li>\n\n\n\n<li><strong>Diverse<\/strong> means varying sequence lengths, many different topics, different sources, differing perspectives, and so on. Of course, any bias in the data will be reflected in the model, so you need to handle that carefully.<\/li>\n<\/ul>\n\n\n\n<p>Scaling the model to billions of parameters takes a huge amount of engineering (and money lol). Training frameworks can get <a href=\"https:\/\/github.com\/NVIDIA\/Megatron-LM\" target=\"_blank\" rel=\"noreferrer noopener\">very long and complex<\/a>. A good place to start is Lilian Weng's <a href=\"https:\/\/lilianweng.github.io\/posts\/2021-09-25-train-large\/\" target=\"_blank\" rel=\"noreferrer noopener\">How to Train Really Large Models on Many GPUs<\/a>. On the same topic there is also NVIDIA's <a href=\"https:\/\/arxiv.org\/pdf\/1909.08053.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">Megatron Framework<\/a>,&nbsp;<a href=\"https:\/\/arxiv.org\/pdf\/2204.06514.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">Cohere's training framework<\/a>, Google's <a href=\"https:\/\/arxiv.org\/pdf\/2204.02311.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">PaLM<\/a>, the open source <a href=\"https:\/\/github.com\/kingoflolz\/mesh-transformer-jax\" target=\"_blank\" rel=\"noreferrer noopener\">mesh-transformer-jax<\/a> (used to train EleutherAI's open source models), and <a href=\"https:\/\/arxiv.org\/pdf\/2203.15556.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">many<\/a>, <a href=\"https:\/\/www.microsoft.com\/en-us\/research\/blog\/turing-nlg-a-17-billion-parameter-language-model-by-microsoft\/\" target=\"_blank\" rel=\"noreferrer noopener\">many<\/a>, <a href=\"https:\/\/arxiv.org\/pdf\/2005.14165.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">more<\/a>.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u8bc4\u4f30\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E8%AF%84%E4%BC%B0\"><\/a>6.6 Evaluation<\/strong><\/h3>\n\n\n\n<p>Oh, and how do you evaluate LLMs? Honestly, it is a really hard problem. <a href=\"https:\/\/arxiv.org\/abs\/2211.09110\" target=\"_blank\" rel=\"noreferrer noopener\">HELM<\/a>&nbsp;is fairly comprehensive and a good place to start, but you should always be skeptical of <a href=\"https:\/\/en.wikipedia.org\/wiki\/Goodhart%27s_law\" target=\"_blank\" rel=\"noreferrer noopener\">benchmarks and evaluation metrics<\/a>.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u67b6\u6784\u6539\u8fdb\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E6%9E%B6%E6%9E%84%E6%94%B9%E8%BF%9B\"><\/a>6.7 Architecture Improvements<\/strong><\/h3>\n\n\n\n<p>I recommend taking a look at Phil Wang's <a href=\"https:\/\/github.com\/lucidrains\/x-transformers\" target=\"_blank\" rel=\"noreferrer noopener\">X-Transformers<\/a>. It collects the latest and greatest research on the transformer architecture. <a href=\"https:\/\/arxiv.org\/pdf\/2102.11972.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">This paper<\/a> is also a pretty good overview (see Table 1). Facebook's recent <a href=\"https:\/\/arxiv.org\/pdf\/2302.13971.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">LLaMA paper<\/a> is also probably a good reference for standard architecture improvements (as of February 2023).<\/p>\n\n\n\n<blockquote class=\"wp-block-quote is-layout-flow wp-block-quote-is-layout-flow\">\n<p>Translator's note: if you are learning transformers, reading through <a href=\"https:\/\/github.com\/lucidrains\/x-transformers\" target=\"_blank\" rel=\"noreferrer noopener\">x-transformers<\/a> will definitely level up your skills.<\/p>\n<\/blockquote>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u505c\u6b62\u751f\u6210\"><strong><a 
href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E5%81%9C%E6%AD%A2%E7%94%9F%E6%88%90\"><\/a>6.8 Stopping Generation<\/strong><\/h3>\n\n\n\n<p>Our current implementation requires us to specify the exact number of tokens to generate ahead of time. This is not a great approach, since the generated text can end up too long, too short, or cut off mid-sentence.<\/p>\n\n\n\n<p>To resolve this, we can introduce a special <strong>end of sentence (EOS) token<\/strong>. During pre-training, we append the EOS token to the end of the input (e.g. <code>tokens = [\"not\", \"all\", \"heroes\", \"wear\", \"capes\", \".\", \"&lt;|EOS|&gt;\"]<\/code>). During generation, we simply stop when we encounter the EOS token (or when we hit the maximum sequence length):<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def generate(inputs, eos_id, max_seq_len):\n    # inputs: token ids of the initial prompt.\n    # eos_id: token id of the end of sentence (EOS) token, which signals when to stop generating.\n    # max_seq_len: upper bound on the total sequence length (prompt plus generated tokens).\n\n    # Remember the prompt length so that only the generated part is returned at the end.\n    prompt_len = len(inputs)\n    \n    # 
Keep generating while the last token is not EOS and the sequence is shorter than max_seq_len.\n    while inputs[-1] != eos_id and len(inputs) &lt; max_seq_len:\n        # Run the model on the current sequence to get one row of next-token scores per position.\n        output = gpt(inputs)\n        \n        # Greedily pick the highest-scoring token at the last position as the next token.\n        next_id = np.argmax(output[-1])\n        \n        # Append the chosen token so that it becomes part of the input for the next step.\n        inputs.append(int(next_id))\n    \n    # Return only the generated tokens, without the prompt.\n    return inputs[prompt_len:]\n<\/pre><\/div>\n\n\n\n<p>GPT-2 was not pre-trained with an EOS token, so we cannot use this approach in our code, but most LLMs nowadays use an EOS token.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u65e0\u6761\u4ef6\u751f\u6210\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E6%97%A0%E6%9D%A1%E4%BB%B6%E7%94%9F%E6%88%90\"><\/a>6.9 Unconditional Generation<\/strong><\/h3>\n\n\n\n<p>Generating text with our model requires us to <strong>condition<\/strong> it with a prompt. However, we can also have the model perform <strong>unconditional generation<\/strong>, where it generates text without any input prompt.<\/p>\n\n\n\n<p>This is achieved by prepending a special <strong>beginning of sentence (BOS) token<\/strong> to the input during pre-training (e.g.&nbsp;<code>tokens = [\"&lt;|BOS|&gt;\", \"not\", \"all\", \"heroes\", \"wear\", \"capes\", \".\"]<\/code>). To generate text unconditionally, we feed the model a list containing only the BOS token:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def generate_unconditioned(bos_id, n_tokens_to_generate):\n    # bos_id: token id of the BOS token, the starting point for generation.\n    # n_tokens_to_generate: number of tokens to generate.\n\n    # Start from a sequence that contains only the BOS token.\n    inputs = [bos_id]\n\n    # Generate one token per iteration.\n    for _ in range(n_tokens_to_generate):\n        # Run the model on the current sequence.\n        output = gpt(inputs)\n        \n        # 
Greedily pick the highest-scoring token at the last position (the prediction for the next token).\n        next_id = np.argmax(output[-1])\n        \n        # Append the chosen token so that it is used in the next iteration.\n        inputs.append(int(next_id))\n    \n    # Return the generated sequence without the initial BOS token.\n    return inputs[1:]\n<\/pre><\/div>\n\n\n\n<p>The function works as follows:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>Start from a sequence that contains only the BOS token (<code>bos_id<\/code>).<\/li>\n\n\n\n<li>On each iteration, call the model <code>gpt<\/code> to predict the next token, until the requested number of tokens has been generated.<\/li>\n\n\n\n<li>Pick the highest-probability token from each prediction and append it to the input sequence.<\/li>\n\n\n\n<li>After the loop finishes, return all generated tokens, excluding the BOS token.<\/li>\n<\/ol>\n\n\n\n<p>This kind of generation is unconditioned because it does not depend on any prior context and starts from a fixed BOS token alone. It is typically used to test or showcase what a language model can do, or when entirely new content is wanted. Compared with conditional generation (which generates text from a specific context or prompt), the output is less constrained, since no prompt steers the topic.<\/p>\n\n\n\n<p>GPT-2 was pre-trained with a BOS token (though it goes by the confusing name <code>&lt;|endoftext|&gt;<\/code>), so running unconditional generation with our implementation is as easy as changing <a href=\"https:\/\/github.com\/jaymody\/picoGPT\/blob\/dfb5df895a7a6b18705866a0bf7ec04947d8e05a\/gpt2.py#L104\" target=\"_blank\" rel=\"noreferrer noopener\">this line<\/a> to:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># `encoder` is assumed to be an already initialized encoder instance that turns text into a list of token ids.\n# `prompt` is a string holding the user's prompt text; when it is empty, fall back to the &lt;|endoftext|&gt; token.\n\ninput_ids = encoder.encode(prompt) if prompt else [encoder.encoder[\"&lt;|endoftext|&gt;\"]]\n<\/pre><\/div>\n\n\n\n<p>And then running:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">python gpt2.py \"\"<\/pre><\/div>\n\n\n\n<p>which generates:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">The first time I saw the new version of the game, I was so excited. 
I was so excited to see the new version of the game, I was so excited to see the new version\n<\/pre><\/div>\n\n\n\n<p>Because we are using greedy sampling, the output is not great (it is repetitive) and it is deterministic (the same output every time the code runs). To get generations that are both higher quality and non-deterministic, we would need to sample directly from the probability distribution, ideally after applying something like <code>top-p<\/code>.<\/p>\n\n\n\n<p>Unconditional generation is not particularly useful, but it is a fun way to demonstrate what a GPT can do.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\" id=\"\u5fae\u8c03\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E5%BE%AE%E8%B0%83\"><\/a>6.10 
\u5fae\u8c03(Fine-tuning)<\/strong><\/h3>\n\n\n\n<p>We briefly touched on fine-tuning in the training section. Recall that fine-tuning means re-using the pre-trained weights and training the model on some downstream task. This process is called transfer learning.<\/p>\n\n\n\n<p>In theory, we could use zero-shot or few-shot prompting to get the model to complete our task. However, if you have access to a labelled dataset, fine-tuning a GPT will generally yield better results, and those results keep improving as you add more data and higher-quality data.<\/p>\n\n\n\n<p>There are a few different topics related to fine-tuning, broken down below:<\/p>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u5206\u7c7b\u5fae\u8c03\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E5%88%86%E7%B1%BB%E5%BE%AE%E8%B0%83\"><\/a>6.10.1 \u5206\u7c7b\u5fae\u8c03(Classification Fine-tuning)<\/strong><\/h4>\n\n\n\n<p>In classification fine-tuning, we give the model some text and ask it to predict which class the text belongs to. Take the <a href=\"https:\/\/huggingface.co\/datasets\/imdb\" target=\"_blank\" rel=\"noreferrer noopener\">IMDB dataset<\/a> as an example: it contains movie reviews that rate a movie as either good or bad:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python highlight:0 decode:true 
\">--- Example 1 ---\nText: I wouldn't rent this one even on dollar rental night.\nLabel: Bad\n--- Example 2 ---\nText: I don't know why I like this movie so well, but I never get tired of watching it.\nLabel: Good\n--- Example 3 ---\n...<\/pre><\/div>\n\n\n\n<p>To fine-tune our model, we replace the language modeling head with a classification head and apply it to the output of the last token:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def gpt2(inputs, wte, wpe, blocks, ln_f, cls_head, n_head):\n    # wte: token embedding matrix, maps input token ids to vectors.\n    # wpe: positional embedding matrix, adds position information to the input vectors.\n    # blocks: list of transformer blocks (layers).\n    # ln_f: parameters of the final layer norm.\n    # cls_head: weight matrix of the classification head, maps the model output to classes.\n    # n_head: number of heads in multi-head attention.\n    \n    # Convert the input tokens to embeddings and add the positional embeddings.\n    x = wte[inputs] + wpe[range(len(inputs))]\n    \n    # Forward pass through all transformer blocks (layers).\n    for block in blocks:\n        x = transformer_block(x, **block, n_head=n_head)\n    \n    # Apply layer norm to the final output.\n    x = layer_norm(x, **ln_f)\n\n    # 
Project the last output vector with the classification head to get the class predictions.\n    # Only the output at the last sequence position is used for classification, which is common for sequence\n    # classification tasks, on the assumption that the last output carries information about the whole sequence.\n    # [n_embd] @ [n_embd, n_classes] -&gt; [n_classes]\n    return x[-1] @ cls_head\n<\/pre><\/div>\n\n\n\n<p>We only use the output of the last token, <code>x[-1]<\/code>, because we only need to produce a single probability distribution for the entire input rather than <code>n_seq<\/code> distributions as in language modeling. We take the last token in particular (instead of, say, the first token or a combination of all the tokens) because the last token is the only one allowed to attend to the entire sequence, and so it has information about the input text as a whole.<\/p>\n\n\n\n<p>As usual, we optimize with respect to the cross-entropy loss:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def single_example_loss_fn(inputs: list[int], label: int, params) -&gt; float:\n    # inputs: list of token ids of the input sequence.\n    # label: index of the correct class.\n    # params: parameters of the GPT model.\n\n    # 
Call the classification model defined above to get one logit per class (`params` is assumed to also carry `n_head`).\n    logits = gpt2(inputs, **params)\n\n    # Turn the logits into a probability distribution with softmax.\n    probs = softmax(logits)\n\n    # Cross-entropy loss: the negative log probability of the correct class.\n    # `label` indexes the probability of the correct class directly.\n    loss = -np.log(probs[label])\n\n    # Return the computed loss.\n    return loss\n<\/pre><\/div>\n\n\n\n<p>We can also perform <strong>multi-label classification<\/strong> (an example can belong to several classes rather than just one) by applying <code>sigmoid<\/code> instead of <code>softmax<\/code> and taking the binary cross-entropy loss with respect to each class (see <a href=\"https:\/\/stats.stackexchange.com\/questions\/207794\/what-loss-function-for-multi-class-multi-label-classification-tasks-in-neural-n\" target=\"_blank\" rel=\"noreferrer noopener\">this Stack Exchange question<\/a>).<\/p>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u751f\u6210\u5f0f\u5fae\u8c03\"><strong>6.10.2 \u751f\u6210\u5f0f\u5fae\u8c03(Generative 
Fine-tuning)<\/strong><\/h4>\n\n\n\n<p>Some tasks cannot be neatly framed as classification, summarization for example. We can fine-tune on these tasks by simply performing language modeling on the input concatenated with the label. For example, here is what a single summarization training sample might look like:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python highlight:0 decode:true \">--- Article ---\nThis is an article I would like to summarize.\n--- Summary ---\nThis is the summary.<\/pre><\/div>\n\n\n\n<p>We train the model as we do during pre-training (optimizing with respect to the language modeling loss).<\/p>\n\n\n\n<p>At prediction time, we feed the model everything up to <code>\"--- Summary ---\"<\/code> and then perform auto-regressive language modeling to generate the summary.<\/p>\n\n\n\n<p>The choice of the delimiters <code>\"--- Article ---\"<\/code> and <code>\"--- Summary ---\"<\/code> is arbitrary. How you format the text is up to you, as long as it is consistent between training and inference.<\/p>\n\n\n\n<p>Note that we can also formulate classification tasks as generative tasks (IMDB, for example):<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python highlight:0 decode:true \">--- Text ---\nI wouldn't rent this one even on dollar rental night.\n--- Label 
---\nBad<\/pre><\/div>\n\n\n\n<p>However, this will probably perform worse than doing classification fine-tuning directly: the loss includes language modeling over the entire sequence, not just the final prediction, so the loss specific to the prediction gets diluted.<\/p>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u6307\u4ee4\u5fae\u8c03\"><strong><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#%E6%8C%87%E4%BB%A4%E5%BE%AE%E8%B0%83\"><\/a>6.10.3 \u6307\u4ee4\u5fae\u8c03(Instruction Fine-tuning)<\/strong><\/h4>\n\n\n\n<p>Most state-of-the-art large language models these days also go through an additional <strong>instruction fine-tuning<\/strong> step after pre-training. In this step, the model is fine-tuned (generatively) on thousands of <strong>human labeled<\/strong> instruction prompt + completion pairs. Instruction fine-tuning is also referred to as SFT (<strong>supervised 
fine-tuning<\/strong>), since the data is human labelled (i.e. <strong>supervised<\/strong>).<\/p>\n\n\n\n<p>So what is the benefit of instruction fine-tuning? Predicting the next word of a Wikipedia article makes the model good at continuing sentences, but it does not make it particularly good at following instructions, holding a conversation, or summarizing a document (all the things we would like a GPT to do). Fine-tuning it on human-labelled instruction + completion pairs is a way to teach the model to be more useful and easier to interact with. This is called <strong>AI alignment<\/strong>, because we are aligning the model to do and behave as we want it to. Alignment is an active area of research and covers more than just following instructions (bias, safety, intent, and so on).<\/p>\n\n\n\n<p>What does this instruction data actually look like? Google's <a href=\"https:\/\/arxiv.org\/pdf\/2109.01652.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">FLAN<\/a> models were trained on various academic NLP datasets (which are already human labelled):<\/p>\n\n\n\n<figure class=\"wp-block-image\"><img 
decoding=\"async\" src=\"https:\/\/jiqihumanr.github.io\/images\/flan.png\" alt=\"(fig 3)\"\/><\/figure>\n\n\n\n<p>Figure 3 from the FLAN paper<\/p>\n\n\n\n<p>OpenAI's <a href=\"https:\/\/arxiv.org\/pdf\/2203.02155.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">InstructGPT<\/a>, on the other hand, was trained on prompts collected from their own API. They then paid workers to write completions for those prompts. Here is a breakdown of the data:<\/p>\n\n\n\n<figure class=\"wp-block-image\"><img decoding=\"async\" src=\"https:\/\/jiqihumanr.github.io\/images\/igpt.png\" alt=\"(igpt)\"\/><\/figure>\n\n\n\n<p>Tables 1 and 2 from the InstructGPT paper<\/p>\n\n\n\n<h4 class=\"wp-block-heading\" id=\"\u53c2\u6570\u9ad8\u6548\u5fae\u8c03\uff08Parameter-Efficient-Fine-tuning\uff09\"><strong>6.10.4 \u53c2\u6570\u9ad8\u6548\u5fae\u8c03(Parameter Efficient Fine-tuning)<\/strong><\/h4>\n\n\n\n<p>When we talked about fine-tuning in the sections above, we assumed that all of the model parameters are updated. While this yields the best performance, it is costly both in compute (we need to backpropagate through the entire model) and in storage (each fine-tuned model requires storing a completely new copy of the parameters).<\/p>\n\n\n\n<p>The simplest approach to this problem is to train just the head (<strong>only update the 
head<\/strong>) and <strong>freeze<\/strong> (i.e. make untrainable) the rest of the model. While this speeds up training and greatly reduces the number of new parameters, it does not perform particularly well, since in a sense we lose the <em>deep<\/em> in deep learning. We could instead <strong>selectively freeze<\/strong> specific layers (e.g. freeze all layers except the last four, or freeze every other layer, or freeze all parameters except the multi-head attention parameters), which helps restore the depth. This performs a lot better, but we become a lot less parameter efficient and give up some of the training speed gains.<\/p>\n\n\n\n<p>Instead, we can use <strong>parameter-efficient fine-tuning<\/strong> methods. This is still an active area of research, and there are <a href=\"https:\/\/aclanthology.org\/2021.emnlp-main.243.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">a lot<\/a> <a href=\"https:\/\/arxiv.org\/pdf\/2110.07602.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">of<\/a> <a href=\"https:\/\/arxiv.org\/pdf\/2101.00190.pdf\" target=\"_blank\" rel=\"noreferrer 
noopener\">different<\/a> <a href=\"https:\/\/arxiv.org\/pdf\/2103.10385.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">methods<\/a> <a href=\"https:\/\/arxiv.org\/pdf\/2106.09685.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">to<\/a> <a href=\"https:\/\/arxiv.org\/pdf\/1902.00751.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">choose<\/a> <a href=\"https:\/\/arxiv.org\/abs\/2205.05638\" target=\"_blank\" rel=\"noreferrer noopener\">from<\/a>.<\/p>\n\n\n\n<p>As an example, take the <a href=\"https:\/\/arxiv.org\/pdf\/1902.00751.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">Adapters paper<\/a>. In this approach, we add an additional \u201cadapter\u201d layer after the FFN and MHA layers in the transformer block. The adapter layer is just a simple two-layer fully connected neural network whose input and output dimensions are <code>n_embd<\/code> and whose hidden dimension is smaller than <code>n_embd<\/code>:<\/p>\n\n\n\n<figure class=\"wp-block-image\"><img decoding=\"async\" src=\"https:\/\/jiqihumanr.github.io\/images\/adapter.png\" 
alt=\"(adapter)\"\/><\/figure>\n\n\n\n<p>Figure 2 from the Adapters paper<\/p>\n\n\n\n<p>The size of the hidden dimension is a hyperparameter that we can set, which lets us trade off parameters against performance. For a BERT model, the paper showed that this approach can reduce the number of trained parameters to 2% while sustaining only a small hit in performance (&lt;1%) compared to a full fine-tune.<\/p>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<ol class=\"wp-block-list\">\n<li>Training at scale, collecting massive amounts of data, making the model fast, evaluating performance, and aligning the models to serve humans is the life's work of the hundreds of engineers\/researchers required to make LLMs what they are today, not just the architecture. The GPT architecture just happened to be the first neural network architecture that scales well, is highly parallelizable on GPUs, and is good at modeling sequences. The real secret sauce comes from scaling the data and the model (<a href=\"http:\/\/www.incompleteideas.net\/IncIdeas\/BitterLesson.html\" target=\"_blank\" rel=\"noreferrer 
noopener\">as always<\/a>); GPT just enables us to do that[9]. It is possible that the transformer simply won the <a href=\"https:\/\/hardwarelottery.github.io\/\" target=\"_blank\" rel=\"noreferrer noopener\">hardware lottery<\/a>, and some other architecture is still out there waiting to dethrone it.<a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fnref:1\">&nbsp;<\/a><\/li>\n\n\n\n<li>For certain applications, the tokenizer does not need a <code>decode<\/code> method. For example, if you want to classify whether a movie review says the movie was good or bad, you only need to be able to <code>encode<\/code> the text and do a forward pass of the model; there is no need for <code>decode<\/code>. For generating text, however, <code>decode<\/code> is required.<a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fnref:2\">&nbsp;<\/a><\/li>\n\n\n\n<li>That said, with the <a href=\"https:\/\/arxiv.org\/pdf\/2210.11416.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">InstructGPT<\/a> and <a href=\"https:\/\/arxiv.org\/pdf\/2203.15556.pdf\" target=\"_blank\" rel=\"noreferrer 
noopener\">Chinchilla<\/a> papers, we have realized that we do not actually need to train models that big. An optimally trained and instruction fine-tuned GPT with 1.3B parameters can outperform GPT-3 at 175B parameters.<a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fnref:3\">&nbsp;<\/a><\/li>\n\n\n\n<li>The original transformer paper used a <a href=\"https:\/\/nlp.seas.harvard.edu\/2018\/04\/03\/attention.html#positional-encoding\" target=\"_blank\" rel=\"noreferrer noopener\">calculated positional embedding<\/a>, which they found performed just as well as learned positional embeddings but has the distinct advantage that you can input arbitrarily long sequences (you are not restricted by a maximum sequence length). In practice, however, your model will only be as good as the sequence lengths it was trained on. You cannot just train a GPT on sequences of length 1024 and then expect it to perform well on sequences of 16k tokens. Recently, however, there has been some success with relative positional embeddings, such as <a href=\"https:\/\/arxiv.org\/pdf\/2108.12409.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">Alibi<\/a> and <a href=\"https:\/\/arxiv.org\/pdf\/2104.09864v4.pdf\" target=\"_blank\" rel=\"noreferrer noopener\">RoPE<\/a>.<a 
href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fnref:4\">&nbsp;<\/a><\/li>\n\n\n\n<li>Different GPT models may choose a hidden width other than <code>4*n_embd<\/code>, but this is the common practice for GPT models. Also, we give the multi-head attention layer a lot of <em>attention<\/em> (pun intended) for driving the success of the transformer, but at the scale of GPT-3, <a href=\"https:\/\/twitter.com\/stephenroller\/status\/1579993017234382849\" target=\"_blank\" rel=\"noreferrer noopener\">80% of the model parameters are contained in the feed forward layers<\/a>. Just something to think about.<a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fnref:5\">&nbsp;<\/a><\/li>\n\n\n\n<li>If you are not convinced, stare at the softmax equation and convince yourself that this is true (maybe even pull out a pen and paper).<a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fnref:6\">&nbsp;<\/a><\/li>\n\n\n\n<li>Love JAX.<a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fnref:7\">&nbsp;\u21a9<\/a><\/li>\n\n\n\n<li>Using JAX, this is as simple as <code>heads = jax.vmap(attention, in_axes=(0, 0, 0, None))(q, k, v, causal_mask)<\/code>.<a 
href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fnref:8\">&nbsp;<\/a><\/li>\n\n\n\n<li>Actually, I might argue that there is something inherently better about the way attention models sequences compared with recurrent\/convolutional layers, but now we are in a footnote inside a footnote, so I will stop here.<a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/13\/gpt-from-scratch\/#fnref:9\">&nbsp;<\/a><\/li>\n<\/ol>\n\n\n\n<p><a href=\"https:\/\/jiqihumanr.github.io\/2023\/04\/12\/distance-matrices-with-numpy\/\">&nbsp;<\/a><\/p>\n","protected":false},"excerpt":{"rendered":"<p>\u672c\u6587\u8fd8\u662f\u6765\u81eaJay Mody\uff0c\u90a3\u7bc7\u88abAndrej Karpathy\u624b\u52a8\u70b9\u8d5e\u7684GPT in 60 Lines o [&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"site-sidebar-layout":"default","site-content-layout":"","ast-site-content-layout":"default","site-content-style":"default","site-sidebar-style":"default","ast-global-header-display":"","ast-banner-title-visibility":"","ast-main-header-display":"","ast-hfb-above-header-display":"","ast-hfb-below-header-display":"","ast-hfb-mobile-header-display":"","site-post-title":"","ast-breadcrumbs-content":"","ast-featured-img":"","footer-sml-layout":"","theme-transparent-header-meta":"","adv-header-id-meta":"","stick-header-meta":"","header-above-stick-meta":"","header-main-stick-meta":"","header-below-stick-meta":"","astra-migrate-meta-layouts":"set","ast-page-background-enabled":"default","ast-page-background-meta":{"desktop":{"background-color":"var(--ast-global-color-4)","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"ast-content-background-meta":{"desktop":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[313,289,443,442,312,43],"tags":[242,314],"class_list":["post-2123","post","type-post","status-publish","format-standard","hentry","category-chatgpt","category-gpt","category-llm","category-llms","category-openai","category-infoarticle","tag-chatgpt","tag-openai-api"],"views":2805,"jetpack_sharing_enabled":true,"jetpack_featured_media_url":"","_links":{"self":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/2123","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=2123"}],"version-history":[{"count":181,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/2123\/revisions"}],"predecessor-version":[{"id":2315,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/2123\/revisions\/2315"}],"wp:attachment":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=2123"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=2123"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=2123"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}