# Source code for torch.fx.subgraph_rewriter
from .graph_module import GraphModule
from .graph import Graph
from .node import Node
from ._symbolic_trace import symbolic_trace
from ._compatibility import compatibility
import copy
from typing import Callable, Dict, List, NamedTuple, Optional, Set
import torch
@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    """
    One place in a larger graph where a pattern graph was found.

    ``replace_pattern`` returns a list of these.
    """
    # Node from which the match was found: the graph node that the pattern's
    # output node was matched against (a node consuming the pattern's result)
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
class _SubgraphMatcher:
    """
    Structural matcher for a single-output pattern ``Graph``.

    Pattern nodes are compared against graph nodes by ``op`` and ``target``
    while walking use-def edges backwards from an anchor node. Pattern
    placeholders act as wildcards, and the pattern's output node matches any
    non-placeholder graph node that consumes the pattern's result.
    """

    def __init__(self, pattern: Graph) -> None:
        self.pattern = pattern
        if len(pattern.nodes) == 0:
            raise ValueError("_SubgraphMatcher cannot be initialized with an "
                             "empty pattern")
        # `self.pattern_anchor` is the output Node in `pattern`
        self.pattern_anchor = next(iter(reversed(pattern.nodes)))
        # Ensure that there is only a single output value in the pattern
        # since we don't support multiple outputs
        assert len(self.pattern_anchor.all_input_nodes) == 1, \
            "Pattern matching on multiple outputs is not supported"
        # Maps nodes in the pattern subgraph to nodes in the larger graph
        self.nodes_map: Dict[Node, Node] = {}

    def matches_subgraph_from_anchor(self, anchor: Node) -> bool:
        """
        Checks if the whole pattern can be matched starting from
        ``anchor`` in the larger graph.

        Pattern matching is done by recursively comparing the pattern
        node's use-def relationships against the graph node's. On success
        ``self.nodes_map`` holds the pattern-node -> graph-node mapping of
        the match; on failure it is left empty.
        """
        self.nodes_map = {}
        return self._match_nodes(self.pattern_anchor, anchor)

    @staticmethod
    def _attributes_are_equal(pn: Node, gn: Node) -> bool:
        # Use placeholder and output nodes as wildcards. The
        # only exception is that an output node can't match
        # a placeholder
        if (pn.op == "placeholder"
                or (pn.op == "output" and gn.op != "placeholder")):
            return True
        return pn.op == gn.op and pn.target == gn.target

    def _rollback(self, checkpoint: int) -> None:
        # `nodes_map` only ever grows by inserting new keys at the end (dicts
        # keep insertion order), so undoing a failed attempt means dropping
        # every entry added after `checkpoint`, newest first.
        while len(self.nodes_map) > checkpoint:
            self.nodes_map.popitem()

    # Compare the pattern node `pn` against the graph node `gn`
    def _match_nodes(self, pn: Node, gn: Node) -> bool:
        # Check if we've already matched these nodes in the current
        # traversal
        if pn in self.nodes_map:
            return self.nodes_map[pn] == gn

        # Terminate early if the node attributes are not equal
        if not self._attributes_are_equal(pn, gn):
            return False

        # Optimistically mark `pn` as a match for `gn`. Remember the map's
        # size first so that a failure below can undo this entry AND every
        # entry added by nested attempts; otherwise stale mappings would
        # make later attempts (e.g. other candidate inputs of the output
        # anchor) fail spuriously.
        checkpoint = len(self.nodes_map)
        self.nodes_map[pn] = gn

        # Traverse the use-def relationships to ensure that `pn` is a true
        # match for `gn`
        if pn.op == "placeholder":
            return True

        if pn.op == "output":
            # The pattern's single result may be any one of the anchor's
            # inputs
            pattern_result = pn.all_input_nodes[0]
            match_found = any(self._match_nodes(pattern_result, gn_)
                              for gn_ in gn.all_input_nodes)
        else:
            pn_inputs = pn.all_input_nodes
            gn_inputs = gn.all_input_nodes
            match_found = (len(pn_inputs) == len(gn_inputs)
                           and all(self._match_nodes(pn_, gn_)
                                   for pn_, gn_ in zip(pn_inputs, gn_inputs)))

        if not match_found:
            self._rollback(checkpoint)
        return match_found
def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
    """
    Make sure every ``call_module`` / ``get_attr`` target in ``gm.graph``
    exists on ``gm``.

    After rewriting, nodes copied from ``replacement``'s graph may refer to
    submodules, parameters or buffers that only exist on ``replacement``.
    Those are deep-copied onto ``gm``. Targets that already exist on ``gm``
    take precedence. Submodules of ``gm`` that the graph no longer uses are
    deleted first.

    Args:
        gm: The rewritten GraphModule to fix up (modified in place).
        replacement: The module whose graph was spliced into ``gm``.

    Raises:
        RuntimeError: if a target exists on neither ``gm`` nor
            ``replacement``.
    """
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_attr(mod: torch.nn.Module, target: str) -> Optional[object]:
        # `get_submodule` only resolves nn.Module paths (it raises
        # AttributeError for parameters/buffers), so resolve the owning
        # module with it and read the final atom with a plain getattr.
        owner_path, _, attr_name = target.rpartition(".")
        try:
            owner = mod.get_submodule(owner_path)
        except AttributeError:
            return None
        return getattr(owner, attr_name, None)

    def get_or_create_owner(owner_path: str) -> torch.nn.Module:
        # Return the module on `gm` that should hold a copied attribute,
        # creating empty intermediate modules if the path is missing.
        try:
            return gm.get_submodule(owner_path)
        except AttributeError:
            gm.add_submodule(owner_path, torch.nn.Module())
            return gm.get_submodule(owner_path)

    for node in gm.graph.nodes:
        if node.op not in ("call_module", "get_attr"):
            continue

        gm_attr = try_get_attr(gm, node.target)

        # CASE 1: This target already exists in our result GraphModule.
        # Whether or not it exists in `replacement`, the existing
        # attribute takes precedence.
        if gm_attr is not None:
            continue

        replacement_attr = try_get_attr(replacement, node.target)

        # CASE 2: The target exists in `replacement` only, so we need to
        # copy it over. Copy the already-resolved object: a plain
        # `getattr(replacement, "a.b")` would fail on dotted targets.
        if replacement_attr is not None:
            new_attr = copy.deepcopy(replacement_attr)
            if isinstance(new_attr, torch.nn.Module):
                gm.add_submodule(node.target, new_attr)
            else:
                owner_path, _, attr_name = node.target.rpartition(".")
                owner = get_or_create_owner(owner_path)
                if (isinstance(new_attr, torch.Tensor)
                        and not isinstance(new_attr, torch.nn.Parameter)):
                    # Plain tensors are registered as buffers so they
                    # follow the module across `.to()` / `state_dict()`
                    owner.register_buffer(attr_name, new_attr)
                else:
                    # Parameters are registered by nn.Module.__setattr__
                    setattr(owner, attr_name, new_attr)
            continue

        # CASE 3: The target doesn't exist in `gm` or `replacement`
        raise RuntimeError(
            f'Attempted to create a "{node.op}" node during subgraph '
            f"rewriting with target {node.target}, but the referenced "
            "submodule or attribute does not exist in either the original "
            "GraphModule `gm` or the replacement GraphModule `replacement`")

    gm.graph.lint()
def _pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
    """
    Return ``True`` if a candidate match does not "leak" out of the pattern.

    ``nodes_map`` maps pattern nodes to nodes of the larger graph. Graph
    nodes matched by pattern-internal nodes are erased once the match is
    replaced, so all of their users must themselves be part of the match.
    These are allowed to have users outside the match:

    * nodes matched by a pattern placeholder (the pattern's inputs; they
      are never erased),
    * the anchor, i.e. the node matched by the pattern's output node, and
    * the pattern's result (the node consumed by the pattern's output),
      whose uses get rewired to the replacement.
    """
    matched = set(nodes_map.values())
    for pn, gn in nodes_map.items():
        # Decide by the *pattern* node's role: a pattern input may match an
        # intermediate graph node that legitimately has other users.
        if pn.op in ("placeholder", "output"):
            continue
        if len(pn.users) == 1 and next(iter(pn.users)).op == "output":
            continue
        # A user outside the match means this computation is still needed
        # elsewhere, so the match cannot be cleanly replaced
        if any(user not in matched for user in gn.users):
            return False
    return True


@compatibility(is_backward_compatible=True)
def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``).

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Returns:
        List[Match]: A list of ``Match`` objects representing the places
        in the original graph that ``pattern`` was matched to. The list
        is empty if there are no matches. Note that it contains every
        non-leaking match that was found, including matches that were
        skipped because they overlapped an earlier, replaced match; the
        graph nodes referenced by a replaced match may since have been
        erased. ``Match`` is defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Raises:
        AssertionError: if a match is found but ``pattern`` has no inputs,
            ``pattern`` and ``replacement`` take different numbers of
            inputs, or ``replacement`` does not return a single value.
        RuntimeError: if ``replacement`` is an ``nn.Module`` and a
            submodule or attribute referenced by the rewritten graph exists
            on neither ``gm`` nor ``replacement``.

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match to the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    Only nodes belonging to a replaced match are erased; unrelated nodes
    (including unused parameters of ``forward``) are left untouched.

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks roughly like this (exact node names may differ):

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            max_1 = torch.max(stack_1)
            add_1 = x + max_1
            stack_2 = torch.stack([w1, w2])
            max_2 = torch.max(stack_2)
            add_2 = add_1 + max_2
            return add_2
    """
    # Get the graphs for `gm`, `pattern`, `replacement`
    original_graph = gm.graph
    pattern_graph = symbolic_trace(pattern).graph
    replacement_graph = symbolic_trace(replacement).graph

    # Find all possible pattern matches in original_graph. Note that
    # pattern matches may overlap with each other.
    matcher = _SubgraphMatcher(pattern_graph)
    matches: List[Match] = []

    # Consider each node as an "anchor" (deepest matching graph node). It's
    # not a match if the pattern leaks out into the rest of the graph.
    for anchor in original_graph.nodes:
        if (matcher.matches_subgraph_from_anchor(anchor)
                and _pattern_is_contained(matcher.nodes_map)):
            # Shallow copy: the matcher rebuilds its map for every anchor
            matches.append(Match(anchor=anchor,
                                 nodes_map=dict(matcher.nodes_map)))

    # These depend only on `pattern` and `replacement`, so compute them once
    pattern_placeholders = [n for n in pattern_graph.nodes
                            if n.op == "placeholder"]
    replacement_placeholders = [n for n in replacement_graph.nodes
                                if n.op == "placeholder"]
    replacement_outputs = [n for n in replacement_graph.nodes
                           if n.op == "output"]
    # `pattern_anchor` is the output node of `pattern_graph`
    pattern_output = matcher.pattern_anchor
    # Maps each `replacement` placeholder to the `pattern` placeholder in
    # the same position
    placeholder_map = dict(zip(replacement_placeholders, pattern_placeholders))

    if matches:
        # Validate before mutating `original_graph`
        assert len(pattern_placeholders) > 0
        assert len(pattern_placeholders) == len(replacement_placeholders)
        assert len(replacement_outputs) == 1
        # Only a single returned Node is supported as the replacement result
        assert isinstance(replacement_outputs[0].args[0], Node)

    # The set of all nodes in `original_graph` that we've seen thus far
    # as part of a pattern match
    replaced_nodes: Set[Node] = set()
    # As we progressively replace nodes, we'll need to keep track of how the
    # match results should change
    match_changed_node: Dict[Node, Node] = {}

    # Return True if one of the nodes in the current match has already
    # been used as part of another match
    def overlaps_with_prev_match(match: Match) -> bool:
        return any(gn in replaced_nodes and gn.op != "placeholder"
                   for pn, gn in match.nodes_map.items()
                   if pn.op not in ("placeholder", "output"))

    for match in matches:
        # Skip overlapping matches
        if overlaps_with_prev_match(match):
            continue

        # Node from `original_graph` that matched with the output node in
        # `pattern`; the replacement is inserted right before it
        subgraph_output: Node = match.anchor

        # Mark every matched node reachable (through matched nodes) from
        # the anchor's inputs as replaced
        matched_nodes = set(match.nodes_map.values())
        visited: Set[Node] = set()
        pending = list(subgraph_output.all_input_nodes)
        while pending:
            node = pending.pop()
            if node in visited or node not in matched_nodes:
                continue
            visited.add(node)
            replaced_nodes.add(node)
            pending.extend(node.all_input_nodes)

        # Map replacement graph nodes to their copy in `original_graph`,
        # starting with its placeholders. An input may itself have been
        # replaced by an earlier match, so follow `match_changed_node`.
        val_map: Dict[Node, Node] = {}
        for replacement_node in replacement_placeholders:
            original_node = match.nodes_map[placeholder_map[replacement_node]]
            val_map[replacement_node] = match_changed_node.get(original_node,
                                                               original_node)

        # Copy the replacement graph over
        with original_graph.inserting_before(subgraph_output):
            original_graph.graph_copy(replacement_graph, val_map)

        # Hook the replacement in by redirecting every use of the pattern's
        # result to the replacement's result. This works whether the anchor
        # is an ordinary node or the graph's output node, and preserves any
        # other values the output node returns.
        for pn_result, rn_result in zip(pattern_output.all_input_nodes,
                                        replacement_outputs[0].all_input_nodes):
            gn_result = match.nodes_map[pn_result]
            new_result = val_map[rn_result]
            gn_result.replace_all_uses_with(new_result)
            # Store the updated node in case later matches use it as input
            match_changed_node[gn_result] = new_result

        # Erase the graph nodes matched by pattern-internal nodes. Walking
        # the graph backwards visits users before the nodes they use; nodes
        # that are still used (e.g. by the anchor) are kept.
        to_erase = {gn for pn, gn in match.nodes_map.items()
                    if pn.op not in ("placeholder", "output")}
        for node in reversed(original_graph.nodes):
            if node in to_erase and len(node.users) == 0:
                original_graph.erase_node(node)

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()

    # If `replacement` was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_submodules(gm, replacement)

    return matches