Shortcuts

Source code for torch.ao.quantization.pt2e.export_utils

import types

import torch
import torch.nn.functional as F


# Public surface for ``from ... import *``. ``_WrapperModule`` is listed
# explicitly even though its leading underscore would otherwise hide it.
__all__ = ["model_is_exported", "_WrapperModule"]


class _WrapperModule(torch.nn.Module):
    """Class to wrap a callable in an :class:`torch.nn.Module`.

    Use this if you are trying to export a callable.

    Args:
        fn: The callable to wrap. It is stored as ``self.fn`` and is invoked
            with whatever positional and keyword arguments are passed to
            :meth:`forward`.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        """Simple forward that just calls the ``fn`` provided to :meth:`_WrapperModule.__init__`.

        Returns:
            Whatever ``fn(*args, **kwargs)`` returns, unchanged.
        """
        return self.fn(*args, **kwargs)


def model_is_exported(m: torch.nn.Module) -> bool:
    """Report whether ``m`` looks like an exported model.

    A module counts as exported when it is a :class:`torch.fx.GraphModule`
    and at least one node in its graph carries a ``"val"`` entry in its
    ``meta`` dict. Plain (non-FX) modules, and FX-traced modules whose nodes
    lack ``"val"`` metadata (e.g. symbolically traced ones), yield False.
    """
    # Anything that is not a GraphModule cannot have been exported.
    if not isinstance(m, torch.fx.GraphModule):
        return False
    # Stop at the first node that has "val" metadata.
    for node in m.graph.nodes:
        if "val" in node.meta:
            return True
    return False
def _replace_train_eval_pattern(
    m: torch.fx.GraphModule,
    train_fn,
    eval_fn,
    example_inputs: tuple,
    train_to_eval: bool,
):
    """Rewrite occurrences of one mode's aten pattern in ``m`` into the other's.

    ``train_fn`` and ``eval_fn`` are each wrapped in a ``_WrapperModule`` and
    converted to aten graphs with ``get_aten_graph_module`` using
    ``example_inputs``. When ``train_to_eval`` is True, subgraphs matching the
    ``train_fn`` graph are replaced by the ``eval_fn`` graph; otherwise the
    direction is reversed. ``m`` is modified in place and recompiled.
    """
    from torch.fx.subgraph_rewriter import replace_pattern_with_filters

    # Avoid circular dependencies
    from .utils import get_aten_graph_module

    if train_to_eval:
        match_fn, replacement_fn = train_fn, eval_fn
    else:
        match_fn, replacement_fn = eval_fn, train_fn

    match_pattern = get_aten_graph_module(_WrapperModule(match_fn), example_inputs)
    replacement_pattern = get_aten_graph_module(
        _WrapperModule(replacement_fn), example_inputs
    )
    replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern,
        match_filters=[],
        ignore_literals=True,
    )
    m.recompile()


def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
    """
    Switch dropout patterns in the model between train and eval modes.

    Dropout has different behavior in train vs eval mode. For exported models,
    however, calling `model.train()` or `model.eval()` does not automatically switch
    the dropout behavior between the two modes, so here we need to rewrite the aten
    dropout patterns manually to achieve the same effect.

    See https://github.com/pytorch/pytorch/issues/103681.

    Args:
        m: The graph module to rewrite in place.
        train_to_eval: If True, train-mode dropout is rewritten to eval mode;
            otherwise eval-mode dropout is rewritten to train mode.
    """
    # Needed to ensure subgraph matches are self-contained
    m.graph.eliminate_dead_code()
    m.recompile()

    # Both the out-of-place and in-place dropout variants need rewriting.
    for inplace in [False, True]:
        # These closures read ``inplace`` at call time. That is safe because
        # they are traced immediately, within this same loop iteration.
        def dropout_train(x):
            return F.dropout(x, p=0.5, training=True, inplace=inplace)

        def dropout_eval(x):
            return F.dropout(x, p=0.5, training=False, inplace=inplace)

        # Created per iteration, as before, so global RNG consumption is unchanged.
        example_inputs = (torch.randn(1),)
        _replace_train_eval_pattern(
            m, dropout_train, dropout_eval, example_inputs, train_to_eval
        )


def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
    """
    Switch batchnorm patterns in the model between train and eval modes.

    Batchnorm has different behavior in train vs eval mode. For exported models,
    however, calling `model.train()` or `model.eval()` does not automatically switch
    the batchnorm behavior between the two modes, so here we need to rewrite the aten
    batchnorm patterns manually to achieve the same effect.

    Args:
        m: The graph module to rewrite in place.
        train_to_eval: If True, train-mode batchnorm is rewritten to eval mode;
            otherwise eval-mode batchnorm is rewritten to train mode.
    """
    # TODO(Leslie): This function still fails to support custom momentum and eps value.
    # Enable this support in future updates.

    # Needed to ensure subgraph matches are self-contained
    m.graph.eliminate_dead_code()
    m.recompile()

    def bn_train(
        x: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ):
        return F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
        )

    def bn_eval(
        x: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ):
        return F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
        )

    example_inputs = (
        torch.randn(1, 1, 3, 3),  # x
        torch.randn(1),  # bn_weight
        torch.randn(1),  # bn_bias
        torch.randn(1),  # bn_running_mean
        torch.randn(1),  # bn_running_var
    )
    _replace_train_eval_pattern(m, bn_train, bn_eval, example_inputs, train_to_eval)


# TODO: expose these under this namespace?
def _move_exported_model_to_eval(model: torch.fx.GraphModule):
    """
    Move an exported GraphModule to eval mode.

    This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm.
    QAT users should call this before performing inference on the model.

    Returns:
        The same ``model`` object, rewritten in place.
    """
    _replace_dropout(model, train_to_eval=True)
    _replace_batchnorm(model, train_to_eval=True)
    return model


def _move_exported_model_to_train(model: torch.fx.GraphModule):
    """
    Move an exported GraphModule to train mode.

    This is equivalent to model.train() but only for certain special ops like dropout, batchnorm.
    QAT users should call this before performing training on the model.

    Returns:
        The same ``model`` object, rewritten in place.
    """
    _replace_dropout(model, train_to_eval=False)
    _replace_batchnorm(model, train_to_eval=False)
    return model


def _allow_exported_model_train_eval(model: torch.fx.GraphModule):
    """
    Allow users to call `model.train()` and `model.eval()` on an exported model,
    but with the effect of changing behavior between the two modes limited to special
    ops only, which are currently dropout and batchnorm.

    Note: This does not achieve the same effect as what `model.train()` and `model.eval()`
    does in eager models, but only provides an approximation. In particular, user code
    branching on `training` flag will not function correctly in general because the branch
    is already specialized at export time. Additionally, other ops beyond dropout and batchnorm
    that have different train/eval behavior will also not be converted properly.
    The patched methods also do not update the model's ``training`` attribute.

    Returns:
        The same ``model`` object, with ``train`` and ``eval`` rebound on it.
    """

    def _train(self, mode: bool = True):
        if mode:
            _move_exported_model_to_train(self)
        else:
            _move_exported_model_to_eval(self)
        # Match torch.nn.Module.train, which returns self so that idioms like
        # ``model = model.train()`` keep working.
        return self

    def _eval(self):
        _move_exported_model_to_eval(self)
        # Match torch.nn.Module.eval, which returns self.
        return self

    model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
    model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
    return model

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources