Torch C++: Getting the value of an int tensor by using *.data<int>()

In the C++ version of LibTorch, I found that I can get the value of a float tensor with *tensor_name[0].data<float>(), where 0 can be replaced by any other valid index. But when I define an int tensor by passing the option at::kInt at tensor creation, I cannot use the same construct to get its values: neither *tensor_name[0].data<at::kInt>() nor *tensor_name[0].data<int>() works, and the debugger keeps saying Couldn't find method at::Tensor::data<at::kInt> or Couldn't find method at::Tensor::data<int>. I can get the values with auto value_array = tensor_name.accessor<int,1>(), but *tensor_name[0].data<int>() was easier to use. How can I use data<>() to get the values of an int tensor?

I have the same problem with the bool type.

asked Jan 15 '19 14:01 by Afshin Oroojlooy


1 Answer

Use item<T>() to get a scalar out of a single-element tensor, where T is the C++ type you want back (e.g. int), not a ScalarType such as at::kInt.

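#include <torch/torch.h>
#include <iostream>

int main() {
  // randint returns a float tensor by default; item<int>() converts the element to int
  torch::Tensor tensor = torch::randint(20, {2, 3});
  std::cout << tensor << std::endl;
  int a = tensor[0][0].item<int>();
  std::cout << a << std::endl;
  return 0;
}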

~/l/build ❯❯❯ ./example-app
  3  10   3
  2   5   8
[ Variable[CPUFloatType]{2,3} ]
3

The following code prints 0 (tested on Linux with the stable libtorch):

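#include <torch/script.h>
#include <cstdint>
#include <iostream>

int main(int argc, const char* argv[])
{
    // kLong tensors hold 64-bit integers; int64_t names that type portably,
    // whereas long is only 32 bits on some platforms (e.g. 64-bit Windows)
    auto indx = torch::zeros({20}, at::dtype(at::kLong));
    std::cout << indx[0].item<int64_t>() << std::endl;

    return 0;
}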
answered Oct 04 '22 18:10 by Fábio Perez