电子发烧友App

硬声App

0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示
电子发烧友网>电子资料下载>电子资料>PyTorch教程16.7之自然语言推理:微调BERT

PyTorch教程16.7之自然语言推理:微调BERT

2023-06-05 | pdf | 0.22 MB | 次下载 | 免费

资料介绍

在本章前面的部分中,我们为 SNLI 数据集上的自然语言推理任务(如第 16.4 节所述)设计了一个基于注意力的架构(第16.5节)。现在我们通过微调 BERT 重新审视这个任务。正如16.6 节所讨论的 ,自然语言推理是一个序列级文本对分类问题,微调 BERT 只需要一个额外的基于 MLP 的架构,如图 16.7.1所示。

https://file.elecfans.com/web2/M00/A9/CD/poYBAGR9POGANyPIAAKGzmOF458734.svg

图 16.7.1本节将预训练的 BERT 提供给基于 MLP 的自然语言推理架构。

在本节中,我们将下载预训练的小型 BERT 版本,然后对其进行微调以在 SNLI 数据集上进行自然语言推理。

import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l
import json
import multiprocessing
import os
from mxnet import gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

16.7.1。加载预训练的 BERT

我们已经在第 15.9 节第 15.10 节中解释了如何在 WikiText-2 数据集上预训练 BERT (请注意,原始 BERT 模型是在更大的语料库上预训练的)。如15.10 节所述,原始 BERT 模型有数亿个参数。在下文中,我们提供了两个版本的预训练 BERT:“bert.base”与需要大量计算资源进行微调的原始 BERT 基础模型差不多大,而“bert.small”是一个小版本方便演示。

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
               '225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
               'c72329e68a732bef0452e4b96a1c341c8910f81f')
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.zip',
               '7b3820b35da691042e5d34c0971ac3edbd80d3f4')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.zip',
               'a4e718a47137ccd1809c9107ab4f5edd317bae2c')

预训练的 BERT 模型都包含一个定义词汇集的“vocab.json”文件和一个预训练参数的“pretrained.params”文件。我们实现以下load_pretrained_model 函数来加载预训练的 BERT 参数。

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
             num_heads, num_blks, dropout, max_len, devices):
  data_dir = d2l.download_extract(pretrained_model)
  # Define an empty vocabulary to load the predefined vocabulary
  vocab = d2l.Vocab()
  vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
  vocab.token_to_idx = {token: idx for idx, token in enumerate(
    vocab.idx_to_token)}
  bert = d2l.BERTModel(
    len(vocab), num_hiddens, ffn_num_hiddens=ffn_num_hiddens, num_heads=4,
    num_blks=2, dropout=0.2, max_len=max_len)
  # Load pretrained BERT parameters
  bert.load_state_dict(torch.load(os.path.join(data_dir,
                         'pretrained.params')))
  return bert, vocab
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
             num_heads, num_blks, dropout, max_len, devices):
  data_dir = d2l.download_extract(pretrained_model)
  # Define an empty vocabulary to load the predefined vocabulary
  vocab = d2l.Vocab()
  vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
  vocab.token_to_idx = {token: idx for idx, token in enumerate(
    vocab.idx_to_token)}
  bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens, num_heads,
             num_blks, dropout, max_len)
  # Load pretrained BERT parameters
  bert.load_parameters(os.path.join(data_dir, 'pretrained.params'),
             ctx=devices)
  return bert, vocab

为了便于在大多数机器上进行演示,我们将在本节中加载和微调预训练 BERT 的小型版本(“bert.small”)。在练习中,我们将展示如何微调更大的“bert.base”以显着提高测试准确性。

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
  'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
  num_blks=2, dropout=0.1, max_len=512, devices=devices)
Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...
devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
  'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
  num_blks=2, dropout=0.1, max_len=512, devices=devices)
Downloading ../data/bert.small.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.zip...

16.7.2。微调 BERT 的数据集

对于 SNLI 数据集上的下游任务自然语言推理,我们定义了一个自定义的数据集类SNLIBERTDataset在每个示例中,前提和假设形成一对文本序列,并被打包到一个 BERT 输入序列中,如图 16.6.2所示。回想第 15.8.4 节 ,段 ID 用于区分 BERT 输入序列中的前提和假设。对于 BERT 输入序列 ( max_len) 的预定义最大长度,输入文本对中较长者的最后一个标记会不断被删除,直到max_len满足为止。为了加速生成用于微调 BERT 的 SNLI 数据集,我们使用 4 个工作进程并行生成训练或测试示例。

class SNLIBERTDataset(torch.utils.data.Dataset):
  def __init__(self, dataset, max_len, vocab=None):
    all_premise_hypothesis_tokens = [[
      p_tokens, h_tokens] for p_tokens, h_tokens in zip(
      *[d2l.tokenize([s.lower() for s in sentences])
       for sentences in dataset[:2]])]

    self.labels = torch.tensor(dataset[2])
    self.vocab = vocab
    self.max_len = max_len
    (self.all_token_ids, self.all_segments,
     self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
    print('read ' + str(len(self.all_token_ids)) + ' examples')

  def _preprocess(self, all_premise_hypothesis_tokens):
    pool = multiprocessing.Pool(4) # Use 4 worker processes
    out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
    all_token_ids = [
      token_ids for token_ids, segments, valid_len in out]
    all_segments = [segments for token_ids, segments, valid_len in out]
    valid_lens = [valid_len for token_ids, segments, valid_len in out]
    return (torch.tensor(all_token_ids, dtype=torch.long),
        torch.tensor(all_segments, dtype=torch.long),
        torch.tensor(valid_lens))

  def _mp_worker(self, premise_hypothesis_tokens):
    p_tokens, h_tokens = premise_hypothesis_tokens
    self._truncate_pair_of_tokens(p_tokens, h_tokens)
    tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
    token_ids = self.vocab[tokens] + [self.vocab['']] \
               * (self.max_len - len(tokens))
    segments = segments + [0] * (self.max_len - len(segments))
    valid_len = len(tokens)
    return token_ids, segments, valid_len

  def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
    # Reserve slots for '', '', and '' tokens for the BERT
    # input
    while len(p_tokens) + len(h_tokens) > self.max_len - 3:
      if len(p_tokens) > len(h_tokens):
        p_tokens.pop()
      else:
        h_tokens.pop()

  def __getitem__(self, idx):
    return (self.all_token_ids[idx], self.all_segments[idx],
        self.valid_lens[idx]), self.labels[idx]

  def __len__(self):
    return len(self.all_token_ids)
class SNLIBERTDataset(gluon.data.Dataset):
  def __init__(self, dataset, max_len, vocab=None):
    all_premise_hypothesis_tokens = [[
      p_tokens, h_tokens
下载该资料的人也在下载 下载该资料的人还在阅读
更多 >

评论

查看更多

下载排行

本周

  1. 1A7159和A7139射频芯片的资料免费下载
  2. 0.20 MB   |  55次下载  |  5 积分
  3. 2PIC12F629/675 数据手册免费下载
  4. 2.38 MB   |  36次下载  |  5 积分
  5. 3PIC16F716 数据手册免费下载
  6. 2.35 MB   |  18次下载  |  5 积分
  7. 4dsPIC33EDV64MC205电机控制开发板用户指南
  8. 5.78MB   |  8次下载  |  免费
  9. 5STC15系列常用寄存器汇总免费下载
  10. 1.60 MB   |  7次下载  |  5 积分
  11. 6模拟电路仿真实现
  12. 2.94MB   |  4次下载  |  免费
  13. 7PCB图绘制实例操作
  14. 2.92MB   |  2次下载  |  免费
  15. 8零死角玩转STM32F103—指南者
  16. 26.78 MB   |  1次下载  |  1 积分

本月

  1. 1ADI高性能电源管理解决方案
  2. 2.43 MB   |  452次下载  |  免费
  3. 2免费开源CC3D飞控资料(电路图&PCB源文件、BOM、
  4. 5.67 MB   |  141次下载  |  1 积分
  5. 3基于STM32单片机智能手环心率计步器体温显示设计
  6. 0.10 MB   |  137次下载  |  免费
  7. 4A7159和A7139射频芯片的资料免费下载
  8. 0.20 MB   |  55次下载  |  5 积分
  9. 5PIC12F629/675 数据手册免费下载
  10. 2.38 MB   |  36次下载  |  5 积分
  11. 6如何正确测试电源的纹波
  12. 0.36 MB   |  19次下载  |  免费
  13. 7PIC16F716 数据手册免费下载
  14. 2.35 MB   |  18次下载  |  5 积分
  15. 8Q/SQR E8-4-2024乘用车电子电器零部件及子系统EMC试验方法及要求
  16. 1.97 MB   |  8次下载  |  10 积分

总榜

  1. 1matlab软件下载入口
  2. 未知  |  935121次下载  |  10 积分
  3. 2开源硬件-PMP21529.1-4 开关降压/升压双向直流/直流转换器 PCB layout 设计
  4. 1.48MB  |  420062次下载  |  10 积分
  5. 3Altium DXP2002下载入口
  6. 未知  |  233088次下载  |  10 积分
  7. 4电路仿真软件multisim 10.0免费下载
  8. 340992  |  191367次下载  |  10 积分
  9. 5十天学会AVR单片机与C语言视频教程 下载
  10. 158M  |  183335次下载  |  10 积分
  11. 6labview8.5下载
  12. 未知  |  81581次下载  |  10 积分
  13. 7Keil工具MDK-Arm免费下载
  14. 0.02 MB  |  73810次下载  |  10 积分
  15. 8LabVIEW 8.6下载
  16. 未知  |  65988次下载  |  10 积分