I have this code belonging to feature_extractor.py
which is a part of this folder in here:
import torch
import torchvision.transforms as transforms
import numpy as np
import cv2
from .model import Net
class Extractor(object):
def __init__(self, model_path, use_cuda=True):
self.net = Net(reid=True)
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict']
self.net.load_state_dict(state_dict)
print("Loading weights from {}... Done!".format(model_path))
self.net.to(self.device)
self.size = (64, 128)
self.norm = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def _preprocess(self, im_crops):
def _resize(im, size):
return cv2.resize(im.astype(np.float32) / 255., size)
im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float()
return im_batch
def __call__(self, im_crops):
im_batch = self._preprocess(im_crops)
with torch.no_grad():
im_batch = im_batch.to(self.device)
features = self.net(im_batch)
return features.cpu().numpy()
if __name__ == '__main__':
img = cv2.imread("demo.jpg")[:, :, (2, 1, 0)]
extr = Extractor("checkpoint/ckpt.t7")
feature = extr(img)
print(feature.shape)
Now Imagine 200 requests are in row to proceed. The process of loading model for each request makes the code run slowly.
So I thought it might be a good idea to keep the pytorch model in cache. I modified it like this:
from redis import Redis
import msgpack as msg
r = Redis('111.222.333.444')
class Extractor(object):
def __init__(self, model_path, use_cuda=True):
try:
self.net = msg.unpackb(r.get('REID_CKPT'))
finally:
self.net = Net(reid=True)
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict']
self.net.load_state_dict(state_dict)
print("Loading weights from {}... Done!".format(model_path))
self.net.to(self.device)
packed_net = msg.packb(self.net)
r.set('REID_CKPT', packed_net)
self.size = (64, 128)
self.norm = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
Unfortunately this error comes up:
File "msgpack/_packer.pyx", line 286, in msgpack._cmsgpack.Packer.pack
File "msgpack/_packer.pyx", line 292, in msgpack._cmsgpack.Packer.pack
File "msgpack/_packer.pyx", line 289, in msgpack._cmsgpack.Packer.pack
File "msgpack/_packer.pyx", line 283, in msgpack._cmsgpack.Packer._pack
TypeError: can not serialize 'Net' object
The reason obviously is because that it cannot convert Net object (pytorch nn.Module
class) to bytes.
How can I efficiently save pytorch model in cache (or somehow keep it in RAM) and call for it for each request?
Thanks everyone.
If you only need to keep model state on RAM, Redis is not necessary. You could instead mount RAM as a virtual disk and store model state there. Check out tmpfs
.
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With