0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

PyTorch教程-14.6. 对象检测数据集

jf_pJlTbmA9 来源:PyTorch 作者:PyTorch 2023-06-05 15:44 次阅读

目标检测领域没有MNIST和Fashion-MNIST这样的小型数据集。为了快速演示对象检测模型,我们收集并标记了一个小型数据集。首先,我们从办公室拍摄了免费香蕉的照片,并生成了 1000 张不同旋转和大小的香蕉图像。然后我们将每个香蕉图像放置在一些背景图像上的随机位置。最后,我们为图像上的那些香蕉标记了边界框。

14.6.1。下载数据集

带有所有图像和 csv 标签文件的香蕉检测数据集可以直接从互联网上下载。

%matplotlib inline
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l

#@save
d2l.DATA_HUB['banana-detection'] = (
  d2l.DATA_URL + 'banana-detection.zip',
  '5de26c8fce5ccdea9f91267273464dc968d20d72')

%matplotlib inline
import os
import pandas as pd
from mxnet import gluon, image, np, npx
from d2l import mxnet as d2l

npx.set_np()

#@save
d2l.DATA_HUB['banana-detection'] = (
  d2l.DATA_URL + 'banana-detection.zip',
  '5de26c8fce5ccdea9f91267273464dc968d20d72')

14.6.2。读取数据集

我们将在 read_data_bananas下面的函数中读取香蕉检测数据集。数据集包括一个 csv 文件,用于对象类标签和左上角和右下角的地面实况边界框坐标。

#@save
def read_data_bananas(is_train=True):
  """Read the banana detection dataset images and labels."""
  data_dir = d2l.download_extract('banana-detection')
  csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
               else 'bananas_val', 'label.csv')
  csv_data = pd.read_csv(csv_fname)
  csv_data = csv_data.set_index('img_name')
  images, targets = [], []
  for img_name, target in csv_data.iterrows():
    images.append(torchvision.io.read_image(
      os.path.join(data_dir, 'bananas_train' if is_train else
             'bananas_val', 'images', f'{img_name}')))
    # Here `target` contains (class, upper-left x, upper-left y,
    # lower-right x, lower-right y), where all the images have the same
    # banana class (index 0)
    targets.append(list(target))
  return images, torch.tensor(targets).unsqueeze(1) / 256

#@save
def read_data_bananas(is_train=True):
  """Read the banana detection dataset images and labels."""
  data_dir = d2l.download_extract('banana-detection')
  csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
               else 'bananas_val', 'label.csv')
  csv_data = pd.read_csv(csv_fname)
  csv_data = csv_data.set_index('img_name')
  images, targets = [], []
  for img_name, target in csv_data.iterrows():
    images.append(image.imread(
      os.path.join(data_dir, 'bananas_train' if is_train else
             'bananas_val', 'images', f'{img_name}')))
    # Here `target` contains (class, upper-left x, upper-left y,
    # lower-right x, lower-right y), where all the images have the same
    # banana class (index 0)
    targets.append(list(target))
  return images, np.expand_dims(np.array(targets), 1) / 256

通过使用read_data_bananas函数读取图像和标签,下面的BananasDataset类将允许我们创建一个自定义Dataset实例来加载香蕉检测数据集。

#@save
class BananasDataset(torch.utils.data.Dataset):
  """A customized dataset to load the banana detection dataset."""
  def __init__(self, is_train):
    self.features, self.labels = read_data_bananas(is_train)
    print('read ' + str(len(self.features)) + (f' training examples' if
       is_train else f' validation examples'))

  def __getitem__(self, idx):
    return (self.features[idx].float(), self.labels[idx])

  def __len__(self):
    return len(self.features)

#@save
class BananasDataset(gluon.data.Dataset):
  """A customized dataset to load the banana detection dataset."""
  def __init__(self, is_train):
    self.features, self.labels = read_data_bananas(is_train)
    print('read ' + str(len(self.features)) + (f' training examples' if
       is_train else f' validation examples'))

  def __getitem__(self, idx):
    return (self.features[idx].astype('float32').transpose(2, 0, 1),
        self.labels[idx])

  def __len__(self):
    return len(self.features)

最后,我们定义load_data_bananas函数为训练集和测试集返回两个数据迭代器实例。对于测试数据集,不需要随机读取。

#@save
def load_data_bananas(batch_size):
  """Load the banana detection dataset."""
  train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
                       batch_size, shuffle=True)
  val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
                      batch_size)
  return train_iter, val_iter

#@save
def load_data_bananas(batch_size):
  """Load the banana detection dataset."""
  train_iter = gluon.data.DataLoader(BananasDataset(is_train=True),
                    batch_size, shuffle=True)
  val_iter = gluon.data.DataLoader(BananasDataset(is_train=False),
                   batch_size)
  return train_iter, val_iter

让我们读取一个 minibatch 并打印这个 minibatch 中图像和标签的形状。图像小批量的形状(批量大小、通道数、高度、宽度)看起来很熟悉:它与我们之前的图像分类任务相同。label minibatch的shape是(batch size,m, 5), 其中m是任何图像在数据集中具有的最大可能数量的边界框。

虽然 minibatch 的计算效率更高,但它要求所有图像示例都包含相同数量的边界框,以通过连接形成一个 minibatch。通常,图像可能具有不同数量的边界框;因此,图像少于m 边界框将被非法边界框填充,直到 m到达了。然后每个边界框的标签用一个长度为5的数组表示,数组的第一个元素是边界框中物体的类,其中-1表示填充的非法边界框。数组的其余四个元素是 (x,y)-边界框左上角和右下角的坐标值(范围在0到1之间)。对于香蕉数据集,由于每张图像上只有一个边界框,我们有m=1.

batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
batch[0].shape, batch[1].shape

read 1000 training examples
read 100 validation examples

(torch.Size([32, 3, 256, 256]), torch.Size([32, 1, 5]))

batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
batch[0].shape, batch[1].shape

read 1000 training examples
read 100 validation examples

((32, 3, 256, 256), (32, 1, 5))

14.6.3。示范

让我们演示十张带有标记的真实边界框的图像。我们可以看到香蕉的旋转、大小和位置在所有这些图像中都不同。当然,这只是一个简单的人工数据集。实际上,真实世界的数据集通常要复杂得多。

imgs = (batch[0][:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][:10]):
  d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])

poYBAGR4YoqASRiBAAZcwltTfMw221.png

imgs = (batch[0][:10].transpose(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][:10]):
  d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])

pYYBAGR4Yo6ACb2eAAY5zXYFqT8820.png

14.6.4。概括

我们收集的香蕉检测数据集可用于演示对象检测模型。

目标检测的数据加载类似于图像分类。然而,在目标检测中,标签还包含图像分类中缺少的真实边界框信息

14.6.5。练习

演示香蕉检测数据集中带有真实边界框的其他图像。它们在边界框和对象方面有何不同?

假设我们要将数据增强(例如随机裁剪)应用于对象检测。它与图像分类中的有何不同?提示:如果裁剪后的图像只包含物体的一小部分怎么办?

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 数据集
    +关注

    关注

    4

    文章

    1208

    浏览量

    24742
  • pytorch
    +关注

    关注

    2

    文章

    808

    浏览量

    13256
收藏 人收藏

    评论

    相关推荐

    基于数据对象平均离群因子的离群点选择算法

    基于数据对象平均离群因子的离群点选择算法_朱付保
    发表于 01-03 17:41 0次下载

    利用Python和PyTorch处理面向对象数据

    本篇是利用 Python 和 PyTorch 处理面向对象数据系列博客的第 2 篇。 如需阅读第 1 篇:原始数据
    的头像 发表于 08-25 15:30 3001次阅读

    利用 Python 和 PyTorch 处理面向对象数据(2)) :创建数据对象

    本篇是利用 Python 和 PyTorch 处理面向对象数据系列博客的第 2 篇。我们在第 1 部分中已定义 MyDataset 类,现在,让我们来例化 MyDataset
    的头像 发表于 08-02 17:35 944次阅读
    利用 Python 和 <b class='flag-5'>PyTorch</b> 处理面向<b class='flag-5'>对象</b>的<b class='flag-5'>数据</b><b class='flag-5'>集</b>(2)) :创建<b class='flag-5'>数据</b><b class='flag-5'>集</b><b class='flag-5'>对象</b>

    利用Python和PyTorch处理面向对象数据(1)

    在本文中,我们将提供一种高效方法,用于完成数据的交互、组织以及最终变换(预处理)。随后,我们将讲解如何在训练过程中正确地把数据输入给模型。PyTorch 框架将帮助我们实现此目标,我们还将从头开始编写几个类。
    的头像 发表于 08-02 08:03 692次阅读

    PyTorch教程3.2之面向对象的设计实现

    电子发烧友网站提供《PyTorch教程3.2之面向对象的设计实现.pdf》资料免费下载
    发表于 06-05 15:48 0次下载
    <b class='flag-5'>PyTorch</b>教程3.2之面向<b class='flag-5'>对象</b>的设计实现

    PyTorch教程4.2之图像分类数据

    电子发烧友网站提供《PyTorch教程4.2之图像分类数据.pdf》资料免费下载
    发表于 06-05 15:41 0次下载
    <b class='flag-5'>PyTorch</b>教程4.2之图像分类<b class='flag-5'>数据</b><b class='flag-5'>集</b>

    PyTorch教程10.5之机器翻译和数据

    电子发烧友网站提供《PyTorch教程10.5之机器翻译和数据.pdf》资料免费下载
    发表于 06-05 15:14 0次下载
    <b class='flag-5'>PyTorch</b>教程10.5之机器翻译和<b class='flag-5'>数据</b><b class='flag-5'>集</b>

    PyTorch教程14.6对象检测数据

    电子发烧友网站提供《PyTorch教程14.6对象检测数据.pdf》资料免费下载
    发表于 06-05 11:23 0次下载
    <b class='flag-5'>PyTorch</b>教程<b class='flag-5'>14.6</b>之<b class='flag-5'>对象</b><b class='flag-5'>检测</b><b class='flag-5'>数据</b><b class='flag-5'>集</b>

    PyTorch教程14.9之语义分割和数据

    电子发烧友网站提供《PyTorch教程14.9之语义分割和数据.pdf》资料免费下载
    发表于 06-05 11:10 0次下载
    <b class='flag-5'>PyTorch</b>教程14.9之语义分割和<b class='flag-5'>数据</b><b class='flag-5'>集</b>

    PyTorch教程15.9之预训练BERT的数据

    电子发烧友网站提供《PyTorch教程15.9之预训练BERT的数据.pdf》资料免费下载
    发表于 06-05 11:06 0次下载
    <b class='flag-5'>PyTorch</b>教程15.9之预训练BERT的<b class='flag-5'>数据</b><b class='flag-5'>集</b>

    PyTorch教程16.1之情绪分析和数据

    电子发烧友网站提供《PyTorch教程16.1之情绪分析和数据.pdf》资料免费下载
    发表于 06-05 10:54 0次下载
    <b class='flag-5'>PyTorch</b>教程16.1之情绪分析和<b class='flag-5'>数据</b><b class='flag-5'>集</b>

    PyTorch教程16.4之自然语言推理和数据

    电子发烧友网站提供《PyTorch教程16.4之自然语言推理和数据.pdf》资料免费下载
    发表于 06-05 10:57 0次下载
    <b class='flag-5'>PyTorch</b>教程16.4之自然语言推理和<b class='flag-5'>数据</b><b class='flag-5'>集</b>

    没有“中间商赚差价”, OpenVINO™ 直接支持 PyTorch 模型对象

    体验—— OpenVINO 的 mo 工具可以直接将 PyTorch 模型对象转化为 OpenVINO 的模型对象,开发者可以不需要将 ONNX 模型作为中间过渡。
    的头像 发表于 06-27 16:39 778次阅读
    没有“中间商赚差价”, OpenVINO™ 直接支持 <b class='flag-5'>PyTorch</b> 模型<b class='flag-5'>对象</b>

    YOLOv8实现旋转对象检测

    YOLOv8框架在在支持分类、对象检测、实例分割、姿态评估的基础上更近一步,现已经支持旋转对象检测(OBB),基于DOTA数据
    的头像 发表于 01-11 10:43 1904次阅读
    YOLOv8实现旋转<b class='flag-5'>对象</b><b class='flag-5'>检测</b>

    PyTorch如何训练自己的数据

    PyTorch是一个广泛使用的深度学习框架,它以其灵活性、易用性和强大的动态图特性而闻名。在训练深度学习模型时,数据是不可或缺的组成部分。然而,很多时候,我们可能需要使用自己的数据
    的头像 发表于 07-02 14:09 1826次阅读