torch.library
torch.library is a collection of APIs for extending PyTorch’s core library of operators. It contains utilities for creating new custom operators as well as extending operators defined with PyTorch’s C++ operator registration APIs (e.g. aten operators).
For a detailed guide on using these APIs effectively, please see this gdoc.
Use torch.library.define() to define new custom operators. Use the impl methods, such as torch.library.impl() and torch.library.impl_abstract(), to add implementations for any operators (they may have been created using torch.library.define() or via PyTorch’s C++ operator registration APIs).
- torch.library.define(qualname, schema, *, lib=None, tags=())
- torch.library.define(lib, schema, alias_analysis='')
Defines a new operator.
In PyTorch, defining an op (short for “operator”) is a two-step process:
- we need to define the op (by providing an operator name and schema)
- we need to implement behavior for how the operator interacts with various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
This entrypoint defines the custom operator (the first step); you must then perform the second step by calling various impl_* APIs, like torch.library.impl() or torch.library.impl_abstract().
- Parameters
qualname (str) – The qualified name for the operator. Should be a string that looks like “namespace::name”, e.g. “aten::sin”. Operators in PyTorch need a namespace to avoid name collisions; a given operator may only be created once. If you are writing a Python library, we recommend using the name of your top-level module as the namespace.
schema (str) – The schema of the operator. E.g. “(Tensor x) -> Tensor” for an op that accepts one Tensor and returns one Tensor. It does not contain the operator name (that is passed in qualname).
lib (Optional[Library]) – If provided, the lifetime of this operator will be tied to the lifetime of the Library object.
tags (Tag | Sequence[Tag]) – one or more torch.Tag to apply to this operator. Tagging an operator changes the operator’s behavior under various PyTorch subsystems; please read the docs for the torch.Tag carefully before applying it.
- Example::
>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin", "cpu")
... def f(x):
...     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.sin(x)
>>> assert torch.allclose(y, x.sin())
- torch.library.impl(qualname, types, func=None, *, lib=None)
- torch.library.impl(lib, name, dispatch_key='')
Register an implementation for a device type for this operator.
You may pass “default” for types to register this implementation as the default implementation for ALL device types. Please only use this if the implementation truly supports all device types; for example, this is true if it is a composition of built-in PyTorch operators.
Some valid types are: “cpu”, “cuda”, “xla”, “mps”, “ipu”, “xpu”.
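- Parameters
qualname (str) – Should be a string that looks like “namespace::operator_name”.
types (str | Sequence[str]) – The device types to register an impl to.
lib (Optional[Library]) – If provided, the lifetime of this registration will be tied to the lifetime of the Library object.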
Examples
>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylibrary::sin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the cpu device
>>> @torch.library.impl("mylibrary::sin", "cpu")
... def f(x):
...     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> x = torch.randn(3)
>>> y = torch.ops.mylibrary.sin(x)
>>> assert torch.allclose(y, x.sin())
- torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)
Register an abstract implementation for this operator.
An “abstract implementation” specifies the behavior of this operator on Tensors that carry no data. Given some input Tensors with certain properties (sizes/strides/storage_offset/device), it specifies what the properties of the output Tensors are.
The abstract implementation has the same signature as the operator. It is run for both FakeTensors and meta tensors. To write an abstract implementation, assume that all Tensor inputs to the operator are regular CPU/CUDA/Meta tensors, but they do not have storage, and you are trying to return regular CPU/CUDA/Meta tensor(s) as output. The abstract implementation must consist of only PyTorch operations (and may not directly access the storage or data of any input or intermediate Tensors).
This API may be used as a decorator (see examples).
For a detailed guide on custom ops, please see https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ/edit
Examples
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Example 1: an operator without data-dependent output shape
>>> torch.library.define(
...     "mylib::custom_linear",
...     "(Tensor x, Tensor weight, Tensor bias) -> Tensor")
>>>
>>> @torch.library.impl_abstract("mylib::custom_linear")
... def custom_linear_abstract(x, weight, bias):
...     assert x.dim() == 2
...     assert weight.dim() == 2
...     assert bias.dim() == 1
...     assert x.shape[1] == weight.shape[1]
...     assert weight.shape[0] == bias.shape[0]
...     assert x.device == weight.device
...
...     return (x @ weight.t()) + bias
>>>
>>> # Example 2: an operator with data-dependent output shape
>>> torch.library.define("mylib::custom_nonzero", "(Tensor x) -> Tensor")
>>>
>>> @torch.library.impl_abstract("mylib::custom_nonzero")
... def custom_nonzero_abstract(x):
...     # Number of nonzero elements is data-dependent.
...     # Since we cannot peek at the data in an abstract impl,
...     # we use the ctx object to construct a new symint that
...     # represents the data-dependent size.
...     ctx = torch.library.get_ctx()
...     nnz = ctx.new_dynamic_size()
...     shape = [nnz, x.dim()]
...     result = x.new_empty(shape, dtype=torch.int64)
...     return result
>>>
>>> @torch.library.impl("mylib::custom_nonzero", "cpu")
... def custom_nonzero_cpu(x):
...     x_np = x.numpy()
...     res = np.stack(np.nonzero(x_np), axis=1)
...     return torch.tensor(res, device=x.device)
- torch.library.get_ctx()
get_ctx() returns the current AbstractImplCtx object.
Calling get_ctx() is only valid inside of an abstract impl (see torch.library.impl_abstract() for more usage details).
- Return type
AbstractImplCtx
Low-level APIs
The following APIs are direct bindings to PyTorch’s C++ low-level operator registration APIs.
Warning
The low-level operator registration APIs and the PyTorch Dispatcher are complicated PyTorch concepts. We recommend you use the higher-level APIs above (that do not require a torch.library.Library object) when possible. This blog post (http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) is a good starting point to learn about the PyTorch Dispatcher.
A tutorial that walks you through some examples on how to use this API is available on Google Colab.
- class torch.library.Library(ns, kind, dispatch_key='')
A class to create libraries that can be used to register new operators or override operators in existing libraries from Python. A user can optionally pass in a dispatch key name if they only want to register kernels corresponding to one specific dispatch key.
To create a library to override operators in an existing library (with name ns), set the kind to “IMPL”. To create a new library (with name ns) to register new operators, set the kind to “DEF”. To create a fragment of a possibly existing library to register operators (and bypass the limitation that there is only one library for a given namespace), set the kind to “FRAGMENT”.
- Parameters
ns – library name
kind – “DEF”, “IMPL” (default: “IMPL”), “FRAGMENT”
dispatch_key – PyTorch dispatch key (default: “”)
- define(schema, alias_analysis='', *, tags=())
Defines a new operator and its semantics in the ns namespace.
- Parameters
schema – function schema to define a new operator.
alias_analysis (optional) – Indicates if the aliasing properties of the operator arguments can be inferred from the schema (default behavior) or not (“CONSERVATIVE”).
tags (Tag | Sequence[Tag]) – one or more torch.Tag to apply to this operator. Tagging an operator changes the operator’s behavior under various PyTorch subsystems; please read the docs for the torch.Tag carefully before applying it.
- Returns
name of the operator as inferred from the schema.
- Example::
>>> from torch.library import Library
>>> my_lib = Library("foo", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
- impl(op_name, fn, dispatch_key='')
Registers the function implementation for an operator defined in the library.
- Parameters
op_name – operator name (along with the overload) or OpOverload object.
fn – function that’s the operator implementation for the input dispatch key or fallthrough_kernel() to register a fallthrough.
dispatch_key – dispatch key that the input function should be registered for. By default, it uses the dispatch key that the library was created with.
- Example::
>>> from torch.library import Library
>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
...     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
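A sketch of the two accepted forms of `op_name`: a string with an optional overload (as in "div.Tensor") and an OpOverload object. The namespace `gr_overloads` and its `shift` operator with a `Scalar` overload are hypothetical, and the sketch assumes the Library API documented above.

```python
import torch
from torch.library import Library

# Hypothetical namespace with two overloads of one operator.
lib = Library("gr_overloads", "DEF")
lib.define("shift(Tensor x, Tensor other) -> Tensor")        # default overload
lib.define("shift.Scalar(Tensor x, float other) -> Tensor")  # overload "Scalar"

def shift_cpu(x, other):
    # Works for both a Tensor and a Python float `other`.
    return x + other

# op_name as a string: "shift" with no suffix selects the default overload.
lib.impl("shift", shift_cpu, "CPU")
# op_name as an OpOverload object for the "Scalar" overload.
lib.impl(torch.ops.gr_overloads.shift.Scalar, shift_cpu, "CPU")

x = torch.zeros(2)
a = torch.ops.gr_overloads.shift.default(x, torch.ones(2))
b = torch.ops.gr_overloads.shift.Scalar(x, 5.0)
print(a, b)  # tensor([1., 1.]) tensor([5., 5.])
```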