{"id":2567,"date":"2024-03-25T02:16:38","date_gmt":"2024-03-24T18:16:38","guid":{"rendered":"https:\/\/www.aqwu.net\/wp\/?p=2567"},"modified":"2024-04-28T20:01:42","modified_gmt":"2024-04-28T12:01:42","slug":"%e4%bb%8e%e5%a4%b4%e5%bc%80%e5%a7%8b%e7%bc%96%e5%86%99-lora-%e4%bb%a3%e7%a0%81","status":"publish","type":"post","link":"https:\/\/www.aqwu.net\/wp\/?p=2567","title":{"rendered":"Coding LoRA from Scratch"},"content":{"rendered":"\n<p><a href=\"https:\/\/lightning.ai\/lightning-ai\/studios\/code-lora-from-scratch?view=public&amp;section=all#lora-from-scratch-implement-low-rank-adaptation-for-llms-in-pytorch\">Implementing low-rank adaptation (LoRA) for LLMs in PyTorch<\/a><\/p>\n\n\n\n<p>LoRA, which stands for low-rank adaptation, is a popular technique for finetuning LLMs more efficiently. Instead of adjusting all the parameters of a deep neural network, LoRA focuses on updating only a small set of low-rank matrices.<\/p>\n\n\n\n<p>Explaining how LoRA works by coding it from scratch is an excellent exercise for understanding the algorithm under the hood.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>1. 
Understanding LoRA<\/strong><\/h2>\n\n\n\n<p>Pretrained LLMs are often called foundation models because of their versatility across many different tasks. However, it is often useful to adapt a pretrained LLM to a specific dataset or task, which we can accomplish via finetuning.<\/p>\n\n\n\n<p>Finetuning lets a model adapt to a target domain without expensive pretraining from scratch, but updating all layers is still computationally costly, especially for larger models.<\/p>\n\n\n\n<p>LoRA offers a more efficient alternative to regular finetuning. As discussed in more detail in the paper <em>LoRA: Low-Rank Adaptation of Large Language Models<\/em>, LoRA approximates a layer's weight change during training, <em>\u0394W<\/em>, in a low-rank format.<\/p>\n\n\n\n<p>Concretely, whereas regular finetuning computes the weight update of a weight matrix W as <em>\u0394W<\/em>, LoRA approximates <em>\u0394W<\/em> through the matrix multiplication of two smaller matrices A and B, as shown in the figure below. (If you are familiar with PCA or SVD, think of this as decomposing <em>\u0394W<\/em> into A and B.)<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"917\" height=\"450\" src=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-1.png\" alt=\"\" 
class=\"wp-image-2568\" srcset=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-1.png 917w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-1-300x147.png 300w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-1-768x377.png 768w\" sizes=\"auto, (max-width: 917px) 100vw, 917px\" \/><\/figure>\n\n\n\n<p class=\"has-text-align-center\"><em>A comparison of the weight-update computation in the forward pass for regular finetuning (left) and LoRA (right).<\/em><\/p>\n\n\n\n<p>Note that r in the figure above is a hyperparameter that specifies the rank of the low-rank matrices used for adaptation. A smaller r leads to simpler low-rank matrices, which means fewer parameters to learn during adaptation. This can result in faster training and potentially lower compute requirements. However, as r decreases, so does the capacity of the low-rank matrices to capture task-specific information.<\/p>\n\n\n\n<p>As a concrete example, suppose the weight matrix W of a given layer has size 5,000&#215;10,000 (50M parameters in total). If we choose rank r=8, we initialize two matrices: a 5,000&#215;8 matrix B and an 8&#215;10,000 matrix A. Together, A and B have only 80,000 + 40,000 = 120,000 parameters, roughly 400 times fewer than the 50M parameters of a regular finetuning update \u0394W.<\/p>\n\n\n\n<p>In practice, it is important to experiment with different r values to find the right balance that achieves the desired performance on the new task.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>2. Coding LoRA from Scratch<\/strong><\/h2>\n\n\n\n<p>Since conceptual explanations can sometimes feel abstract, let us now implement LoRA ourselves to better understand how it works. In code, we can implement a LoRA layer as follows:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">class LoRALayer(torch.nn.Module):\n    def __init__(self, in_dim, out_dim, rank, alpha):\n        super().__init__()\n        std_dev = 1 \/ torch.sqrt(torch.tensor(rank).float())\n        self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)\n        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))\n        self.alpha = alpha\n\n    def forward(self, x):\n        x = self.alpha * (x @ self.A @ self.B)\n        return 
x<\/pre><\/div>\n\n\n\n<p>In the code above, <code><strong>in_dim<\/strong><\/code> is the input dimension of the layer we want to modify with LoRA, and <code><strong>out_dim<\/strong><\/code> is the corresponding output dimension of that layer.<\/p>\n\n\n\n<p>As discussed earlier, the <code>rank<\/code> of the matrices <code><strong><em>A<\/em><\/strong><\/code> and <code><strong><em>B<\/em><\/strong><\/code> is a hyperparameter that controls the complexity and the number of additional parameters introduced by LoRA.<\/p>\n\n\n\n<p>However, looking at the code above, we also added another hyperparameter, the scaling factor <code><strong>alpha<\/strong><\/code>. This factor determines the magnitude of the change that the LoRA layer introduces to the model's existing weights: <code>alpha * (x @ A @ B)<\/code>. A higher <code>alpha<\/code> value means larger adjustments to the model's behavior, while a lower value means more subtle changes.<\/p>\n\n\n\n<p>Another thing to note is that we initialize <code>A<\/code> with small values drawn from a random distribution, where the standard deviation of this distribution is determined by the square root of the rank (this choice ensures that the initial values in <code>A<\/code> are not too large). However, we initialize <code>B<\/code> with zeros. The rationale is that at the beginning of training, before <code>A<\/code> and <code>B<\/code> are updated via backpropagation, the <code>LoRALayer<\/code> does not affect the original weights, because <code><em><strong>AB=0<\/strong><\/em><\/code> if <code><strong><em>B=0<\/em><\/strong><\/code>.<\/p>\n\n\n\n<p>Note that LoRA is usually applied to the linear (feedforward) layers of a neural network. For example, suppose we have a simple PyTorch model or module with two linear layers (this could be, for instance, the feedforward module of a transformer block). Suppose the module's forward method looks like this:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def forward(self, x):\n    x = self.linear_1(x)\n    x = F.relu(x)\n    x = self.linear_2(x)\n    return x<\/pre><\/div>\n\n\n\n<p>If we use LoRA, we add the LoRA updates to the outputs of these linear layers, as follows:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">def forward(self, x):\n    x = self.linear_1(x) + self.lora_1(x)\n    x = F.relu(x)\n    x = self.linear_2(x) + self.lora_2(x)\n    return x<\/pre><\/div>\n\n\n\n<p>In code, when implementing LoRA by modifying an existing PyTorch model, a simple way to apply this LoRA modification to the linear layers is to replace each <code>Linear<\/code> layer with a <code>LinearWithLoRA<\/code> layer that combines the <code>Linear<\/code> layer with our earlier <code>LoRALayer<\/code> implementation:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">class LinearWithLoRA(torch.nn.Module):\n    def __init__(self, linear, rank, alpha):\n        super().__init__()\n        self.linear = linear\n        self.lora = LoRALayer(\n            linear.in_features, linear.out_features, rank, alpha\n        )\n\n    def forward(self, x):\n        return self.linear(x) + self.lora(x)<\/pre><\/div>\n\n\n\n<p>The concepts above are summarized in the figure below.<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"391\" src=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-2-1024x391.png\" alt=\"\" class=\"wp-image-2577\" 
srcset=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-2-1024x391.png 1024w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-2-300x115.png 300w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-2-768x293.png 768w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-2-1536x586.png 1536w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-2-2048x782.png 2048w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<p>In practice, to equip and finetune a model with LoRA, all we have to do is replace its pretrained <code>Linear<\/code> layers with our new <code>LinearWithLoRA<\/code> layers. We will see how this works by applying the <code>LinearWithLoRA<\/code> layers to a pretrained language model in the hands-on section below.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>3. 
Finetuning with LoRA: A Hands-On Example<\/strong><\/h2>\n\n\n\n<p>LoRA is a method that can be applied to various types of neural networks, not just generative models such as GPT or image-generation models. For this hands-on example, we will train a small BERT model for text classification, because classification accuracy is easier to evaluate than generated text.<\/p>\n\n\n\n<p>Specifically, we will use a pretrained DistilBERT (a smaller version of BERT) model from the transformers library:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">from transformers import AutoModelForSequenceClassification\n\nmodel = AutoModelForSequenceClassification.from_pretrained(\n    \"distilbert-base-uncased\", num_labels=2)<\/pre><\/div>\n\n\n\n<p>Since we only want to train the new LoRA weights, we freeze all model parameters by setting <code>requires_grad<\/code> to <code>False<\/code> for all trainable parameters:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">for param in model.parameters():\n    param.requires_grad = False\nprint(model)<\/pre><\/div>\n\n\n\n<p>Next, let's briefly inspect the model's structure with <code>print(model)<\/code>:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true 
\">DistilBertForSequenceClassification(\n  (distilbert): DistilBertModel(\n    (embeddings): Embeddings(\n      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n      (position_embeddings): Embedding(512, 768)\n      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n      (dropout): Dropout(p=0.1, inplace=False)\n    )\n    (transformer): Transformer(\n      (layer): ModuleList(\n        (0-5): 6 x TransformerBlock(\n          (attention): MultiHeadSelfAttention(\n            (dropout): Dropout(p=0.1, inplace=False)\n            (q_lin): Linear(in_features=768, out_features=768, bias=True)\n            (k_lin): Linear(in_features=768, out_features=768, bias=True)\n            (v_lin): Linear(in_features=768, out_features=768, bias=True)\n            (out_lin): Linear(in_features=768, out_features=768, bias=True)\n          )\n          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n          (ffn): FFN(\n            (dropout): Dropout(p=0.1, inplace=False)\n            (lin1): Linear(in_features=768, out_features=3072, bias=True)\n            (lin2): Linear(in_features=3072, out_features=768, bias=True)\n            (activation): GELUActivation()\n          )\n          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n        )\n      )\n    )\n  )\n  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n  (classifier): Linear(in_features=768, out_features=2, bias=True)\n  (dropout): Dropout(p=0.2, inplace=False)\n)<\/pre><\/div>\n\n\n\n<p>Based on the output above, we can see that the model consists of 6 transformer layers containing linear layers:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">(0-5): 6 x TransformerBlock(\n          (attention): MultiHeadSelfAttention(\n            ...\n            (q_lin): 
Linear(in_features=768, out_features=768, bias=True)\n\t\t\t\t\t\t...\n          )\n          ...<\/pre><\/div>\n\n\n\n<p>In addition, the model has two <code>Linear<\/code> output layers:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">...\n(pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n(classifier): Linear(in_features=768, out_features=2, bias=True)\n<\/pre><\/div>\n\n\n\n<p>We can selectively enable LoRA for these <code>Linear<\/code> layers by defining the following assignment function and loop:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">from transformers import AutoModelForSequenceClassification\nimport torch\n\nclass LoRALayer(torch.nn.Module):\n    def __init__(self, in_dim, out_dim, rank, alpha):\n        super().__init__()\n        std_dev = 1 \/ torch.sqrt(torch.tensor(rank).float())\n        self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)\n        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))\n        self.alpha = alpha\n\n    def forward(self, x):\n        x = self.alpha * (x @ self.A @ self.B)\n        return x\n\nclass LinearWithLoRA(torch.nn.Module):\n    def __init__(self, linear, rank, alpha):\n        super().__init__()\n        self.linear = linear\n        self.lora = LoRALayer(\n            linear.in_features, linear.out_features, rank, alpha\n        )\n\n    def forward(self, x):\n        return self.linear(x) + self.lora(x)\n        \nmodel = AutoModelForSequenceClassification.from_pretrained(\n    \"distilbert-base-uncased\", num_labels=2)\n\nfor param in model.parameters():\n    param.requires_grad = False\n\nfrom functools import partial\n\n# default hyperparameter choices\nlora_r = 8\nlora_alpha = 16\nlora_dropout = 
0.05\nlora_query = True\nlora_key = False\nlora_value = True\nlora_projection = False\nlora_mlp = False\nlora_head = False\n\nassign_lora = partial(LinearWithLoRA, rank=lora_r, alpha=lora_alpha)\n\nfor layer in model.distilbert.transformer.layer:\n    if lora_query:\n        layer.attention.q_lin = assign_lora(layer.attention.q_lin)\n    if lora_key:\n        layer.attention.k_lin = assign_lora(layer.attention.k_lin)\n    if lora_value:\n        layer.attention.v_lin = assign_lora(layer.attention.v_lin)\n    if lora_projection:\n        layer.attention.out_lin = assign_lora(layer.attention.out_lin)\n    if lora_mlp:\n        layer.ffn.lin1 = assign_lora(layer.ffn.lin1)\n        layer.ffn.lin2 = assign_lora(layer.ffn.lin2)\nif lora_head:\n    model.pre_classifier = assign_lora(model.pre_classifier)\n    model.classifier = assign_lora(model.classifier)\n\nprint(model)\n    <\/pre><\/div>\n\n\n\n<p>We can now inspect the model again with <code>print(model)<\/code> to verify its updated structure:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">DistilBertForSequenceClassification(\n  (distilbert): DistilBertModel(\n    (embeddings): Embeddings(\n      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n      (position_embeddings): Embedding(512, 768)\n      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n      (dropout): Dropout(p=0.1, inplace=False)\n    )\n    (transformer): Transformer(\n      (layer): ModuleList(\n        (0-5): 6 x TransformerBlock(\n          (attention): MultiHeadSelfAttention(\n            (dropout): Dropout(p=0.1, inplace=False)\n            (q_lin): LinearWithLoRA(\n              (linear): Linear(in_features=768, out_features=768, bias=True)\n              (lora): LoRALayer()\n            )\n            (k_lin): Linear(in_features=768, 
out_features=768, bias=True)\n            (v_lin): LinearWithLoRA(\n              (linear): Linear(in_features=768, out_features=768, bias=True)\n              (lora): LoRALayer()\n            )\n            (out_lin): Linear(in_features=768, out_features=768, bias=True)\n          )\n          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n          (ffn): FFN(\n            (dropout): Dropout(p=0.1, inplace=False)\n            (lin1): Linear(in_features=768, out_features=3072, bias=True)\n            (lin2): Linear(in_features=3072, out_features=768, bias=True)\n            (activation): GELUActivation()\n          )\n          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n        )\n      )\n    )\n  )\n  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n  (classifier): Linear(in_features=768, out_features=2, bias=True)\n  (dropout): Dropout(p=0.2, inplace=False)\n)<\/pre><\/div>\n\n\n\n<p>As shown above, the <code>Linear<\/code> layers have been successfully replaced by <code>LinearWithLoRA<\/code> layers.<\/p>\n\n\n\n<p>If we train the model with the default hyperparameter choices shown above, we get the following performance on the IMDb movie review classification dataset:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Train acc: 92.15%<\/li>\n\n\n\n<li>Val acc: 89.98%<\/li>\n\n\n\n<li>Test acc: 89.44%<\/li>\n<\/ul>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>4. 
<a href=\"https:\/\/lightning.ai\/lightning-ai\/studios\/code-lora-from-scratch?tab=overview#comparison-to-traditional-finetuning\">Comparison to Traditional Finetuning<\/a><\/strong><\/h2>\n\n\n\n<p>In the previous section, we obtained 89.44% test accuracy with the default LoRA settings. How does this compare to traditional finetuning?<\/p>\n\n\n\n<p>Let's start by training the DistilBERT model but updating only the last 2 layers during training. We can achieve this by first freezing all model weights and then unfreezing the two linear output layers:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># freeze all\nfor param in model.parameters():\n    param.requires_grad = False\n\n# unfreeze output layers\nfor param in model.pre_classifier.parameters():\n    param.requires_grad = True\n\nfor param in model.classifier.parameters():\n    param.requires_grad = True<\/pre><\/div>\n\n\n\n<p>After training only the last two layers, we get the following classification performance:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Train acc: 86.68%<\/li>\n\n\n\n<li>Val acc: 87.26%<\/li>\n\n\n\n<li>Test acc: 86.22%<\/li>\n<\/ul>\n\n\n\n<p>As we can see, LoRA, with its 89.44% test accuracy, outperforms training only the last two layers. Moreover, the LoRA configuration above is also much lighter: it requires training only 147,456 parameters. In contrast, finetuning the last two layers requires updating 592,130 parameters, because these two layers are large.<\/p>\n\n\n\n<p>Now, how does finetuning all layers compare?<\/p>\n\n\n\n<p>If we finetune all layers of the DistilBERT model the traditional way, we get the following results:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Train acc: 96.41%<\/li>\n\n\n\n<li>Val acc: 92.80%<\/li>\n\n\n\n<li>Test acc: 92.31%<\/li>\n<\/ul>\n\n\n\n<p>So, as we can see, finetuning all layers (which involves training 66,955,010 parameters) performs better than finetuning only the last two layers (592,130 parameters) or the default LoRA layers (147,456 parameters).<\/p>\n\n\n\n<p>As a takeaway, so far, LoRA performs better than traditional finetuning of the last two layers even though it uses 4 times fewer parameters. Finetuning all layers requires updating roughly 450 times more parameters than the LoRA setup, but it also improves test accuracy by about 2%.<\/p>\n\n\n\n<p>However, one aspect to consider is that we have only used the default LoRA settings so far. Perhaps we can close the gap between full finetuning and LoRA finetuning with a different LoRA hyperparameter configuration? We will answer this question in the next section.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>5. 
Optimizing the LoRA Configuration<\/strong><\/h2>\n\n\n\n<p>The results in the previous sections were based on the following default LoRA settings:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">lora_r = 8\nlora_alpha = 16\nlora_dropout = 0.05\nlora_query = True\nlora_key = False\nlora_value = True\nlora_projection = False\nlora_mlp = False\nlora_head = False<\/pre><\/div>\n\n\n\n<p>Note that these defaults only apply LoRA to the query and value weight matrices of the attention layers. Alternatively, we can enable LoRA for other layers as well. In addition, we can control the number of trainable parameters in each LoRA layer by modifying the rank (<code>lora_r<\/code>).<\/p>\n\n\n\n<p>To experiment with different hyperparameter configurations, you can use our compact <code>03_finetune-lora.py<\/code> script, which accepts the hyperparameter choices as command-line arguments:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">import argparse\nimport os\nimport shutil\nimport time\nfrom functools import partial\n\nimport lightning as L\nfrom lightning.pytorch.loggers import CSVLogger\nfrom lightning.pytorch.callbacks import ModelCheckpoint\n\nfrom transformers import AutoModelForSequenceClassification\nimport torch\n\nfrom local_dataset_utilities import tokenization, setup_dataloaders, get_dataset\nfrom local_model_utilities import CustomLightningModule\n\n\ndef str2bool(v):\n    if isinstance(v, bool):\n        return v\n    if v.lower() in ('yes', 'true'):\n        return True\n    elif v.lower() in ('no', 'false'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected.')\n\n\nclass LoRALayer(torch.nn.Module):\n    def 
__init__(self, in_dim, out_dim, rank, alpha):\n        super().__init__()\n        std_dev = 1 \/ torch.sqrt(torch.tensor(rank).float())\n        self.W_a = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)\n        self.W_b = torch.nn.Parameter(torch.zeros(rank, out_dim))\n        self.alpha = alpha\n\n    def forward(self, x):\n        x = self.alpha * (x @ self.W_a @ self.W_b)\n        return x\n\n\nclass LinearWithLoRA(torch.nn.Module):\n    def __init__(self, linear, rank, alpha):\n        super().__init__()\n        self.linear = linear\n        self.lora = LoRALayer(\n            linear.in_features, linear.out_features, rank, alpha\n        )\n\n    def forward(self, x):\n        return self.linear(x) + self.lora(x)\n\n\ndef count_parameters(model):\n    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(description='LoRA parameters configuration')\n    parser.add_argument('--lora_r', type=int, default=8, help='Rank for LoRA layers')\n    parser.add_argument('--lora_alpha', type=int, default=16, help='Alpha for LoRA layers')\n    parser.add_argument('--lora_query', type=str2bool, default=True, help='Apply LoRA to query')\n    parser.add_argument('--lora_key', type=str2bool, default=False, help='Apply LoRA to key')\n    parser.add_argument('--lora_value', type=str2bool, default=True, help='Apply LoRA to value')\n    parser.add_argument('--lora_projection', type=str2bool, default=False, help='Apply LoRA to projection')\n    parser.add_argument('--lora_mlp', type=str2bool, default=False, help='Apply LoRA to MLP')\n    parser.add_argument('--lora_head', type=str2bool, default=False, help='Apply LoRA to head')\n    parser.add_argument('--device', type=int, default=0, help='Specify GPU device index')\n    parser.add_argument('--verbose', type=str2bool, default=True, help='Enable\/disable progress bars')\n    args = parser.parse_args()\n\n    if not 
torch.cuda.is_available():\n        print(\"Please switch to a GPU machine before running this code.\")\n        quit()\n\n    df_train, df_val, df_test = get_dataset()\n    imdb_tokenized = tokenization()\n    train_loader, val_loader, test_loader = setup_dataloaders(imdb_tokenized)\n\n    model = AutoModelForSequenceClassification.from_pretrained(\n        \"distilbert-base-uncased\", num_labels=2\n    )\n\n    # Freeze all layers\n    for param in model.parameters():\n        param.requires_grad = False\n\n    assign_lora = partial(LinearWithLoRA, rank=args.lora_r, alpha=args.lora_alpha)\n\n    for layer in model.distilbert.transformer.layer:\n        if args.lora_query:\n            layer.attention.q_lin = assign_lora(layer.attention.q_lin)\n        if args.lora_key:\n            layer.attention.k_lin = assign_lora(layer.attention.k_lin)\n        if args.lora_value:\n            layer.attention.v_lin = assign_lora(layer.attention.v_lin)\n        if args.lora_projection:\n            layer.attention.out_lin = assign_lora(layer.attention.out_lin)\n        if args.lora_mlp:\n            layer.ffn.lin1 = assign_lora(layer.ffn.lin1)\n            layer.ffn.lin2 = assign_lora(layer.ffn.lin2)\n    if args.lora_head:\n        model.pre_classifier = assign_lora(model.pre_classifier)\n        model.classifier = assign_lora(model.classifier)\n\n    print(\"Total number of trainable parameters:\", count_parameters(model))\n\n    lightning_model = CustomLightningModule(model)\n    callbacks = [\n        ModelCheckpoint(\n            save_top_k=1, mode=\"max\", monitor=\"val_acc\"\n        )  # save top 1 model\n    ]\n    logger = CSVLogger(save_dir=\"logs\/\", name=f\"my-model-{args.device}\")\n\n    trainer = L.Trainer(\n        max_epochs=3,\n        callbacks=callbacks,\n        accelerator=\"gpu\",\n        precision=\"16-mixed\",\n        devices=[int(args.device)],\n        logger=logger,\n        log_every_n_steps=10,\n        enable_progress_bar=args.verbose\n    
)\n\n    start = time.time()\n\n    trainer.fit(\n        model=lightning_model,\n        train_dataloaders=train_loader,\n        val_dataloaders=val_loader,\n    )\n\n    end = time.time()\n    elapsed = end - start\n    print(f\"Time elapsed {elapsed\/60:.2f} min\")\n\n    # Evaluate the best checkpoint on all three splits\n    train_acc = trainer.test(lightning_model, dataloaders=train_loader, ckpt_path=\"best\", verbose=False)\n    val_acc = trainer.test(lightning_model, dataloaders=val_loader, ckpt_path=\"best\", verbose=False)\n    test_acc = trainer.test(lightning_model, dataloaders=test_loader, ckpt_path=\"best\", verbose=False)\n\n    # Print settings and results, and append them to results.txt\n    with open(\"results.txt\", \"a\") as f:\n        s = \"------------------------------------------------\"\n        print(s), f.write(s+\"\\n\")\n        for arg in vars(args):\n            s = f'{arg}: {getattr(args, arg)}'\n            print(s), f.write(s+\"\\n\")\n\n        s = f\"Train acc: {train_acc[0]['accuracy']*100:2.2f}%\"\n        print(s), f.write(s+\"\\n\")\n        s = f\"Val acc:   {val_acc[0]['accuracy']*100:2.2f}%\"\n        print(s), f.write(s+\"\\n\")\n        s = f\"Test acc:  {test_acc[0]['accuracy']*100:2.2f}%\"\n        print(s), f.write(s+\"\\n\")\n        s = \"------------------------------------------------\"\n        print(s), f.write(s+\"\\n\")\n\n    # Cleanup\n    log_dir = f\"logs\/my-model-{args.device}\"\n    if os.path.exists(log_dir):\n        
shutil.rmtree(log_dir)<\/pre><\/div>\n\n\n\n<p>Run the code with:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:sh decode:true \">python 03_finetune-lora.py --lora_alpha 32 --lora_r 16<\/pre><\/div>\n\n\n\n<p>In addition, you can toggle the other hyperparameter settings, for example:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:sh decode:true \">python 03_finetune-lora.py \\\n--lora_alpha 32 \\\n--lora_r 16 \\\n--lora_query True \\\n--lora_key True \\\n--lora_value True \\\n--lora_projection True \\\n--lora_mlp True \\\n--lora_head True<\/pre><\/div>\n\n\n\n<p>One way to improve LoRA performance is to tune these hyperparameter choices by hand. To make hyperparameter tuning more convenient, however, you can also use the&nbsp;<code>03_gridsearch.py<\/code>&nbsp;script, which runs the following hyperparameter grid on all available GPUs:<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">alpha_values = [1, 4, 8, 16, 32, 64]\nrank_values = [1, 2, 4, 8, 16, 32]\ndevices = range(torch.cuda.device_count())\nlora_query = [\"True\"]\nlora_key = [\"False\", \"True\"]\nlora_value = [\"True\"]\nlora_projection = [\"False\", \"True\"]\nlora_mlp = [\"False\", \"True\"]\nlora_head = [\"False\", 
\"True\"]<\/pre><\/div>\n\n\n\n<p><code><code>03_finetune-lora-script.py<\/code><\/code>&nbsp;\u811a\u672c\u7684\u8bbe\u7f6e\u65b9\u5f0f\u662f\u5c06\u7ed3\u679c\u4fdd\u5b58\u5230&nbsp;<code><code>results.txt<\/code><\/code>&nbsp;<code>file<\/code>&nbsp;\u4e2d\u3002\u8fd0\u884c\u5b8c\u6210\u540e\u68c0\u67e5&nbsp;<code><code>results.txt<\/code><\/code>&nbsp;\u6587\u4ef6\uff0c\u6700\u4f73\u8d85\u53c2\u6570\u914d\u7f6e\u5982\u4e0b\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">lora_r: 8\nlora_alpha: 1\nlora_query: True\nlora_key: False\nlora_value: True\nlora_projection: False\nlora_mlp: True\nlora_head: False<\/pre><\/div>\n\n\n\n<p>\u5bfc\u81f4\uff1a<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Val acc: 92.96%<\/li>\n\n\n\n<li>Test acc: 92.39%<\/li>\n<\/ul>\n\n\n\n<p>\u8bf7\u6ce8\u610f\uff0c\u5373\u4f7f LoRA \u8bbe\u7f6e\u53ea\u6709\u4e00\u5c0f\u90e8\u5206\u53ef\u8bad\u7ec3\u53c2\u6570\uff08500k\uff09\u4e0e\uff0866M\uff09\uff0c\u8fd9\u4e9b\u7cbe\u5ea6\u751a\u81f3\u7565\u9ad8\u4e8e\u901a\u8fc7\u5b8c\u5168\u5fae\u8c03\u83b7\u5f97\u7684\u7cbe\u5ea6\u3002<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"385\" src=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-3-1024x385.png\" alt=\"\" class=\"wp-image-2602\" srcset=\"https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-3-1024x385.png 1024w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-3-300x113.png 300w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-3-768x289.png 768w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-3-1536x577.png 1536w, https:\/\/www.aqwu.net\/wp\/wp-content\/uploads\/2024\/03\/\u56fe\u7247-3-2048x769.png 2048w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<p><em>\u4e3a\u4e86\u5e94\u7528 
LoRA\uff0c\u6211\u4eec\u5c06\u795e\u7ecf\u7f51\u7edc\u4e2d\u73b0\u6709\u7684\u7ebf\u6027\u5c42\u66ff\u6362\u4e3a\u7ed3\u5408\u4e86\u539f\u59cb\u7ebf\u6027\u5c42\u548c LoRALayer \u7684 LinearWithLoRA \u5c42\u3002<\/em><\/p>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>6. \u7ed3\u8bba<\/strong><\/h2>\n\n\n\n<p>\u8fc7\u4ece\u5934\u5f00\u59cb\u7f16\u7801\u6765\u4e86\u89e3\u4f4e\u79e9\u9002\u5e94 (LoRA)\u3002\u901a\u8fc7\u5fae\u8c03 DistilBERT \u6a21\u578b\u8fdb\u884c\u5206\u7c7b\uff0c\u6211\u4eec\u53d1\u73b0 LoRA \u6bd4\u4ec5\u5fae\u8c03\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u8981\u6709\u5229\uff0892.39% \u7684\u6d4b\u8bd5\u51c6\u786e\u7387 vs 86.22% \u7684\u6d4b\u8bd5\u51c6\u786e\u7387\uff09\u3002<\/p>\n\n\n\n<p>\u539f\u6587\u94fe\u63a5\uff1a<a href=\"https:\/\/lightning.ai\/lightning-ai\/studios\/code-lora-from-scratch?tab=overview\">\u4ece\u5934\u5f00\u59cb\u7f16\u5199 LoRA \u4ee3\u7801 &#8212; Code LoRA from Scratch (lightning.ai)<\/a><\/p>\n","protected":false},"excerpt":{"rendered":"<p>\u5728 PyTorch \u4e2d\u5b9e\u73b0 LLMs \u7684\u4f4e\u9636\u9002\u5e94 LoRA \u4ee3\u8868\u4f4e\u9636\u9002\u5e94\uff0c\u662f\u4e00\u79cd\u66f4\u6709\u6548\u5730\u5fae\u8c03 LLMs \u7684\u6d41\u884c 
[&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"site-sidebar-layout":"default","site-content-layout":"","ast-site-content-layout":"default","site-content-style":"default","site-sidebar-style":"default","ast-global-header-display":"","ast-banner-title-visibility":"","ast-main-header-display":"","ast-hfb-above-header-display":"","ast-hfb-below-header-display":"","ast-hfb-mobile-header-display":"","site-post-title":"","ast-breadcrumbs-content":"","ast-featured-img":"","footer-sml-layout":"","theme-transparent-header-meta":"","adv-header-id-meta":"","stick-header-meta":"","header-above-stick-meta":"","header-main-stick-meta":"","header-below-stick-meta":"","astra-migrate-meta-layouts":"set","ast-page-background-enabled":"default","ast-page-background-meta":{"desktop":{"background-color":"var(--ast-global-color-4)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"ast-content-background-meta":{"desktop":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[313,289,443,442,312],"tags":[242,397,395,314,396],"class_list":["post-2567","post","type-post","status-publish","format-standard","hentry","category-chatgpt","category-gpt","category-llm","category-llms","category-openai","tag-chatgpt","tag-llms","tag-lora","tag-openai-api","tag-396"],"views":2997,"jetpack_sharing_enabled":true,"jetpack_featured_media_url":"","_links":{"self":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/2567","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=2567"}],"version-history":[{"count":28,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/2567\/revisions"}],"predecessor-version":[{"id":2605,"href"
:"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/2567\/revisions\/2605"}],"wp:attachment":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=2567"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=2567"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=2567"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}