卓尔高考网

PyTorch: Printing Model Structure, Output Shapes, and Parameter Info (torchsummary)

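Introduction

After building a model in PyTorch, you can view its structure simply by calling print(model). That output lists the modules and their hyperparameters, but it does not show each layer's output shape or parameter count, so it is not very intuitive. This post introduces a tool that produces a Keras-style model.summary() overview instead.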

Installation

pip install torchsummary

Usage

  • Example
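```python
from torchvision import models
from torchsummary import summary

# summary() uses device="cuda" by default, so the model must live on the GPU;
# without .cuda() this raises an error. For a CPU-only model, call
# summary(resnet18, (3, 224, 224), device="cpu") instead.
resnet18 = models.resnet18().cuda()
summary(resnet18, (3, 224, 224))  # input_size is (C, H, W), without the batch dimension
```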
  • Output (the leading -1 in each Output Shape is the unspecified batch dimension):
```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 64, 56, 56]          36,864
      BatchNorm2d-16           [-1, 64, 56, 56]             128
             ReLU-17           [-1, 64, 56, 56]               0
       BasicBlock-18           [-1, 64, 56, 56]               0
           Conv2d-19          [-1, 128, 28, 28]          73,728
      BatchNorm2d-20          [-1, 128, 28, 28]             256
             ReLU-21          [-1, 128, 28, 28]               0
           Conv2d-22          [-1, 128, 28, 28]         147,456
      BatchNorm2d-23          [-1, 128, 28, 28]             256
           Conv2d-24          [-1, 128, 28, 28]           8,192
      BatchNorm2d-25          [-1, 128, 28, 28]             256
             ReLU-26          [-1, 128, 28, 28]               0
       BasicBlock-27          [-1, 128, 28, 28]               0
           Conv2d-28          [-1, 128, 28, 28]         147,456
      BatchNorm2d-29          [-1, 128, 28, 28]             256
             ReLU-30          [-1, 128, 28, 28]               0
           Conv2d-31          [-1, 128, 28, 28]         147,456
      BatchNorm2d-32          [-1, 128, 28, 28]             256
             ReLU-33          [-1, 128, 28, 28]               0
       BasicBlock-34          [-1, 128, 28, 28]               0
           Conv2d-35          [-1, 256, 14, 14]         294,912
      BatchNorm2d-36          [-1, 256, 14, 14]             512
             ReLU-37          [-1, 256, 14, 14]               0
           Conv2d-38          [-1, 256, 14, 14]         589,824
      BatchNorm2d-39          [-1, 256, 14, 14]             512
           Conv2d-40          [-1, 256, 14, 14]          32,768
      BatchNorm2d-41          [-1, 256, 14, 14]             512
             ReLU-42          [-1, 256, 14, 14]               0
       BasicBlock-43          [-1, 256, 14, 14]               0
           Conv2d-44          [-1, 256, 14, 14]         589,824
      BatchNorm2d-45          [-1, 256, 14, 14]             512
             ReLU-46          [-1, 256, 14, 14]               0
           Conv2d-47          [-1, 256, 14, 14]         589,824
      BatchNorm2d-48          [-1, 256, 14, 14]             512
             ReLU-49          [-1, 256, 14, 14]               0
       BasicBlock-50          [-1, 256, 14, 14]               0
           Conv2d-51            [-1, 512, 7, 7]       1,179,648
      BatchNorm2d-52            [-1, 512, 7, 7]           1,024
             ReLU-53            [-1, 512, 7, 7]               0
           Conv2d-54            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-55            [-1, 512, 7, 7]           1,024
           Conv2d-56            [-1, 512, 7, 7]         131,072
      BatchNorm2d-57            [-1, 512, 7, 7]           1,024
             ReLU-58            [-1, 512, 7, 7]               0
       BasicBlock-59            [-1, 512, 7, 7]               0
           Conv2d-60            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-61            [-1, 512, 7, 7]           1,024
             ReLU-62            [-1, 512, 7, 7]               0
           Conv2d-63            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-64            [-1, 512, 7, 7]           1,024
             ReLU-65            [-1, 512, 7, 7]               0
       BasicBlock-66            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                 [-1, 1000]         513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 62.79
Params size (MB): 44.59
Estimated Total Size (MB): 107.96
----------------------------------------------------------------
```
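The numbers in this table follow from simple arithmetic, which makes the output easy to sanity-check. The sketch below is plain Python and needs no PyTorch. It assumes float32 (4 bytes) for every parameter and activation and 1 MB = 1024² bytes, which is what the printed sizes are consistent with. The helper names (`conv2d_params`, `batchnorm_params`, `linear_params`, `resnet18_params`) are hypothetical, written for illustration only, and the per-shape row counts are read off the table.

```python
import math

BYTES_PER_FLOAT = 4   # assumption: float32 parameters and activations
MB = 1024 ** 2        # assumption: sizes are reported in units of 1024**2 bytes


def conv2d_params(c_in, c_out, k, bias=False):
    """Weights of a k x k Conv2d (torchvision's ResNet convs use bias=False)."""
    return c_out * c_in * k * k + (c_out if bias else 0)


def batchnorm_params(c):
    """Learnable weight and bias; running_mean/running_var are buffers, not parameters."""
    return 2 * c


def linear_params(f_in, f_out):
    """Weight matrix plus bias vector."""
    return f_in * f_out + f_out


def resnet18_params():
    """Sum ResNet-18's parameters stage by stage, following print(resnet18)."""
    total = conv2d_params(3, 64, 7) + batchnorm_params(64)  # stem: conv1 + bn1
    c_in = 64
    for c_out in (64, 128, 256, 512):  # layer1 .. layer4
        for block in range(2):         # two BasicBlocks per stage
            src = c_in if block == 0 else c_out
            total += conv2d_params(src, c_out, 3) + batchnorm_params(c_out)    # conv1 + bn1
            total += conv2d_params(c_out, c_out, 3) + batchnorm_params(c_out)  # conv2 + bn2
            if src != c_out:  # channel change: 1x1 downsample conv + BN on the shortcut
                total += conv2d_params(src, c_out, 1) + batchnorm_params(c_out)
        c_in = c_out
    return total + linear_params(512, 1000)  # final fc layer


# Output shapes from the table (batch dimension dropped) and how many rows share each.
OUTPUT_SHAPES = {
    (64, 112, 112): 3,  # Conv2d-1 .. ReLU-3
    (64, 56, 56): 15,   # MaxPool2d-4 .. BasicBlock-18
    (128, 28, 28): 16,  # Conv2d-19 .. BasicBlock-34
    (256, 14, 14): 16,  # Conv2d-35 .. BasicBlock-50
    (512, 7, 7): 16,    # Conv2d-51 .. BasicBlock-66
    (512, 1, 1): 1,     # AdaptiveAvgPool2d-67
    (1000,): 1,         # Linear-68
}

total_params = resnet18_params()
activations = sum(math.prod(shape) * n for shape, n in OUTPUT_SHAPES.items())

input_mb = math.prod((3, 224, 224)) * BYTES_PER_FLOAT / MB
fwd_bwd_mb = 2 * activations * BYTES_PER_FLOAT / MB  # x2: each output plus its gradient
params_mb = total_params * BYTES_PER_FLOAT / MB
total_mb = input_mb + fwd_bwd_mb + params_mb

print(conv2d_params(3, 64, 7), linear_params(512, 1000))  # 9408 513000
print(f"Total params: {total_params:,}")                    # Total params: 11,689,512
print(f"Input size (MB): {input_mb:.2f}")                   # Input size (MB): 0.57
print(f"Forward/backward pass size (MB): {fwd_bwd_mb:.2f}")  # Forward/backward pass size (MB): 62.79
print(f"Params size (MB): {params_mb:.2f}")                 # Params size (MB): 44.59
print(f"Estimated Total Size (MB): {total_mb:.2f}")         # Estimated Total Size (MB): 107.96
```

Counting a BatchNorm layer as 2 × channels matches the table (128 for BatchNorm2d-2), because the running statistics are stored as buffers rather than trainable parameters.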
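  • Open question

ResNet-18 is supposed to have 17 convolutional layers plus 1 fully connected layer, so why does the output above list 20 convolutional layers?

  1. In other words, there are 3 extra convolutional layers. Let's first print the model structure with print(resnet18) and take a look: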
```
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, e
```
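The printout above already points to the answer: the first BasicBlock of layer2, layer3 and layer4 each has a `downsample` branch holding a 1×1, stride-2 `Conv2d` that projects the shortcut to the new channel count and resolution. torchsummary lists every `Conv2d` module, so it reports 17 + 3 = 20, while the "18" in ResNet-18 conventionally counts only the 17 main-path convolutions plus the final `Linear` layer. The plain-Python sketch below (no PyTorch needed) rebuilds the conv list from the printed hyperparameters and checks both counts as well as the 7×7 spatial size in the table; `resnet18_convs` and `out_size` are hypothetical helpers written for illustration.

```python
def resnet18_convs():
    """Hypothetical helper: (name, in_ch, out_ch, kernel, stride, padding) for every
    Conv2d in ResNet-18, transcribed from the print(resnet18) output."""
    convs = [("conv1", 3, 64, 7, 2, 3)]
    c_in = 64
    for stage, c_out in enumerate((64, 128, 256, 512), start=1):
        for block in range(2):
            first = block == 0 and stage > 1  # first block of layer2..layer4
            src = c_in if block == 0 else c_out
            name = f"layer{stage}.{block}"
            convs.append((f"{name}.conv1", src, c_out, 3, 2 if first else 1, 1))
            convs.append((f"{name}.conv2", c_out, c_out, 3, 1, 1))
            if first:
                # 1x1, stride-2 projection on the shortcut branch
                convs.append((f"{name}.downsample.0", src, c_out, 1, 2, 0))
        c_in = c_out
    return convs


def out_size(size, k, s, p):
    """Spatial size after a conv or pool with dilation 1: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * p - k) // s + 1


convs = resnet18_convs()
shortcut = [c for c in convs if ".downsample." in c[0]]
main_path = [c for c in convs if c not in shortcut]
print(len(convs), len(main_path), len(shortcut))  # 20 17 3

# Follow the main path for a 224x224 input; the maxpool sits right after conv1.
size = out_size(224, 7, 2, 3)   # conv1   -> 112
size = out_size(size, 3, 2, 1)  # maxpool -> 56
for _, _, _, k, s, p in main_path[1:]:
    size = out_size(size, k, s, p)
print(size)  # 7, matching [-1, 512, 7, 7] before AdaptiveAvgPool2d
```

With PyTorch installed, the same total can be checked directly on the model with `sum(isinstance(m, torch.nn.Conv2d) for m in resnet18.modules())`, which also counts the three `downsample` convolutions.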
