PyTorch: Data Processing and Loading

    xiaoxiao2022-07-13  165

    1. Dataset and DataLoader

    from torch.utils.data import Dataset, DataLoader, TensorDataset

    1. Dataset is a wrapper class (TensorDataset is one of its subclasses). It wraps data as a Dataset object, which is then passed to a DataLoader.
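To make the wrapper idea concrete, here is a minimal sketch of a custom Dataset subclass. The class name `SquaresDataset` and its contents are hypothetical and chosen only for illustration; the point is that a map-style dataset implements `__getitem__` and `__len__`.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32)
        self.y = self.x ** 2

    def __getitem__(self, index):
        # Return one (data, label) pair for the requested index
        return self.x[index], self.y[index]

    def __len__(self):
        # Number of samples in the dataset
        return self.x.size(0)


dataset = SquaresDataset(5)
sample_x, sample_y = dataset[3]
```

Any object shaped like this can be handed to a DataLoader in the same way as a TensorDataset.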

    TensorDataset source code:

    class TensorDataset(Dataset):
        """Dataset wrapping tensors.

        Each sample will be retrieved by indexing tensors along the first dimension.

        Arguments:
            *tensors (Tensor): tensors that have the same size of the first dimension.
        """

        # *tensors is a variadic parameter: any number of tensors can be passed in,
        # typically the data and the labels
        def __init__(self, *tensors):
            # Check that the data and the labels contain the same number of samples
            assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
            self.tensors = tensors

        def __getitem__(self, index):
            return tuple(tensor[index] for tensor in self.tensors)

        def __len__(self):
            return self.tensors[0].size(0)
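The behaviour of this class can be checked with a small sketch. The tensors `features` and `targets` below are made-up toy data, not part of the original post; the snippet only exercises indexing, length, and the size check in `__init__`.

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical toy data: 6 samples with 3 features each, plus one label per sample
features = torch.arange(18, dtype=torch.float32).reshape(6, 3)
targets = torch.tensor([0, 1, 0, 1, 0, 1])

toy_dataset = TensorDataset(features, targets)

num_samples = len(toy_dataset)      # size of the first dimension
first_x, first_y = toy_dataset[0]   # tuple (features[0], targets[0])

# Tensors whose first dimensions differ are rejected by the assertion in __init__
try:
    TensorDataset(features, targets[:4])
    mismatch_rejected = False
except AssertionError:
    mismatch_rejected = True
```

Indexing returns a tuple with one element per wrapped tensor, which is why the training loop can unpack each batch into data and labels.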

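    2. DataLoader makes it more convenient to work with the data. Its commonly used parameters are batch_size (the number of samples in each batch), shuffle (whether to shuffle the data), and num_workers (how many subprocesses to use when loading data).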
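The effect of these parameters can be seen in a short sketch. The toy tensors and the names `demo_dataset` and `demo_loader` are assumptions made for illustration; `num_workers=0` is used so that loading stays in the main process.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples with 2 features each
xs = torch.arange(20, dtype=torch.float32).reshape(10, 2)
ys = torch.arange(10)
demo_dataset = TensorDataset(xs, ys)

# num_workers=0 loads data in the main process; larger values use subprocesses
demo_loader = DataLoader(demo_dataset, batch_size=4, shuffle=True, num_workers=0)

batch_sizes = []
seen_labels = []
for batch_x, batch_y in demo_loader:
    batch_sizes.append(batch_x.size(0))
    seen_labels.extend(batch_y.tolist())
```

With 10 samples and batch_size=4, the last batch holds only the 2 remaining samples. Shuffling changes the order in which samples appear, but every sample is still visited once per pass.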

    2. Example

    train_dataset = TensorDataset(train_data, train_labels)
    train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

    # Training loop
    for epoch in range(epochs):
        for index, (data, labels) in enumerate(train_loader):
            # Move the batch to the GPU. Since PyTorch 0.4, tensors no longer
            # need to be wrapped in torch.autograd.Variable.
            data, labels = data.cuda(), labels.cuda()
            # train...
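For readers who want something runnable end to end, the following is a minimal sketch of the same loop with the training step filled in. Everything beyond TensorDataset and DataLoader is an assumption made for illustration: the synthetic regression data, the `nn.Linear` model, the SGD optimizer, and the learning rate are not from the original post. The sketch falls back to the CPU when no GPU is available.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical regression data: y = 2x + 1
toy_x = torch.linspace(-1, 1, 256).unsqueeze(1)
toy_y = 2 * toy_x + 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loader = DataLoader(TensorDataset(toy_x, toy_y), batch_size=32, shuffle=True)

# Assumed model, optimizer, and loss, chosen only to keep the example small
model = nn.Linear(1, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

epoch_losses = []
for epoch in range(20):
    total = 0.0
    for data, labels in loader:
        # Move the current batch to the same device as the model
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(data), labels)
        loss.backward()
        optimizer.step()
        total += loss.item() * data.size(0)
    # Average loss per sample over the epoch
    epoch_losses.append(total / len(loader.dataset))
```

Using `.to(device)` instead of `.cuda()` is a design choice that lets the same script run on machines without a GPU.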

    Reference:

    https://blog.csdn.net/zw__chen/article/details/82806900
