
Convert a PyTorch tensor to an OpenCV Mat and vice versa in C++

I want to convert PyTorch tensors to OpenCV Mats and vice versa in C++. I have these two functions:

cv::Mat TensorToCVMat(torch::Tensor tensor)
{
    std::cout << "converting tensor to cvmat\n";
    tensor = tensor.squeeze().detach().permute({1, 2, 0});
    tensor = tensor.mul(255).clamp(0, 255).to(torch::kU8);
    tensor = tensor.to(torch::kCPU);
    int64_t height = tensor.size(0);
    int64_t width = tensor.size(1);
    cv::Mat mat(width, height, CV_8UC3);
    std::memcpy((void *)mat.data, tensor.data_ptr(), sizeof(torch::kU8) * tensor.numel());
    return mat.clone();
}

torch::Tensor CVMatToTensor(cv::Mat mat)
{
    std::cout << "converting cvmat to tensor\n";
    cv::cvtColor(mat, mat, cv::COLOR_BGR2RGB);
    cv::Mat matFloat;
    mat.convertTo(matFloat, CV_32F, 1.0 / 255);
    auto size = matFloat.size();
    auto nChannels = matFloat.channels();
    auto tensor = torch::from_blob(matFloat.data, {1, size.height, size.width, nChannels});
    return tensor.permute({0, 3, 1, 2});
}

In my code I load two images (image1 and image2), convert them to PyTorch tensors and then back to OpenCV Mats to check whether the conversion works. The problem is that I get a memory access error on the first call of TensorToCVMat, and I can't figure out what's wrong, as I do not have much experience with C++ programming.

cv::Mat image1;
image1 = cv::imread(argv[1]);
if (!image1.data)
{
    std::cout << "no image data\n";
    return -1;
}
cv::Mat image2;
image2 = cv::imread(argv[2]);
if (!image2.data)
{
    std::cout << "no image data\n";
    return -1;
}

torch::Tensor tensor1 = CVMatToTensor(image1);
cv::Mat new_image1 = TensorToCVMat(tensor1); // <<< this is where the memory access error is thrown
torch::Tensor tensor2 = CVMatToTensor(image2);
cv::Mat new_image2 = TensorToCVMat(tensor2);

It would be great if you could give me hints or an explanation to solve this problem.

Asked by bastisstackoverflow on Dec 28 '19 at 15:12


3 Answers

I'm not sure whether the error happens at the memcpy step, but you can use the void* data variant of the Mat constructor

Mat (int rows, int cols, int type, void *data, size_t step=AUTO_STEP)

which lets you skip the manual memcpy step:

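// tensor: contiguous torch::kU8 CPU tensor of shape (h, w, 3)
int height = static_cast<int>(tensor.size(0));
int width = static_cast<int>(tensor.size(1));
cv::Mat mat(height, width, CV_8UC3, tensor.data_ptr<uint8_t>());
// mat only wraps the tensor's buffer, so clone it before the tensor is destroyed
return mat.clone();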

Answered by Saravanabalagi Ramachandran on Sep 21 '22 at 19:09


I am using torch>=1.7.0.

For a tensor of dtype float and shape [1, 3, height, width], this is what worked for me:

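// tensor is taken by value so the caller's tensor handle is not reassigned
cv::Mat torchTensortoCVMat(torch::Tensor tensor)
{
    tensor = tensor.squeeze().detach();
    // CHW -> HWC; contiguous() makes the memory order match the new shape
    tensor = tensor.permute({1, 2, 0}).contiguous();
    tensor = tensor.mul(255).clamp(0, 255).to(torch::kU8);
    tensor = tensor.to(torch::kCPU);
    int64_t height = tensor.size(0);
    int64_t width = tensor.size(1);
    // the Mat wraps the tensor's buffer; clone() gives it its own copy
    cv::Mat mat = cv::Mat(cv::Size(width, height), CV_8UC3, tensor.data_ptr<uchar>());
    return mat.clone();
}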

Answered by Operator77 on Sep 18 '22 at 19:09


My tensor shape was 500x500x3, and I had to add tensor.reshape({width * height * 3}) to get the actual image:

cv::Mat TensorToCVMat(torch::Tensor tensor)
{
    // tensor.detach() returns a new tensor detached from the autograd graph.
    // permute dimensions, 3x700x700 => 700x700x3
    tensor = tensor.detach().permute({1, 2, 0});
    // scale floats in [0, 1] to the 0..255 range and convert to 8-bit
    tensor = tensor.mul(255).clamp(0, 255).to(torch::kU8);
    // move to CPU (a no-op if the tensor is already there)
    tensor = tensor.to(torch::kCPU);
    // shape of tensor
    int64_t height = tensor.size(0);
    int64_t width = tensor.size(1);

    // Mat expects interleaved data like {B,G,R,B,G,R,...}, but after permute()
    // the memory can still be in channel-by-channel (CHW) order. reshape() cannot
    // express this as a view, so it copies into HWC order; otherwise we get a 3x3 grid.
    tensor = tensor.reshape({width * height * 3});
    // CV_8UC3 is an 8-bit unsigned integer matrix/image with 3 channels
    cv::Mat imgbin(cv::Size(width, height), CV_8UC3, tensor.data_ptr());

    // imgbin only wraps the local tensor's buffer, so return a copy
    return imgbin.clone();
}

Answered by YugoAmaryl on Sep 18 '22 at 19:09
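To see why the answers above need contiguous() (or a copying reshape()) before handing the pointer to cv::Mat, here is a small standard-library-only sketch of the copy that permute({1, 2, 0}).contiguous() performs: element (c, y, x) of a planar CHW buffer moves to index (y * W + x) * C + c of an interleaved HWC buffer, which is the layout a CV_8UC3 Mat expects. Copying a CHW buffer byte for byte into such a Mat instead mixes values from different channels and pixels, which is likely the "3x3 grid" effect mentioned in the last answer. The name chwToHwc is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: repacks a planar CHW buffer (all of channel 0, then
// channel 1, ...) into an interleaved HWC buffer (all channels of pixel 0,
// then pixel 1, ...), which is the memory layout of a CV_8UC3 cv::Mat.
std::vector<std::uint8_t> chwToHwc(const std::vector<std::uint8_t>& chw,
                                   std::size_t channels,
                                   std::size_t height,
                                   std::size_t width)
{
    assert(chw.size() == channels * height * width);
    std::vector<std::uint8_t> hwc(chw.size());
    for (std::size_t c = 0; c < channels; ++c) {
        for (std::size_t y = 0; y < height; ++y) {
            for (std::size_t x = 0; x < width; ++x) {
                // source index in CHW order -> destination index in HWC order
                hwc[(y * width + x) * channels + c] = chw[(c * height + y) * width + x];
            }
        }
    }
    return hwc;
}
```

For a 3-channel 1x2 image stored planar as {R0, R1, G0, G1, B0, B1}, the function returns {R0, G0, B0, R1, G1, B1}.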