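1. Dataset is a wrapper class (TensorDataset is one of its subclasses) used to wrap data as a Dataset object, which is then passed to a DataLoader. A subclass provides `__getitem__` (return one sample by index) and `__len__` (return the number of samples).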
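As a minimal sketch of point 1, the class below subclasses Dataset by implementing the two methods a DataLoader relies on. `SquaresDataset` and its toy data (sample i is the pair i, i²) are hypothetical, made up only for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (i, i**2) as float tensors."""

    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32)
        self.y = self.x ** 2

    def __getitem__(self, index):
        # Return one (input, target) pair for the given index
        return self.x[index], self.y[index]

    def __len__(self):
        # Number of samples; DataLoader's default sampler uses this
        return len(self.x)


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4)
batches = list(loader)

print(len(dataset))   # 10
print(dataset[3])     # (tensor(3.), tensor(9.))
print(len(batches))   # 3 (batches of 4, 4 and 2 samples)
```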
TensorDataset source code:
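```python
from torch.utils.data import Dataset


class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    # *tensors is a variadic argument: any number of tensors can be passed,
    # typically the data and the labels
    def __init__(self, *tensors):
        # Check that every tensor has the same number of samples (size of dim 0),
        # e.g. that there are as many labels as data points
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        # Sample `index` is the tuple of each tensor's index-th entry
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
```

2. DataLoader makes it more convenient to work with the data: it wraps a Dataset in an iterable that yields batches. Commonly used parameters include:
- batch_size: the number of samples in each batch
- shuffle: whether to reshuffle the data at every epoch
- num_workers: how many subprocesses to use when loading the data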
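The sketch below, assuming randomly generated toy data (the names `features`, `labels` and `seen` are hypothetical), wraps two tensors in TensorDataset and iterates over them with a DataLoader using the three parameters listed above.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy data (assumed for illustration): 10 samples with 3 features each, plus integer labels
features = torch.randn(10, 3)
labels = torch.arange(10)

# Wrap data and labels; both tensors have size 10 along dimension 0
dataset = TensorDataset(features, labels)
x0, y0 = dataset[0]  # indexing returns (features[0], labels[0])

loader = DataLoader(
    dataset,
    batch_size=4,   # samples per batch
    shuffle=True,   # reshuffle the sample order at every epoch
    num_workers=0,  # 0 = load data in the main process (no subprocesses)
)

for epoch in range(2):
    seen = []
    for batch_x, batch_y in loader:
        # batch_x has shape (batch, 3), batch_y has shape (batch,);
        # the last batch holds the remaining 10 % 4 = 2 samples
        print(epoch, batch_x.shape, batch_y.tolist())
        seen.extend(batch_y.tolist())
    # The order differs between epochs, but each sample appears exactly once per epoch
    print(epoch, sorted(seen) == list(range(10)))
```

With num_workers=0 all loading happens in the main process; a positive value starts that many worker subprocesses, and on platforms that start them with spawn (e.g. Windows) the loop should sit under `if __name__ == "__main__":`.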
https://blog.csdn.net/zw__chen/article/details/82806900