
Issues with ConvTranspose2d support #946

@mortal-Zero


Hello, and thank you for this outstanding project.
I get an error when using torch2trt to convert a model that contains ConvTranspose2d. The code and the resulting error are below.

```python
import torch
import torch.nn as nn
from torch2trt import torch2trt

# Upsampling block: transposed conv (32 -> 64 channels, 2x spatial) + BN + LeakyReLU
model = nn.Sequential(
    nn.ConvTranspose2d(in_channels=32, out_channels=64,
                       kernel_size=4, stride=2,
                       padding=1, bias=True),
    nn.BatchNorm2d(num_features=64),
    nn.LeakyReLU()
)
model.to("cuda:0").eval()
x = torch.zeros([1, 32, 16, 16]).to("cuda:0")

# The PyTorch forward pass works and produces the expected shape
y = model(x)
print("=====>> input: {} || output: {}".format(x.shape, y.shape))

# The conversion is what fails
model_trt = torch2trt(model, [x])
```
```
=====>> input: torch.Size([1, 32, 16, 16]) || output: torch.Size([1, 64, 32, 32])
[09/05/2024-11:09:34] [TRT] [E] 3: 0:0:DECONVOLUTION:GPU:kernel weights has count 32768 but 16384 was expected
[09/05/2024-11:09:34] [TRT] [E] 4: 0:0:DECONVOLUTION:GPU: count of 32768 weights in kernel, but kernel dimensions (4,4) with 32 input channels, 32 output channels and 1 groups were specified. Expected Weights count is 32 * 4*4 * 32 / 1 = 16384
[09/05/2024-11:09:34] [TRT] [E] 4: [graphShapeAnalyzer.cpp::needTypeAndDimensions::2212] Error Code 4: Internal Error (0:0:DECONVOLUTION:GPU: output shape can not be computed)
[09/05/2024-11:09:34] [TRT] [E] 3: [network.cpp::addScaleNd::1162] Error Code 3: API Usage Error (Parameter check failed at: optimizer/api/network.cpp::addScaleNd::1162, condition: qdqScale || basicScale
)
Traceback (most recent call last):
  File "/workspace/baiyixuan/test_cvcuda/digitalhuman_service/debug_codes/test.py", line 30, in <module>
    model_trt = torch2trt(model, [x])
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch2trt-0.5.0-py3.10-linux-x86_64.egg/torch2trt/torch2trt.py", line 643, in torch2trt
    outputs = module(*inputs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 171, in forward
    return F.batch_norm(
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch2trt-0.5.0-py3.10-linux-x86_64.egg/torch2trt/torch2trt.py", line 262, in wrapper
    converter["converter"](ctx)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch2trt-0.5.0-py3.10-linux-x86_64.egg/torch2trt/converters/native_converters.py", line 183, in convert_batch_norm
    output._trt = layer.get_output(0)
AttributeError: 'NoneType' object has no attribute 'get_output'
```
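For reference, the numbers in the log can be checked by hand. PyTorch stores `ConvTranspose2d` weights as `(in_channels, out_channels // groups, kH, kW)`, so this layer really has 32 * 64 * 4 * 4 = 32768 weights. TensorRT's expected 16384 corresponds to a deconvolution with 32 output channels, which suggests (my assumption, not confirmed by the log) that `in_channels` ended up being used as the output channel count. The `BatchNorm2d` `AttributeError` then looks like a follow-on failure: `addScaleNd` rejects its input, returns `None`, and `convert_batch_norm` calls `get_output` on it. The helper names below are hypothetical; the output-size formula is the one from the PyTorch `ConvTranspose2d` docs.

```python
def conv_transpose2d_out_size(h_in, kernel_size, stride=1, padding=0,
                              output_padding=0, dilation=1):
    """Spatial output size of nn.ConvTranspose2d along one axis (PyTorch docs formula)."""
    return ((h_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def conv_transpose2d_weight_count(in_channels, out_channels, kernel_size, groups=1):
    """Number of weight elements for nn.ConvTranspose2d.

    PyTorch stores the weight as (in_channels, out_channels // groups, kH, kW),
    i.e. input channels come first, the opposite of nn.Conv2d.
    """
    kh, kw = kernel_size
    return in_channels * (out_channels // groups) * kh * kw


# Layer from the repro: ConvTranspose2d(32, 64, kernel_size=4, stride=2, padding=1)
out_hw = conv_transpose2d_out_size(16, kernel_size=4, stride=2, padding=1)
provided = conv_transpose2d_weight_count(32, 64, (4, 4))
# What TensorRT computes when told the layer has 32 output channels
expected_by_trt = conv_transpose2d_weight_count(32, 32, (4, 4))

print(out_hw, provided, expected_by_trt)  # 32 32768 16384
```

The 32x32 output matches what the PyTorch forward pass printed, so the model itself is valid; only the channel count handed to TensorRT's deconvolution layer appears to be off.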

Looking forward to your reply.
