Shortcuts

Source code for torch.backends.opt_einsum

from typing import Any
import warnings
import sys
from functools import lru_cache as _lru_cache
from contextlib import contextmanager
from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation

try:
    import opt_einsum as _opt_einsum  # type: ignore[import]
except ImportError:
    _opt_einsum = None


@_lru_cache()
def is_available() -> bool:
    r"""Report whether the optional ``opt_einsum`` package is usable.

    Returns:
        bool: ``True`` when the module-level ``_opt_einsum`` handle is not
        ``None``, ``False`` otherwise. The result is memoized by the
        ``_lru_cache`` decorator after the first call.
    """
    return not (_opt_einsum is None)
def get_opt_einsum() -> Any:
    r"""Hand back the ``opt_einsum`` package object, if there is one.

    Returns:
        Any: whatever the module-level ``_opt_einsum`` name is bound to --
        the imported package, or ``None`` when it is not available.
    """
    package = _opt_einsum
    return package
def _set_enabled(_enabled: bool) -> None:
    """Set the module-level ``enabled`` flag.

    Args:
        _enabled: New value for ``enabled``.

    Raises:
        ValueError: If ``_enabled`` is truthy while opt_einsum is not
            available. Disabling is always allowed.
    """
    global enabled
    if not is_available() and _enabled:
        raise ValueError(f'opt_einsum is not available, so setting `enabled` to {_enabled} will not reap '
                         'the benefits of calculating an optimal path for einsum. torch.einsum will '
                         'fall back to contracting from left to right. To enable this optimal path '
                         'calculation, please install opt-einsum.')
    enabled = _enabled


def _get_enabled() -> bool:
    """Return the current module-level ``enabled`` flag."""
    return enabled


def _set_strategy(_strategy: str) -> None:
    """Set the module-level ``strategy`` value after validating it.

    Args:
        _strategy: One of ``'auto'``, ``'greedy'`` or ``'optimal'``.

    Raises:
        ValueError: If opt_einsum is not available, if ``enabled`` is
            falsy, or if ``_strategy`` is not one of the accepted names.
    """
    global strategy
    if not is_available():
        raise ValueError(f'opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. '
                         'torch.einsum will bypass path calculation and simply contract from left to right. '
                         'Please install opt_einsum or unset `strategy`.')
    if not enabled:
        raise ValueError(f'opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. '
                         'torch.einsum will bypass path calculation and simply contract from left to right. '
                         'Please set `enabled` to `True` as well or unset `strategy`.')
    if _strategy not in ('auto', 'greedy', 'optimal'):
        raise ValueError(f'`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}')
    strategy = _strategy


def _get_strategy() -> str:
    """Return the current module-level ``strategy`` value."""
    return strategy


def set_flags(_enabled=None, _strategy=None):
    """Apply new flag values and return the previous ones.

    Args:
        _enabled: New ``enabled`` value, or ``None`` to leave it unchanged.
        _strategy: New ``strategy`` value, or ``None`` to leave it unchanged.

    Returns:
        tuple: ``(enabled, strategy)`` as they were before this call; the
        strategy slot is ``None`` when opt_einsum is not available.

    Raises:
        ValueError: Propagated from ``_set_enabled`` / ``_set_strategy``.
            ``enabled`` is applied first, so it may already have been
            changed when the strategy is rejected.
    """
    orig_flags = (enabled, strategy if is_available() else None)
    if _enabled is not None:
        _set_enabled(_enabled)
    if _strategy is not None:
        _set_strategy(_strategy)
    return orig_flags


def _restore_flags(_enabled, _strategy) -> None:
    """Write back a snapshot previously returned by ``set_flags``.

    The values are assigned directly instead of going through the validating
    setters: they were the live state before, and ``_set_strategy`` refuses to
    run while ``enabled`` is falsy, which would make a disabled state
    impossible to restore.
    """
    global enabled, strategy
    enabled = _enabled
    # A ``None`` strategy slot means "nothing was captured" (opt_einsum not
    # available), so leave ``strategy`` untouched in that case.
    if _strategy is not None:
        strategy = _strategy


@contextmanager
def flags(enabled=None, strategy=None):
    """Context manager that temporarily overrides ``enabled`` / ``strategy``.

    Args:
        enabled: Temporary ``enabled`` value, or ``None`` to keep the current one.
        strategy: Temporary ``strategy`` value, or ``None`` to keep the current one.

    Raises:
        ValueError: If the requested values are rejected by the setters. The
            previous values are restored before the error propagates.
    """
    # Calling set_flags() without arguments changes nothing and just returns
    # the current values, giving a snapshot before any mutation happens.
    orig_flags = set_flags()
    try:
        with __allow_nonbracketed_mutation():
            set_flags(enabled, strategy)
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            _restore_flags(*orig_flags)


# The magic here is to allow us to intercept code like this:
#
#     torch.backends.opt_einsum.enabled = True
class OptEinsumModule(PropModule):
    """Module wrapper installed in ``sys.modules`` in place of this module."""

    def __init__(self, m, name):
        super().__init__(m, name)

    # NOTE(review): because of the ``global`` statements these assignments bind
    # module-level names, not class attributes, and the module-level defaults
    # assigned at the bottom of the file overwrite them afterwards. Confirm
    # whether class-level properties were intended.
    global enabled
    enabled = ContextProp(_get_enabled, _set_enabled)
    global strategy
    strategy = None
    if is_available():
        strategy = ContextProp(_get_strategy, _set_strategy)


# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)

# Defaults: path optimisation is on, with the 'auto' strategy, only when
# opt_einsum is available.
enabled = is_available()
strategy = 'auto' if is_available() else None

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources