DeformConv2d Cannot be Quantized #2794

Open
1 task
anzr299 opened this issue Jul 8, 2024 · 1 comment

anzr299 commented Jul 8, 2024

🐛 Describe the bug

The operator metatype DeformConv2dOp is defined in nncf/nncf/torch/graph/operator_metatypes.py under the torch.nn.functional namespace, but the function deform_conv2d is actually provided by torchvision.ops (torchvision.ops.deform_conv2d). This mismatch appears to be why the deformable convolutions are not quantized, as the output of the example below shows.
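The namespace mismatch can be checked directly. The following is a minimal sketch, not part of NNCF: the helper name find_owner_modules is hypothetical, and it only reports which of the given modules expose an attribute with the given name. Modules that are not installed or fail to load are skipped.

```python
import importlib


def find_owner_modules(func_name, candidate_modules):
    """Return the candidate modules that expose an attribute named func_name."""
    owners = []
    for module_name in candidate_modules:
        try:
            module = importlib.import_module(module_name)
        except (ImportError, OSError, RuntimeError):
            # The module is missing or failed to load; skip it.
            continue
        if hasattr(module, func_name):
            owners.append(module_name)
    return owners


if __name__ == "__main__":
    # Hypothetical helper from above, applied to the two namespaces in question.
    print(find_owner_modules("deform_conv2d", ["torch.nn.functional", "torchvision.ops"]))
```

With the torch and torchvision versions listed in the environment below, this is expected to report only torchvision.ops, which matches the description above.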

Environment

about-time==4.2.1
absl-py==2.1.0
accelerate==0.28.0
accuracy_checker @ git+https://github.com/openvinotoolkit/open_model_zoo.git@37f60eb7fe1dcdedc552b2fb184d646723ed5e80#subdirectory=tools/accuracy_checker
addict==2.4.0
aiohttp==3.9.5
aiosignal==1.3.1
alive-progress==3.1.5
async-timeout==4.0.3
attrs==23.2.0
autograd==1.6.2
certifi==2024.7.4
cfgv==3.4.0
charset-normalizer==3.3.2
cma==3.2.2
coloredlogs==15.0.1
contourpy==1.2.1
coverage==7.5.4
cycler==0.12.1
datasets==2.14.7
defusedxml==0.7.1
Deprecated==1.2.14
dill==0.3.7
distlib==0.3.8
efficientnet-pytorch==0.7.1
evaluate==0.3.0
exceptiongroup==1.2.1
execnet==2.1.1
fastcore==1.5.48
fastdownload==0.0.7
fastprogress==1.0.3
filelock==3.15.4
flatbuffers==24.3.25
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2023.10.0
future==1.0.0
grapheme==0.6.0
grpcio==1.64.1
huggingface-hub==0.23.4
humanfriendly==10.0
identify==2.6.0
idna==3.7
iniconfig==2.0.0
Jinja2==3.1.4
joblib==1.4.2
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jstyleson==0.0.2
kiwisolver==1.4.5
lightning-utilities==0.11.3.post0
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.1
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.15
natsort==8.4.0
networkx==3.3
ninja==1.11.1.1
-e git+https://github.com/anzr299/nncf.git@bfc94b9d1078024b04246f8fc41106582c227f7b#egg=nncf
nodeenv==1.9.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.82
nvidia-nvtx-cu12==12.1.105
onnx==1.16.0
onnxruntime==1.17.1
opencv-python==4.10.0.84
openvino==2024.2.0
openvino-telemetry==2024.1.0
packaging==24.1
pandas==2.2.2
pillow==10.4.0
platformdirs==4.2.2
pluggy==1.5.0
pre-commit==3.2.2
protobuf==4.25.3
psutil==6.0.0
pyarrow==16.1.0
pyarrow-hotfix==0.6
pycocotools==2.0.7
pydot==2.0.0
Pygments==2.18.0
pymoo==0.6.1.1
pyparsing==3.1.2
pytest==8.0.2
pytest-cov==4.1.0
pytest-dependency==0.6.0
pytest-mock==3.12.0
pytest-xdist==3.5.0
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
referencing==0.35.1
regex==2024.5.15
requests==2.32.3
responses==0.18.0
rich==13.7.1
rpds-py==0.18.1
safetensors==0.4.3
scikit-learn==1.5.1
scipy==1.14.0
six==1.16.0
sympy==1.12.1
tabulate==0.9.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
threadpoolctl==3.5.0
timm==0.9.2
tokenizers==0.15.2
tomli==2.0.1
torch==2.3.0
torchmetrics==1.0.1
torchvision==0.18.0
tqdm==4.66.4
transformers==4.38.2
triton==2.3.0
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
virtualenv==20.26.3
Werkzeug==3.0.3
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4

Minimal Reproducible Example


import nncf
import torch
import torch.nn as nn
from torchvision.ops import DeformConv2d

# Sample deformable convolution network that uses torchvision.ops.DeformConv2d
class DCNV2(nn.Module):
    def __init__(self):
        super(DCNV2, self).__init__()
        self.deform_conv1 = DeformConv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.offset_conv1 = nn.Conv2d(3, 18, kernel_size=3, stride=1, padding=1)
        
        self.deform_conv2 = DeformConv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.offset_conv2 = nn.Conv2d(32, 18, kernel_size=3, stride=1, padding=1)
        
        self.fc = nn.Linear(64 * 8 * 8, 10)
        
    def forward(self, x):
        offset1 = self.offset_conv1(x)
        x = self.deform_conv1(x, offset1)
        x = nn.ReLU()(x)
        x = nn.MaxPool2d(kernel_size=2, stride=2)(x)
        
        offset2 = self.offset_conv2(x)
        x = self.deform_conv2(x, offset2)
        x = nn.ReLU()(x)
        x = nn.MaxPool2d(kernel_size=2, stride=2)(x)
        
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

model = DCNV2()

#Dummy data for our model
class RandomDataset(torch.utils.data.Dataset):
    def __getitem__(self, index):
        return torch.randn(3, 32, 32), torch.tensor(0)
    
    def __len__(self):
        return 1000

data_loader = torch.utils.data.DataLoader(RandomDataset(), batch_size=32)

#transform function for the calibration dataset
def transform_fn(data_item):
    images, _ = data_item
    return images

calibration_dataset = nncf.Dataset(data_loader, transform_fn)
model.eval()
quantized_model = nncf.quantize(model, calibration_dataset)
print(quantized_model)
'''
OUTPUT: 

DCNV2(
  (deform_conv1): DeformConv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (offset_conv1): Conv2d(3, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (deform_conv2): DeformConv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (offset_conv2): Conv2d(32, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc): Linear(in_features=4096, out_features=10, bias=True)
  (_nncf): NNCFNetworkInterface(
    (external_quantizers): ModuleDict(
      (/nncf_model_input_0|OUTPUT): SymmetricQuantizer(bit=8, ch=False)
      (DCNV2/Conv2d[offset_conv1]/conv2d_0|INPUT1): SymmetricQuantizer(bit=8, ch=True)
      (DCNV2/Conv2d[offset_conv2]/conv2d_0|INPUT1): SymmetricQuantizer(bit=8, ch=True)
      (DCNV2/Linear[fc]/linear_0|INPUT1): SymmetricQuantizer(bit=8, ch=True)
    )
  )
)
'''
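To make the gap in the printed output explicit, here is a small sketch that compares the layer names with the quantizer keys, both copied by hand from the output above. The helper name layers_without_quantizer is hypothetical, and the check is purely textual: it assumes that a layer is covered when its name appears in square brackets in some quantizer key.

```python
import re

# Layer names and quantizer keys copied from the printed model above.
LAYERS = ["deform_conv1", "offset_conv1", "deform_conv2", "offset_conv2", "fc"]
QUANTIZER_KEYS = [
    "/nncf_model_input_0|OUTPUT",
    "DCNV2/Conv2d[offset_conv1]/conv2d_0|INPUT1",
    "DCNV2/Conv2d[offset_conv2]/conv2d_0|INPUT1",
    "DCNV2/Linear[fc]/linear_0|INPUT1",
]


def layers_without_quantizer(layers, quantizer_keys):
    """Return layers whose name never appears in [brackets] in any quantizer key."""
    covered = set()
    for key in quantizer_keys:
        covered.update(re.findall(r"\[([^\]]+)\]", key))
    return [name for name in layers if name not in covered]


print(layers_without_quantizer(LAYERS, QUANTIZER_KEYS))
# prints ['deform_conv1', 'deform_conv2']
```

Only the two DeformConv2d layers are left without a dedicated quantizer, while both regular Conv2d layers and the Linear layer each have one.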


Are you going to submit a PR?

  • Yes I'd like to help by submitting a PR!
anzr299 added the bug label Jul 8, 2024

anzr299 commented Jul 8, 2024

@AlexanderDokuchaev
