torch.autograd.Function.vmap

static Function.vmap(info, in_dims, *args)
Define the behavior for this autograd.Function underneath torch.vmap().

For a torch.autograd.Function to support torch.vmap(), you must either override this static method or set generate_vmap_rule to True (you may not do both).

If you choose to override this staticmethod, it must accept:

- an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap().
- an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying which dimension of the Tensor is being vmapped over.
- *args, which are the same as the args to forward().
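As a minimal sketch, assuming PyTorch 2.x (where an autograd.Function used with torch.func transforms defines forward without ctx and a separate setup_context), the class below overrides vmap. ScaleBy and its factor argument are hypothetical names chosen for illustration. The rule reads in_dims (the Python float factor always gets None), moves the vmapped dimension of x to the front, checks it against info.batch_size, and re-applies the Function to the batched tensor.

```python
import torch


class ScaleBy(torch.autograd.Function):
    # Hypothetical example: multiply a Tensor by a Python float.
    # forward takes no ctx; setup_context is separate (needed for torch.func).
    @staticmethod
    def forward(x, factor):
        return x * factor

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, factor = inputs
        ctx.factor = factor

    @staticmethod
    def backward(ctx, grad_output):
        # factor is not a Tensor, so its gradient is None
        return grad_output * ctx.factor, None

    @staticmethod
    def vmap(info, in_dims, x, factor):
        x_bdim, factor_bdim = in_dims
        # Non-Tensor args always have an in_dim of None
        assert factor_bdim is None
        # Move the vmapped dimension of x to the front (index 0)
        x = x.movedim(x_bdim, 0)
        assert x.shape[0] == info.batch_size
        out = ScaleBy.apply(x, factor)
        # Return the output together with the index of its vmapped dimension
        return out, 0


x = torch.randn(4, 3)
# Vmap over dim 1 of x; the float factor is not vmapped (in_dim None)
y = torch.vmap(ScaleBy.apply, in_dims=(1, None))(x, 2.0)
print(y.shape)  # torch.Size([3, 4])
```

Moving the batch dimension to a fixed position, as done here, is one simple design choice; a rule may also leave it in place, as long as the returned out_dims reports where it ended up.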
The vmap staticmethod returns a tuple of (output, out_dims). Similar to in_dims, out_dims should have the same structure as output and contain one out_dim per output that specifies whether the output has the vmapped dimension and, if so, at which index.

Please see Extending torch.func with autograd.Function for more details.
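To illustrate out_dims for a Function with several outputs, here is another sketch under the same PyTorch 2.x assumptions; SinCos is a hypothetical example. Because sin and cos are elementwise, the rule leaves the batch dimension where it is and reports x_bdim for each output, so out_dims mirrors the (sin, cos) tuple structure of output. An output that did not carry the vmapped dimension would use None as its out_dim instead.

```python
import torch


class SinCos(torch.autograd.Function):
    # Hypothetical example: return both sin(x) and cos(x).
    @staticmethod
    def forward(x):
        return x.sin(), x.cos()

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_sin, grad_cos):
        (x,) = ctx.saved_tensors
        # d/dx sin(x) = cos(x), d/dx cos(x) = -sin(x)
        return grad_sin * x.cos() - grad_cos * x.sin()

    @staticmethod
    def vmap(info, in_dims, x):
        (x_bdim,) = in_dims
        # sin and cos are elementwise, so the vmapped dimension stays at x_bdim
        sin_out, cos_out = SinCos.apply(x)
        # out_dims mirrors the structure of output: one entry per output
        return (sin_out, cos_out), (x_bdim, x_bdim)


x = torch.randn(5, 3)
s, c = torch.vmap(SinCos.apply, in_dims=1)(x)
# torch.vmap's default out_dims=0 puts the batch dimension first
print(s.shape, c.shape)  # torch.Size([3, 5]) torch.Size([3, 5])
```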