Dataset 数据集

Pytorch中表示数据集的抽象类//

任何自定义的数据集都需要继承这个类并覆写相关方法

Dataset的描述:

所有表示从键到数据样本的映射的数据集都应继承它。
所有子类都应覆写getitem,以支持获取给定键的数据样本。
子类还可以选择性地覆盖len,许多~torch.utils.data.Sampler类实现和~torch.utils.data.DataLoader类的默认选项都希望它返回数据集的大小。
子类也可以选择实现getitems,以加快成批样本的加载速度。
此方法接受批次样本的索引列表,并返回样本列表。

假设需要加载的是当前目录下images文件夹中的42张png图片

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class MyDataset(Dataset):
def __init__(self, path, processor=None):
"""Dataset类的初始化方法

Args:
path: 数据路径列表
processor: 数据预处理的函数,f:数据路径->目标数据
"""
self.path = path
self.processor = processor

def __len__(self):
"""返回数据集的大小"""
return len(self.path)

def __getitem__(self, idx):
"""根据索引返回处理后的数据"""
data_path = self.path[idx]
data = self.processor(data_path)
return data

def processor(data_path):
"""预处理函数,这里"""
img = Image.open(data_path).size
return img

def get_images_path(data_dir):
"""获取data_dir目录下所有png图片的路径,不含子目录"""
images_path = [os.path.join(data_dir,image)
for image in os.listdir(data_dir)
if image.endswith('.png')]
return images_path

data_dir = 'images' # 数据目录
images_path = get_images_path(data_dir) # 获取数据路径
dataset = MyDataset(images_path, processor=processor) # 创建数据集

print(dataset[0]) # 打印第一个数据
print(len(dataset)) # 打印数据集大小
(367, 126)
42

DataLoader 数据迭代器

Pytorch加载和处理数据集的可迭代对象//

参数 Default 说明
dataset 必选 用于加载数据的数据集
必须是torch.utils.data.Dataset的子类实例
batch_size 1 每个batch的样本数
shuffle False 是否在每个epoch开始时打乱数据
sampler None 定义从数据集中提取样本的策略
如果指定这个参数,则忽略shuffle参数
batch_sampler None 与sampler类似,但返回的是一个batch的索引
不能与batch_size、shuffle、sampler同时使用
num_workers 0 用于数据加载的子进程数
collate_fn None 将多个样本组合成一个mini-batch的函数
drop_last False 如果数据集大小不能被batch_size整除,是否丢弃最后一个不完整的batch
1
2
3
4
5
6
from torch.utils.data import DataLoader
data_loader = DataLoader(dataset,batch_size=5,shuffle=True)

if __name__ == '__main__':
for data in data_loader:
print(data)
[tensor([1866,  439,  970, 1378, 1416]), tensor([660, 159, 354, 554, 498])]
[tensor([2018, 1116, 1720,  812, 1500]), tensor([ 968,  556, 1302,  465, 1168])]
[tensor([2260, 1378, 1046, 1774, 1521]), tensor([648, 612, 514, 524, 375])]
[tensor([ 956,  962,  412, 1700,  978]), tensor([530, 542, 640, 904, 465])]
[tensor([1202, 1434, 1686, 1583, 3873]), tensor([1006,  598,  540,  844, 5041])]
[tensor([1402, 1824, 1580, 1422, 1638]), tensor([1414, 1454,  630,  570, 1044])]
[tensor([1944, 1390, 1408,  367, 1320]), tensor([592, 568, 864, 126, 364])]
[tensor([1983, 1844,  912, 1400, 1564]), tensor([482, 866, 501, 602, 514])]
[tensor([1144, 1594]), tensor([718, 494])]