
Zero Diagonal Of A PyTorch Tensor?

Is there a simple way to zero the diagonal of a PyTorch tensor? For example, I have: tensor([[2.7183, 0.4005, 2.7183, 0.5236], [0.4005, 2.7183, 0.4004, 1.3469], [2.7…

Solution 1:

I believe the simplest approach is to use torch.diagonal:

import torch

z = torch.randn(4, 4)
# torch.diagonal returns a view, so zero_() writes through to z
torch.diagonal(z, 0).zero_()
print(z)
# tensor([[ 0.0000, -0.6211,  0.1120,  0.8362],
#         [-0.1043,  0.0000,  0.1770,  0.4197],
#         [ 0.7211,  0.1138,  0.0000, -0.7486],
#         [-0.5434, -0.8265, -0.2436,  0.0000]])

This keeps the code explicit and delegates the work to PyTorch's built-in functions. Because torch.diagonal returns a view rather than a copy, the in-place zero_() modifies z directly.
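The second argument in torch.diagonal(z, 0) is the offset, so the same view-based trick can zero other diagonals too. A minimal sketch, assuming a 2-D float tensor; the deterministic 1..16 input and the choice of offset=1 are only for illustration:

```python
import torch

# Deterministic 4 x 4 input holding 1.0 .. 16.0, so the effect is easy to see.
z = torch.arange(1.0, 17.0).reshape(4, 4)

# offset=0 selects the main diagonal; offset=1 selects the one just above it.
# torch.diagonal returns a view, so zero_() writes into z itself.
torch.diagonal(z, offset=0).zero_()
torch.diagonal(z, offset=1).zero_()

# Both diagonals are now zero; every other entry is untouched.
print(z)
```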


Solution 2:

Yes, there are a couple of ways to do that. The simplest is to index the diagonal directly:

import torch

tensor = torch.rand(4, 4)
n = min(tensor.shape)  # diagonal length; also keeps the index tensors equal for non-square input
idx = torch.arange(n)
tensor[idx, idx] = 0

This assigns 0 to every index pair (0, 0), (1, 1), ..., (n-1, n-1) at once; the scalar 0 is broadcast across all of them.
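The same advanced-indexing idea extends to a batch of square matrices: index the last two dimensions with the paired index tensors and let an ellipsis cover the leading batch dimensions. A minimal sketch, assuming a hypothetical tensor `batch` of shape (B, n, n):

```python
import torch

# Hypothetical batch of 3 square 4 x 4 matrices, filled with ones for clarity.
batch = torch.ones(3, 4, 4)

n = batch.shape[-1]
idx = torch.arange(n)

# The ellipsis keeps every batch dimension; the paired idx tensors pick
# (0, 0), (1, 1), ..., (n-1, n-1) inside each matrix.
batch[..., idx, idx] = 0

print(batch[0])
```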

Another way (readability is debatable) is to multiply by the logical negation of an identity mask built with torch.eye:

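tensor = torch.rand(4, 4)
# boolean identity mask, negated so the diagonal is False; built on the same device as tensor
tensor *= ~torch.eye(*tensor.shape, device=tensor.device).bool()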

This creates an additional matrix and performs a full elementwise multiplication, so I'd stick with the first version.
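If you prefer the mask approach, a variant that is not part of the original answer is Tensor.masked_fill_, which overwrites masked positions instead of multiplying. A minimal sketch, assuming a 2-D float tensor; the NaN placed on the diagonal is only there to show the difference:

```python
import torch

tensor = torch.rand(4, 4)
tensor[1, 1] = float("nan")  # a diagonal NaN, to contrast with multiplying

# Boolean identity mask on the same device as the data.
mask = torch.eye(*tensor.shape, device=tensor.device).bool()

# masked_fill_ writes 0 wherever mask is True, so the NaN becomes 0.
# tensor *= ~mask would leave it as NaN, because NaN * 0 is NaN.
tensor.masked_fill_(mask, 0)

print(tensor)
```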


Solution 3:

As an alternative to building the two index tensors separately, you can combine Tensor.repeat and Tensor.split, taking advantage of the fact that split returns a tuple, which PyTorch treats as one index tensor per dimension:

>>> x[torch.arange(len(x)).repeat(2).split(len(x))] = 0
>>> x
tensor([[0.0000, 0.4005, 2.7183, 0.5236],
        [0.4005, 0.0000, 0.4004, 1.3469],
        [2.7183, 0.4004, 0.0000, 0.5239],
        [0.5236, 1.3469, 0.5239, 0.0000]])
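To see why this one-liner works, it helps to print the intermediate index. Indexing with a tuple of two equal-length integer tensors selects element-wise (row, column) pairs. A small sketch, assuming a square 4 × 4 tensor in place of the question's data:

```python
import torch

x = torch.rand(4, 4)
n = len(x)  # number of rows

# arange(4).repeat(2) -> tensor([0, 1, 2, 3, 0, 1, 2, 3])
# .split(4)           -> (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))
index = torch.arange(n).repeat(2).split(n)
print(index)

# The tuple means "row indices, column indices",
# exactly like x[torch.arange(n), torch.arange(n)].
x[index] = 0
print(torch.diagonal(x))
```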

Solution 4:

This starts to look like a coding contest :)

Here's my way:

# view(-1) raises on a non-contiguous x instead of silently writing into a copy, as flatten() may
x.view(-1)[::x.shape[-1] + 1] = 0
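The stride works because in a contiguous row-major n × m tensor, element (i, j) sits at flat offset i*m + j, so diagonal element (i, i) is at i*(m + 1). A quick check of that arithmetic, assuming a square, contiguous tensor:

```python
import torch

n = 4
x = torch.arange(float(n * n)).reshape(n, n)  # contiguous grid of 0.0 .. 15.0

# Flat offsets of (i, i) in row-major order: i * n + i == i * (n + 1).
offsets = [i * (n + 1) for i in range(n)]
print(offsets)  # [0, 5, 10, 15]

# view(-1) shares storage with x, so assigning through the strided slice edits x.
x.view(-1)[:: n + 1] = 0
print(torch.diagonal(x))  # tensor([0., 0., 0., 0.])
```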

Solution 5:

You can simply use:

x.fill_diagonal_(0)
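Tensor.fill_diagonal_ works in place and also accepts non-square 2-D tensors. As a sanity check, here is a small sketch (not taken from any of the answers) that compares it with the torch.diagonal and indexing approaches on the same random square input:

```python
import torch

base = torch.randn(5, 5)

# Built-in in-place fill.
a = base.clone()
a.fill_diagonal_(0)

# Zero a diagonal view.
b = base.clone()
torch.diagonal(b).zero_()

# Advanced indexing with paired index tensors.
c = base.clone()
idx = torch.arange(len(c))
c[idx, idx] = 0

print(torch.equal(a, b) and torch.equal(a, c))  # True

# Non-square input: only (0, 0), (1, 1), (2, 2) exist on the main diagonal.
wide = torch.ones(3, 5)
wide.fill_diagonal_(0)
print(wide)
```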
