0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

PyTorch教程-4.3. 基本分类模型

jf_pJlTbmA9 来源:PyTorch 作者:PyTorch 2023-06-05 15:43 次阅读

您可能已经注意到,在回归的情况下,从头开始的实现和使用框架功能的简洁实现非常相似。分类也是如此。由于本书中的许多模型都处理分类,因此值得添加专门支持此设置的功能。本节为分类模型提供了一个基类,以简化以后的代码。

import torch
from d2l import torch as d2l

from mxnet import autograd, gluon, np, npx
from d2l import mxnet as d2l

npx.set_np()

from functools import partial
import jax
import optax
from jax import numpy as jnp
from d2l import jax as d2l

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

import tensorflow as tf
from d2l import tensorflow as d2l

4.3.1. 类Classifier_

我们在下面定义Classifier类。在中,validation_step我们报告了验证批次的损失值和分类准确度。我们为每个批次绘制一个更新num_val_batches 。这有利于在整个验证数据上生成平均损失和准确性。如果最后一批包含的示例较少,则这些平均数并不完全正确,但我们忽略了这一微小差异以保持代码简单。

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def validation_step(self, batch):
    Y_hat = self(*batch[:-1])
    self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
    self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def validation_step(self, batch):
    Y_hat = self(*batch[:-1])
    self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
    self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.

We also redefine the training_step method for JAX since all models that will subclass Classifier later will have a loss that returns auxiliary data. This auxiliary data can be used for models with batch normalization (to be explained in Section 8.5), while in all other cases we will make the loss also return a placeholder (empty dictionary) to represent the auxiliary data.

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def training_step(self, params, batch, state):
    # Here value is a tuple since models with BatchNorm layers require
    # the loss to return auxiliary data
    value, grads = jax.value_and_grad(
      self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
    l, _ = value
    self.plot("loss", l, train=True)
    return value, grads

  def validation_step(self, params, batch, state):
    # Discard the second returned value. It is used for training models
    # with BatchNorm layers since loss also returns auxiliary data
    l, _ = self.loss(params, batch[:-1], batch[-1], state)
    self.plot('loss', l, train=False)
    self.plot('acc', self.accuracy(params, batch[:-1], batch[-1], state),
         train=False)

We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def validation_step(self, batch):
    Y_hat = self(*batch[:-1])
    self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
    self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

默认情况下,我们使用随机梯度下降优化器,在小批量上运行,就像我们在线性回归的上下文中所做的那样。

@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  return torch.optim.SGD(self.parameters(), lr=self.lr)

@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  params = self.parameters()
  if isinstance(params, list):
    return d2l.SGD(params, self.lr)
  return gluon.Trainer(params, 'sgd', {'learning_rate': self.lr})

@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  return optax.sgd(self.lr)

@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  return tf.keras.optimizers.SGD(self.lr)

4.3.2. 准确性

给定预测概率分布y_hat,每当我们必须输出硬预测时,我们通常会选择预测概率最高的类别。事实上,许多应用程序需要我们做出选择。例如,Gmail 必须将电子邮件分类为“主要”、“社交”、“更新”、“论坛”或“垃圾邮件”。它可能会在内部估计概率,但最终它必须在类别中选择一个。

当预测与标签 class 一致时y,它们是正确的。分类准确度是所有正确预测的分数。尽管直接优化精度可能很困难(不可微分),但它通常是我们最关心的性能指标。它通常是基准测试中的相关数量。因此,我们几乎总是在训练分类器时报告它。

准确度计算如下。首先,如果y_hat是一个矩阵,我们假设第二个维度存储每个类别的预测分数。我们使用argmax每行中最大条目的索引来获取预测类。然后我们将预测的类别与真实的元素进行比较y。由于相等运算符== 对数据类型敏感,因此我们转换 的y_hat数据类型以匹配 的数据类型y。结果是一个包含条目 0(假)和 1(真)的张量。求和得出正确预测的数量。

@d2l.add_to_class(Classifier) #@save
def accuracy(self, Y_hat, Y, averaged=True):
  """Compute the number of correct predictions."""
  Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
  preds = Y_hat.argmax(axis=1).type(Y.dtype)
  compare = (preds == Y.reshape(-1)).type(torch.float32)
  return compare.mean() if averaged else compare

@d2l.add_to_class(Classifier) #@save
def accuracy(self, Y_hat, Y, averaged=True):
  """Compute the number of correct predictions."""
  Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
  preds = Y_hat.argmax(axis=1).astype(Y.dtype)
  compare = (preds == Y.reshape(-1)).astype(np.float32)
  return compare.mean() if averaged else compare

@d2l.add_to_class(d2l.Module) #@save
def get_scratch_params(self):
  params = []
  for attr in dir(self):
    a = getattr(self, attr)
    if isinstance(a, np.ndarray):
      params.append(a)
    if isinstance(a, d2l.Module):
      params.extend(a.get_scratch_params())
  return params

@d2l.add_to_class(d2l.Module) #@save
def parameters(self):
  params = self.collect_params()
  return params if isinstance(params, gluon.parameter.ParameterDict) and len(
    params.keys()) else self.get_scratch_params()

@d2l.add_to_class(Classifier) #@save
@partial(jax.jit, static_argnums=(0, 5))
def accuracy(self, params, X, Y, state, averaged=True):
  """Compute the number of correct predictions."""
  Y_hat = state.apply_fn({'params': params,
              'batch_stats': state.batch_stats}, # BatchNorm Only
              *X)
  Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
  preds = Y_hat.argmax(axis=1).astype(Y.dtype)
  compare = (preds == Y.reshape(-1)).astype(jnp.float32)
  return compare.mean() if averaged else compare

@d2l.add_to_class(Classifier) #@save
def accuracy(self, Y_hat, Y, averaged=True):
  """Compute the number of correct predictions."""
  Y_hat = tf.reshape(Y_hat, (-1, Y_hat.shape[-1]))
  preds = tf.cast(tf.argmax(Y_hat, axis=1), Y.dtype)
  compare = tf.cast(preds == tf.reshape(Y, -1), tf.float32)
  return tf.reduce_mean(compare) if averaged else compare

4.3.3. 概括

分类是一个足够普遍的问题,它保证了它自己的便利功能。分类中最重要的是 分类器的准确性。请注意,虽然我们通常主要关心准确性,但出于统计和计算原因,我们训练分类器以优化各种其他目标。然而,无论在训练过程中哪个损失函数被最小化,有一个方便的方法来根据经验评估我们的分类器的准确性是有用的。

4.3.4. 练习

表示为Lv验证损失,让Lvq是通过本节中的损失函数平均计算的快速而肮脏的估计。最后,表示为lvb最后一个小批量的损失。表达Lv按照Lvq, lvb,以及样本和小批量大小。

表明快速而肮脏的估计Lvq是公正的。也就是说,表明E[Lv]=E[Lvq]. 为什么你还想使用Lv反而?

给定多类分类损失,表示为l(y,y′) 估计的惩罚y′当我们看到y并给出一个概率p(y∣x), 制定最佳选择规则y′. 提示:表达预期损失,使用 l和p(y∣x).

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • pytorch
    +关注

    关注

    2

    文章

    808

    浏览量

    13235
收藏 人收藏

    评论

    相关推荐

    Pytorch模型训练实用PDF教程【中文】

    模型部分?还是优化器?只有这样不断的通过可视化诊断你的模型,不断的对症下药,才能训练出一个较满意的模型。本教程内容及结构:本教程内容主要为在 PyTorch 中训练一个
    发表于 12-21 09:18

    pyhanlp文本分类与情感分析

    关系如下:训练训练指的是,利用给定训练集寻找一个能描述这种语言现象的模型的过程。开发者只需调用train接口即可,但在实现中,有许多细节。分词目前,本系统中的分词器接口一共有两种实现: 但文本分类是否
    发表于 02-20 15:37

    pytorch模型转化为onxx模型的步骤有哪些

    首先pytorch模型要先转化为onxx模型,然后从onxx模型转化为rknn模型直接转化会出现如下问题,环境都是正确的,论坛询问后也没给出
    发表于 05-09 16:36

    通过Cortex来非常方便的部署PyTorch模型

    ?你可以部署一个 AlexNet 模型,使用 PyTorch 和 Cortex 来标记图像。那语言分类器呢,比如 Chrome 用来检测页面不是用默认语言写的那个?fastText 是这个任务的完美
    发表于 11-01 15:25

    Pytorch模型转换为DeepViewRT模型时出错怎么解决?

    我正在寻求您的帮助以解决以下问题.. 我在 Windows 10 上安装了 eIQ Toolkit 1.7.3,我想将我的 Pytorch 模型转换为 DeepViewRT (.rtm) 模型,这样
    发表于 06-09 06:42

    pytorch模型转换需要注意的事项有哪些?

    什么是JIT(torch.jit)? 答:JIT(Just-In-Time)是一组编译工具,用于弥合PyTorch研究与生产之间的差距。它允许创建可以在不依赖Python解释器的情况下运行的模型
    发表于 09-18 08:05

    基于PLSA主题模型的多标记文本分类_蒋铭初

    基于PLSA主题模型的多标记文本分类_蒋铭初
    发表于 01-08 10:40 0次下载

    textCNN论文与原理——短文本分类

    前言 之前书写了使用pytorch进行短文本分类,其中的数据处理方式比较简单粗暴。自然语言处理领域包含很多任务,很多的数据向之前那样处理的话未免有点繁琐和耗时。在pytorch中众所周知的数据处理包
    的头像 发表于 12-31 10:08 2535次阅读
    textCNN论文与原理——短文<b class='flag-5'>本分类</b>

    结合BERT模型的中文文本分类算法

    针对现有中文短文夲分类算法通常存在特征稀疏、用词不规范和数据海量等问题,提出一种基于Transformer的双向编码器表示(BERT)的中文短文本分类算法,使用BERT预训练语言模型对短文本进行句子
    发表于 03-11 16:10 6次下载
    结合BERT<b class='flag-5'>模型</b>的中文文<b class='flag-5'>本分类</b>算法

    融合文本分类和摘要的多任务学习摘要模型

    文本摘要应包含源文本中所有重要信息,传统基于编码器-解码器架构的摘要模型生成的摘要准确性较低。根据文本分类和文本摘要的相关性,提出一种多任务学习摘要模型。从文本分类辅助任务中学习抽象信
    发表于 04-27 16:18 11次下载
    融合文<b class='flag-5'>本分类</b>和摘要的多任务学习摘要<b class='flag-5'>模型</b>

    基于不同神经网络的文本分类方法研究对比

    神经网络、时间递归神经网络、结构递归神经网络和预训练模型等主流方法在文本分类中应用的发展历程比较不同模型基于常用数据集的分类效果,表明利用人工神经网络伂构自动获取文本特征,可避免繁杂的
    发表于 05-13 16:34 49次下载

    基于LSTM的表示学习-文本分类模型

    的关键。为了获得妤的文本表示,提高文本分类性能,构建了基于LSTM的表示学习-文本分类模型,其中表示学习模型利用语言模型为文
    发表于 06-15 16:17 18次下载

    基于注意力机制的新闻文本分类模型

    基于注意力机制的新闻文本分类模型
    发表于 06-27 15:32 30次下载

    PyTorch本分类任务的基本流程

    本分类是NLP领域的较为容易的入门问题,本文记录文本分类任务的基本流程,大部分操作使用了**torch**和**torchtext**两个库。 ## 1. 文本数据预处理
    的头像 发表于 02-22 14:23 1112次阅读

    PyTorch教程4.3之基本分类模型

    电子发烧友网站提供《PyTorch教程4.3之基本分类模型.pdf》资料免费下载
    发表于 06-05 15:43 0次下载
    <b class='flag-5'>PyTorch</b>教程<b class='flag-5'>4.3</b>之基<b class='flag-5'>本分类</b><b class='flag-5'>模型</b>