Pytorch|Dataset&DataLoader
Dataset 数据集
Pytorch中表示数据集的抽象类//
任何自定义的数据集都需要继承这个类并覆写相关方法
Dataset的描述:
所有表示从键到数据样本的映射的数据集都应继承它。
所有子类都应覆写getitem,以支持获取给定键的数据样本。
子类还可以选择性地覆盖len,许多~torch.utils.data.Sampler类实现和~torch.utils.data.DataLoader类的默认选项都希望它返回数据集的大小。
子类也可以选择实现getitems,以加快成批样本的加载速度。
此方法接受批次样本的索引列表,并返回样本列表。
假设需要加载的是当前目录下images
文件夹中的42张png图片
1 | import os |
(367, 126)
42
DataLoader 数据迭代器
Pytorch加载和处理数据集的可迭代对象//
参数 | Default | 说明 |
---|---|---|
dataset | 必选 | 用于加载数据的数据集 必须是torch.utils.data.Dataset的子类实例 |
batch_size | 1 | 每个batch的样本数 |
shuffle | False | 是否在每个epoch开始时打乱数据 |
sampler | None | 定义从数据集中提取样本的策略 如果指定这个参数,则忽略shuffle参数 |
batch_sampler | None | 与sampler类似,但返回的是一个batch的索引 不能与batch_size、shuffle、sampler同时使用 |
num_workers | 0 | 用于数据加载的子进程数 |
collate_fn | None | 将多个样本组合成一个mini-batch的函数 |
drop_last | False | 如果数据集大小不能被batch_size整除,是否丢弃最后一个不完整的batch |
1 | from torch.utils.data import DataLoader |
[tensor([1866, 439, 970, 1378, 1416]), tensor([660, 159, 354, 554, 498])]
[tensor([2018, 1116, 1720, 812, 1500]), tensor([ 968, 556, 1302, 465, 1168])]
[tensor([2260, 1378, 1046, 1774, 1521]), tensor([648, 612, 514, 524, 375])]
[tensor([ 956, 962, 412, 1700, 978]), tensor([530, 542, 640, 904, 465])]
[tensor([1202, 1434, 1686, 1583, 3873]), tensor([1006, 598, 540, 844, 5041])]
[tensor([1402, 1824, 1580, 1422, 1638]), tensor([1414, 1454, 630, 570, 1044])]
[tensor([1944, 1390, 1408, 367, 1320]), tensor([592, 568, 864, 126, 364])]
[tensor([1983, 1844, 912, 1400, 1564]), tensor([482, 866, 501, 602, 514])]
[tensor([1144, 1594]), tensor([718, 494])]
评论