 

Torch sum a tensor along an axis

ipdb> outputs.size()
torch.Size([10, 100])
ipdb> print sum(outputs,0).size(),sum(outputs,1).size(),sum(outputs,2).size()
(100L,) (100L,) (100L,)

How do I sum over the columns instead?

Abhishek Bhatia asked Jun 27 '17 at 22:06





1 Answer

The simplest and best solution is to use torch.sum().
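The calls in the question use Python's builtin `sum`, not `torch.sum`. The builtin's second argument is a start value rather than a dimension, so `sum(outputs, k)` iterates over the first dimension (the 10 rows of size 100), adds them together, and adds `k`. That is why every call returned size (100,). A minimal sketch of this, assuming Python 3 with PyTorch installed and using a tensor of ones as a stand-in for `outputs`:

```python
import torch

outputs = torch.ones(10, 100)  # stand-in with the question's shape: 10 rows, 100 columns

# Builtin sum(iterable, start): iterating a tensor yields its rows (each of shape [100]),
# so these add up the rows and then add the start value 0 or 2.
a = sum(outputs, 0)
b = sum(outputs, 2)

print(a.size(), b.size())  # both are torch.Size([100])
```

Only `a` happens to match `torch.sum(outputs, dim=0)`; `b` is the same column sums shifted by 2, and no call ever sums along dimension 1.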

To sum all elements of a tensor:

torch.sum(outputs) # returns a 0-dim tensor holding the total; call .item() for a Python number
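A small worked example of the full reduction, assuming PyTorch is installed; the 2x3 tensor built with `torch.arange` is an illustrative stand-in, not data from the question:

```python
import torch

# Illustrative 2x3 tensor:
# [[0, 1, 2],
#  [3, 4, 5]]
outputs = torch.arange(6).reshape(2, 3)

total = torch.sum(outputs)  # 0-dim tensor: tensor(15)
value = total.item()        # plain Python int: 15

print(total.dim(), value)   # 0 15
```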

To sum over all rows (i.e. for each column):

torch.sum(outputs, dim=0) # size = [ncol]; with keepdim=True, size = [1, ncol]
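A sketch of the per-column sums on the same kind of small illustrative tensor (an assumption, not the question's data), showing how `keepdim` changes the result's shape:

```python
import torch

outputs = torch.arange(6).reshape(2, 3)  # 2 rows, 3 columns: [[0, 1, 2], [3, 4, 5]]

col_sums = torch.sum(outputs, dim=0)                   # one value per column
col_sums_2d = torch.sum(outputs, dim=0, keepdim=True)  # reduced dim kept with size 1

print(col_sums)           # tensor([3, 5, 7])
print(col_sums_2d.shape)  # torch.Size([1, 3])
```

Keeping the reduced dimension is handy when the result must broadcast back against `outputs`, e.g. `outputs / col_sums_2d`.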

To sum over all columns (i.e. for each row):

torch.sum(outputs, dim=1) # size = [nrow]; with keepdim=True, size = [nrow, 1]
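Applied to the question's shape, summing over the columns gives one value per row. A minimal sketch, assuming PyTorch is installed and using a tensor of ones as a stand-in for `outputs`; the method form `outputs.sum(...)` and `dim=-1` (the last dimension) are equivalent spellings:

```python
import torch

outputs = torch.ones(10, 100)  # stand-in with the question's shape: 10 rows, 100 columns

row_sums = torch.sum(outputs, dim=1)  # sums across the 100 columns of each row
same = outputs.sum(dim=-1)            # method form; dim=-1 is the last dim, here dim 1

print(row_sums.size())  # torch.Size([10])
```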
mbpaulus answered Oct 03 '22 at 03:10
