Shortcuts

Source code for torch.nn.qat.modules.conv

import torch
import torch.nn as nn
from torch.nn.intrinsic import ConvReLU2d, ConvReLU3d


class Conv2d(nn.Conv2d):
    r"""2-D convolution whose weight passes through a fake-quantize module.

    Used for quantization aware training. The constructor mirrors
    `torch.nn.Conv2d` (see
    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d)
    and adds one required argument, ``qconfig``.

    Attributes:
        weight_fake_quant: module returned by ``qconfig.weight(...)``; it is
            applied to ``weight`` on every forward call.
    """
    # Float module type that ``from_float`` accepts; subclasses may override.
    _FLOAT_MODULE = nn.Conv2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', qconfig=None, device=None,
                 dtype=None) -> None:
        """Build the convolution, then attach the weight fake-quantizer.

        All arguments except ``qconfig`` are forwarded to ``nn.Conv2d``.

        Raises:
            AssertionError: if ``qconfig`` is falsy.
        """
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            **factory_kwargs,
        )
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        # The fake-quant module receives the same device/dtype as the conv.
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        """Convolve ``input`` using the fake-quantized weight and the bias."""
        fake_quantized_weight = self.weight_fake_quant(self.weight)
        return self._conv_forward(input, fake_quantized_weight, self.bias)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module.

        Args:
            mod: a float module of exactly type ``cls._FLOAT_MODULE`` that
                carries a truthy ``qconfig`` attribute.

        Returns:
            A new ``cls`` instance that reuses ``mod``'s weight and bias
            objects (they are assigned, not copied).

        Raises:
            AssertionError: on a type mismatch or a missing/falsy qconfig.
        """
        float_cls = cls._FLOAT_MODULE
        assert type(mod) == float_cls, \
            f'qat.{cls.__name__}.from_float only works for {float_cls.__name__}'
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        # Reachable when a subclass sets _FLOAT_MODULE to ConvReLU2d: the
        # convolution is then the first element of the fused module.
        if type(mod) == ConvReLU2d:
            mod = mod[0]
        conv_kwargs = dict(
            stride=mod.stride,
            padding=mod.padding,
            dilation=mod.dilation,
            groups=mod.groups,
            bias=mod.bias is not None,
            padding_mode=mod.padding_mode,
            qconfig=mod.qconfig,
        )
        qat_conv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                       **conv_kwargs)
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        """Return a plain ``torch.nn.Conv2d`` holding detached weight and bias."""
        has_bias = self.bias is not None
        conv = torch.nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,  # type: ignore[arg-type]
            stride=self.stride,  # type: ignore[arg-type]
            padding=self.padding,  # type: ignore[arg-type]
            dilation=self.dilation,  # type: ignore[arg-type]
            groups=self.groups,
            bias=has_bias,
            padding_mode=self.padding_mode,
        )
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if has_bias:
            conv.bias = torch.nn.Parameter(self.bias.detach())
        return conv
class Conv3d(nn.Conv3d):
    r"""3-D convolution whose weight passes through a fake-quantize module.

    Used for quantization aware training. The constructor mirrors
    `torch.nn.Conv3d` (see
    https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d)
    and adds one required argument, ``qconfig``.

    Attributes:
        weight_fake_quant: module returned by ``qconfig.weight(...)``; it is
            applied to ``weight`` on every forward call.
    """
    # Float module type that ``from_float`` accepts; subclasses may override.
    _FLOAT_MODULE = nn.Conv3d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode="zeros", qconfig=None, device=None,
                 dtype=None) -> None:
        """Build the convolution, then attach the weight fake-quantizer.

        All arguments except ``qconfig`` are forwarded to ``nn.Conv3d``.

        Raises:
            AssertionError: if ``qconfig`` is falsy.
        """
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            **factory_kwargs,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        # The fake-quant module receives the same device/dtype as the conv.
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        """Convolve ``input`` using the fake-quantized weight and the bias."""
        fake_quantized_weight = self.weight_fake_quant(self.weight)
        return self._conv_forward(input, fake_quantized_weight, self.bias)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module.

        Args:
            mod: a float module of exactly type ``cls._FLOAT_MODULE`` that
                carries a truthy ``qconfig`` attribute.

        Returns:
            A new ``cls`` instance that reuses ``mod``'s weight and bias
            objects (they are assigned, not copied).

        Raises:
            AssertionError: on a type mismatch or a missing/falsy qconfig.
        """
        float_cls = cls._FLOAT_MODULE
        assert type(mod) == float_cls, \
            f"qat.{cls.__name__}.from_float only works for {float_cls.__name__}"
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        # Reachable when a subclass sets _FLOAT_MODULE to ConvReLU3d: the
        # convolution is then the first element of the fused module.
        if type(mod) == ConvReLU3d:
            mod = mod[0]
        conv_kwargs = dict(
            stride=mod.stride,
            padding=mod.padding,
            dilation=mod.dilation,
            groups=mod.groups,
            bias=mod.bias is not None,
            padding_mode=mod.padding_mode,
            qconfig=mod.qconfig,
        )
        qat_conv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                       **conv_kwargs)
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        """Return a plain ``torch.nn.Conv3d`` holding detached weight and bias."""
        has_bias = self.bias is not None
        conv = torch.nn.Conv3d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,  # type: ignore[arg-type]
            stride=self.stride,  # type: ignore[arg-type]
            padding=self.padding,  # type: ignore[arg-type]
            dilation=self.dilation,  # type: ignore[arg-type]
            groups=self.groups,
            bias=has_bias,
            padding_mode=self.padding_mode,
        )
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if has_bias:
            conv.bias = torch.nn.Parameter(self.bias.detach())
        return conv

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources