0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

PyTorch教程-4.2. 图像分类数据集

jf_pJlTbmA9 来源:PyTorch 作者:PyTorch 2023-06-05 15:38 次阅读

广泛用于图像分类的数据集之一是手写数字的MNIST 数据集 (LeCun等人,1998 年) 。在 1990 年代发布时,它对大多数机器学习算法提出了巨大挑战,其中包含 60,000 张图像 28×28像素分辨率(加上 10,000 张图像的测试数据集)。客观地说,在 1995 年,配备高达 64MB RAM 和惊人的 5 MFLOPs 的 Sun SPARCStation 5 被认为是 AT&T 贝尔实验室最先进的机器学习设备。实现数字识别的高精度是一个1990 年代 USPS 自动分拣信件的关键组件。深度网络,如 LeNet-5 (LeCun等人,1995 年)、具有不变性的支持向量机 (Schölkopf等人,1996 年)和切线距离分类器 (Simard等人,1998 年)都允许达到 1% 以下的错误率。

十多年来,MNIST 一直是比较机器学习算法的参考点。虽然它作为基准数据集运行良好,但即使是按照当今标准的简单模型也能达到 95% 以上的分类准确率,这使得它不适合区分强模型和弱模型。更重要的是,数据集允许非常高的准确性,这在许多分类问题中通常是看不到的。这种算法的发展偏向于可以利用干净数据集的特定算法系列,例如活动集方法和边界搜索活动集算法。今天,MNIST 更像是一种健全性检查,而不是基准。ImageNet ( Deng et al. , 2009 )提出了一个更相关的挑战。不幸的是,对于本书中的许多示例和插图来说,ImageNet 太大了,因为训练这些示例需要很长时间才能使示例具有交互性。作为替代,我们将在接下来的部分中重点讨论定性相似但规模小得多的 Fashion-MNIST 数据集(Xiao等人,2017 年),该数据集于 2017 年发布。它包含 10 类服装的图像 28×28像素分辨率。

%matplotlib inline
import time
import torch
import torchvision
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

%matplotlib inline
import time
from mxnet import gluon, npx
from mxnet.gluon.data.vision import transforms
from d2l import mxnet as d2l

npx.set_np()

d2l.use_svg_display()

%matplotlib inline
import time
import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from jax import numpy as jnp
from d2l import jax as d2l

d2l.use_svg_display()

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

%matplotlib inline
import time
import tensorflow as tf
from d2l import tensorflow as d2l

d2l.use_svg_display()

4.2.1. 加载数据集

由于它是一个经常使用的数据集,所有主要框架都提供了它的预处理版本。我们可以使用内置的框架实用程序将 Fashion-MNIST 数据集下载并读取到内存中。

class FashionMNIST(d2l.DataModule): #@save
  """The Fashion-MNIST dataset."""
  def __init__(self, batch_size=64, resize=(28, 28)):
    super().__init__()
    self.save_hyperparameters()
    trans = transforms.Compose([transforms.Resize(resize),
                  transforms.ToTensor()])
    self.train = torchvision.datasets.FashionMNIST(
      root=self.root, train=True, transform=trans, download=True)
    self.val = torchvision.datasets.FashionMNIST(
      root=self.root, train=False, transform=trans, download=True)

class FashionMNIST(d2l.DataModule): #@save
  """The Fashion-MNIST dataset."""
  def __init__(self, batch_size=64, resize=(28, 28)):
    super().__init__()
    self.save_hyperparameters()
    trans = transforms.Compose([transforms.Resize(resize),
                  transforms.ToTensor()])
    self.train = gluon.data.vision.FashionMNIST(
      train=True).transform_first(trans)
    self.val = gluon.data.vision.FashionMNIST(
      train=False).transform_first(trans)

class FashionMNIST(d2l.DataModule): #@save
  """The Fashion-MNIST dataset."""
  def __init__(self, batch_size=64, resize=(28, 28)):
    super().__init__()
    self.save_hyperparameters()
    self.train, self.val = tf.keras.datasets.fashion_mnist.load_data()

class FashionMNIST(d2l.DataModule): #@save
  """The Fashion-MNIST dataset."""
  def __init__(self, batch_size=64, resize=(28, 28)):
    super().__init__()
    self.save_hyperparameters()
    self.train, self.val = tf.keras.datasets.fashion_mnist.load_data()

Fashion-MNIST 包含来自 10 个类别的图像,每个类别在训练数据集中由 6,000 个图像表示,在测试数据集中由 1,000 个图像表示。测试 数据集用于评估模型性能(不得用于训练)。因此,训练集和测试集分别包含 60,000 和 10,000 张图像。

data = FashionMNIST(resize=(32, 32))
len(data.train), len(data.val)

(60000, 10000)

data = FashionMNIST(resize=(32, 32))
len(data.train), len(data.val)

(60000, 10000)

data = FashionMNIST(resize=(32, 32))
len(data.train[0]), len(data.val[0])

(60000, 10000)

data = FashionMNIST(resize=(32, 32))
len(data.train[0]), len(data.val[0])

(60000, 10000)

图像是灰度和放大到32×32分辨率以上的像素。这类似于由(二进制)黑白图像组成的原始 MNIST 数据集。但请注意,大多数具有 3 个通道(红色、绿色、蓝色)的现代图像数据和超过 100 个通道的高光谱图像(HyMap 传感器有 126 个通道)。按照惯例,我们将图像存储为 c×h×w张量,其中c是颜色通道数,h是高度和w是宽度。

data.train[0][0].shape

torch.Size([1, 32, 32])

data.train[0][0].shape

(1, 32, 32)

data.train[0][0].shape

(28, 28)

data.train[0][0].shape

(28, 28)

Fashion-MNIST 的类别具有人类可理解的名称。以下便捷方法在数字标签及其名称之间进行转换。

@d2l.add_to_class(FashionMNIST) #@save
def text_labels(self, indices):
  """Return text labels."""
  labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
       'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
  return [labels[int(i)] for i in indices]

4.2.2. 读取一个小批量

为了让我们在读取训练集和测试集时更轻松,我们使用内置的数据迭代器而不是从头开始创建一个。回想一下,在每次迭代中,数据迭代器读取一个大小为 的小批量数据batch_size。我们还随机打乱训练数据迭代器的示例。

@d2l.add_to_class(FashionMNIST) #@save
def get_dataloader(self, train):
  data = self.train if train else self.val
  return torch.utils.data.DataLoader(data, self.batch_size, shuffle=train,
                    num_workers=self.num_workers)

@d2l.add_to_class(FashionMNIST) #@save
def get_dataloader(self, train):
  data = self.train if train else self.val
  return gluon.data.DataLoader(data, self.batch_size, shuffle=train,
                 num_workers=self.num_workers)

@d2l.add_to_class(FashionMNIST) #@save
def get_dataloader(self, train):
  data = self.train if train else self.val
  process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,
              tf.cast(y, dtype='int32'))
  resize_fn = lambda X, y: (tf.image.resize_with_pad(X, *self.resize), y)
  shuffle_buf = len(data[0]) if train else 1
  return tfds.as_numpy(
    tf.data.Dataset.from_tensor_slices(process(*data)).batch(
      self.batch_size).map(resize_fn).shuffle(shuffle_buf))

@d2l.add_to_class(FashionMNIST) #@save
def get_dataloader(self, train):
  data = self.train if train else self.val
  process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,
              tf.cast(y, dtype='int32'))
  resize_fn = lambda X, y: (tf.image.resize_with_pad(X, *self.resize), y)
  shuffle_buf = len(data[0]) if train else 1
  return tf.data.Dataset.from_tensor_slices(process(*data)).batch(
    self.batch_size).map(resize_fn).shuffle(shuffle_buf)

为了了解这是如何工作的,让我们通过调用该 train_dataloader方法来加载一小批图像。它包含 64 张图像。

X, y = next(iter(data.train_dataloader()))
print(X.shape, X.dtype, y.shape, y.dtype)

torch.Size([64, 1, 32, 32]) torch.float32 torch.Size([64]) torch.int64

X, y = next(iter(data.train_dataloader()))
print(X.shape, X.dtype, y.shape, y.dtype)

(64, 1, 32, 32) float32 (64,) int32

X, y = next(iter(data.train_dataloader()))
print(X.shape, X.dtype, y.shape, y.dtype)

WARNING:tensorflow:From /home/d2l-worker/miniconda3/envs/d2l-en-release-1/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
(64, 32, 32, 1) float32 (64,) int32

X, y = next(iter(data.train_dataloader()))
print(X.shape, X.dtype, y.shape, y.dtype)

WARNING:tensorflow:From /home/d2l-worker/miniconda3/envs/d2l-en-release-1/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
(64, 32, 32, 1)  (64,) 

让我们看看读取图像所花费的时间。尽管它是一个内置的加载程序,但速度并不快。尽管如此,这已经足够了,因为使用深度网络处理图像需要更长的时间。因此,训练网络不受 IO 约束就足够了。

tic = time.time()
for X, y in data.train_dataloader():
  continue
f'{time.time() - tic:.2f} sec'

'5.06 sec'

tic = time.time()
for X, y in data.train_dataloader():
  continue
f'{time.time() - tic:.2f} sec'

'4.12 sec'

tic = time.time()
for X, y in data.train_dataloader():
  continue
f'{time.time() - tic:.2f} sec'

'0.96 sec'

tic = time.time()
for X, y in data.train_dataloader():
  continue
f'{time.time() - tic:.2f} sec'

'0.95 sec'

4.2.3. 可视化

我们将经常使用 Fashion-MNIST 数据集。一个便利的功能show_images可以用来可视化图像和相关的标签。其实施细节推迟到附录。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
  """Plot a list of images."""
  raise NotImplementedError

让我们好好利用它。通常,可视化和检查您正在训练的数据是个好主意。人类非常善于发现不寻常的方面,因此,可视化可以作为一种额外的保护措施,防止实验设计中的错误和错误。以下是训练数据集中前几个示例的图像及其相应标签(文本)。

@d2l.add_to_class(FashionMNIST) #@save
def visualize(self, batch, nrows=1, ncols=8, labels=[]):
  X, y = batch
  if not labels:
    labels = self.text_labels(y)
  d2l.show_images(X.squeeze(1), nrows, ncols, titles=labels)
batch = next(iter(data.val_dataloader()))
data.visualize(batch)

pYYBAGR5VLOAE8DAAAFXlI5prpg972.svg

@d2l.add_to_class(FashionMNIST) #@save
def visualize(self, batch, nrows=1, ncols=8, labels=[]):
  X, y = batch
  if not labels:
    labels = self.text_labels(y)
  d2l.show_images(X.squeeze(1), nrows, ncols, titles=labels)
batch = next(iter(data.val_dataloader()))
data.visualize(batch)

poYBAGR5VLWABCDeAAFUVW5zHbQ247.svg

@d2l.add_to_class(FashionMNIST) #@save
def visualize(self, batch, nrows=1, ncols=8, labels=[]):
  X, y = batch
  if not labels:
    labels = self.text_labels(y)
  d2l.show_images(jnp.squeeze(X), nrows, ncols, titles=labels)

batch = next(iter(data.val_dataloader()))
data.visualize(batch)

pYYBAGR5VLiAMQdTAAFW9OrJp3Q736.svg

@d2l.add_to_class(FashionMNIST) #@save
def visualize(self, batch, nrows=1, ncols=8, labels=[]):
  X, y = batch
  if not labels:
    labels = self.text_labels(y)
  d2l.show_images(tf.squeeze(X), nrows, ncols, titles=labels)
batch = next(iter(data.val_dataloader()))
data.visualize(batch)

pYYBAGR5VLiAMQdTAAFW9OrJp3Q736.svg

我们现在准备好在接下来的部分中使用 Fashion-MNIST 数据集。

4.2.4. 概括

我们现在有一个稍微更真实的数据集用于分类。Fashion-MNIST 是一个服装分类数据集,由代表 10 个类别的图像组成。我们将在后续部分和章节中使用该数据集来评估各种网络设计,从简单的线性模型到高级残差网络。正如我们通常对图像所做的那样,我们将它们读取为形状的张量(批量大小、通道数、高度、宽度)。目前,我们只有一个通道,因为图像是灰度的(上面的可视化使用假调色板来提高可见性)。

最后,数据迭代器是实现高效性能的关键组件。例如,我们可能会使用 GPU 进行高效的图像解压缩、视频转码或其他预处理。只要有可能,您就应该依靠利用高性能计算的良好实现的数据迭代器来避免减慢您的训练循环。

4.2.5. 练习

减少batch_size(例如,减少到 1)会影响阅读性能吗?

数据迭代器的性能很重要。您认为当前的实施是否足够快?探索改进它的各种选项。使用系统分析器找出瓶颈所在。

查看框架的在线 API 文档。还有哪些其他数据集可用?

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 数据集
    +关注

    关注

    4

    文章

    1208

    浏览量

    24697
  • pytorch
    +关注

    关注

    2

    文章

    808

    浏览量

    13221
收藏 人收藏

    评论

    相关推荐

    使用卷积神经网络进行图像分类的步骤

    (例如,高分辨率、不同光照条件等)。 2. 数据收集 获取数据 :收集或购买一个包含你想要分类图像
    的头像 发表于 11-15 15:01 298次阅读

    主动学习在图像分类技术中的应用:当前状态与未来展望

    基于Transformer结构提升模型预测性能,以确保模型预测结果的可靠性。 此外,本文还对各类主动学习图像分类算法下的重要学术工作进行了实验对比,并对各算法在不同规模数据上的
    的头像 发表于 11-14 10:12 292次阅读
    主动学习在<b class='flag-5'>图像</b><b class='flag-5'>分类</b>技术中的应用:当前状态与未来展望

    PyTorch 数据加载与处理方法

    PyTorch 是一个流行的开源机器学习库,它提供了强大的工具来构建和训练深度学习模型。在构建模型之前,一个重要的步骤是加载和处理数据。 1. PyTorch 数据加载基础 在
    的头像 发表于 11-05 17:37 399次阅读

    如何在 PyTorch 中训练模型

    准备好数据PyTorch 提供了 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 两个类来帮助我们加载和批量处理数据
    的头像 发表于 11-05 17:36 331次阅读

    pytorch怎么在pycharm中运行

    第一部分:PyTorch和PyCharm的安装 1.1 安装PyTorch PyTorch是一个开源的机器学习库,用于构建和训练神经网络。要在PyCharm中使用PyTorch,首先需
    的头像 发表于 08-01 16:22 1412次阅读

    基于PyTorch的卷积核实例应用

    在深度学习和计算机视觉领域,卷积操作是一种至关重要的技术,尤其在图像处理和特征提取方面发挥着核心作用。PyTorch作为当前最流行的深度学习框架之一,提供了强大的张量操作功能和灵活的API,使得实现
    的头像 发表于 07-11 15:19 460次阅读

    如何在PyTorch中实现LeNet-5网络

    等人提出,主要用于手写数字识别任务(如MNIST数据)。下面,我将详细阐述如何在PyTorch中从头开始实现LeNet-5网络,包括网络架构设计、参数初始化、前向传播、损失函数选择、优化器配置以及训练流程等方面。
    的头像 发表于 07-11 10:58 789次阅读

    pytorch如何训练自己的数据

    本文将详细介绍如何使用PyTorch框架来训练自己的数据。我们将从数据准备、模型构建、训练过程、评估和测试等方面进行讲解。 环境搭建 首先,我们需要安装PyTorch。可以通过访问
    的头像 发表于 07-11 10:04 529次阅读

    pytorch中有神经网络模型吗

    处理、语音识别等领域取得了显著的成果。PyTorch是一个开源的深度学习框架,由Facebook的AI研究团队开发。它以其易用性、灵活性和高效性而受到广泛欢迎。在PyTorch中,有许多预训练的神经网络模型可供选择,这些模型可以用于各种任务,如
    的头像 发表于 07-11 09:59 700次阅读

    PyTorch的介绍与使用案例

    PyTorch是一个基于Python的开源机器学习库,它主要面向深度学习和科学计算领域。PyTorch由Meta Platforms(原Facebook)的人工智能研究团队开发,并逐渐发展成为深度
    的头像 发表于 07-10 14:19 397次阅读

    计算机视觉怎么给图像分类

    图像分类是计算机视觉领域中的一项核心任务,其目标是将输入的图像自动分配到预定义的类别集合中。这一过程涉及图像的特征提取、特征表示以及分类器的
    的头像 发表于 07-08 17:06 679次阅读

    tensorflow和pytorch哪个更简单?

    PyTorch更简单。选择TensorFlow还是PyTorch取决于您的具体需求和偏好。如果您需要一个易于使用、灵活且具有强大社区支持的框架,PyTorch可能是一个更好的选择。如果您需要一个在
    的头像 发表于 07-05 09:45 861次阅读

    PyTorch如何训练自己的数据

    PyTorch是一个广泛使用的深度学习框架,它以其灵活性、易用性和强大的动态图特性而闻名。在训练深度学习模型时,数据是不可或缺的组成部分。然而,很多时候,我们可能需要使用自己的数据
    的头像 发表于 07-02 14:09 1691次阅读

    PyTorch中激活函数的全面概览

    为了更清晰地学习Pytorch中的激活函数,并对比它们之间的不同,这里对最新版本的Pytorch中的激活函数进行了汇总,主要介绍激活函数的公式、图像以及使用方法,具体细节可查看官方文档。
    的头像 发表于 04-30 09:26 548次阅读
    <b class='flag-5'>PyTorch</b>中激活函数的全面概览

    了解如何使用PyTorch构建图神经网络

    图神经网络直接应用于图数据,您可以训练它们以预测节点、边缘和与图相关的任务。它用于图和节点分类、链路预测、图聚类和生成,以及图像和文本分类
    发表于 02-21 12:19 768次阅读
    了解如何使用<b class='flag-5'>PyTorch</b>构建图神经网络