torch.Tensor.scatter_
Tensor.scatter_(dim, index, src, reduce=None) → Tensor
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

For a 3-D tensor, self is updated as:

    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
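As a minimal sketch of the dim == 0 rule above, the snippet below compares scatter_ against an explicit triple loop. The helper name scatter_dim0_reference and the sample tensors are assumptions made for illustration; they are not part of the PyTorch API. The index values are chosen so that no two source elements target the same location, which keeps the result deterministic.

```python
import torch


def scatter_dim0_reference(out, index, src):
    """Hypothetical helper: loop form of the dim == 0 rule for 3-D tensors."""
    result = out.clone()
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            for k in range(index.size(2)):
                # self[index[i][j][k]][j][k] = src[i][j][k]
                row = index[i, j, k].item()
                result[row, j, k] = src[i, j, k]
    return result


src = torch.arange(1, 13).reshape(2, 3, 2)
# For every (j, k) position the two rows of index differ, so targets are unique.
index = torch.tensor([[[0, 1], [2, 0], [1, 2]],
                      [[1, 0], [0, 2], [2, 1]]])
base = torch.zeros(3, 3, 2, dtype=src.dtype)

expected = scatter_dim0_reference(base, index, src)
actual = base.clone().scatter_(0, index, src)
print(torch.equal(expected, actual))
```

The loop is only meant to make the indexing rule concrete; it is far slower than the built-in operation.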
This is the reverse of the operation described in gather().

self, index and src (if it is a Tensor) should all have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim. Note that index and src do not broadcast.

Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive.

Warning
When indices are not unique, the behavior is non-deterministic (one of the values from src will be picked arbitrarily) and the gradient will be incorrect (it will be propagated to all locations in the source that correspond to the same index)!

Note
The backward pass is implemented only for src.shape == index.shape.

Additionally accepts an optional reduce argument that specifies a reduction operation, which is applied when writing each value of src into self at the indices specified in index. For each value in src, the reduction operation is applied to an index in self which is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

Given a 3-D tensor and reduction using the multiplication operation, self is updated as:

    self[index[i][j][k]][j][k] *= src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] *= src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] *= src[i][j][k]  # if dim == 2
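The multiplicative update above can be illustrated with a 2-D analogue of the dim == 1 rule. This is a sketch under assumptions: the tensor sizes and the scalar 1.23 are arbitrary, and a scalar src is used so the example does not rely on the Tensor-src form of reduce.

```python
import torch

base = torch.full((2, 4), 2.0)
index = torch.tensor([[2], [3]])
factor = 1.23  # scalar src, chosen arbitrarily for the illustration

# Built-in form: multiply the selected entries of each row by the scalar.
actual = base.clone().scatter_(1, index, factor, reduce='multiply')

# Loop form of the 2-D rule: self[i][index[i][j]] *= src
expected = base.clone()
for i in range(index.size(0)):
    for j in range(index.size(1)):
        col = index[i, j].item()
        expected[i, col] *= factor

print(torch.allclose(actual, expected))
```

Entries not named in index keep their original value, since the reduction only touches the addressed locations.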
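Reducing with the addition operation is the same as using scatter_add_().

Warning
The reduce argument with Tensor src is deprecated and will be removed in a future PyTorch release. Please use scatter_reduce_() instead for more reduction options.

Parameters
- dim (int): the axis along which to index
- index (LongTensor): the indices of elements to scatter
- src (Tensor or float): the source element(s) to scatter
- reduce (str, optional): the reduction operation to apply, either 'add' or 'multiply'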
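For a Tensor src, the deprecation warning above points to scatter_reduce_(). The sketch below shows one way the additive and multiplicative reductions might be written with scatter_add_() and scatter_reduce_(); the sample tensors are assumptions, and scatter_reduce_ is left at its default of including the existing values of self in the reduction.

```python
import torch

src = torch.tensor([[1., 2., 3.],
                    [4., 5., 6.]])
# Each column contains both 0 and 1, so every target location is hit once.
index = torch.tensor([[0, 1, 0],
                      [1, 0, 1]])
base = torch.ones(2, 3)

# Additive reduction written two ways.
added = base.clone().scatter_add_(0, index, src)
reduced_sum = base.clone().scatter_reduce_(0, index, src, reduce='sum')

# Multiplicative reduction via scatter_reduce_.
reduced_prod = base.clone().scatter_reduce_(0, index, src, reduce='prod')

print(torch.equal(added, reduced_sum))
print(reduced_prod)
```

Note that scatter_reduce_() names its reductions 'sum' and 'prod', whereas the reduce argument of scatter_ uses 'add' and 'multiply'.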
Example:
    >>> src = torch.arange(1, 11).reshape((2, 5))
    >>> src
    tensor([[ 1,  2,  3,  4,  5],
            [ 6,  7,  8,  9, 10]])
    >>> index = torch.tensor([[0, 1, 2, 0]])
    >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
    tensor([[1, 0, 0, 4, 0],
            [0, 2, 0, 0, 0],
            [0, 0, 3, 0, 0]])
    >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
    >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
    tensor([[1, 2, 3, 0, 0],
            [6, 7, 0, 0, 8],
            [0, 0, 0, 0, 0]])

    >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
    ...            1.23, reduce='multiply')
    tensor([[2.0000, 2.0000, 2.4600, 2.0000],
            [2.0000, 2.0000, 2.0000, 2.4600]])
    >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
    ...            1.23, reduce='add')
    tensor([[2.0000, 2.0000, 3.2300, 2.0000],
            [2.0000, 2.0000, 2.0000, 3.2300]])
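The relationship to gather() mentioned earlier can also be checked directly. In this sketch (sample values are assumptions), scattering src along dim 0 and then gathering with the same index recovers src, provided the indices are unique within each column and index has the same shape as src.

```python
import torch

src = torch.arange(1, 11).reshape(2, 5)
# Within each column the two index values differ, so no location is written twice.
index = torch.tensor([[0, 1, 2, 0, 1],
                      [1, 2, 0, 2, 0]])

out = torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)

# gather reads out[index[i][j]][j], which is exactly where src[i][j] was written.
recovered = out.gather(0, index)
print(torch.equal(recovered, src))
```

With repeated indices the round trip is not guaranteed, because only one of the competing source values survives the scatter.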