0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

PyTorch教程-16.7。自然语言推理:微调 BERT

jf_pJlTbmA9 来源:PyTorch 作者:PyTorch 2023-06-05 15:44 次阅读

在本章前面的部分中,我们为 SNLI 数据集上的自然语言推理任务(如第 16.4 节所述)设计了一个基于注意力的架构(第16.5节)。现在我们通过微调 BERT 重新审视这个任务。正如16.6 节所讨论的 ,自然语言推理是一个序列级文本对分类问题,微调 BERT 只需要一个额外的基于 MLP 的架构,如图 16.7.1所示。

poYBAGR9POGANyPIAAKGzmOF458734.svg

图 16.7.1本节将预训练的 BERT 提供给基于 MLP 的自然语言推理架构。

在本节中,我们将下载预训练的小型 BERT 版本,然后对其进行微调以在 SNLI 数据集上进行自然语言推理。

import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l

import json
import multiprocessing
import os
from mxnet import gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

16.7.1。加载预训练的 BERT

我们已经在第 15.9 节和第 15.10 节中解释了如何在 WikiText-2 数据集上预训练 BERT (请注意,原始 BERT 模型是在更大的语料库上预训练的)。如15.10 节所述,原始 BERT 模型有数亿个参数。在下文中,我们提供了两个版本的预训练 BERT:“bert.base”与需要大量计算资源进行微调的原始 BERT 基础模型差不多大,而“bert.small”是一个小版本方便演示。

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
               '225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
               'c72329e68a732bef0452e4b96a1c341c8910f81f')

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.zip',
               '7b3820b35da691042e5d34c0971ac3edbd80d3f4')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.zip',
               'a4e718a47137ccd1809c9107ab4f5edd317bae2c')

预训练的 BERT 模型都包含一个定义词汇集的“vocab.json”文件和一个预训练参数的“pretrained.params”文件。我们实现以下load_pretrained_model 函数来加载预训练的 BERT 参数。

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
             num_heads, num_blks, dropout, max_len, devices):
  data_dir = d2l.download_extract(pretrained_model)
  # Define an empty vocabulary to load the predefined vocabulary
  vocab = d2l.Vocab()
  vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
  vocab.token_to_idx = {token: idx for idx, token in enumerate(
    vocab.idx_to_token)}
  bert = d2l.BERTModel(
    len(vocab), num_hiddens, ffn_num_hiddens=ffn_num_hiddens, num_heads=4,
    num_blks=2, dropout=0.2, max_len=max_len)
  # Load pretrained BERT parameters
  bert.load_state_dict(torch.load(os.path.join(data_dir,
                         'pretrained.params')))
  return bert, vocab

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
             num_heads, num_blks, dropout, max_len, devices):
  data_dir = d2l.download_extract(pretrained_model)
  # Define an empty vocabulary to load the predefined vocabulary
  vocab = d2l.Vocab()
  vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
  vocab.token_to_idx = {token: idx for idx, token in enumerate(
    vocab.idx_to_token)}
  bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens, num_heads,
             num_blks, dropout, max_len)
  # Load pretrained BERT parameters
  bert.load_parameters(os.path.join(data_dir, 'pretrained.params'),
             ctx=devices)
  return bert, vocab

为了便于在大多数机器上进行演示,我们将在本节中加载和微调预训练 BERT 的小型版本(“bert.small”)。在练习中,我们将展示如何微调更大的“bert.base”以显着提高测试准确性。

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
  'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
  num_blks=2, dropout=0.1, max_len=512, devices=devices)

Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
  'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
  num_blks=2, dropout=0.1, max_len=512, devices=devices)

Downloading ../data/bert.small.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.zip...

16.7.2。微调 BERT 的数据集

对于 SNLI 数据集上的下游任务自然语言推理,我们定义了一个自定义的数据集类SNLIBERTDataset。在每个示例中,前提和假设形成一对文本序列,并被打包到一个 BERT 输入序列中,如图 16.6.2所示。回想第 15.8.4 节 ,段 ID 用于区分 BERT 输入序列中的前提和假设。对于 BERT 输入序列 ( max_len) 的预定义最大长度,输入文本对中较长者的最后一个标记会不断被删除,直到max_len满足为止。为了加速生成用于微调 BERT 的 SNLI 数据集,我们使用 4 个工作进程并行生成训练或测试示例。

class SNLIBERTDataset(torch.utils.data.Dataset):
  def __init__(self, dataset, max_len, vocab=None):
    all_premise_hypothesis_tokens = [[
      p_tokens, h_tokens] for p_tokens, h_tokens in zip(
      *[d2l.tokenize([s.lower() for s in sentences])
       for sentences in dataset[:2]])]

    self.labels = torch.tensor(dataset[2])
    self.vocab = vocab
    self.max_len = max_len
    (self.all_token_ids, self.all_segments,
     self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
    print('read ' + str(len(self.all_token_ids)) + ' examples')

  def _preprocess(self, all_premise_hypothesis_tokens):
    pool = multiprocessing.Pool(4) # Use 4 worker processes
    out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
    all_token_ids = [
      token_ids for token_ids, segments, valid_len in out]
    all_segments = [segments for token_ids, segments, valid_len in out]
    valid_lens = [valid_len for token_ids, segments, valid_len in out]
    return (torch.tensor(all_token_ids, dtype=torch.long),
        torch.tensor(all_segments, dtype=torch.long),
        torch.tensor(valid_lens))

  def _mp_worker(self, premise_hypothesis_tokens):
    p_tokens, h_tokens = premise_hypothesis_tokens
    self._truncate_pair_of_tokens(p_tokens, h_tokens)
    tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
    token_ids = self.vocab[tokens] + [self.vocab['']] 
               * (self.max_len - len(tokens))
    segments = segments + [0] * (self.max_len - len(segments))
    valid_len = len(tokens)
    return token_ids, segments, valid_len

  def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
    # Reserve slots for '', '', and '' tokens for the BERT
    # input
    while len(p_tokens) + len(h_tokens) > self.max_len - 3:
      if len(p_tokens) > len(h_tokens):
        p_tokens.pop()
      else:
        h_tokens.pop()

  def __getitem__(self, idx):
    return (self.all_token_ids[idx], self.all_segments[idx],
        self.valid_lens[idx]), self.labels[idx]

  def __len__(self):
    return len(self.all_token_ids)

class SNLIBERTDataset(gluon.data.Dataset):
  def __init__(self, dataset, max_len, vocab=None):
    all_premise_hypothesis_tokens = [[
      p_tokens, h_tokens] for p_tokens, h_tokens in zip(
      *[d2l.tokenize([s.lower() for s in sentences])
       for sentences in dataset[:2]])]

    self.labels = np.array(dataset[2])
    self.vocab = vocab
    self.max_len = max_len
    (self.all_token_ids, self.all_segments,
     self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
    print('read ' + str(len(self.all_token_ids)) + ' examples')

  def _preprocess(self, all_premise_hypothesis_tokens):
    pool = multiprocessing.Pool(4) # Use 4 worker processes
    out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
    all_token_ids = [
      token_ids for token_ids, segments, valid_len in out]
    all_segments = [segments for token_ids, segments, valid_len in out]
    valid_lens = [valid_len for token_ids, segments, valid_len in out]
    return (np.array(all_token_ids, dtype='int32'),
        np.array(all_segments, dtype='int32'),
        np.array(valid_lens))

  def _mp_worker(self, premise_hypothesis_tokens):
    p_tokens, h_tokens = premise_hypothesis_tokens
    self._truncate_pair_of_tokens(p_tokens, h_tokens)
    tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
    token_ids = self.vocab[tokens] + [self.vocab['']] 
               * (self.max_len - len(tokens))
    segments = segments + [0] * (self.max_len - len(segments))
    valid_len = len(tokens)
    return token_ids, segments, valid_len

  def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
    # Reserve slots for '', '', and '' tokens for the BERT
    # input
    while len(p_tokens) + len(h_tokens) > self.max_len - 3:
      if len(p_tokens) > len(h_tokens):
        p_tokens.pop()
      else:
        h_tokens.pop()

  def __getitem__(self, idx):
    return (self.all_token_ids[idx], self.all_segments[idx],
        self.valid_lens[idx]), self.labels[idx]

  def __len__(self):
    return len(self.all_token_ids)

下载 SNLI 数据集后,我们通过实例化SNLIBERTDataset类来生成训练和测试示例。此类示例将在自然语言推理的训练和测试期间以小批量读取。

# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                  num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                 num_workers=num_workers)

read 549367 examples
read 9824 examples

# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True,
                  num_workers=num_workers)
test_iter = gluon.data.DataLoader(test_set, batch_size,
                 num_workers=num_workers)

read 549367 examples
read 9824 examples

16.7.3。微调 BERT

如图16.6.2所示,为自然语言推理微调 BERT 只需要一个额外的 MLP,该 MLP 由两个完全连接的层组成(参见下一类中的self.hidden和)。该 MLP 将特殊“”标记的 BERT 表示形式(对前提和假设的信息进行编码)转换为自然语言推理的三个输出:蕴含、矛盾和中性。self.outputBERTClassifier

class BERTClassifier(nn.Module):
  def __init__(self, bert):
    super(BERTClassifier, self).__init__()
    self.encoder = bert.encoder
    self.hidden = bert.hidden
    self.output = nn.LazyLinear(3)

  def forward(self, inputs):
    tokens_X, segments_X, valid_lens_x = inputs
    encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
    return self.output(self.hidden(encoded_X[:, 0, :]))

class BERTClassifier(nn.Block):
  def __init__(self, bert):
    super(BERTClassifier, self).__init__()
    self.encoder = bert.encoder
    self.hidden = bert.hidden
    self.output = nn.Dense(3)

  def forward(self, inputs):
    tokens_X, segments_X, valid_lens_x = inputs
    encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
    return self.output(self.hidden(encoded_X[:, 0, :]))

接下来,预训练的 BERT 模型bert被输入到 下游应用程序的BERTClassifier实例中。net在 BERT 微调的常见实现中,只会net.output从头学习附加 MLP ( ) 输出层的参数。net.encoder预训练的 BERT 编码器 ( ) 和附加 MLP 的隐藏层 ( )的所有参数都net.hidden将被微调。

net = BERTClassifier(bert)

net = BERTClassifier(bert)
net.output.initialize(ctx=devices)

回想一下15.8 节中类MaskLM和 NextSentencePred类在它们使用的 MLP 中都有参数。这些参数是预训练 BERT 模型中参数bert的一部分,因此也是net. 然而,这些参数仅用于计算预训练期间的掩码语言建模损失和下一句预测损失。MaskLM这两个损失函数与微调下游应用程序无关,因此在微调 BERT 时,在和中使用的 MLP 的参数NextSentencePred不会更新(失效)。

为了允许具有陈旧梯度的参数,在的函数 ignore_stale_grad=True中设置了标志 。我们使用此函数使用SNLI 的训练集 ( ) 和测试集 ( )来训练和评估模型。由于计算资源有限,训练和测试的准确性可以进一步提高:我们将其讨论留在练习中。stepd2l.train_batch_ch13nettrain_itertest_iter

lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
net(next(iter(train_iter))[0])
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

loss 0.519, train acc 0.791, test acc 0.782
9226.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

pYYBAGR9POSAGNU7AAECAEIwxyI914.svg

lr, num_epochs = 1e-4, 5
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices,
        d2l.split_batch_multi_inputs)

loss 0.477, train acc 0.810, test acc 0.785
4626.9 examples/sec on [gpu(0), gpu(1)]

pYYBAGR9POaAD1LSAAECF6LTtKs956.svg

16.7.4。概括

我们可以为下游应用微调预训练的 BERT 模型,例如 SNLI 数据集上的自然语言推理。

在微调期间,BERT 模型成为下游应用模型的一部分。仅与预训练损失相关的参数在微调期间不会更新。

16.7.5。练习

如果您的计算资源允许,微调一个更大的预训练 BERT 模型,该模型与原始 BERT 基础模型差不多大。将函数中的参数设置load_pretrained_model为:将“bert.small”替换为“bert.base”,将 、 、 和 的值分别增加到 num_hiddens=256768、3072、12ffn_num_hiddens=512和num_heads=412 num_blks=2。通过增加微调周期(并可能调整其他超参数),您能否获得高于 0.86 的测试精度?

如何根据长度比截断一对序列?比较这对截断方法和类中使用的方法 SNLIBERTDataset。他们的优缺点是什么?

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 自然语言
    +关注

    关注

    1

    文章

    284

    浏览量

    13315
  • pytorch
    +关注

    关注

    2

    文章

    798

    浏览量

    13099
收藏 人收藏

    评论

    相关推荐

    如何开始使用PyTorch进行自然语言处理

    随着人工智能和深度学习程序在未来几年的蓬勃发展,自然语言处理(NLP)将日益普及,而且必要性也与日俱增。PyTorch 自然语言处理是实现这些程序的不错选择。
    的头像 发表于 07-07 10:01 2459次阅读

    python自然语言

    最近,python自然语言是越来越火了,那么什么是自然语言自然语言(Natural Language )广纳了众多技术,对自然或人类语言
    发表于 05-02 13:50

    自然语言处理的语言模型

    自然语言处理——53 语言模型(数据平滑)
    发表于 04-16 11:11

    什么是自然语言处理

    什么是自然语言处理?自然语言处理任务有哪些?自然语言处理的方法是什么?
    发表于 09-08 06:51

    自然语言处理怎么最快入门_自然语言处理知识了解

    自然语言处理就是实现人机间自然语言通信,实现自然语言理解和自然语言生成是十分困难的,造成困难的根本原因是自然语言文本和对话的各个层次上广泛存
    发表于 12-28 17:10 5272次阅读

    自然语言入门之ESIM

    ESIM是ACL2017的一篇论文,在当时成为各个NLP比赛的杀器,直到现在仍是入门自然语言推理值得一读的文章。 本文根据ESIM原文以及pytorch代码实现对ESIM模型进行总结
    的头像 发表于 02-22 11:34 951次阅读
    <b class='flag-5'>自然语言</b>入门之ESIM

    PyTorch教程16.4之自然语言推理和数据集

    电子发烧友网站提供《PyTorch教程16.4之自然语言推理和数据集.pdf》资料免费下载
    发表于 06-05 10:57 0次下载
    <b class='flag-5'>PyTorch</b>教程16.4之<b class='flag-5'>自然语言</b><b class='flag-5'>推理</b>和数据集

    PyTorch教程16.5之自然语言推理:使用注意力

    电子发烧友网站提供《PyTorch教程16.5之自然语言推理:使用注意力.pdf》资料免费下载
    发表于 06-05 10:49 0次下载
    <b class='flag-5'>PyTorch</b>教程16.5之<b class='flag-5'>自然语言</b><b class='flag-5'>推理</b>:使用注意力

    PyTorch教程16.6之针对序列级和令牌级应用程序微调BERT

    电子发烧友网站提供《PyTorch教程16.6之针对序列级和令牌级应用程序微调BERT.pdf》资料免费下载
    发表于 06-05 10:51 0次下载
    <b class='flag-5'>PyTorch</b>教程16.6之针对序列级和令牌级应用程序<b class='flag-5'>微调</b><b class='flag-5'>BERT</b>

    PyTorch教程16.7自然语言推理微调BERT

    电子发烧友网站提供《PyTorch教程16.7自然语言推理微调BERT.pdf》资料免费下载
    发表于 06-05 10:52 0次下载
    <b class='flag-5'>PyTorch</b>教程<b class='flag-5'>16.7</b>之<b class='flag-5'>自然语言</b><b class='flag-5'>推理</b>:<b class='flag-5'>微调</b><b class='flag-5'>BERT</b>

    PyTorch教程-16.4。自然语言推理和数据集

    16.4。自然语言推理和数据集¶ Colab [火炬]在 Colab 中打开笔记本 Colab [mxnet] Open the notebook in Colab Colab [jax
    的头像 发表于 06-05 15:44 487次阅读

    PyTorch教程-16.5。自然语言推理:使用注意力

    16.5。自然语言推理:使用注意力¶ Colab [火炬]在 Colab 中打开笔记本 Colab [mxnet] Open the notebook in Colab Colab
    的头像 发表于 06-05 15:44 520次阅读
    <b class='flag-5'>PyTorch</b>教程-16.5。<b class='flag-5'>自然语言</b><b class='flag-5'>推理</b>:使用注意力

    PyTorch教程-16.6. 针对序列级和令牌级应用程序微调 BERT

    和 MLPs。当存在空间或时间限制时,这些模型很有用,但是,为每个自然语言处理任务制作一个特定模型实际上是不可行的。在 15.8 节中,我们介绍了一种预训练模型 BERT,它需要对各种自然语言处理任务进行
    的头像 发表于 06-05 15:44 392次阅读
    <b class='flag-5'>PyTorch</b>教程-16.6. 针对序列级和令牌级应用程序<b class='flag-5'>微调</b> <b class='flag-5'>BERT</b>

    自然语言处理的概念和应用 自然语言处理属于人工智能吗

      自然语言处理(Natural Language Processing)是一种人工智能技术,它是研究自然语言与计算机之间的交互和通信的一门学科。自然语言处理旨在研究机器如何理解人类语言
    发表于 08-23 17:31 1244次阅读

    ChatGPT是一个好的因果推理器吗?

    因果推理能力对于许多自然语言处理(NLP)应用至关重要。最近的因果推理系统主要基于经过微调的预训练语言模型(PLMs),如
    的头像 发表于 01-03 09:55 787次阅读
    ChatGPT是一个好的因果<b class='flag-5'>推理</b>器吗?