How to construct a 3D Tensor where every 2D sub tensor is a diagonal matrix in PyTorch?

Suppose I have a 2D tensor of shape index_in_batch * diag_ele. How can I get a 3D tensor of shape index_in_batch * Matrix, where each matrix is a diagonal matrix constructed from the corresponding row of diag_ele?

torch.diag() constructs a diagonal matrix only when the input is 1D; when the input is 2D, it returns the diagonal elements instead.
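The behavior described above can be reproduced with a short sketch. The tensor values and the variable names (`v`, `m`, `back`, `batch`, `diag_of_batch`) are made up for illustration:

```python
import torch

# 1D input: torch.diag builds a square matrix with v on the diagonal.
v = torch.tensor([1.0, 2.0, 3.0])
m = torch.diag(v)            # shape (3, 3)

# 2D input: torch.diag extracts the main diagonal instead.
back = torch.diag(m)         # shape (3,), equal to v

# A (batch, n) tensor is also 2D, so it is treated as a single matrix,
# not as a batch of diagonals.
batch = torch.tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])
diag_of_batch = torch.diag(batch)   # shape (2,): elements [0, 0] and [1, 1]
```

This is why calling torch.diag() directly on the 2D input does not give the batched result the question asks for.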

Qinqing Liu asked Nov 19 '17 00:11




2 Answers

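import torch

a = torch.rand(2, 3)  # batch of 2 rows, each holding 3 diagonal elements
print(a)
b = torch.eye(a.size(1))  # 3x3 identity matrix, used as a mask
# Repeat each row's elements along a new last dimension: (2, 3) -> (2, 3, 3)
c = a.unsqueeze(2).expand(*a.size(), a.size(1))
d = c * b  # broadcasting against the identity zeroes every off-diagonal entry
print(d)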

Output

 0.5938  0.5769  0.0555
 0.9629  0.5343  0.2576
[torch.FloatTensor of size 2x3]


(0 ,.,.) = 
  0.5938  0.0000  0.0000
  0.0000  0.5769  0.0000
  0.0000  0.0000  0.0555

(1 ,.,.) = 
  0.9629  0.0000  0.0000
  0.0000  0.5343  0.0000
  0.0000  0.0000  0.2576
[torch.FloatTensor of size 2x3x3]
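As a possible simplification of the same idea, the explicit expand() call can be left to broadcasting. This is a minimal sketch, not part of the original answer; the name `eye` and the choice to match `dtype` and `device` are assumptions made for illustration:

```python
import torch

a = torch.rand(2, 3)

# Identity mask created with the same dtype and device as the input.
eye = torch.eye(a.size(1), dtype=a.dtype, device=a.device)

# (2, 3, 1) * (3, 3) broadcasts to (2, 3, 3):
# d[i, j, k] = a[i, j] * eye[j, k], which is nonzero only when j == k.
d = a.unsqueeze(2) * eye
```

Each slice d[i] should then match torch.diag(a[i]).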
Wasi Ahmad answered Oct 13 '22 00:10



Use torch.diag_embed:

>>> a = torch.randn(2, 3)
>>> torch.diag_embed(a)
tensor([[[ 1.5410,  0.0000,  0.0000],
         [ 0.0000, -0.2934,  0.0000],
         [ 0.0000,  0.0000, -2.1788]],

        [[ 0.5684,  0.0000,  0.0000],
         [ 0.0000, -1.0845,  0.0000],
         [ 0.0000,  0.0000, -1.3986]]])
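A quick way to check this result is to compare it with a per-row torch.diag() loop, and to recover the original rows with torch.diagonal(). The sketch below uses illustrative names (`manual`, `recovered`) that are not from the answer:

```python
import torch

a = torch.randn(2, 3)

# Batched construction: the last dimension of `a` becomes the diagonal
# of each matrix in the last two dimensions of the result.
d = torch.diag_embed(a)                                # shape (2, 3, 3)

# Reference result built one row at a time.
manual = torch.stack([torch.diag(row) for row in a])

# Inverse direction: pull the diagonals back out of the batch.
recovered = torch.diagonal(d, dim1=-2, dim2=-1)        # shape (2, 3)
```

Because torch.diag_embed() works on the last dimension, the same call should also handle inputs with more than one leading batch dimension.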
Zhi Zhang answered Oct 13 '22 02:10
