上一篇分析了SummaryWriter的主要函数,本篇借助tensorboardX官方demo,解释add_graph的用法。
本文选用demo中代表性的条目,完整demo参见:https://github.com/lanpa/tensorboardX/blob/master/examples/demo_graph.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable
from tensorboardX import SummaryWriter

# Input for the single-input demo below: a 1x3 tensor of ones.
dummy_input = torch.ones(1, 3)


class LinearInLinear(nn.Module):
    """Minimal demo model wrapping a single ``nn.Linear(3, 5)`` layer."""

    def __init__(self):
        super(LinearInLinear, self).__init__()
        self.l = nn.Linear(3, 5)

    def forward(self, x):
        """Return the output of the linear layer applied to ``x``."""
        return self.l(x)


# The third positional argument is presumably ``verbose`` (the last call in
# this listing passes it by keyword) -- confirm against the installed version.
with SummaryWriter(comment='LinearInLinear2') as w:
    w.add_graph(LinearInLinear(), dummy_input, False)


class MutipleInput(nn.Module):
    """Demo model whose ``forward`` takes two tensors instead of one."""

    def __init__(self):
        super(MutipleInput, self).__init__()
        self.Linear_1 = nn.Linear(3, 5)

    def forward(self, x, y):
        """Add the two inputs element-wise and feed the sum to the layer."""
        return self.Linear_1(x + y)


model_m = MutipleInput()
# Several model inputs are handed to add_graph as one tuple.
with SummaryWriter(comment='MutipleInput') as w:
    w.add_graph(model_m, (torch.ones(1, 3), torch.zeros(1, 3)), True)


def conv3x3(in_channels, out_channels, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride (default 1).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    """Residual block: conv-bn-relu, conv-bn, shortcut add, relu.

    Args:
        inplanes: Number of channels of the block input.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (default 1).
        downsample: Optional callable applied to the block input to produce
            the shortcut. Needed whenever ``stride != 1`` or
            ``inplanes != planes``, because otherwise the shortcut and the
            main branch have different shapes and cannot be added.
    """

    # Not referenced anywhere in this listing; presumably the channel
    # multiplier used by ResNet builders -- confirm before relying on it.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Previously this argument was accepted but silently discarded.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Only reshape the shortcut when a projection was supplied; with the
        # default ``None`` the identity shortcut is used, exactly as before.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = F.relu(out)
        return out


# Rebinds dummy_input to a random 1x3x224x224 tensor for the conv block.
dummy_input = torch.rand(1, 3, 224, 224)
with SummaryWriter(comment='basicblock') as w:
    model = BasicBlock(3, 3)
    w.add_graph(model, (dummy_input, ), verbose=True)
# Article text that followed this listing on the same line, kept verbatim:
# 截取一个输出图(basicblock)如下:
官方demo中如下的代码片段:
# Snippet quoted from the official demo. The article reports that it does not
# run and shows an AssertionError ("Only output_size=[1, 1] is supported")
# right after it; it does not say which of the two add_graph calls raises.
# NOTE(review): torch.Tensor(1,3,224,224) presumably allocates the tensor
# without initialising its values -- confirm; the earlier listing uses
# torch.rand for an input of the same shape.
dummy_input = torch.Tensor(1,3,224,224)
# Build a torchvision VGG-19 and record its graph under the 'vgg19' comment.
with SummaryWriter(comment='vgg19') as w:
    model = torchvision.models.vgg19()
    w.add_graph(model, (dummy_input,))
# Same steps for ResNet-18, reusing the same dummy input.
with SummaryWriter(comment='resnet18') as w:
    model = torchvision.models.resnet18()
    w.add_graph(model, (dummy_input,))
# Article text that followed this listing on the same line, kept verbatim:
# 无法运行,报错信息如下:
    assert output_size == [1, 1], "Only output_size=[1, 1] is supported"
AssertionError: Only output_size=[1, 1] is supported
仅仅是增加一个graph,和output有什么关系?
也许是本机安装的tensorboardX、PyTorch、torchvision版本与官方demo所用的版本不一致所致。
仍将持续更新,感谢关注!
