目标检测领域没有MNIST和Fashion-MNIST这样的小型数据集。为了快速演示对象检测模型,我们收集并标记了一个小型数据集。首先,我们从办公室拍摄了免费香蕉的照片,并生成了 1000 张不同旋转和大小的香蕉图像。然后我们将每个香蕉图像放置在一些背景图像上的随机位置。最后,我们为图像上的那些香蕉标记了边界框。
14.6.1。下载数据集
带有所有图像和 csv 标签文件的香蕉检测数据集可以直接从互联网上下载。
%matplotlib inline import os import pandas as pd import torch import torchvision from d2l import torch as d2l #@save d2l.DATA_HUB['banana-detection'] = ( d2l.DATA_URL + 'banana-detection.zip', '5de26c8fce5ccdea9f91267273464dc968d20d72')
%matplotlib inline import os import pandas as pd from mxnet import gluon, image, np, npx from d2l import mxnet as d2l npx.set_np() #@save d2l.DATA_HUB['banana-detection'] = ( d2l.DATA_URL + 'banana-detection.zip', '5de26c8fce5ccdea9f91267273464dc968d20d72')
14.6.2。读取数据集
我们将在 read_data_bananas下面的函数中读取香蕉检测数据集。数据集包括一个 csv 文件,用于对象类标签和左上角和右下角的地面实况边界框坐标。
#@save def read_data_bananas(is_train=True): """Read the banana detection dataset images and labels.""" data_dir = d2l.download_extract('banana-detection') csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'label.csv') csv_data = pd.read_csv(csv_fname) csv_data = csv_data.set_index('img_name') images, targets = [], [] for img_name, target in csv_data.iterrows(): images.append(torchvision.io.read_image( os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'images', f'{img_name}'))) # Here `target` contains (class, upper-left x, upper-left y, # lower-right x, lower-right y), where all the images have the same # banana class (index 0) targets.append(list(target)) return images, torch.tensor(targets).unsqueeze(1) / 256
#@save def read_data_bananas(is_train=True): """Read the banana detection dataset images and labels.""" data_dir = d2l.download_extract('banana-detection') csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'label.csv') csv_data = pd.read_csv(csv_fname) csv_data = csv_data.set_index('img_name') images, targets = [], [] for img_name, target in csv_data.iterrows(): images.append(image.imread( os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'images', f'{img_name}'))) # Here `target` contains (class, upper-left x, upper-left y, # lower-right x, lower-right y), where all the images have the same # banana class (index 0) targets.append(list(target)) return images, np.expand_dims(np.array(targets), 1) / 256
通过使用read_data_bananas函数读取图像和标签,下面的BananasDataset类将允许我们创建一个自定义Dataset实例来加载香蕉检测数据集。
#@save class BananasDataset(torch.utils.data.Dataset): """A customized dataset to load the banana detection dataset.""" def __init__(self, is_train): self.features, self.labels = read_data_bananas(is_train) print('read ' + str(len(self.features)) + (f' training examples' if is_train else f' validation examples')) def __getitem__(self, idx): return (self.features[idx].float(), self.labels[idx]) def __len__(self): return len(self.features)
#@save class BananasDataset(gluon.data.Dataset): """A customized dataset to load the banana detection dataset.""" def __init__(self, is_train): self.features, self.labels = read_data_bananas(is_train) print('read ' + str(len(self.features)) + (f' training examples' if is_train else f' validation examples')) def __getitem__(self, idx): return (self.features[idx].astype('float32').transpose(2, 0, 1), self.labels[idx]) def __len__(self): return len(self.features)
最后,我们定义load_data_bananas函数为训练集和测试集返回两个数据迭代器实例。对于测试数据集,不需要随机读取。
#@save def load_data_bananas(batch_size): """Load the banana detection dataset.""" train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True), batch_size, shuffle=True) val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False), batch_size) return train_iter, val_iter
#@save def load_data_bananas(batch_size): """Load the banana detection dataset.""" train_iter = gluon.data.DataLoader(BananasDataset(is_train=True), batch_size, shuffle=True) val_iter = gluon.data.DataLoader(BananasDataset(is_train=False), batch_size) return train_iter, val_iter
让我们读取一个 minibatch 并打印这个 minibatch 中图像和标签的形状。图像小批量的形状(批量大小、通道数、高度、宽度)看起来很熟悉:它与我们之前的图像分类任务相同。label minibatch的shape是(batch size,m, 5), 其中m是任何图像在数据集中具有的最大可能数量的边界框。
虽然 minibatch 的计算效率更高,但它要求所有图像示例都包含相同数量的边界框,以通过连接形成一个 minibatch。通常,图像可能具有不同数量的边界框;因此,图像少于m 边界框将被非法边界框填充,直到 m到达了。然后每个边界框的标签用一个长度为5的数组表示,数组的第一个元素是边界框中物体的类,其中-1表示填充的非法边界框。数组的其余四个元素是 (x,y)-边界框左上角和右下角的坐标值(范围在0到1之间)。对于香蕉数据集,由于每张图像上只有一个边界框,我们有m=1.
batch_size, edge_size = 32, 256 train_iter, _ = load_data_bananas(batch_size) batch = next(iter(train_iter)) batch[0].shape, batch[1].shape
read 1000 training examples read 100 validation examples
(torch.Size([32, 3, 256, 256]), torch.Size([32, 1, 5]))
batch_size, edge_size = 32, 256 train_iter, _ = load_data_bananas(batch_size) batch = next(iter(train_iter)) batch[0].shape, batch[1].shape
read 1000 training examples read 100 validation examples
((32, 3, 256, 256), (32, 1, 5))
14.6.3。示范
让我们演示十张带有标记的真实边界框的图像。我们可以看到香蕉的旋转、大小和位置在所有这些图像中都不同。当然,这只是一个简单的人工数据集。实际上,真实世界的数据集通常要复杂得多。
imgs = (batch[0][:10].permute(0, 2, 3, 1)) / 255 axes = d2l.show_images(imgs, 2, 5, scale=2) for ax, label in zip(axes, batch[1][:10]): d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
imgs = (batch[0][:10].transpose(0, 2, 3, 1)) / 255 axes = d2l.show_images(imgs, 2, 5, scale=2) for ax, label in zip(axes, batch[1][:10]): d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
14.6.4。概括
我们收集的香蕉检测数据集可用于演示对象检测模型。
目标检测的数据加载类似于图像分类。然而,在目标检测中,标签还包含图像分类中缺少的真实边界框信息。
14.6.5。练习
演示香蕉检测数据集中带有真实边界框的其他图像。它们在边界框和对象方面有何不同?
假设我们要将数据增强(例如随机裁剪)应用于对象检测。它与图像分类中的有何不同?提示:如果裁剪后的图像只包含物体的一小部分怎么办?
-
数据集
+关注
关注
4文章
1208浏览量
24742 -
pytorch
+关注
关注
2文章
808浏览量
13256
发布评论请先 登录
相关推荐
评论