{"id":4854,"date":"2024-12-02T17:03:26","date_gmt":"2024-12-02T09:03:26","guid":{"rendered":"https:\/\/www.aqwu.net\/wp\/?p=4854"},"modified":"2024-12-02T17:51:47","modified_gmt":"2024-12-02T09:51:47","slug":"%e5%b0%86%e4%b8%a4%e4%b8%aa%e6%95%99%e5%b8%88%e6%a8%a1%e5%9e%8b%e7%9a%84%e7%9f%a5%e8%af%86%e8%92%b8%e9%a6%8f%e5%88%b0%e4%b8%80%e4%b8%aa%e5%ad%a6%e7%94%9f%e6%a8%a1%e5%9e%8b%e4%b8%ad","status":"publish","type":"post","link":"https:\/\/www.aqwu.net\/wp\/?p=4854","title":{"rendered":"\u5c06\u4e24\u4e2a\u6559\u5e08\u6a21\u578b\u7684\u77e5\u8bc6\u84b8\u998f\u5230\u4e00\u4e2a\u5b66\u751f\u6a21\u578b\u4e2d"},"content":{"rendered":"\n<h1 class=\"wp-block-heading\">\u4e00\u3001<strong>\u63cf\u8ff0<\/strong>\uff1a<\/h1>\n\n\n\n<ul class=\"wp-block-list\">\n<li>\u5c06\u4e24\u4e2a\u6a21\u578b\u7684\u77e5\u8bc6\u63d0\u70bc\u5230\u4e00\u4e2a\u65b0\u6a21\u578b\uff08\u5b66\u751f\u6a21\u578b\uff09\u4e2d\uff0c\u4f7f\u5176\u517c\u5177\u4e24\u8005\u7684\u4f18\u70b9\u3002<\/li>\n<\/ul>\n\n\n\n<p>\u4f7f\u7528 <code>AutoModelForCausalLM<\/code> \u6765\u5b9e\u73b0\u4ece\u4e24\u4e2a\u6559\u5e08\u6a21\u578b\uff08<code>teacher1<\/code> \u548c <code>teacher2<\/code>\uff09\u5230\u5b66\u751f\u6a21\u578b\u7684\u77e5\u8bc6\u84b8\u998f\uff08Knowledge Distillation\uff09\u3002\u5728\u56e0\u679c\u8bed\u8a00\u5efa\u6a21\uff08Causal Language Modeling\uff09\u4e2d\uff0c\u6a21\u578b\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u5e8f\u5217\u4e2d\u7684\u4e0b\u4e00\u4e2a\u8bcd\uff0c\u56e0\u6b64\u8bad\u7ec3\u548c\u635f\u5931\u51fd\u6570\u4e0e\u5e8f\u5217\u5206\u7c7b\u4efb\u52a1\u6709\u6240\u4e0d\u540c\u3002<\/p>\n\n\n\n<p>\u4e0b\u9762\uff0c\u6211\u5c06\u63d0\u4f9b\u5b8c\u6574\u7684\u4ee3\u7801\u5b9e\u73b0\uff0c\u5e76\u5bf9\u6bcf\u4e2a\u6b65\u9aa4\u8fdb\u884c\u8be6\u7ec6\u89e3\u91ca\u3002<\/p>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h1 class=\"wp-block-heading\">\u4e8c\u3001<strong>\u6b65\u9aa4\u6982\u8ff0<\/strong><\/h1>\n\n\n\n<ol 
class=\"wp-block-list\">\n<li><strong>\u73af\u5883\u51c6\u5907<\/strong>\uff1a\u5b89\u88c5\u5fc5\u8981\u7684\u5e93\u3002<\/li>\n\n\n\n<li><strong>\u52a0\u8f7d\u6559\u5e08\u6a21\u578b\u548c\u6570\u636e\u96c6<\/strong>\uff1a\u4ece Hugging Face Hub \u52a0\u8f7d\u6559\u5e08\u6a21\u578b\u548c\u6570\u636e\u96c6\u3002<\/li>\n\n\n\n<li><strong>\u521d\u59cb\u5316\u5b66\u751f\u6a21\u578b<\/strong>\uff1a\u57fa\u4e8e\u76f8\u540c\u7684\u57fa\u7840\u6a21\u578b\uff0c\u521b\u5efa\u4e00\u4e2a\u8f83\u5c0f\u7684\u5b66\u751f\u6a21\u578b\u3002<\/li>\n\n\n\n<li><strong>\u5b9a\u4e49\u84b8\u998f\u635f\u5931\u51fd\u6570<\/strong>\uff1a\u7ed3\u5408\u6559\u5e08\u6a21\u578b\u8f93\u51fa\u548c\u771f\u5b9e\u6807\u7b7e\uff0c\u5b9a\u4e49\u603b\u635f\u5931\u51fd\u6570\u3002<\/li>\n\n\n\n<li><strong>\u6570\u636e\u9884\u5904\u7406\u548c\u52a0\u8f7d<\/strong>\uff1a\u5bf9\u6570\u636e\u8fdb\u884c\u9884\u5904\u7406\uff0c\u5e76\u521b\u5efa\u6570\u636e\u52a0\u8f7d\u5668\u3002<\/li>\n\n\n\n<li><strong>\u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570<\/strong>\uff1a\u5b9a\u4e49\u8bad\u7ec3\u8d85\u53c2\u6570\uff0c\u5982\u5b66\u4e60\u7387\u3001\u6279\u91cf\u5927\u5c0f\u7b49\u3002<\/li>\n\n\n\n<li><strong>\u521b\u5efa\u8bad\u7ec3\u5faa\u73af<\/strong>\uff1a\u5b9e\u73b0\u8bad\u7ec3\u8fc7\u7a0b\uff0c\u5305\u62ec\u524d\u5411\u4f20\u64ad\u3001\u8ba1\u7b97\u635f\u5931\u3001\u53cd\u5411\u4f20\u64ad\u548c\u53c2\u6570\u66f4\u65b0\u3002<\/li>\n\n\n\n<li><strong>\u4fdd\u5b58\u5b66\u751f\u6a21\u578b<\/strong>\uff1a\u8bad\u7ec3\u5b8c\u6210\u540e\uff0c\u4fdd\u5b58\u6a21\u578b\u3002<\/li>\n\n\n\n<li><strong>\u5b8c\u6574\u4ee3\u7801\u6c47\u603b<\/strong>\uff1a\u63d0\u4f9b\u5b8c\u6574\u7684\u4ee3\u7801\u3002<\/li>\n<\/ol>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>1. 
\u73af\u5883\u51c6\u5907<\/strong><\/h2>\n\n\n\n<p>\u9996\u5148\uff0c\u786e\u4fdd\u60a8\u5df2\u7ecf\u5b89\u88c5\u4e86\u5fc5\u8981\u7684\u5e93\uff1a<\/p>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:sh decode:true \">pip install transformers datasets torch\n<\/pre><\/div>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>2. \u52a0\u8f7d\u6559\u5e08\u6a21\u578b\u548c\u6570\u636e\u96c6<\/strong><\/h2>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">from transformers import AutoModelForCausalLM, AutoTokenizer\nfrom datasets import load_dataset\nimport torch\nfrom torch.utils.data import DataLoader\nfrom tqdm.auto import tqdm\n\n# \u68c0\u67e5\u8bbe\u5907\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n# \u5b9a\u4e49\u6a21\u578b\u540d\u79f0\u548c\u6570\u636e\u96c6\u540d\u79f0\nteacher1_name = 'your-teacher1-model-name'  # \u66ff\u6362\u4e3a\u5b9e\u9645\u7684\u6a21\u578b\u540d\u79f0\nteacher2_name = 'your-teacher2-model-name'\nstudent_name = 'your-student-model-name'    # \u53ef\u4ee5\u4f7f\u7528\u76f8\u540c\u7684\u57fa\u7840\u6a21\u578b\ndataset_name = 'your-dataset-name'          # \u66ff\u6362\u4e3a\u5b9e\u9645\u7684\u6570\u636e\u96c6\u540d\u79f0\n\n# \u52a0\u8f7d\u5206\u8bcd\u5668\ntokenizer = AutoTokenizer.from_pretrained(teacher1_name)\n\n# \u52a0\u8f7d\u6559\u5e08\u6a21\u578b\nteacher1 = AutoModelForCausalLM.from_pretrained(teacher1_name).to(device)\nteacher2 = AutoModelForCausalLM.from_pretrained(teacher2_name).to(device)\n\n# \u52a0\u8f7d\u6570\u636e\u96c6\ndataset = load_dataset(dataset_name)\n<\/pre><\/div>\n\n\n\n<p><strong>\u6ce8\u610f<\/strong>\uff1a\u8bf7\u5c06 <code>'your-teacher1-model-name'<\/code>\u3001<code>'your-teacher2-model-name'<\/code> \u548c <code>'your-dataset-name'<\/code> 
\u66ff\u6362\u4e3a\u5b9e\u9645\u7684\u6a21\u578b\u548c\u6570\u636e\u96c6\u540d\u79f0\u3002<\/p>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>3. \u521d\u59cb\u5316\u5b66\u751f\u6a21\u578b<\/strong><\/h2>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \" >from transformers import AutoConfig\n\n# \u4ece\u6559\u5e08\u6a21\u578b\u7684\u914d\u7f6e\u52a0\u8f7d\uff0c\u5e76\u4fee\u6539\u5c42\u6570\nstudent_config = AutoConfig.from_pretrained(teacher1_name)\nstudent_config.num_hidden_layers = 6  # \u4f8b\u5982\uff0c\u5c06\u5c42\u6570\u51cf\u534a\n\n# \u4ece\u5934\u521d\u59cb\u5316\u5b66\u751f\u6a21\u578b\uff08\u4e0d\u52a0\u8f7d\u9884\u8bad\u7ec3\u6743\u91cd\uff09\nstudent = AutoModelForCausalLM.from_config(student_config).to(device)\n<\/pre><\/div>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>4. \u5b9a\u4e49\u84b8\u998f\u635f\u5931\u51fd\u6570<\/strong><\/h2>\n\n\n\n<p>\u5bf9\u4e8e\u56e0\u679c\u8bed\u8a00\u6a21\u578b\u7684\u77e5\u8bc6\u84b8\u998f\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u4ee5\u4e0b\u635f\u5931\u51fd\u6570\uff1a<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u4ea4\u53c9\u71b5\u635f\u5931\uff08Cross-Entropy Loss\uff09<\/strong>\uff1a\u7528\u4e8e\u5b66\u751f\u6a21\u578b\u7684\u8f93\u51fa\u4e0e\u771f\u5b9e\u6807\u7b7e\u4e4b\u95f4\u7684\u635f\u5931\u3002<\/li>\n\n\n\n<li><strong>KL \u6563\u5ea6\u635f\u5931\uff08Kullback-Leibler Divergence\uff09<\/strong>\uff1a\u7528\u4e8e\u5b66\u751f\u6a21\u578b\u4e0e\u6559\u5e08\u6a21\u578b\u8f93\u51fa\u5206\u5e03\u4e4b\u95f4\u7684\u635f\u5931\u3002<\/li>\n<\/ul>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>5. 
\u6570\u636e\u9884\u5904\u7406\u548c\u52a0\u8f7d<\/strong><\/h2>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># \u5b9a\u4e49\u6570\u636e\u9884\u5904\u7406\u51fd\u6570\ndef preprocess_function(examples):\n    # \u5bf9\u6587\u672c\u8fdb\u884c\u62fc\u63a5\u548c\u5206\u8bcd\n    inputs = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128)\n    inputs['labels'] = inputs['input_ids'].copy()\n    return inputs\n\n# \u5bf9\u6570\u636e\u96c6\u8fdb\u884c\u9884\u5904\u7406\ntokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset['train'].column_names)\n\n# \u521b\u5efa\u6570\u636e\u52a0\u8f7d\u5668\ntrain_loader = DataLoader(tokenized_dataset['train'], batch_size=8, shuffle=True)\neval_loader = DataLoader(tokenized_dataset['validation'], batch_size=8)\n<\/pre><\/div>\n\n\n\n<p><strong>\u8bf4\u660e<\/strong>\uff1a<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>\u6211\u4eec\u5c06 <code>labels<\/code> \u8bbe\u7f6e\u4e3a <code>input_ids<\/code> \u7684\u526f\u672c\uff0c\u4ee5\u4fbf\u6a21\u578b\u5b66\u4e60\u9884\u6d4b\u4e0b\u4e00\u4e2a\u8bcd\u3002<\/li>\n\n\n\n<li>\u6839\u636e\u60a8\u7684\u8d44\u6e90\uff0c\u8c03\u6574 <code>max_length<\/code> \u548c <code>batch_size<\/code>\u3002<\/li>\n<\/ul>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>6. \u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570<\/strong><\/h2>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">from transformers import AdamW\n\n# \u5b9a\u4e49\u4f18\u5316\u5668\noptimizer = AdamW(student.parameters(), lr=5e-5)\n<\/pre><\/div>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>7. 
\u521b\u5efa\u8bad\u7ec3\u5faa\u73af<\/strong><\/h2>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \">import torch.nn.functional as F\n\n# \u8d85\u53c2\u6570\nnum_epochs = 3\ntemperature = 2.0\nalpha = 0.5  # \u84b8\u998f\u635f\u5931\u4e0e\u771f\u5b9e\u6807\u7b7e\u635f\u5931\u7684\u6743\u91cd\u5e73\u8861\n\nfor epoch in range(num_epochs):\n    student.train()\n    progress_bar = tqdm(train_loader, desc=f\"Epoch {epoch+1}\/{num_epochs}\")\n    \n    for batch in progress_bar:\n        optimizer.zero_grad()\n        \n        # \u5c06\u6570\u636e\u79fb\u52a8\u5230\u8bbe\u5907\n        input_ids = torch.stack(batch['input_ids']).to(device)\n        attention_mask = torch.stack(batch['attention_mask']).to(device)\n        labels = torch.stack(batch['labels']).to(device)\n        \n        # \u83b7\u53d6\u6559\u5e08\u6a21\u578b\u7684\u8f93\u51fa\n        with torch.no_grad():\n            outputs_t1 = teacher1(input_ids=input_ids, attention_mask=attention_mask)\n            outputs_t2 = teacher2(input_ids=input_ids, attention_mask=attention_mask)\n            # \u5e73\u5747\u4e24\u4e2a\u6559\u5e08\u6a21\u578b\u7684 logits\n            logits_teacher = (outputs_t1.logits + outputs_t2.logits) \/ 2\n        \n        # \u5b66\u751f\u6a21\u578b\u7684\u8f93\u51fa\n        outputs_student = student(input_ids=input_ids, attention_mask=attention_mask)\n        logits_student = outputs_student.logits\n        \n        # \u8ba1\u7b97\u84b8\u998f\u635f\u5931\uff08KL \u6563\u5ea6\uff09\n        loss_kd = F.kl_div(\n            input=F.log_softmax(logits_student \/ temperature, dim=-1),\n            target=F.softmax(logits_teacher \/ temperature, dim=-1),\n            reduction='batchmean'\n        ) * (temperature ** 2)\n        \n        # \u8ba1\u7b97\u771f\u5b9e\u6807\u7b7e\u7684\u4ea4\u53c9\u71b5\u635f\u5931\n        loss_ce = F.cross_entropy(logits_student.view(-1, logits_student.size(-1)), labels.view(-1), 
ignore_index=tokenizer.pad_token_id)\n        \n        # \u603b\u635f\u5931\n        loss = alpha * loss_ce + (1 - alpha) * loss_kd\n        \n        # \u53cd\u5411\u4f20\u64ad\u548c\u4f18\u5316\n        loss.backward()\n        optimizer.step()\n        \n        # \u66f4\u65b0\u8fdb\u5ea6\u6761\n        progress_bar.set_postfix({'loss': loss.item()})\n    \n    # \u6bcf\u4e2a epoch \u7ed3\u675f\u540e\u8fdb\u884c\u8bc4\u4f30\n    student.eval()\n    total_loss = 0\n    with torch.no_grad():\n        for batch in eval_loader:\n            input_ids = torch.stack(batch['input_ids']).to(device)\n            attention_mask = torch.stack(batch['attention_mask']).to(device)\n            labels = torch.stack(batch['labels']).to(device)\n            \n            outputs = student(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n            loss = outputs.loss\n            total_loss += loss.item()\n    \n    avg_loss = total_loss \/ len(eval_loader)\n    print(f\"Validation Loss after epoch {epoch+1}: {avg_loss:.4f}\")\n<\/pre><\/div>\n\n\n\n<p><strong>\u89e3\u91ca<\/strong>\uff1a<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u6559\u5e08\u6a21\u578b\u8f93\u51fa<\/strong>\uff1a\u4f7f\u7528 <code>torch.no_grad()<\/code>\uff0c\u907f\u514d\u8ba1\u7b97\u68af\u5ea6\u3002<\/li>\n\n\n\n<li><strong>\u84b8\u998f\u635f\u5931\uff08KL 
\u6563\u5ea6\uff09<\/strong>\uff1a\u8ba1\u7b97\u5b66\u751f\u6a21\u578b\u548c\u6559\u5e08\u6a21\u578b\u8f93\u51fa\u5206\u5e03\u4e4b\u95f4\u7684\u5dee\u5f02\u3002<\/li>\n\n\n\n<li><strong>\u771f\u5b9e\u6807\u7b7e\u635f\u5931\uff08\u4ea4\u53c9\u71b5\uff09<\/strong>\uff1a\u8ba1\u7b97\u5b66\u751f\u6a21\u578b\u8f93\u51fa\u4e0e\u771f\u5b9e\u6807\u7b7e\u4e4b\u95f4\u7684\u635f\u5931\u3002<\/li>\n\n\n\n<li><strong>\u603b\u635f\u5931<\/strong>\uff1a\u84b8\u998f\u635f\u5931\u548c\u771f\u5b9e\u6807\u7b7e\u635f\u5931\u7684\u52a0\u6743\u548c\u3002<\/li>\n\n\n\n<li><strong>\u5ffd\u7565\u586b\u5145<\/strong>\uff1a\u5728\u8ba1\u7b97\u4ea4\u53c9\u71b5\u635f\u5931\u65f6\uff0c\u4f7f\u7528 <code>ignore_index=tokenizer.pad_token_id<\/code>\uff0c\u5ffd\u7565\u586b\u5145\u6807\u8bb0\u3002<\/li>\n<\/ul>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>8. \u4fdd\u5b58\u5b66\u751f\u6a21\u578b<\/strong><\/h2>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \"># \u4fdd\u5b58\u8bad\u7ec3\u597d\u7684\u5b66\u751f\u6a21\u578b\nstudent.save_pretrained('path_to_save_student_model')\ntokenizer.save_pretrained('path_to_save_student_model')\n<\/pre><\/div>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h2 class=\"wp-block-heading\"><strong>9. 
\u5b8c\u6574\u4ee3\u7801\u6c47\u603b<\/strong><\/h2>\n\n\n\n<div class=\"wp-block-urvanov-syntax-highlighter-code-block\"><pre class=\"lang:python decode:true \" >from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AdamW\nfrom datasets import load_dataset\nimport torch\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader\nfrom tqdm.auto import tqdm\n\n# \u68c0\u67e5\u8bbe\u5907\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n# \u5b9a\u4e49\u6a21\u578b\u540d\u79f0\u548c\u6570\u636e\u96c6\u540d\u79f0\nteacher1_name = 'your-teacher1-model-name'  # \u66ff\u6362\u4e3a\u5b9e\u9645\u7684\u6a21\u578b\u540d\u79f0\nteacher2_name = 'your-teacher2-model-name'\nstudent_name = 'your-student-model-name'    # \u4f7f\u7528\u6559\u5e08\u6a21\u578b\u7684\u540d\u79f0\u6216\u5176\u4ed6\ndataset_name = 'your-dataset-name'          # \u66ff\u6362\u4e3a\u5b9e\u9645\u7684\u6570\u636e\u96c6\u540d\u79f0\n\n# \u52a0\u8f7d\u5206\u8bcd\u5668\ntokenizer = AutoTokenizer.from_pretrained(teacher1_name)\n\n# \u52a0\u8f7d\u6559\u5e08\u6a21\u578b\nteacher1 = AutoModelForCausalLM.from_pretrained(teacher1_name).to(device)\nteacher2 = AutoModelForCausalLM.from_pretrained(teacher2_name).to(device)\n\n# \u52a0\u8f7d\u6570\u636e\u96c6\ndataset = load_dataset(dataset_name)\n\n# \u5b9a\u4e49\u6570\u636e\u9884\u5904\u7406\u51fd\u6570\ndef preprocess_function(examples):\n    inputs = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128)\n    inputs['labels'] = inputs['input_ids'].copy()\n    return inputs\n\n# \u5bf9\u6570\u636e\u96c6\u8fdb\u884c\u9884\u5904\u7406\ntokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset['train'].column_names)\n\n# \u521b\u5efa\u6570\u636e\u52a0\u8f7d\u5668\ntrain_loader = DataLoader(tokenized_dataset['train'], batch_size=8, shuffle=True)\neval_loader = DataLoader(tokenized_dataset['validation'], batch_size=8)\n\n# 
\u4ece\u6559\u5e08\u6a21\u578b\u7684\u914d\u7f6e\u52a0\u8f7d\uff0c\u5e76\u4fee\u6539\u5c42\u6570\nstudent_config = AutoConfig.from_pretrained(teacher1_name)\nstudent_config.num_hidden_layers = 6  # \u4f8b\u5982\uff0c\u5c06\u5c42\u6570\u51cf\u534a\n\n# \u4ece\u5934\u521d\u59cb\u5316\u5b66\u751f\u6a21\u578b\uff08\u4e0d\u52a0\u8f7d\u9884\u8bad\u7ec3\u6743\u91cd\uff09\nstudent = AutoModelForCausalLM.from_config(student_config).to(device)\n\n# \u5b9a\u4e49\u4f18\u5316\u5668\noptimizer = AdamW(student.parameters(), lr=5e-5)\n\n# \u8d85\u53c2\u6570\nnum_epochs = 3\ntemperature = 2.0\nalpha = 0.5  # \u84b8\u998f\u635f\u5931\u4e0e\u771f\u5b9e\u6807\u7b7e\u635f\u5931\u7684\u6743\u91cd\u5e73\u8861\n\nfor epoch in range(num_epochs):\n    student.train()\n    progress_bar = tqdm(train_loader, desc=f\"Epoch {epoch+1}\/{num_epochs}\")\n    \n    for batch in progress_bar:\n        optimizer.zero_grad()\n        \n        # \u5c06\u6570\u636e\u79fb\u52a8\u5230\u8bbe\u5907\n        input_ids = batch['input_ids'].to(device)\n        attention_mask = batch['attention_mask'].to(device)\n        labels = batch['labels'].to(device)\n        \n        # \u83b7\u53d6\u6559\u5e08\u6a21\u578b\u7684\u8f93\u51fa\n        with torch.no_grad():\n            outputs_t1 = teacher1(input_ids=input_ids, attention_mask=attention_mask)\n            outputs_t2 = teacher2(input_ids=input_ids, attention_mask=attention_mask)\n            logits_teacher = (outputs_t1.logits + outputs_t2.logits) \/ 2\n        \n        # \u5b66\u751f\u6a21\u578b\u7684\u8f93\u51fa\n        outputs_student = student(input_ids=input_ids, attention_mask=attention_mask)\n        logits_student = outputs_student.logits\n        \n        # \u8ba1\u7b97\u84b8\u998f\u635f\u5931\uff08KL \u6563\u5ea6\uff09\n        loss_kd = F.kl_div(\n            input=F.log_softmax(logits_student \/ temperature, dim=-1),\n            target=F.softmax(logits_teacher \/ temperature, dim=-1),\n            reduction='batchmean'\n        ) * (temperature ** 
2)\n        \n        # \u8ba1\u7b97\u771f\u5b9e\u6807\u7b7e\u7684\u4ea4\u53c9\u71b5\u635f\u5931\n        loss_ce = F.cross_entropy(logits_student.view(-1, logits_student.size(-1)), labels.view(-1), ignore_index=tokenizer.pad_token_id)\n        \n        # \u603b\u635f\u5931\n        loss = alpha * loss_ce + (1 - alpha) * loss_kd\n        \n        # \u53cd\u5411\u4f20\u64ad\u548c\u4f18\u5316\n        loss.backward()\n        optimizer.step()\n        \n        # \u66f4\u65b0\u8fdb\u5ea6\u6761\n        progress_bar.set_postfix({'loss': loss.item()})\n    \n    # \u6bcf\u4e2a epoch \u7ed3\u675f\u540e\u8fdb\u884c\u8bc4\u4f30\n    student.eval()\n    total_loss = 0\n    with torch.no_grad():\n        for batch in eval_loader:\n            input_ids = batch['input_ids'].to(device)\n            attention_mask = batch['attention_mask'].to(device)\n            labels = batch['labels'].to(device)\n            \n            outputs = student(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n            loss = outputs.loss\n            total_loss += loss.item()\n    \n    avg_loss = total_loss \/ len(eval_loader)\n    print(f\"Validation Loss after epoch {epoch+1}: {avg_loss:.4f}\")\n\n# \u4fdd\u5b58\u8bad\u7ec3\u597d\u7684\u5b66\u751f\u6a21\u578b\nstudent.save_pretrained('path_to_save_student_model')\ntokenizer.save_pretrained('path_to_save_student_model')\n<\/pre><\/div>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h1 class=\"wp-block-heading\">\u4e09\u3001<strong>\u6ce8\u610f\u4e8b\u9879\u548c\u89e3\u91ca<\/strong><\/h1>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>1. 
\u6570\u636e\u52a0\u8f7d\u548c\u9884\u5904\u7406<\/strong><\/h3>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u62fc\u63a5\u548c\u5206\u8bcd<\/strong>\uff1a\u5728\u9884\u5904\u7406\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u5bf9\u6587\u672c\u8fdb\u884c\u5206\u8bcd\uff0c\u5e76\u8bbe\u7f6e <code>labels<\/code> \u7b49\u4e8e <code>input_ids<\/code> \u7684\u526f\u672c\u3002<\/li>\n\n\n\n<li><strong>\u5220\u9664\u539f\u59cb\u5217<\/strong>\uff1a\u4f7f\u7528 <code>remove_columns<\/code> \u5220\u9664\u539f\u59cb\u6570\u636e\u96c6\u7684\u5217\uff0c\u907f\u514d\u4e0d\u5fc5\u8981\u7684\u6570\u636e\u5197\u4f59\u3002<\/li>\n<\/ul>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>2. \u5904\u7406\u6279\u6b21\u6570\u636e<\/strong><\/h3>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong><code>torch.stack<\/code><\/strong>\uff1a\u5728\u8bad\u7ec3\u5faa\u73af\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528 <code>torch.stack<\/code> \u5c06\u6279\u6b21\u4e2d\u7684\u5f20\u91cf\u7ec4\u5408\u8d77\u6765\u3002<\/li>\n\n\n\n<li><strong>\u5f20\u91cf\u5f62\u72b6<\/strong>\uff1a\u786e\u4fdd\u8f93\u5165\u7684\u5f20\u91cf\u5f62\u72b6\u6b63\u786e\uff0c\u5339\u914d\u6a21\u578b\u7684\u9884\u671f\u8f93\u5165\u3002<\/li>\n<\/ul>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>3. 
\u635f\u5931\u51fd\u6570<\/strong><\/h3>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>KL \u6563\u5ea6\u635f\u5931<\/strong>\uff1a\u7528\u4e8e\u5ea6\u91cf\u5b66\u751f\u6a21\u578b\u8f93\u51fa\u5206\u5e03\u4e0e\u6559\u5e08\u6a21\u578b\u8f93\u51fa\u5206\u5e03\u4e4b\u95f4\u7684\u5dee\u5f02\u3002<\/li>\n\n\n\n<li><strong>\u4ea4\u53c9\u71b5\u635f\u5931<\/strong>\uff1a\u7528\u4e8e\u5ea6\u91cf\u5b66\u751f\u6a21\u578b\u8f93\u51fa\u4e0e\u771f\u5b9e\u6807\u7b7e\u4e4b\u95f4\u7684\u5dee\u5f02\u3002<\/li>\n\n\n\n<li><strong>\u6e29\u5ea6\u53c2\u6570<\/strong>\uff1a\u901a\u8fc7\u6e29\u5ea6\u53c2\u6570\u8f6f\u5316\u6982\u7387\u5206\u5e03\uff0c\u4ee5\u66f4\u597d\u5730\u5b66\u4e60\u6559\u5e08\u6a21\u578b\u7684\u77e5\u8bc6\u3002<\/li>\n<\/ul>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>4. \u8d85\u53c2\u6570\u8c03\u6574<\/strong><\/h3>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong><code>num_epochs<\/code><\/strong>\uff1a\u6839\u636e\u6570\u636e\u96c6\u5927\u5c0f\u548c\u6a21\u578b\u6536\u655b\u60c5\u51b5\u8c03\u6574\u8bad\u7ec3\u8f6e\u6570\u3002<\/li>\n\n\n\n<li><strong><code>learning_rate<\/code><\/strong>\uff1a\u5b66\u4e60\u7387\u5bf9\u8bad\u7ec3\u7a33\u5b9a\u6027\u548c\u6536\u655b\u901f\u5ea6\u6709\u91cd\u8981\u5f71\u54cd\uff0c\u53ef\u6839\u636e\u9700\u8981\u8c03\u6574\u3002<\/li>\n\n\n\n<li><strong><code>temperature<\/code><\/strong>\uff1a\u5e38\u7528\u503c\u4e3a 1 \u5230 5\uff0c\u9700\u6839\u636e\u5b9e\u9a8c\u6548\u679c\u8c03\u6574\u3002<\/li>\n\n\n\n<li><strong><code>alpha<\/code><\/strong>\uff1a\u7528\u4e8e\u5e73\u8861\u84b8\u998f\u635f\u5931\u548c\u771f\u5b9e\u6807\u7b7e\u635f\u5931\uff0c\u53d6\u503c\u8303\u56f4\u5728 0 \u5230 1 \u4e4b\u95f4\u3002<\/li>\n<\/ul>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>5. 
\u8d44\u6e90\u8981\u6c42<\/strong><\/h3>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u663e\u5b58\u5360\u7528<\/strong>\uff1a\u56e0\u679c\u8bed\u8a00\u6a21\u578b\u7684\u8f93\u51fa\u7ef4\u5ea6\u4e3a\u8bcd\u6c47\u8868\u5927\u5c0f\uff0c\u53ef\u80fd\u5bfc\u81f4\u663e\u5b58\u5360\u7528\u8f83\u9ad8\u3002\u53ef\u901a\u8fc7\u51cf\u5c0f <code>batch_size<\/code> \u6216\u4f7f\u7528\u68af\u5ea6\u7d2f\u79ef\u6765\u7f13\u89e3\u3002<\/li>\n\n\n\n<li><strong>\u8ba1\u7b97\u65f6\u95f4<\/strong>\uff1a\u8bad\u7ec3\u65f6\u95f4\u53ef\u80fd\u8f83\u957f\uff0c\u5efa\u8bae\u4f7f\u7528 GPU \u52a0\u901f\u3002<\/li>\n<\/ul>\n\n\n\n<h3 class=\"wp-block-heading\"><strong>6. \u6cd5\u5f8b\u548c\u7248\u6743<\/strong><\/h3>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u6a21\u578b\u548c\u6570\u636e\u96c6\u8bb8\u53ef<\/strong>\uff1a\u5728\u4f7f\u7528\u9884\u8bad\u7ec3\u6a21\u578b\u548c\u6570\u636e\u96c6\u65f6\uff0c\u52a1\u5fc5\u9075\u5b88\u5176\u8bb8\u53ef\u534f\u8bae\u548c\u4f7f\u7528\u6761\u6b3e\u3002<\/li>\n<\/ul>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h1 class=\"wp-block-heading\">\u56db\u3001<strong>\u603b\u7ed3<\/strong><\/h1>\n\n\n\n<p>\u901a\u8fc7\u4e0a\u8ff0\u4ee3\u7801\uff0c\u6211\u4eec\u5b9e\u73b0\u4e86\u4f7f\u7528 <code>AutoModelForCausalLM<\/code> \u7684\u77e5\u8bc6\u84b8\u998f\u8fc7\u7a0b\uff0c\u5c06\u4e24\u4e2a\u6559\u5e08\u6a21\u578b\u7684\u77e5\u8bc6\u84b8\u998f\u5230\u4e00\u4e2a\u8f83\u5c0f\u7684\u5b66\u751f\u6a21\u578b\u4e2d\u3002\u8be5\u5b66\u751f\u6a21\u578b\u80fd\u591f\u5728\u4fdd\u6301\u6027\u80fd\u7684\u540c\u65f6\uff0c\u51cf\u5c0f\u6a21\u578b\u5927\u5c0f\uff0c\u63d0\u9ad8\u63a8\u7406\u901f\u5ea6\u3002<\/p>\n\n\n\n<p><strong>\u5982\u679c\u60a8\u6709\u4efb\u4f55\u7591\u95ee\u6216\u9700\u8981\u8fdb\u4e00\u6b65\u7684\u5e2e\u52a9\uff0c\u8bf7\u968f\u65f6\u544a\u8bc9\u6211\uff01<\/strong><\/p>\n","protected":false},"excerpt":{"rendered":"<p>\u4e00\u3001\u63cf\u8ff0\uff1a \u4f7f\u7528 AutoModelForCausalLM 
\u6765\u5b9e\u73b0\u4ece\u4e24\u4e2a\u6559\u5e08\u6a21\u578b\uff08teacher1 \u548c tea [&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"site-sidebar-layout":"default","site-content-layout":"","ast-site-content-layout":"default","site-content-style":"default","site-sidebar-style":"default","ast-global-header-display":"","ast-banner-title-visibility":"","ast-main-header-display":"","ast-hfb-above-header-display":"","ast-hfb-below-header-display":"","ast-hfb-mobile-header-display":"","site-post-title":"","ast-breadcrumbs-content":"","ast-featured-img":"","footer-sml-layout":"","theme-transparent-header-meta":"","adv-header-id-meta":"","stick-header-meta":"","header-above-stick-meta":"","header-main-stick-meta":"","header-below-stick-meta":"","astra-migrate-meta-layouts":"set","ast-page-background-enabled":"default","ast-page-background-meta":{"desktop":{"background-color":"var(--ast-global-color-4)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"ast-content-background-meta":{"desktop":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[444,443,442],"tags":[296,569,568],"class_list":["post-4854","post","type-post","status-publish","format-standard","hentry","category-ai","category-llm","category-llms","tag-ai","tag-knowledge-distillation","tag-568"],"views":3270,"jetpack_sharing_enabled":true,"jetpack_featured_media_url":"","_links":{"self":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/4854","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=4854"}],"version-history":[{"count":4,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/4854\/revisions"}],"predecessor-version":[{"id":4860,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/4854\/revisions\/4860"}],"wp:attachment":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=4854"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=4854"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=4854"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}