Torch.cat

    xiaoxiao · 2022-07-07


    Usage of torch.cat

    torch.cat concatenates two (or more) tensors along an existing dimension. "cat" is short for concatenate, meaning to join or link together.
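    To make the `dim` argument concrete before the 4-D example below, here is a minimal sketch (my addition, not from the original post) using small 2-D tensors:

    ```python
    import torch

    x = torch.tensor([[1, 2], [3, 4]])
    y = torch.tensor([[5, 6], [7, 8]])

    # dim=0 stacks the rows: a 2x2 and a 2x2 become a 4x2
    print(torch.cat([x, y], dim=0).shape)  # torch.Size([4, 2])

    # dim=1 joins the columns side by side: the result is 2x4
    print(torch.cat([x, y], dim=1).shape)  # torch.Size([2, 4])
    ```

    In other words, only the size along `dim` grows; every other dimension is unchanged.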

    Example

    ```python
    import torch

    a = torch.randn(1, 3, 5, 5)
    b = torch.randn(1, 3, 5, 5)
    print(a)  # prints a 1x3x5x5 tensor of random values
    print(b)  # prints another 1x3x5x5 tensor of random values

    # Concatenate along dim=1 (the channel dimension)
    c = torch.cat([a, b], dim=1)
    print(c)
    print(c.shape)  # torch.Size([1, 6, 5, 5])

    # Concatenate along dim=0 (the batch dimension)
    c = torch.cat([a, b], dim=0)
    print(c.shape)  # torch.Size([2, 3, 5, 5])
    ```

    As the experiment shows, torch.cat([a, b], dim=1) concatenates a and b along the channel dimension, growing it from 3 to 6. Likewise, dim=0 concatenates along the batch dimension, and the resulting shape becomes torch.Size([2, 3, 5, 5]).
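    One constraint worth making explicit: the tensors must have matching sizes on every dimension except the one being concatenated. A short sketch (my addition, not from the original post):

    ```python
    import torch

    a = torch.randn(1, 3, 5, 5)
    b = torch.randn(1, 4, 5, 5)  # channel count differs from a

    # Concatenating along dim=1 works: only the sizes on dim 1 may differ
    print(torch.cat([a, b], dim=1).shape)  # torch.Size([1, 7, 5, 5])

    # Concatenating along dim=0 fails: the sizes on dim 1 (3 vs 4) must match
    try:
        torch.cat([a, b], dim=0)
    except RuntimeError as e:
        print("RuntimeError:", e)
    ```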
