{"id":2106,"date":"2024-02-19T14:00:23","date_gmt":"2024-02-19T06:00:23","guid":{"rendered":"https:\/\/www.aqwu.net\/wp\/?p=2106"},"modified":"2024-04-28T20:04:36","modified_gmt":"2024-04-28T12:04:36","slug":"%e4%bb%8e%e5%a4%b4%e5%bc%80%e5%a7%8b%e5%ae%9e%e7%8e%b0%e4%b8%80%e4%b8%aa%e7%ae%80%e5%8c%96%e7%9a%84%e7%89%88%e6%9c%ac%e7%9a%84gpt%e6%a8%a1%e5%9e%8b","status":"publish","type":"post","link":"https:\/\/www.aqwu.net\/wp\/?p=2106","title":{"rendered":"\u4ece\u5934\u5f00\u59cb\u5b9e\u73b0\u4e00\u4e2a\u7b80\u5316\u7684\u7248\u672c\u7684GPT\u6a21\u578b"},"content":{"rendered":"\n<p>\u5982\u679c\u4f60\u60f3\u4ece\u5934\u5f00\u59cb\u5b9e\u73b0\u4e00\u4e2a\u7b80\u5316\u7684\u7248\u672c\u7684GPT\u6a21\u578b\uff0c\u800c\u4e0d\u4f9d\u8d56\u4e8e\u73b0\u6210\u7684GPT-2\u6a21\u578b\u5e93\uff0c\u4f60\u53ef\u4ee5\u91c7\u7528PyTorch\u8fd9\u6837\u7684\u6df1\u5ea6\u5b66\u4e60\u6846\u67b6\u3002\u4e0b\u9762\u662f\u4e00\u4e2a\u975e\u5e38\u57fa\u7840\u7684\u4f8b\u5b50\uff0c\u5c55\u793a\u4e86\u5982\u4f55\u5b9e\u73b0\u4e00\u4e2a\u7b80\u5316\u7684Transformer\u6a21\u578b\u67b6\u6784\uff0c\u8fd9\u662f\u6784\u5efaGPT\u6a21\u578b\u7684\u57fa\u7840\u3002<\/p>\n\n\n\n<p>\u8fd9\u4e2a\u4f8b\u5b50\u5c06\u4e0d\u4f1a\u8986\u76d6GPT-2\u7684\u6240\u6709\u590d\u6742\u6027\u548c\u7279\u6027\uff0c\u4f46\u53ef\u4ee5\u63d0\u4f9b\u4e00\u4e2a\u8d77\u70b9\uff0c\u5e2e\u52a9\u4f60\u7406\u89e3\u5982\u4f55\u4ece\u5934\u5f00\u59cb\u6784\u5efa\u7c7b\u4f3cGPT\u7684\u6a21\u578b\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading has-large-font-size\"><strong>1. \u57fa\u7840Transformer\u5757<\/strong><\/h3>\n\n\n\n<p>\u9996\u5148\uff0c\u6211\u4eec\u5b9a\u4e49\u4e00\u4e2a\u57fa\u7840\u7684Transformer\u5757\uff0c\u5b83\u662f\u6784\u6210GPT\u6a21\u578b\u7684\u57fa\u672c\u5355\u5143\u3002\u8fd9\u4e2a\u5757\u5c06\u5305\u62ec\u81ea\u6ce8\u610f\u529b\u673a\u5236\u548c\u524d\u9988\u795e\u7ecf\u7f51\u7edc\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport re\nimport jieba\n\nfrom torch.utils.data import Dataset, DataLoader\nfrom torch.nn.utils.rnn import pad_sequence\n\n\nclass SelfAttention(nn.Module):\n    def __init__(self, embed_size, heads):\n        super(SelfAttention, self).__init__()\n        self.embed_size = embed_size\n        self.heads = heads\n        self.head_dim = embed_size \/\/ heads\n\n        assert self.head_dim * heads == embed_size, \"Embedding size needs to be divisible by heads\"\n\n        self.values = nn.Linear(embed_size, embed_size, bias=False)\n        self.keys = nn.Linear(embed_size, embed_size, bias=False)\n        self.queries = nn.Linear(embed_size, embed_size, bias=False)\n        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)\n\n    def forward(self, value, key, query, mask):\n        N = query.shape&#91;0]\n        value_len, key_len, query_len = value.shape&#91;1], key.shape&#91;1], query.shape&#91;1]\n\n        # Split the embedding into `self.heads` pieces\n        values = self.values(value).view(N, value_len, self.heads, self.head_dim)\n        keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)\n        queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)\n\n        # Transpose for attention dot product: from &#91;N, value_len, self.heads, self.head_dim]\n        # to &#91;N, self.heads, value_len, self.head_dim] to match the shape for `torch.einsum`\n        values = values.transpose(1, 2)\n        keys = keys.transpose(1, 2)\n        queries = queries.transpose(1, 2)\n\n        # Attention mechanism\n        energy = torch.einsum(\"nqhd,nkhd->nhqk\", &#91;queries, keys])\n\n        if mask is not None:\n            energy = energy.masked_fill(mask == 0, float(\"-1e20\"))\n\n        attention = torch.softmax(energy \/ (self.embed_size ** (1 \/ 2)), dim=3)\n\n        # \u91cd\u5851\u524d\u8fdb\u884c\u5f20\u91cf\u4e58\u6cd5\uff0c\u7136\u540e\u91cd\u5851\u56de &#91;batch_size, seq_len, heads * head_dim]\n        out = torch.einsum(\"nhql,nlhd->nqhd\", &#91;attention, values]).reshape(\n            N, query_len, self.heads * self.head_dim\n        )\n\n        out = self.fc_out(out)\n        return out\n\nclass TransformerBlock(nn.Module):\n    def __init__(self, embed_size, heads, dropout, forward_expansion):\n        super(TransformerBlock, self).__init__()\n        self.attention = SelfAttention(embed_size, heads)\n        self.norm1 = nn.LayerNorm(embed_size)\n        self.norm2 = nn.LayerNorm(embed_size)\n\n        self.feed_forward = nn.Sequential(\n            nn.Linear(embed_size, forward_expansion * embed_size),\n            nn.ReLU(),\n            nn.Linear(forward_expansion * embed_size, embed_size),\n        )\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, value, key, query, mask):\n        attention = self.attention(value, key, query, mask)\n\n        # Add skip connection, followed by layer normalization\n        x = self.norm1(attention + query)\n        forward = self.feed_forward(x)\n        out = self.norm2(forward + x)  # Add skip connection, followed by layer normalization\n        return out\n<\/code><\/pre>\n\n\n\n<h3 class=\"wp-block-heading has-large-font-size\"><strong>2. \u7b80\u5316\u7248\u7684GPT\u6a21\u578b<\/strong><\/h3>\n\n\n\n<p>\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u5b9a\u4e49\u4e00\u4e2a\u7b80\u5316\u7248\u7684GPT\u6a21\u578b\uff0c\u5b83\u5229\u7528\u4e0a\u9762\u5b9a\u4e49\u7684Transformer\u5757\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>class GPT(nn.Module):\n    def __init__(self, embed_size, num_layers, heads, forward_expansion, dropout, vocab_size, max_length):\n        super(GPT, self).__init__()\n        self.embed_size = embed_size\n        self.transformer_blocks = nn.ModuleList(\n            &#91;\n                TransformerBlock(\n                    embed_size,\n                    heads,\n                    dropout=dropout,\n                    forward_expansion=forward_expansion,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n\n        self.word_embedding = nn.Embedding(vocab_size, embed_size)\n        self.position_embedding = nn.Embedding(max_length, embed_size)\n\n    def forward(self, x, mask):\n        N, seq_length = x.shape\n        print(f\"Input shape: {x.shape}\")  # \u6253\u5370\u8f93\u5165\u5f62\u72b6\n\n        positions = torch.arange(0, seq_length).expand(N, seq_length).to(x.device)\n        out = self.word_embedding(x) + self.position_embedding(positions)\n        print(f\"After embedding and position shape: {out.shape}\")  # \u6253\u5370\u5d4c\u5165\u548c\u4f4d\u7f6e\u7f16\u7801\u540e\u7684\u5f62\u72b6\n\n        for layer in self.transformer_blocks:\n            out = layer(out, out, out, mask)\n            print(f\"After transformer block shape: {out.shape}\")  # \u6253\u5370\u7ecf\u8fc7\u6bcf\u4e2aTransformer\u5757\u540e\u7684\u5f62\u72b6\n\n        return out\n<\/code><\/pre>\n\n\n\n<p>\u7ee7\u7eed\u524d\u9762\u7684\u7b80\u5316\u7248GPT\u6a21\u578b\u5b9e\u73b0\uff0c\u4e0b\u9762\u63d0\u4f9b\u4e00\u4e2a\u57fa\u672c\u7684\u8bad\u7ec3\u6846\u67b6\u3002\u8fd9\u4e2a\u4f8b\u5b50\u5c06\u5c55\u793a\u5982\u4f55\u51c6\u5907\u6570\u636e\u3001\u5b9a\u4e49\u635f\u5931\u51fd\u6570\u3001\u9009\u62e9\u4f18\u5316\u5668\uff0c\u5e76\u6267\u884c\u8bad\u7ec3\u5faa\u73af\u3002\u8bf7\u6ce8\u610f\uff0c\u8fd9\u662f\u4e00\u4e2a\u9ad8\u5ea6\u7b80\u5316\u7684\u4f8b\u5b50\uff0c\u65e8\u5728\u6f14\u793a\u57fa\u672c\u6982\u5ff5\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading has-large-font-size\"><strong>3. \u51c6\u5907\u6570\u636e<\/strong><\/h3>\n\n\n\n<p>\u5047\u8bbe\u4f60\u5df2\u7ecf\u6709\u4e86\u4e00\u4e2a\u6587\u672c\u6570\u636e\u96c6\uff0c\u5e76\u4e14\u4f60\u5df2\u7ecf\u8fdb\u884c\u4e86\u9884\u5904\u7406\uff08\u4f8b\u5982\uff0c\u5206\u8bcd\u548c\u8f6c\u6362\u4e3a\u8bcd\u6c47\u7d22\u5f15\uff09\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u8fd9\u91cc\u4e0d\u5c55\u793a\u6570\u636e\u9884\u5904\u7406\u7684\u4ee3\u7801\u3002\u6211\u4eec\u5c06\u76f4\u63a5\u4ece\u521b\u5efa\u6570\u636e\u52a0\u8f7d\u5668\u5f00\u59cb\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>def clean_text_mixed_with_symbols(text):\n    # \u4fdd\u7559\u4e2d\u6587\u3001\u82f1\u6587\u5b57\u7b26\u3001\u6570\u5b57\u548c\u5e38\u89c1\u7684\u6807\u70b9\u7b26\u53f7\n    # \u6ce8\u610f\uff1a\u6839\u636e\u9700\u8981\uff0c\u4f60\u53ef\u4ee5\u5728\u8fd9\u91cc\u6dfb\u52a0\u6216\u5220\u9664\u7279\u5b9a\u7684\u7b26\u53f7\n    text = re.sub(r'&#91;^\\u4e00-\\u9fffA-Za-z0-9\uff0c\u3002\uff01\uff1f\u3001\uff1b\uff1a\u201c\u201d\u2018\u2019\uff08\uff09\u300a\u300b\u3010\u3011\u2014\u2026]+', ' ', text)\n    return text.strip()\n\ndef preprocess_text_mixed_with_symbols(text):\n    text = clean_text_mixed_with_symbols(text)\n    tokens = &#91;]\n    for token in jieba.cut(text, cut_all=False):\n        token = token.strip()\n        if token:\n            tokens.append(token)\n    return tokens\n\ndef load_and_preprocess_data(file_paths):\n    # \u8fd9\u91cc\u7b80\u5316\u5904\u7406\uff0c\u5177\u4f53\u5b9e\u73b0\u4f9d\u636e\u4f60\u7684\u9700\u6c42\u5b9a\n    texts = &#91;]\n    for file_path in file_paths:\n        with open(file_path, 'r', encoding='utf-8') as file:\n            text = file.read()\n            # \u6dfb\u52a0\u6587\u672c\u6e05\u6d17\u548c\u9884\u5904\u7406\u903b\u8f91\n            processed_text = preprocess_text_mixed_with_symbols(text)\n            texts.append(processed_text)\n    return texts\n\n\nclass TextDataset(Dataset):\n    def __init__(self, indexed_texts, vocab_size):\n        self.texts = &#91;torch.tensor(text, dtype=torch.long) for text in indexed_texts]  # \u7d22\u5f15\u5316\u6587\u672c\u8f6c\u6362\u4e3atensor\n        self.vocab_size = vocab_size\n\n    def __len__(self):\n        return len(self.texts)\n\n    def __getitem__(self, idx):\n        return self.texts&#91;idx]\n\n    def collate_fn(self, batch):\n        input_ids = &#91;item&#91;:-1] for item in batch]\n        target_ids = &#91;item&#91;1:] for item in batch]\n        input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)\n        target_ids_padded = pad_sequence(target_ids, batch_first=True, padding_value=0)\n        return input_ids_padded, target_ids_padded\n\n\ndef build_vocab(texts):\n    vocab = set(token for text in texts for token in text)\n    vocab_to_index = {word: i for i, word in enumerate(vocab, start=1)}  # \u4ece1\u5f00\u59cb\u7f16\u53f7\n    return vocab_to_index\n\n\ndef index_text(text, vocab_to_index):\n    return &#91;vocab_to_index&#91;token] for token in text if token in vocab_to_index]\n<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading has-large-font-size\"><strong>4. \u5b9a\u4e49\u6a21\u578b\u3001\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668<\/strong><\/h2>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># \u5b9e\u4f8b\u5316\u6a21\u578b\nmodel = GPT(\n    embed_size=embed_size,\n    num_layers=num_layers,\n    heads=heads,\n    forward_expansion=forward_expansion,\n    dropout=dropout,\n    vocab_size=vocab_size,\n    max_length=max_length\n)\n\nloss_fn = nn.CrossEntropyLoss()\noptimizer = optim.Adam(model.parameters(), lr=0.0001)\n<\/code><\/pre>\n\n\n\n<h3 class=\"wp-block-heading has-large-font-size\"><strong>5. \u8bad\u7ec3\u5faa\u73af<\/strong><\/h3>\n\n\n\n<p>\u6700\u540e\uff0c\u6211\u4eec\u6267\u884c\u8bad\u7ec3\u5faa\u73af\uff0c\u6bcf\u4e2a\u6279\u6b21\u5904\u7406\u6570\u636e\uff0c\u8ba1\u7b97\u635f\u5931\uff0c\u5e76\u66f4\u65b0\u6a21\u578b\u7684\u6743\u91cd\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>def train(model, dataloader, loss_fn, optimizer, device, epochs):\n    model.train()\n    model.to(device)\n    \n    for epoch in range(epochs):\n        for batch_idx, (input_ids, target_ids) in enumerate(dataloader):\n            input_ids = input_ids.to(device)\n            target_ids = target_ids.to(device)\n            \n            # \u524d\u5411\u4f20\u64ad\n            predictions = model(input_ids, mask=None) # \u8fd9\u91cc\u7b80\u5316\u5904\u7406\uff0c\u6ca1\u6709\u4f7f\u7528mask\n            predictions = predictions.view(-1, predictions.size(-1))\n            target_ids = target_ids.view(-1)\n            \n            # \u8ba1\u7b97\u635f\u5931\n            loss = loss_fn(predictions, target_ids)\n            \n            # \u53cd\u5411\u4f20\u64ad\u548c\u4f18\u5316\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            \n            if batch_idx % 100 == 0:\n                print(f\"Epoch {epoch} Batch {batch_idx} Loss {loss.item()}\")\n<\/code><\/pre>\n\n\n\n<p>\u8fd9\u6bb5\u4ee3\u7801\u5c55\u793a\u4e86\u5982\u4f55\u8bbe\u7f6e\u548c\u6267\u884c\u6a21\u578b\u7684\u8bad\u7ec3\u8fc7\u7a0b\u3002\u8bf7\u6ce8\u610f\uff0c\u8fd9\u53ea\u662f\u4e00\u4e2a\u8d77\u70b9\uff0c\u771f\u5b9e\u4e16\u754c\u7684\u5e94\u7528\u53ef\u80fd\u9700\u8981\u66f4\u590d\u6742\u7684\u6570\u636e\u5904\u7406\u3001\u6a21\u578b\u8c03\u53c2\u3001\u6b63\u5219\u5316\u7b56\u7565\u3001\u4ee5\u53ca\u8bad\u7ec3\u8fc7\u7a0b\u76d1\u63a7\u3002\u6b64\u5916\uff0c\u4e3a\u4e86\u5904\u7406\u5927\u89c4\u6a21\u6570\u636e\u96c6\u548c\u6a21\u578b\uff0c\u53ef\u80fd\u8fd8\u9700\u8981\u8003\u8651\u5206\u5e03\u5f0f\u8bad\u7ec3\u548c\u6a21\u578b\u5e76\u884c\u5316\u3002<\/p>\n\n\n\n<h2 class=\"wp-block-heading has-large-font-size\"><strong>6. <\/strong>\u5b8c\u6574\u7684\u8bad\u7ec3\u4ee3\u7801<\/h2>\n\n\n\n<p>\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport re\nimport jieba\n\nfrom torch.utils.data import Dataset, DataLoader\nfrom torch.nn.utils.rnn import pad_sequence\n\n\nclass SelfAttention(nn.Module):\n    def __init__(self, embed_size, heads):\n        super(SelfAttention, self).__init__()\n        self.embed_size = embed_size\n        self.heads = heads\n        self.head_dim = embed_size \/\/ heads\n\n        assert self.head_dim * heads == embed_size, \"Embedding size needs to be divisible by heads\"\n\n        self.values = nn.Linear(embed_size, embed_size, bias=False)\n        self.keys = nn.Linear(embed_size, embed_size, bias=False)\n        self.queries = nn.Linear(embed_size, embed_size, bias=False)\n        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)\n\n    def forward(self, value, key, query, mask):\n        N = query.shape&#91;0]\n        value_len, key_len, query_len = value.shape&#91;1], key.shape&#91;1], query.shape&#91;1]\n\n        # Split the embedding into `self.heads` pieces\n        values = self.values(value).view(N, value_len, self.heads, self.head_dim)\n        keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)\n        queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)\n\n        # Transpose for attention dot product: from &#91;N, value_len, self.heads, self.head_dim]\n        # to &#91;N, self.heads, value_len, self.head_dim] to match the shape for `torch.einsum`\n        values = values.transpose(1, 2)\n        keys = keys.transpose(1, 2)\n        queries = queries.transpose(1, 2)\n\n        # Attention mechanism\n        energy = torch.einsum(\"nqhd,nkhd-&gt;nhqk\", &#91;queries, keys])\n\n        if mask is not None:\n            energy = energy.masked_fill(mask == 0, float(\"-1e20\"))\n\n        attention = torch.softmax(energy \/ (self.embed_size ** (1 \/ 2)), dim=3)\n\n        # \u91cd\u5851\u524d\u8fdb\u884c\u5f20\u91cf\u4e58\u6cd5\uff0c\u7136\u540e\u91cd\u5851\u56de &#91;batch_size, seq_len, heads * head_dim]\n        out = torch.einsum(\"nhql,nlhd-&gt;nqhd\", &#91;attention, values]).reshape(\n            N, query_len, self.heads * self.head_dim\n        )\n\n        out = self.fc_out(out)\n        return out\n\nclass TransformerBlock(nn.Module):\n    def __init__(self, embed_size, heads, dropout, forward_expansion):\n        super(TransformerBlock, self).__init__()\n        self.attention = SelfAttention(embed_size, heads)\n        self.norm1 = nn.LayerNorm(embed_size)\n        self.norm2 = nn.LayerNorm(embed_size)\n\n        self.feed_forward = nn.Sequential(\n            nn.Linear(embed_size, forward_expansion * embed_size),\n            nn.ReLU(),\n            nn.Linear(forward_expansion * embed_size, embed_size),\n        )\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, value, key, query, mask):\n        attention = self.attention(value, key, query, mask)\n\n        # Add skip connection, followed by layer normalization\n        x = self.norm1(attention + query)\n        forward = self.feed_forward(x)\n        out = self.norm2(forward + x)  # Add skip connection, followed by layer normalization\n        return out\n\nclass GPT(nn.Module):\n    def __init__(self, embed_size, num_layers, heads, forward_expansion, dropout, vocab_size, max_length):\n        super(GPT, self).__init__()\n        self.embed_size = embed_size\n        self.transformer_blocks = nn.ModuleList(\n            &#91;\n                TransformerBlock(\n                    embed_size,\n                    heads,\n                    dropout=dropout,\n                    forward_expansion=forward_expansion,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n\n        self.word_embedding = nn.Embedding(vocab_size, embed_size)\n        self.position_embedding = nn.Embedding(max_length, embed_size)\n\n    def forward(self, x, mask):\n        N, seq_length = x.shape\n        print(f\"Input shape: {x.shape}\")  # \u6253\u5370\u8f93\u5165\u5f62\u72b6\n\n        positions = torch.arange(0, seq_length).expand(N, seq_length).to(x.device)\n        out = self.word_embedding(x) + self.position_embedding(positions)\n        print(f\"After embedding and position shape: {out.shape}\")  # \u6253\u5370\u5d4c\u5165\u548c\u4f4d\u7f6e\u7f16\u7801\u540e\u7684\u5f62\u72b6\n\n        for layer in self.transformer_blocks:\n            out = layer(out, out, out, mask)\n            print(f\"After transformer block shape: {out.shape}\")  # \u6253\u5370\u7ecf\u8fc7\u6bcf\u4e2aTransformer\u5757\u540e\u7684\u5f62\u72b6\n\n        return out\n\ndef clean_text_mixed_with_symbols(text):\n    # \u4fdd\u7559\u4e2d\u6587\u3001\u82f1\u6587\u5b57\u7b26\u3001\u6570\u5b57\u548c\u5e38\u89c1\u7684\u6807\u70b9\u7b26\u53f7\n    # \u6ce8\u610f\uff1a\u6839\u636e\u9700\u8981\uff0c\u4f60\u53ef\u4ee5\u5728\u8fd9\u91cc\u6dfb\u52a0\u6216\u5220\u9664\u7279\u5b9a\u7684\u7b26\u53f7\n    text = re.sub(r'&#91;^\\u4e00-\\u9fffA-Za-z0-9\uff0c\u3002\uff01\uff1f\u3001\uff1b\uff1a\u201c\u201d\u2018\u2019\uff08\uff09\u300a\u300b\u3010\u3011\u2014\u2026]+', ' ', text)\n    return text.strip()\n\ndef preprocess_text_mixed_with_symbols(text):\n    text = clean_text_mixed_with_symbols(text)\n    tokens = &#91;]\n    for token in jieba.cut(text, cut_all=False):\n        token = token.strip()\n        if token:\n            tokens.append(token)\n    return tokens\n\ndef load_and_preprocess_data(file_paths):\n    # \u8fd9\u91cc\u7b80\u5316\u5904\u7406\uff0c\u5177\u4f53\u5b9e\u73b0\u4f9d\u636e\u4f60\u7684\u9700\u6c42\u5b9a\n    texts = &#91;]\n    for file_path in file_paths:\n        with open(file_path, 'r', encoding='utf-8') as file:\n            text = file.read()\n            # \u6dfb\u52a0\u6587\u672c\u6e05\u6d17\u548c\u9884\u5904\u7406\u903b\u8f91\n            processed_text = preprocess_text_mixed_with_symbols(text)\n            texts.append(processed_text)\n    return texts\n\n\nclass TextDataset(Dataset):\n    def __init__(self, indexed_texts, vocab_size):\n        self.texts = &#91;torch.tensor(text, dtype=torch.long) for text in indexed_texts]  # \u7d22\u5f15\u5316\u6587\u672c\u8f6c\u6362\u4e3atensor\n        self.vocab_size = vocab_size\n\n    def __len__(self):\n        return len(self.texts)\n\n    def __getitem__(self, idx):\n        return self.texts&#91;idx]\n\n    def collate_fn(self, batch):\n        input_ids = &#91;item&#91;:-1] for item in batch]\n        target_ids = &#91;item&#91;1:] for item in batch]\n        input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)\n        target_ids_padded = pad_sequence(target_ids, batch_first=True, padding_value=0)\n        return input_ids_padded, target_ids_padded\n\n\ndef build_vocab(texts):\n    vocab = set(token for text in texts for token in text)\n    vocab_to_index = {word: i for i, word in enumerate(vocab, start=1)}  # \u4ece1\u5f00\u59cb\u7f16\u53f7\n    return vocab_to_index\n\n\ndef index_text(text, vocab_to_index):\n    return &#91;vocab_to_index&#91;token] for token in text if token in vocab_to_index]\n\ndef train(model, dataloader, loss_fn, optimizer, device, epochs):\n    model.train()\n    model.to(device)\n    \n    for epoch in range(epochs):\n        for batch_idx, (input_ids, target_ids) in enumerate(dataloader):\n            input_ids = input_ids.to(device)\n            target_ids = target_ids.to(device)\n            \n            # \u524d\u5411\u4f20\u64ad\n            predictions = model(input_ids, mask=None) # \u8fd9\u91cc\u7b80\u5316\u5904\u7406\uff0c\u6ca1\u6709\u4f7f\u7528mask\n            predictions = predictions.view(-1, predictions.size(-1))\n            target_ids = target_ids.view(-1)\n            \n            # \u8ba1\u7b97\u635f\u5931\n            loss = loss_fn(predictions, target_ids)\n            \n            # \u53cd\u5411\u4f20\u64ad\u548c\u4f18\u5316\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            \n            if batch_idx % 100 == 0:\n                print(f\"Epoch {epoch} Batch {batch_idx} Loss {loss.item()}\")\n\n\n# \u6a21\u578b\u53c2\u6570\nvocab_size = 10000  # \u5047\u8bbe\u7684\u8bcd\u6c47\u8868\u5927\u5c0f\nembed_size = 256\nmax_length = 100\nnum_layers = 6\nheads = 8\nforward_expansion = 4\ndropout = 0.1\n\n# \u5047\u8bbe\u4f60\u7684\u6587\u672c\u6587\u4ef6\u8def\u5f84\n#file_paths = &#91;'path\/to\/your\/text1.txt', 'path\/to\/your\/text2.txt']\n#texts = load_and_preprocess_data(file_paths)\n\n# \u5047\u8bbe\u6587\u672c\u5305\u542b\u4e2d\u6587\u3001\u82f1\u6587\u548c\u5e38\u7528\u7b26\u53f7\ntext = \"1977\u5e74\uff0c\u4e09\u4f4d\u6570\u5b66\u5bb6Rivest\u3001Shamir \u548c Adleman \u8bbe\u8ba1\u4e86\u4e00\u79cd\u7b97\u6cd5\uff0c\u53ef\u4ee5\u5b9e\u73b0\u975e\u5bf9\u79f0\u52a0\u5bc6\u3002\u8fd9\u79cd\u7b97\u6cd5\u7528\u4ed6\u4eec\u4e09\u4e2a\u4eba\u7684\u540d\u5b57\u547d\u540d\uff0c\u53eb\u505aRSA\u7b97\u6cd5\u3002\u4ece\u90a3\u65f6\u76f4\u5230\u73b0\u5728\uff0cRSA\u7b97\u6cd5\u4e00\u76f4\u662f\u6700\u5e7f\u4e3a\u4f7f\u7528\u7684\u201d\u975e\u5bf9\u79f0\u52a0\u5bc6\u7b97\u6cd5\u201d\u3002\u6beb\u4e0d\u5938\u5f20\u5730\u8bf4\uff0c\u53ea\u8981\u6709\u8ba1\u7b97\u673a\u7f51\u7edc\u7684\u5730\u65b9\uff0c\u5c31\u6709RSA\u7b97\u6cd5\u3002\"\n\n# \u9884\u5904\u7406\u6587\u672c\ntexts = preprocess_text_mixed_with_symbols(text)\n# \u8f93\u51fa\u5206\u8bcd\u7ed3\u679c\nprint(texts)\n\n# \u5047\u8bbe`texts`\u662f\u5206\u8bcd\u540e\u7684\u6587\u672c\u5217\u8868\nvocab_to_index = build_vocab(texts)\nindexed_texts = &#91;index_text(text, vocab_to_index) for text in texts]\n\nvocab_size=len(vocab_to_index) + 1\n# \u73b0\u5728`texts`\u5e94\u8be5\u662f\u7d22\u5f15\u5316\u540e\u7684\u6587\u672c\u5217\u8868\ndataset = TextDataset(indexed_texts, vocab_size=len(vocab_to_index) + 1)  # +1\u56e0\u4e3a\u4ece1\u5f00\u59cb\u7f16\u53f7\ndataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=dataset.collate_fn)\n\n# \u5b9e\u4f8b\u5316\u6a21\u578b\nmodel = GPT(\n    embed_size=embed_size,\n    num_layers=num_layers,\n    heads=heads,\n    forward_expansion=forward_expansion,\n    dropout=dropout,\n    vocab_size=vocab_size,\n    max_length=max_length\n)\n\nloss_fn = nn.CrossEntropyLoss()\noptimizer = optim.Adam(model.parameters(), lr=0.0001)\n\nepochs = 1\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\ntrain(model, dataloader, loss_fn, optimizer, device, epochs)\n\n# \u4fdd\u5b58\u6a21\u578b\u53c2\u6570\n# model_path = \"gpt_simple_model.pth\"\n# torch.save(model.state_dict(), model_path)\n\n# \u5982\u679c\u8981\u4fdd\u5b58\u6574\u4e2a\u6a21\u578b\uff08\u5305\u62ec\u6a21\u578b\u7ed3\u6784\uff09\uff0c\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u65b9\u5f0f\nmodel_path = \"gpt_simple_model_full.pth\"\ntorch.save(model, model_path)\n<\/code><\/pre>\n\n\n\n<p>\u4ee3\u7801\u63d0\u4f9b\u4e86\u4e00\u4e2a\u4f7f\u7528PyTorch\u5b9e\u73b0\u7c7b\u4f3cGPT\u6a21\u578b\u7684\u5168\u9762\u793a\u4f8b\uff0c\u8fd9\u4e2a\u793a\u4f8b\u6db5\u76d6\u4e86\u591a\u4e2a\u5173\u952e\u65b9\u9762\uff0c\u5305\u62ec\u81ea\u6ce8\u610f\u529b\u5c42\u7684\u5b9a\u4e49\u3001\u53d8\u538b\u5668\u5757\u3001\u6574\u4f53GPT\u6a21\u578b\u3001\u6587\u672c\u6570\u636e\u7684\u9884\u5904\u7406\uff08\u5305\u62ec\u6df7\u5408\u8bed\u8a00\u5185\u5bb9\u7684\u6587\u672c\u6e05\u7406\u548c\u4f7f\u7528Jieba\u8fdb\u884c\u5206\u8bcd\uff09\uff0c\u4ee5\u53ca\u6700\u540e\u7684\u6a21\u578b\u8bad\u7ec3\u3001\u81ea\u5b9a\u4e49\u6570\u636e\u96c6\u548c\u6570\u636e\u52a0\u8f7d\u5668\u7684\u4f7f\u7528\u3002<\/p>\n\n\n\n<p>\u4ee5\u4e0b\u662f\u4e00\u4e9b\u5efa\u8bae\u548c\u6f84\u6e05\u70b9\uff0c\u4ee5\u786e\u4fdd\u4ee3\u7801\u6309\u9884\u671f\u5de5\u4f5c\uff0c\u5e76\u9075\u5faa\u6700\u4f73\u5b9e\u8df5\uff1a<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>\u81ea\u6ce8\u610f\u529b\u548c\u53d8\u538b\u5668\u5757\u5b9e\u73b0<\/strong>\uff1a\u60a8\u7684\u81ea\u6ce8\u610f\u529b\u548c\u53d8\u538b\u5668\u5757\u5b9e\u73b0\u770b\u8d77\u6765\u5f88\u597d\u3002\u5b83\u9075\u5faa\u4e86\u6784\u5efa\u57fa\u4e8e\u53d8\u538b\u5668\u6a21\u578b\u7684\u6807\u51c6\u65b9\u6cd5\uff0c\u5305\u62ec\u5c06\u8f93\u5165\u5206\u5272\u6210\u591a\u4e2a\u5934\u3001\u5e94\u7528\u81ea\u6ce8\u610f\u529b\uff0c\u7136\u540e\u4f7f\u7528\u524d\u9988\u7f51\u7edc\u3002<\/li>\n\n\n\n<li><strong>\u6a21\u578b\u8bad\u7ec3\u5faa\u73af<\/strong>\uff1a\u8bad\u7ec3\u5faa\u73af\u5305\u62ec\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u5178\u578b\u8bad\u7ec3\u8fc7\u7a0b\u7684\u57fa\u672c\u6b65\u9aa4\u3002\u5b83\u901a\u8fc7\u6a21\u578b\u5904\u7406\u8f93\u5165\u3001\u8ba1\u7b97\u635f\u5931\u3001\u6267\u884c\u53cd\u5411\u4f20\u64ad\u548c\u66f4\u65b0\u6a21\u578b\u7684\u6743\u91cd\u3002\u60a8\u8fd8\u5305\u62ec\u4e86\u8bbe\u5907\u517c\u5bb9\u6027\uff0c\u4ee5\u4fbf\u5728GPU\u4e0a\u8fd0\u884c\u6a21\u578b\uff08\u5982\u679c\u53ef\u7528\uff09\uff0c\u8fd9\u5bf9\u4e8e\u8bad\u7ec3\u6548\u7387\u81f3\u5173\u91cd\u8981\u3002<\/li>\n\n\n\n<li><strong>\u6587\u672c\u9884\u5904\u7406\u548c\u5206\u8bcd<\/strong>\uff1a\u60a8\u5305\u542b\u4e86\u6e05\u7406\u6587\u672c\u548c\u5206\u8bcd\u7684\u529f\u80fd\uff0c\u8fd9\u5bf9\u4e8eNLP\u4efb\u52a1\u81f3\u5173\u91cd\u8981\u3002\u4f7f\u7528Jieba\u8fdb\u884c\u5206\u8bcd\u9002\u7528\u4e8e\u5904\u7406\u4e2d\u6587\u6587\u672c\uff0c\u60a8\u7684\u6b63\u5219\u8868\u8fbe\u5f0f\u6e05\u7406\u6df7\u5408\u8bed\u8a00\u6587\u672c\u6db5\u76d6\u4e86\u5e7f\u6cdb\u7684\u5b57\u7b26\u3002<\/li>\n\n\n\n<li><strong>\u6570\u636e\u5904\u7406\u548c\u6570\u636e\u52a0\u8f7d\u5668<\/strong>\uff1a\u60a8\u5b9a\u4e49\u4e86\u4e00\u4e2a\u81ea\u5b9a\u4e49\u7684<code>Dataset<\/code>\u7c7b\uff0c\u5e76\u4f7f\u7528PyTorch\u7684<code>DataLoader<\/code>\u8fdb\u884c\u6279\u5904\u7406\u548c\u586b\u5145\u3002\u8fd9\u662f\u5904\u7406NLP\u4efb\u52a1\u4e2d\u53ef\u53d8\u957f\u5ea6\u5e8f\u5217\u7684\u597d\u65b9\u6cd5\u3002<\/li>\n\n\n\n<li><strong>\u6f5c\u5728\u6539\u8fdb<\/strong>\uff1a\n<ul class=\"wp-block-list\">\n<li><strong>\u6570\u636e\u9884\u5904\u7406\u4e2d\u7684\u9519\u8bef\u5904\u7406<\/strong>\uff1a\u786e\u4fdd\u60a8\u7684\u6587\u4ef6\u8bfb\u53d6\u548c\u6587\u672c\u9884\u5904\u7406\u80fd\u591f\u4f18\u96c5\u5730\u5904\u7406\u9519\u8bef\uff0c\u5c24\u5176\u662f\u5bf9\u4e8e\u53ef\u80fd\u4e0d\u5b58\u5728\u6216\u6709\u7f16\u7801\u95ee\u9898\u7684\u6587\u4ef6\u3002<\/li>\n\n\n\n<li><strong>\u6a21\u578b\u4e2d\u7684\u63a9\u7801\u4f7f\u7528<\/strong>\uff1a\u60a8\u7684\u8bc4\u8bba\u63d0\u5230\u4e86\u4e3a\u4e86\u7b80\u5316\u800c\u6ca1\u6709\u4f7f\u7528\u63a9\u7801\u3002\u5b9e\u9645\u4e0a\uff0c\u7279\u522b\u662f\u5bf9\u4e8e\u957f\u5ea6\u4e0d\u540c\u7684\u5e8f\u5217\uff0c\u63a9\u7801\u5bf9\u4e8e\u901a\u77e5\u6a21\u578b\u54ea\u4e9b\u8f93\u5165\u90e8\u5206\u662f\u586b\u5145\u4e14\u4e0d\u5e94\u8be5\u88ab\u5173\u6ce8\u662f\u81f3\u5173\u91cd\u8981\u7684\u3002<\/li>\n\n\n\n<li><strong>\u8bcd\u6c47\u8868\u6784\u5efa<\/strong>\uff1a\u6784\u5efa\u8bcd\u6c47\u8868\u548c\u7d22\u5f15\u6587\u672c\u7684\u8fc7\u7a0b\u5047\u8bbe\u6240\u6709\u6587\u672c\u90fd\u88ab\u5206\u8bcd\u6210\u4e00\u4e2a\u5e73\u9762\u5217\u8868\u3002\u5b9e\u9645\u4e0a\uff0c\u60a8\u53ef\u80fd\u6709\u591a\u4e2a\u6587\u6863\u6216\u53e5\u5b50\uff0c\u60a8\u53ef\u80fd\u5e0c\u671b\u5206\u522b\u5904\u7406\u5b83\u4eec\u6216\u4fdd\u6301\u53e5\u5b50\u8fb9\u754c\u3002<\/li>\n\n\n\n<li><strong>\u4fdd\u5b58\u6a21\u578b<\/strong>\uff1a\u60a8\u5c55\u793a\u4e86\u4e24\u79cd\u4fdd\u5b58\u6a21\u578b\u7684\u65b9\u5f0f\uff1b\u4ec5\u4fdd\u5b58\u6a21\u578b\u53c2\u6570\uff08<code>state_dict<\/code>\uff09\u66f4\u8282\u7701\u7a7a\u95f4\uff0c\u662f\u5927\u591a\u6570\u7528\u4f8b\u63a8\u8350\u7684\u65b9\u6cd5\u3002\u4fdd\u5b58\u6574\u4e2a\u6a21\u578b\u867d\u7136\u65b9\u4fbf\uff0c\u4f46\u5982\u679c\u9700\u8981\u5728\u4e0d\u540c\u73af\u5883\u4e2d\u52a0\u8f7d\u6a21\u578b\uff0c\u53ef\u80fd\u4f1a\u5bfc\u81f4\u95ee\u9898\u3002<\/li>\n<\/ul>\n<\/li>\n<\/ol>\n\n\n\n<p>\u5728\u8fd0\u884c\u4ee3\u7801\u4e4b\u524d\uff0c\u8bf7\u786e\u4fdd\u8c03\u6574\u6587\u4ef6\u8def\u5f84\uff0c\u5e76\u6839\u636e\u60a8\u7684\u5177\u4f53\u9700\u6c42\u53ef\u80fd\u6269\u5c55\u9884\u5904\u7406\u548c\u6570\u636e\u96c6\u5904\u7406\u3002\u6b64\u5916\uff0c\u8003\u8651\u5c1d\u8bd5\u4e0d\u540c\u7684\u6a21\u578b\u8d85\u53c2\u6570\uff08\u5982<code>embed_size<\/code>\u3001<code>num_layers<\/code>\u3001<code>heads<\/code>\u7b49\uff09\u548c\u8bad\u7ec3\u914d\u7f6e\uff0c\u4ee5\u627e\u5230\u9002\u5408\u60a8\u4efb\u52a1\u7684\u6700\u4f73\u8bbe\u7f6e\u3002<\/p>\n\n\n\n<p><\/p>\n","protected":false},"excerpt":{"rendered":"<p>\u5982\u679c\u4f60\u60f3\u4ece\u5934\u5f00\u59cb\u5b9e\u73b0\u4e00\u4e2a\u7b80\u5316\u7684\u7248\u672c\u7684GPT\u6a21\u578b\uff0c\u800c\u4e0d\u4f9d\u8d56\u4e8e\u73b0\u6210\u7684GPT-2\u6a21\u578b\u5e93\uff0c\u4f60\u53ef\u4ee5\u91c7\u7528PyTorch\u8fd9\u6837 [&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"site-sidebar-layout":"default","site-content-layout":"","ast-site-content-layout":"default","site-content-style":"default","site-sidebar-style":"default","ast-global-header-display":"","ast-banner-title-visibility":"","ast-main-header-display":"","ast-hfb-above-header-display":"","ast-hfb-below-header-display":"","ast-hfb-mobile-header-display":"","site-post-title":"","ast-breadcrumbs-content":"","ast-featured-img":"","footer-sml-layout":"","theme-transparent-header-meta":"","adv-header-id-meta":"","stick-header-meta":"","header-above-stick-meta":"","header-main-stick-meta":"","header-below-stick-meta":"","astra-migrate-meta-layouts":"set","ast-page-background-enabled":"default","ast-page-background-meta":{"desktop":{"background-color":"var(--ast-global-color-4)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"ast-content-background-meta":{"desktop":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[313,289,443,442,312],"tags":[242,314],"class_list":["post-2106","post","type-post","status-publish","format-standard","hentry","category-chatgpt","category-gpt","category-llm","category-llms","category-openai","tag-chatgpt","tag-openai-api"],"views":2208,"jetpack_sharing_enabled":true,"jetpack_featured_media_url":"","_links":{"self":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/2106","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=2106"}],"version-history":[{"count":6,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/2106\/revisions"}],"predecessor-version":[{"id":2113,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=\/wp\/v2\/posts\/2106\/revisions\/2113"}],"wp:attachment":[{"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=2106"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=2106"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.aqwu.net\/wp\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=2106"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}