torch.utils.data.dataloader参数collate

    xiaoxiao2022-07-02  122

    torch.utils.data.DataLoader是pytorch提供的数据加载类,初始化函数如下,

    torch.utils.data.DataLoader(dataset,batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

    dataset,batch_size等参数重要且容易理解,而collate_fn参数就不太直白,官方解释为:

    collate_fn (callable, optional) – merges a list of samples to form a mini-batch

    不明不白。

    其实,collate_fn可理解为函数句柄、指针...或者其他可调用类(实现__call__函数)。 函数输入为list,list中的元素为欲取出的一系列样本。具体如下

    indices = next(self.sample_iter) batch = self.collate_fn([dataset[i] for i in indices])

    其中self.sampler_iter即采样器,返回下一个batch中样本的序号,indices。

    通过collate_fn函数可以对这些样本做进一步的处理(任何你想要的处理),原则上返回值应当是一个有结构的batch。而DataLoader每次迭代的返回值就是collate_fn的返回值。

     

    最新回复(0)