    [Semantic Segmentation Series 3] FCN: Paper Reading Notes + Resources + PyTorch Implementation

    xiaoxiao  2024-11-08

    2015 CVPR best paper

    keywords: pixel level, fully supervised, CNN semantic segmentation milestone

    Fully Convolutional Networks for Semantic Segmentation

    Official GitHub: FCN

    wkentaro Pytorch FCN

    Pytorch FCN bag | I ran this myself on 2019/4/22

    A very good blog post: 全卷积网络 FCN 详解 (a detailed walkthrough of fully convolutional networks)

    Introduction

    FCN pioneered CNN-based image semantic segmentation, performing pixel-level classification (i.e. end-to-end, pixel-wise prediction).

    Features:

    - Extends deep classification architectures (AlexNet, VGG, GoogLeNet) by fine-tuning
    - Pixel-to-pixel, end-to-end training
    - Input of any size, output classification maps (heatmaps)
    - Deconvolution (upsampling) whose filters can be learned
    - For pixelwise prediction, the coarse outputs are connected back to the pixels

    Features (in more detail):

    - Makes pixel-wise predictions, using the ground truth as supervision to predict a label map.
    - A fully convolutional (fully conv) network without fully connected (fc) layers: it outputs a heatmap-like image, instead of appending several fully connected layers after the convolutions, as a classification CNN does, to obtain class probabilities for the whole input image.
    - At the end of the fully convolutional network, deconvolution is used for upsampling, so the output has the same size as the original image.
    - A skip architecture combines results from layers at different depths, ensuring both robustness and precision.

    Figure above: feeding a cat image into AlexNet gives a 1000-dimensional output vector with the probability of each class, and "tabby cat" scores highest. But because details of the object are lost, the network cannot give the object's precise contour or say which object each pixel belongs to. Feeding the cat into FCN instead gives a heatmap-type image.

    Key points:

    The loss function is the sum of per-pixel losses over the spatial map of the last layer, with a softmax loss applied at every pixel.

    A skip architecture fuses the outputs of multiple (three) layers. Shallower layers should be able to predict more location information, because their receptive fields are small and they see finer pixels.

    When upsampling lower-resolution layers, if the upsampled map differs in size from the earlier map because of padding and so on, crop it; once the two maps have the same size and are spatially aligned, fuse the two layers (the official implementation fuses with an element-wise sum rather than a concat).
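    A minimal sketch of this crop-then-fuse step (the tensor names and sizes below are made up for illustration; the fusion is written as an element-wise sum):

    import torch

    def center_crop(t, target_h, target_w):
        # crop a (N, C, H, W) tensor to the target spatial size
        _, _, h, w = t.shape
        top = (h - target_h) // 2
        left = (w - target_w) // 2
        return t[:, :, top:top + target_h, left:left + target_w]

    # upsampled: coarse prediction after deconvolution, e.g. (1, 21, 36, 36)
    # skip_feat: prediction from a shallower layer,     e.g. (1, 21, 34, 34)
    upsampled = torch.randn(1, 21, 36, 36)
    skip_feat = torch.randn(1, 21, 34, 34)

    upsampled = center_crop(upsampled, *skip_feat.shape[2:])
    fused = upsampled + skip_feat  # same size and spatially aligned: fuse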

    Drawbacks:

    The results are still not fine-grained enough. 8x upsampling works much better than 32x, but the upsampled output is still rather blurry and smooth, and insensitive to details in the image.

    Each pixel is classified on its own, without fully considering the relationships between pixels. The spatial regularization step used in conventional pixel-classification segmentation methods is skipped, so the result lacks spatial consistency.

    Architecture

    FCN architecture

    - Convolutionalization: backbones such as VGG16 and ResNet50/101 with their fully connected layers discarded
    - Upsampling: deconvolution
    - Skip connections (skip layers): used to refine the result

    The input can be a color image of any size; the output has the same spatial size as the input, with depth 20 object classes + background = 21. (The experiments are on the PASCAL dataset, which has 20 classes.)

    Blue: convolution. Green: max pooling.

    Fully convolutional part: feature extraction

    This is the part of the figure above the dashed line.

    For input images of different sizes, the (height, width) of each layer's data changes accordingly, while the depth (channels) stays fixed.

    The last two fully connected (fc) layers of AlexNet are converted into convolutional layers. (The most accurate classification network in the paper is VGG16; AlexNet is used here as the example.)
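    As a rough sketch of this fc-to-conv conversion (using VGG16-style shapes for illustration; this is not the official model code):

    import torch
    import torch.nn as nn

    # A classifier fc layer that expects a flattened 512x7x7 feature map ...
    fc6 = nn.Linear(512 * 7 * 7, 4096)

    # ... becomes a 7x7 convolution with 4096 output channels.
    conv6 = nn.Conv2d(512, 4096, kernel_size=7)

    # Reuse the fc parameters: same weights, just reshaped into a kernel.
    conv6.weight.data.copy_(fc6.weight.data.view(4096, 512, 7, 7))
    conv6.bias.data.copy_(fc6.bias.data)

    # The convolutional version now accepts inputs of any size and produces
    # a spatial map of scores instead of a single vector.
    x = torch.randn(1, 512, 14, 14)
    print(conv6(x).shape)  # torch.Size([1, 4096, 8, 8])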

    Per-pixel prediction

    In the part below the dashed line, convolutional layers (blue ×3) predict 21-channel classification results from different stages of the network.

    A deconvolution layer upsamples the feature map of the last convolutional layer back to the size of the input image, so a prediction is produced for every pixel while the spatial information of the original input is preserved; per-pixel classification is then performed on the upsampled feature map.

    The final output is 21 heatmaps upsampled to the original image size. To classify every pixel and obtain the final segmented image, there is a small trick: for each pixel, take the channel whose value (probability) is largest among the 21 maps at that position as the pixel's class. This produces a fully classified image, like the dog-and-cat image on the right of the figure.
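    A short sketch of this per-pixel argmax trick (the heatmaps tensor below is a stand-in for the 21 upsampled score maps):

    import torch

    heatmaps = torch.randn(1, 21, 256, 256)  # 21 upsampled score maps for one image
    pred = heatmaps.argmax(dim=1)            # (1, 256, 256): class index of every pixel
    # pred[0, y, x] is the predicted class (0..20) for pixel (x, y)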

    Deconvolution: upsampling

    FCN uses a transposed convolution strategy: the feature map is deconvolved to obtain the upsampled output, and this process is learned.

    First upsample, i.e. enlarge the spatial resolution; then convolve, with the weights obtained by learning.

    The deconvolution layers (orange ×3) enlarge the spatial size of the input.

    The method used by the FCN authors is a variant of the deconvolution described here; it produces the corresponding pixel values so the model works end to end on images.

    (With different feature-map values and weights, the generated upsampled regions differ as well.)
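    A sketch of such a learnable upsampling layer, initialized with a bilinear kernel as the paper describes (the 2x kernel size and stride here are illustrative; the official FCN-32s layer uses a much larger kernel):

    import numpy as np
    import torch
    import torch.nn as nn

    def bilinear_kernel(channels, kernel_size):
        # build a (channels, channels, k, k) bilinear interpolation kernel
        factor = (kernel_size + 1) // 2
        center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
        og = np.ogrid[:kernel_size, :kernel_size]
        filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
        weight = np.zeros((channels, channels, kernel_size, kernel_size), dtype='float32')
        for c in range(channels):
            weight[c, c] = filt
        return torch.from_numpy(weight)

    # 2x upsampling of 21-channel score maps; the weights start as bilinear
    # interpolation but remain free to be learned.
    upscore = nn.ConvTranspose2d(21, 21, kernel_size=4, stride=2, padding=1, bias=False)
    upscore.weight.data.copy_(bilinear_kernel(21, 4))

    x = torch.randn(1, 21, 16, 16)
    print(upscore(x).shape)  # torch.Size([1, 21, 32, 32])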

    Skip architecture

    Loss

    The FCN network uses a batch size of 1, because the segmentation loss has a ground-truth label at every pixel: each pixel is effectively its own classification task, so a single image already contains as many samples as it has pixels. A batch is therefore one image, and the classification losses over all of its pixels are summed to compute one gradient update.
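    A hedged sketch of this per-pixel loss with batch size 1 (the shapes are illustrative, and the bag demo below works with a 2-channel one-hot mask, so its loss differs; CrossEntropyLoss here is just one common way to write a per-pixel softmax loss):

    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss(reduction='sum')  # sum the softmax loss over all pixels

    scores = torch.randn(1, 21, 256, 256, requires_grad=True)  # (batch=1, classes, H, W)
    target = torch.randint(0, 21, (1, 256, 256))               # ground-truth class per pixel

    loss = criterion(scores, target)  # per-pixel softmax loss, summed over the image
    loss.backward()                   # one gradient update per image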

    Train

    The network is initialized from a classic classification network. The last two stages are fully connected layers (red); their parameters are discarded.

    From the small feature map (16×16×4096), predict a small segmentation map (16×16×21), then upsample it directly to full size. The deconvolution (orange) has stride 32; this network is called FCN-32s.

    The upsampling is done in two steps (orange ×2). Before the second upsampling, the prediction (blue) from the 4th pooling layer (green) is fused in; this skip structure improves precision. The second deconvolution has stride 16, and this network is called FCN-16s.

    The upsampling is done in three steps (orange ×3), further fusing the prediction from the 3rd pooling layer. The third deconvolution has stride 8; this is FCN-8s. This stage takes about one day to train on a single GPU.

    We now have a 1/32-size heatmap, a 1/16-size feature map and a 1/8-size feature map. Upsampling the 1/32-size heatmap alone only restores the features captured by the conv5 filters, and its limited precision cannot recover the details of the image well, so we iterate forward: the conv4 predictions are fused with the previously upsampled map to add detail back (roughly an interpolation step, like simple bilinear interpolation), and then the conv3 predictions are fused with the next upsampled map to add still more detail, finally completing the restoration of the whole image.

    Predictions from shallower layers contain more detail. The skip architecture uses this shallow-layer information to assist the step-by-step upsampling, giving finer results.
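    Putting the three stages together, a minimal FCN-8s-style decoder sketch (it assumes the pool3/pool4/pool5 predictions have already been reduced to 21 channels; the kernel sizes and strides are illustrative rather than the official Caffe settings):

    import torch
    import torch.nn as nn

    class FCN8sDecoder(nn.Module):
        def __init__(self, n_class=21):
            super().__init__()
            # two 2x upsampling steps for the skip fusion, then a final 8x upsampling
            self.up2_a = nn.ConvTranspose2d(n_class, n_class, 4, stride=2, padding=1)
            self.up2_b = nn.ConvTranspose2d(n_class, n_class, 4, stride=2, padding=1)
            self.up8 = nn.ConvTranspose2d(n_class, n_class, 16, stride=8, padding=4)

        def forward(self, score5, score4, score3):
            # score5: 1/32 size, score4: 1/16 size, score3: 1/8 size (all 21 channels)
            x = self.up2_a(score5) + score4  # fuse the pool4 prediction -> 1/16
            x = self.up2_b(x) + score3       # fuse the pool3 prediction -> 1/8
            return self.up8(x)               # upsample to the input size

    decoder = FCN8sDecoder()
    s5 = torch.randn(1, 21, 8, 8)     # 1/32 of a 256x256 input
    s4 = torch.randn(1, 21, 16, 16)   # 1/16
    s3 = torch.randn(1, 21, 32, 32)   # 1/8
    print(decoder(s5, s4, s3).shape)  # torch.Size([1, 21, 256, 256])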

    PyTorch Demo

    Pytorch FCN bag | I ran this myself on 2019/4/22

    question1: bag_msk in BagDataset

    First, resize imgB to (16,16) so the printed output is easy to inspect.

    imgB (16,16)

    In imgB, pixels with value 255 are white and pixels with value 0 are black; only a small number of pixels have other values.

    [[255 255 255 255 255 255 255 255 255 255 255 255 255 255 255 255] [255 255 255 255 255 255 255 255 255 255 255 255 255 255 255 255] [255 255 255 255 255 255 0 255 252 255 254 255 255 255 255 255] [255 255 255 255 255 0 0 0 0 0 0 255 255 255 255 255] [255 255 255 255 0 0 0 0 0 0 0 0 255 255 255 255] [255 255 255 255 0 0 0 0 0 0 0 0 255 255 255 255] [255 255 255 0 0 0 0 0 0 0 0 0 1 255 255 255] [255 255 255 0 0 0 0 0 0 0 0 0 13 255 255 255] [255 255 255 0 0 0 0 0 0 0 0 0 0 255 255 255] [255 255 255 255 255 255 255 255 255 255 255 255 255 255 255 255] [255 255 255 255 255 255 255 255 255 255 255 255 255 255 255 255] [255 255 255 255 255 255 255 255 255 255 255 255 255 255 255 255] [255 255 255 255 255 255 255 0 251 255 255 1 255 255 255 255] [255 255 255 255 14 0 0 0 0 0 0 255 255 255 255 255] [255 255 255 174 0 0 0 0 0 0 0 253 6 255 255 255] [255 255 255 0 0 0 0 0 0 0 0 0 0 255 255 255]] imgB/255 (16,16)

    Divide by 255 to move the values toward 0 and 1 (at this point they are still floats in [0, 1]):

    [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. ] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. ] [1. 1. 1. 1. 1. 1. 0. 1. 0.98823529 1. 0.99607843 1. 1. 1. 1. 1. ] [1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. ] [1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. ] [1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. ] [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.00392157 1. 1. 1. ] [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.05098039 1. 1. 1. ] [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. ] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. ] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. ] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. ] [1. 1. 1. 1. 1. 1. 1. 0. 0.98431373 1. 1. 0.00392157 1. 1. 1. 1. ] [1. 1. 1. 1. 0.05490196 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. ] [1. 1. 1. 0.68235294 0. 0. 0. 0. 0. 0. 0. 0.99215686 0.02352941 1. 1. 1. ] [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. ]] imgB.astype(‘uint8’) (16,16)

    This converts the data type. uint8 is an unsigned integer (0 to 255). With astype, 0.9, 0.8 and 0.2 all become 0; this truncation to 0 matches intuition, because only pure white (255) is background and should become 1.

    [[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] [1 1 1 1 1 1 0 1 0 1 0 1 1 1 1 1] [1 1 1 1 1 0 0 0 0 0 0 1 1 1 1 1] [1 1 1 1 0 0 0 0 0 0 0 0 1 1 1 1] [1 1 1 1 0 0 0 0 0 0 0 0 1 1 1 1] [1 1 1 0 0 0 0 0 0 0 0 0 0 1 1 1] [1 1 1 0 0 0 0 0 0 0 0 0 0 1 1 1] [1 1 1 0 0 0 0 0 0 0 0 0 0 1 1 1] [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] [1 1 1 1 1 1 1 0 0 1 1 0 1 1 1 1] [1 1 1 1 0 0 0 0 0 0 0 1 1 1 1 1] [1 1 1 0 0 0 0 0 0 0 0 0 0 1 1 1] [1 1 1 0 0 0 0 0 0 0 0 0 0 1 1 1]] onehot(imgB, 2) (16,16,2)

    Each pixel value is converted into one-hot form:

    [[[1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 1.] [0. 0.] [1. 1.] [0. 0.] [1. 1.] [0. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 0.] [1. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 0.] [1. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 1.] [0. 1.] [0. 0.] [1. 0.] [1. 1.] [0. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 0.] [1. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 0.] [1. 0.] [1. 0.] [1. 0.]] [[1. 0.] [1. 0.] [1. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 0.] [1. 0.] [1. 0.] [1. 0.]]] imgB.transpose(2,0,1) (2,16,16)

    Change the dimensions from (16,16,2) to (2,16,16):

    [[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] [1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 0. 1. 1. 1. 1. 1.] [1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.] [1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.] [1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.] [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1.] [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1.] [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1.] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] [1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 0. 1. 1. 1. 1.] [1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.] [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1.] [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]] [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0.] [0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0.] [0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.] [0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.] [0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 0. 0. 0. 0.] [0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.] [0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.] [0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]]]
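    Putting the steps of question1 together, a hedged sketch of the mask preprocessing (the file path is made up, and the onehot helper is the one examined in question2 below):

    import cv2
    import numpy as np
    import torch

    def onehot(data, n):
        # same helper as in question2 below
        buf = np.zeros(data.shape + (n,))
        nmsk = np.arange(data.size) * n + data.ravel()
        buf.ravel()[nmsk - 1] = 1
        return buf

    imgB = cv2.imread('bag_msk.jpg', 0)  # grayscale mask: background ~255, bag ~0 (path is illustrative)
    imgB = cv2.resize(imgB, (16, 16))    # shrink so the printed arrays above stay readable
    imgB = imgB / 255                    # scale to floats in [0, 1]
    imgB = imgB.astype('uint8')          # truncate: 0.9, 0.8, 0.2 -> 0; only exact 1.0 (pure white) stays 1
    imgB = onehot(imgB, 2)               # (16, 16) -> (16, 16, 2), one channel per class
    imgB = imgB.transpose(2, 0, 1)       # (16, 16, 2) -> (2, 16, 16), channels first
    bag_msk = torch.FloatTensor(imgB)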

    question2: one hot

    import numpy as np

    # Each pixel value encodes its class; convert every pixel value to one-hot form.
    def onehot(data, n):
        buf = np.zeros(data.shape + (n,))
        nmsk = np.arange(data.size) * n + data.ravel()
        # note: indexing with nmsk - 1 shifts everything by one position compared
        # with the usual one-hot assignment buf.ravel()[nmsk] = 1
        buf.ravel()[nmsk - 1] = 1
        return buf

    A worked example:

    import torch
    import numpy as np

    data = np.array([[1.0, 0.0, 1.0, 0.0],
                     [0.0, 0.0, 0.0, 0.0],
                     [1.0, 1.0, 1.0, 1.0],
                     [0.0, 0.0, 0.0, 0.0]])
    data = data.astype('uint8')  # needed because arrays used as indices must be of integer (or boolean) type
    n = 2
    buf = np.zeros(data.shape + (n,))
    nmsk = np.arange(data.size) * n + data.ravel()
    buf.ravel()[nmsk - 1] = 1
    buf = buf.transpose(2, 0, 1)
    buf = torch.FloatTensor(buf)
    '''
    numpy flattening functions ravel() vs flatten():
    buf does not become one-dimensional here, but its values do change, because
    flatten() allocates new memory, so modifying the result does not affect the original,
    while ravel() returns a view, so modifying the view modifies the original array.
    '''

    Result:

    data.size=16 data= [[1 0 1 0] [0 0 0 0] [1 1 1 1] [0 0 0 0]] buf zero= [[[0. 0.] [0. 0.] [0. 0.] [0. 0.]] [[0. 0.] [0. 0.] [0. 0.] [0. 0.]] [[0. 0.] [0. 0.] [0. 0.] [0. 0.]] [[0. 0.] [0. 0.] [0. 0.] [0. 0.]]] nmsk= [ 1 2 5 6 8 10 12 14 17 19 21 23 24 26 28 30] buf = [[[1. 1.] [0. 0.] [1. 1.] [0. 1.]] [[0. 1.] [0. 1.] [0. 1.] [0. 0.]] [[1. 0.] [1. 0.] [1. 0.] [1. 1.]] [[0. 1.] [0. 1.] [0. 1.] [0. 0.]]] buf transpose= [[[1. 0. 1. 0.] [0. 0. 0. 0.] [1. 1. 1. 1.] [0. 0. 0. 0.]] [[1. 0. 1. 1.] [1. 1. 1. 0.] [0. 0. 0. 1.] [1. 1. 1. 0.]]] buf FloatTensor= tensor([[[1., 0., 1., 0.], [0., 0., 0., 0.], [1., 1., 1., 1.], [0., 0., 0., 0.]], [[1., 0., 1., 1.], [1., 1., 1., 0.], [0., 0., 0., 1.], [1., 1., 1., 0.]]])

    question3: output and bag_msk

    The images are resized to (16,16):

    bag= tensor([[[[ 0.1597, 0.2111, 0.1939, 0.1768, 0.1768, 0.1939, 0.2624, 0.3481, 0.1083, 0.1254, 0.2282, 0.1254, 0.0569, 0.1083, 0.4508, 0.0398], [ 0.5707, 0.4337, 1.0844, 0.7077, 0.5364, 0.5193, 0.5536, 0.6392, 0.6563, -0.6965, 0.5364, 0.3823, 0.2796, 0.2796, 0.7419, 0.4851], [ 0.6563, 0.6221, 0.5878, 1.1872, 0.7591, 0.6734, 0.1083, 0.9303, 1.3927, 1.3755, 0.8104, 0.3823, 0.3309, 0.2453, 0.9132, 0.5022], [-1.3473, 0.3481, 0.6049, 0.6906, 1.3927, 1.0159, -0.8678, 1.3413, 1.1700, 1.1872, 0.9817, 0.4166, 0.4679, 0.4166, 1.1187, 0.0227], [-0.8507, -0.7993, -1.3130, 0.6906, 0.6906, 1.5982, -1.0562, 1.0502, 0.7933, 0.8618, 1.3413, 1.2899, 1.4098, 1.0673, 0.6563, -0.4226], [-0.8849, -0.9534, -0.8678, -0.6623, 0.0056, 0.5536, 0.3823, 0.9303, -0.3198, -0.5596, -0.1314, -0.0458, -0.2684, -0.3027, -0.4739, -0.6965], [-0.6965, -0.9534, 0.2796, 0.3652, -0.7137, 0.6906, -0.5938, 0.0569, 1.2043, 1.3242, -0.4739, 1.3927, 1.4612, -0.0629, -0.5938, -0.7308], [-0.5253, -0.4226, -1.0562, 1.7352, 0.7419, -0.5938, 1.6667, 1.5810, 1.4269, 1.5639, 1.3070, 1.4440, 1.7009, -0.7308, -1.2103, -1.6042], [-0.3198, -0.3712, -1.5870, 1.6838, 1.3413, 1.6838, 1.3755, 1.3242, 1.5810, 1.4954, 0.6563, 1.3755, 0.9132, -1.2959, -0.2342, -0.5253], [ 2.2489, -0.4397, -0.0629, 2.0092, 1.7694, 1.5125, 1.2728, 0.8789, 1.3755, 1.4269, -0.5767, 1.4612, 1.6667, -1.4158, 0.0741, -0.3198], [-1.9295, 0.7077, -1.7754, 1.7352, 1.5810, 1.3584, 1.2557, 1.3413, 1.1358, 1.1529, 0.0056, 1.1872, 1.2557, -1.5185, -0.2684, -0.9705], [-1.2103, -0.8164, -1.7240, 1.6667, 1.7009, 1.4954, 1.2728, 1.3927, 1.0331, 1.1529, 1.3927, 1.4440, 1.6495, -1.5185, -0.4397, -0.5596], [-1.5014, -1.4843, 0.9132, 1.6838, 1.6667, 1.6153, 1.4783, 1.4098, 1.4783, 1.1015, 1.4098, 1.5468, 1.6324, -0.8335, -0.9192, -1.6042], [-1.9295, -1.0562, -0.0801, 1.6838, 1.6667, 1.6153, 1.6324, 1.6153, 1.5125, 1.2214, 0.9988, 1.5297, 1.5639, -0.8507, -0.7993, -1.2103], [-0.5424, -0.1828, -1.9295, 1.6838, 1.7009, 1.6838, 1.7180, 1.7352, 1.6495, 1.6324, 1.7523, 1.6153, 1.6153, -2.1008, -0.4911, -0.6794], [ 1.0502, 0.0227, -0.5082, -0.4226, -0.3541, -0.3027, -0.1999, -0.0629, -0.2856, -0.4911, -0.6109, -1.6555, -0.7137, -0.7479, -0.6281, -0.7137]], [[ 0.3452, 0.3277, 0.2927, 0.3277, 0.3277, 0.4678, 0.4853, 0.5553, 0.3452, 0.3102, 0.3102, 0.3277, 0.2927, 0.3627, 0.8880, 0.3803], [ 0.6078, 0.6254, 1.6408, 1.0630, 0.8354, 0.7829, 0.7829, 0.8004, 0.8354, -0.5826, 0.6954, 0.6254, 0.5203, 0.5203, 1.2381, 0.6779], [ 0.8354, 0.8004, 0.6954, 1.7808, 1.1506, 0.9755, 0.2227, 1.1331, 1.3606, 1.3431, 0.8179, 0.5378, 0.5728, 0.5203, 1.3957, 0.8004], [-1.0553, 0.6078, 0.7829, 0.8354, 1.8859, 1.3256, -0.8452, 1.4307, 1.3256, 1.3782, 1.1506, 0.6954, 0.7304, 0.7829, 1.6583, 0.3452], [-0.4426, -0.5651, -1.0728, 1.0805, 0.9405, 1.9734, -0.9328, 1.2206, 1.1681, 1.2731, 1.6232, 1.6933, 1.7808, 1.3957, 0.8704, 0.1001], [-0.4776, -0.5301, -0.4251, -0.0749, 0.3452, 0.8704, 0.5378, 1.1506, 0.2752, 0.0826, -0.0224, 0.3803, 0.2577, 0.2577, 0.1176, -0.0749], [-0.2850, -0.4601, 0.7479, 0.8704, -0.4426, 0.8529, -0.3901, 0.5553, 1.7108, 1.7808, -0.3550, 1.8333, 1.9034, 0.5553, 0.2052, 0.2052], [-0.1275, 0.0301, 0.1527, 1.9734, 0.8004, -0.3025, 1.7983, 1.6583, 1.5007, 1.6933, 1.4132, 1.5707, 1.8158, -0.1275, -0.6176, -1.1253], [ 0.1001, -0.0224, -0.4776, 1.7633, 1.4132, 1.7808, 1.4482, 1.3957, 1.6583, 1.6057, 0.8004, 1.4482, 1.0280, -0.9853, 0.3978, 0.1527], [ 2.3060, -0.0924, 0.4328, 2.0784, 1.9034, 1.5882, 1.3782, 0.9930, 1.4482, 1.5532, -0.5651, 1.5882, 1.7983, -1.2129, 0.5903, 0.4678], [-1.8782, 0.7129, -1.6681, 
1.8683, 1.6583, 1.4307, 1.3606, 1.4657, 1.2031, 1.2731, 0.0651, 1.3081, 1.3782, -1.3354, 0.3627, -0.5301], [-1.0903, -0.6176, -1.5455, 1.7983, 1.7808, 1.6232, 1.3431, 1.4657, 1.1331, 1.2206, 1.4657, 1.5707, 1.7808, -1.4230, 0.2402, 0.0126], [-1.3880, -1.3880, 0.8880, 1.8158, 1.7458, 1.6933, 1.5532, 1.4832, 1.5532, 1.2206, 1.4832, 1.6232, 1.7108, -0.0749, -0.3200, -1.4055], [-1.7906, -0.8627, 0.2227, 1.7983, 1.7458, 1.6933, 1.7108, 1.6933, 1.5882, 1.3256, 1.0630, 1.6057, 1.6408, -0.3550, -0.2500, -0.8978], [ 0.0826, 0.4678, -1.5280, 1.7633, 1.7808, 1.7633, 1.7983, 1.7808, 1.7283, 1.7108, 1.8333, 1.6933, 1.7108, -1.7906, 0.0826, -0.1800], [ 1.5532, 0.4678, 0.1702, 0.3978, 0.3803, 0.4328, 0.4328, 0.6604, 0.4153, 0.1877, -0.0049, -1.2829, -0.4601, -0.3901, -0.2325, -0.2850]], [[ 0.4091, 0.4091, 0.3742, 0.3916, 0.4265, 0.6531, 0.6705, 0.6705, 0.4788, 0.4265, 0.4439, 0.4614, 0.4265, 0.4962, 1.1411, 0.5834], [ 0.7054, 0.7402, 1.9080, 1.2805, 1.0017, 0.9494, 0.9494, 0.8622, 0.7054, -0.3578, 0.6356, 0.6531, 0.6531, 0.6531, 1.4548, 0.8622], [ 0.9494, 0.8971, 0.7402, 1.9777, 1.2980, 1.1411, 0.5136, 1.0888, 1.2108, 1.3154, 0.7402, 0.6008, 0.7054, 0.7054, 1.6640, 0.9842], [-0.4973, 0.8099, 1.0017, 0.8622, 2.0474, 1.4897, -0.4624, 1.4548, 1.3328, 1.3851, 1.1934, 0.8971, 0.8971, 0.9842, 1.9777, 0.6356], [ 0.0082, 0.1999, -0.6367, 1.2457, 1.0714, 2.1520, -0.6715, 1.2457, 1.2631, 1.4025, 1.8557, 1.8731, 1.9603, 1.5420, 1.0365, 0.4614], [-0.0267, -0.0092, 0.1128, 0.7228, 0.7402, 1.1759, 0.8274, 1.2457, 0.2696, 0.3568, 0.3393, 0.7054, 0.6182, 0.6008, 0.5136, 0.3393], [ 0.1651, -0.0441, 1.1585, 1.2457, -0.0267, 1.0365, -0.1138, 0.7402, 1.8731, 1.9777, 0.0256, 1.9951, 2.0997, 0.8099, 0.5485, 0.5659], [ 0.3742, 0.4962, 1.1934, 2.2391, 1.0539, -0.0092, 2.0125, 1.8905, 1.7337, 1.9080, 1.6640, 1.7860, 2.0474, 0.2871, -0.2010, -0.9330], [ 0.5485, 0.3742, 0.5311, 2.0823, 1.6640, 2.0300, 1.6814, 1.6291, 1.8905, 1.8557, 1.0191, 1.6814, 1.2457, -0.3404, 0.7054, 0.4962], [ 2.2217, 0.3568, 1.0017, 2.3611, 2.1171, 1.8208, 1.6291, 1.2108, 1.6814, 1.7685, -0.2532, 1.8034, 2.0125, -0.6541, 0.8971, 0.9494], [-1.7522, 0.6705, -1.5953, 2.0823, 1.8905, 1.6640, 1.6117, 1.6814, 1.4374, 1.4897, 0.3742, 1.5245, 1.5942, -0.8807, 0.6182, 0.7576], [-0.9330, -0.4101, -1.3339, 2.0125, 2.0125, 1.8383, 1.5768, 1.6988, 1.3851, 1.4548, 1.6988, 1.7860, 1.9951, -0.9853, 0.5485, 0.4439], [-1.2293, -1.3339, 0.5834, 2.0300, 1.9777, 1.9254, 1.7860, 1.7163, 1.7860, 1.4374, 1.7163, 1.8557, 1.9428, 0.7054, 0.3568, -1.1770], [-1.6127, -0.7587, 0.0256, 2.0300, 1.9777, 1.9254, 1.9254, 1.9254, 1.8208, 1.5420, 1.2980, 1.8383, 1.8731, 0.1302, 0.2696, -0.7238], [ 0.3916, 0.8448, -1.1073, 1.9951, 1.9777, 1.9777, 1.9603, 1.9603, 1.8905, 1.8731, 2.0300, 1.8905, 1.9080, -1.3687, 0.5311, 0.1476], [ 1.8208, 0.8099, 0.5834, 0.7054, 0.7576, 0.8099, 0.8274, 0.9842, 0.7925, 0.5834, 0.3916, -0.9330, -0.1661, -0.0267, 0.2173, 0.1128]]]], device='cuda:0') bag_msk=tensor([[[[1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1.], [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.], [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.], [1., 1., 1., 
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.], [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.], [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.], [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.], [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.], [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.], [1., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1., 0., 1., 1., 1., 1.]], [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 0.], [0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.], [0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.], [0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.], [0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.], [0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.], [0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.], [0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.], [0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.], [0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 0.]]]], device='cuda:0')

    question4: plt

    Noticed that the testLoader (batch=4) plotted 15 figures with plt, while the test set has 60 images.

    135 = 540 ÷ 4 (train_dataset total / batch size)
    15 = 60 ÷ 4 (test_dataset total / batch size)

    Also noticed that only one image is plotted for each test batch.

    plt.subplot(1, 2, 1)
    plt.imshow(np.squeeze(bag_msk_np[0, ...]), 'gray')
    plt.subplot(1, 2, 2)
    plt.imshow(np.squeeze(output_np[0, ...]), 'gray')
    plt.pause(0.5)  # pause for half a second

    The reason:

    In np.squeeze(output_np[0, ...]), the expression output_np[0, ...] takes the data of the first image only; the 0 is index 0 along axis=0.

    squeeze has no real effect here.

    squeeze mainly removes singleton dimensions; squeezing an axis whose size is greater than 1 raises an error.

    >>> x = np.array([[[0], [1], [2]]])
    >>> x.shape
    (1, 3, 1)
    >>> np.squeeze(x).shape
    (3,)
    >>> np.squeeze(x, axis=0).shape
    (3, 1)
    >>> np.squeeze(x, axis=1).shape
    Traceback (most recent call last):
    ...
    ValueError: cannot select an axis to squeeze out which has size not equal to one
    >>> np.squeeze(x, axis=2).shape
    (1, 3)
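    To display every image in a test batch instead of only the first one, a sketch like the following could replace the plotting snippet above (it reuses the bag_msk_np and output_np arrays from the demo and assumes they are shaped (batch, H, W)):

    import numpy as np
    import matplotlib.pyplot as plt

    # bag_msk_np and output_np come from the demo's evaluation loop
    batch = output_np.shape[0]  # e.g. 4
    for i in range(batch):
        plt.subplot(2, batch, i + 1)
        plt.imshow(np.squeeze(bag_msk_np[i, ...]), 'gray')  # ground-truth mask
        plt.subplot(2, batch, batch + i + 1)
        plt.imshow(np.squeeze(output_np[i, ...]), 'gray')   # network prediction
    plt.pause(0.5)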

    References

    https://blog.csdn.net/qq_36269513/article/details/80420363

    http://www.cnblogs.com/gujianhan/p/6030639.html

    https://blog.csdn.net/gavin__zhou/article/details/52130677
