How to keep a PyTorch model in a Redis cache to access it faster for video streaming?

I have the following code in feature_extractor.py, which is part of this folder:

import torch
import torchvision.transforms as transforms
import numpy as np
import cv2
from .model import Net

class Extractor(object):
    def __init__(self, model_path, use_cuda=True):
        # Build the re-ID network and load the pretrained weights from disk.
        self.net = Net(reid=True)
        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict']
        self.net.load_state_dict(state_dict)
        print("Loading weights from {}... Done!".format(model_path))
        self.net.to(self.device)
        self.size = (64, 128)
        self.norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _preprocess(self, im_crops):
        # Resize each crop, scale pixel values to [0, 1], then normalize
        # and stack everything into a single batch tensor.
        def _resize(im, size):
            return cv2.resize(im.astype(np.float32) / 255., size)

        im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float()
        return im_batch

    def __call__(self, im_crops):
        im_batch = self._preprocess(im_crops)
        with torch.no_grad():
            im_batch = im_batch.to(self.device)
            features = self.net(im_batch)
        return features.cpu().numpy()


if __name__ == '__main__':
    img = cv2.imread("demo.jpg")[:, :, (2, 1, 0)]  # BGR -> RGB
    extr = Extractor("checkpoint/ckpt.t7")
    feature = extr([img])  # __call__ expects a list of crops
    print(feature.shape)

Now imagine 200 requests are queued up for processing. Loading the model from disk for each request makes the code run slowly.

So I thought it might be a good idea to keep the PyTorch model in a Redis cache. I modified the code like this:

import torch
import torchvision.transforms as transforms
from redis import Redis
import msgpack as msg
from .model import Net

r = Redis('111.222.333.444')

class Extractor(object):
    def __init__(self, model_path, use_cuda=True):
        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        try:
            # Try to fetch the cached model from Redis first.
            self.net = msg.unpackb(r.get('REID_CKPT'))
        except Exception:
            # Cache miss: load the weights from disk and repopulate the cache.
            self.net = Net(reid=True)
            state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict']
            self.net.load_state_dict(state_dict)
            print("Loading weights from {}... Done!".format(model_path))
            self.net.to(self.device)
            packed_net = msg.packb(self.net)
            r.set('REID_CKPT', packed_net)

        self.size = (64, 128)
        self.norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

Unfortunately, this error comes up:

 File "msgpack/_packer.pyx", line 286, in msgpack._cmsgpack.Packer.pack
 File "msgpack/_packer.pyx", line 292, in msgpack._cmsgpack.Packer.pack
 File "msgpack/_packer.pyx", line 289, in msgpack._cmsgpack.Packer.pack
 File "msgpack/_packer.pyx", line 283, in msgpack._cmsgpack.Packer._pack
 TypeError: can not serialize 'Net' object

The reason, obviously, is that msgpack cannot convert the Net object (a PyTorch nn.Module subclass) to bytes.
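For context: while the live Net module cannot be msgpack-serialized, its state_dict is an ordinary dict of tensors and can be turned into bytes. Below is a minimal sketch of caching only the weights in Redis (reusing the host address and key name from the snippet above); torch.save writes into an in-memory buffer, and torch.load accepts any file-like object:

import io
import torch
from redis import Redis

r = Redis('111.222.333.444')

def cache_state_dict(model_path):
    # Serialize only the weights: a state_dict is a plain dict of tensors,
    # which torch.save can write into an in-memory bytes buffer.
    state_dict = torch.load(model_path, map_location='cpu')['net_dict']
    buffer = io.BytesIO()
    torch.save(state_dict, buffer)
    r.set('REID_CKPT', buffer.getvalue())

def load_cached_state_dict():
    # Wrap the cached bytes in a BytesIO and deserialize them directly.
    return torch.load(io.BytesIO(r.get('REID_CKPT')), map_location='cpu')

The Net object itself would still be constructed locally on every request, with the cached weights loaded via load_state_dict.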

How can I efficiently keep the PyTorch model in a cache (or otherwise in RAM) and reuse it for each request?

Thanks everyone.

asked Oct 26 '22 by Masoud Masoumi Moghadam


1 Answer

If you only need to keep the model state in RAM, Redis is not necessary. You could instead mount RAM as a virtual disk and store the model state there; check out tmpfs.
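A minimal sketch of that approach, assuming a tmpfs mount at the hypothetical path /mnt/ramdisk (created beforehand with something like mount -t tmpfs -o size=512m tmpfs /mnt/ramdisk):

import os
import torch

RAMDISK_CKPT = "/mnt/ramdisk/ckpt.t7"  # hypothetical tmpfs-backed path

# One-time: copy the checkpoint onto the RAM-backed filesystem so that
# subsequent loads read from memory instead of disk.
if not os.path.exists(RAMDISK_CKPT):
    checkpoint = torch.load("checkpoint/ckpt.t7", map_location='cpu')
    torch.save(checkpoint, RAMDISK_CKPT)

# Per request: point the extractor at the RAM-backed copy.
extractor = Extractor(RAMDISK_CKPT)

This avoids disk I/O, but the checkpoint still has to be deserialized on every load; keeping a single Extractor instance alive in the serving process (load once, reuse for all requests) avoids even that cost.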

answered Oct 29 '22 by roman