I am trying to create a custom handler in TorchServe and also want to use TorchServe's batching capability for parallelism and optimal use of resources. I cannot figure out how to write a custom handler for this kind of inference.
Problem: I have a 20-page document and I have OCRed each of its pages. I am using the "transfo-xl-wt103" model for the inference.
Here is my code:
import json
import logging
import os
import re
import sys

sys.path.append(os.getcwd())

# Deep learning
import torch, transformers
from transformers import AutoModel, AutoTokenizer
from ts.torch_handler.base_handler import BaseHandler
import concurrent.futures
from abc import ABC

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)8.8s] %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)


class transformer_embedding_handler(BaseHandler, ABC):
    """Transformers handler class for text embedding

    Args:
        BaseHandler (Class): Base default handler to load torch based models
        ABC (Class): Helper class that provides a standard way to create ABC using inheritance
    """

    def __init__(self):
        """Class constructor"""
        logger.info("I am here, __init__ method")
        logger.info(f"transformer version:{transformers.__version__}")
        logger.info(f"torch version:{torch.__version__}")
        # run the constructor of the base classes
        super(transformer_embedding_handler, self).__init__()
        # will be set to true once initialize() function is completed
        self.initialized = False
        # configurations
        self.model_name = "transfo-xl-wt103"
        self.do_lower_case = True
        self.max_length = 1024
        # want batching?
        self.batching = False
        # if batching is set to true, padding should be true
        self.padding = False
        # Num of tensors in each batch
        self.batch_size = 8
        self.torch_worker = os.getenv('torch_worker') if os.getenv('torch_worker') else 5
    def initialize(self, ctx):
        """Get the properties for the serving context

        Args:
            ctx (object): context object that contains model server system properties
        """
        logger.info("I am here, initialize method")
        # properties = ctx.system_properties
        properties = ctx["system_properties"]
        logger.info(f"Properties: {properties}")
        model_dir = properties.get("model_dir")
        logger.info("Model dir: '%s'", model_dir)
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else "cpu"
        )
        # Load the model and the tokenizer
        self.model = AutoModel.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, do_lower_case=self.do_lower_case)
        self.model.to(self.device)
        self.model.eval()
        logger.debug(
            "Transformer model from path {0} loaded successfully".format(model_dir)
        )
        # Initialization is done
        self.initialized = True
    def preprocess(self, requests):
        """Method for preprocessing the requests received before inference

        Args:
            requests (list): list of requests to perform inference on

        Returns:
            list: list of encoded token tensors
        """
        logger.info("I am here, preprocess method")
        # decode the request
        full_text_list = requests[0].get("data")
        logger.info(f"Text from data {full_text_list}. type: {type(full_text_list)}")
        prepare_for_transformers_config = {"max_lens": self.max_length}
        with concurrent.futures.ThreadPoolExecutor(max_workers=int(self.torch_worker)) as executor:
            futures = [executor.submit(self.tokenizer.encode_plus, full_text, pad_to_max_length=False,
                                       add_special_tokens=True, return_tensors='pt') for full_text in full_text_list]
            # Not using as_completed keeps the results in the same order as the input list
            token_ids = [future.result().to(self.device) for future in futures]
        logger.info(f"Token Generation completed! Number of items in array: {len(token_ids)}")
        return token_ids
    def inference(self, token_ids):
        logger.info(f"I am here, inference method, processing token array length: {len(token_ids)}")
        logger.info("Padding: '%s'", self.padding)
        logger.info("Batching: '%s'", self.batching)
        inference = []
        for token in token_ids:
            inference.append(self.model(**token))
        return inference

    def postprocess(self, outputs):
        logger.info("I am here, postprocess method")
        out_file = "C:\\Users\\pbansal2\\Documents\\PycharmProjects\\embedder_test\\output\\embeddings.pt"
        with open(out_file, 'wb') as f:
            torch.save(outputs, f)
        return json.dumps({"success": True})
if __name__ == '__main__':
    tmp = transformer_embedding_handler()
    requests = [
        {
            "data": [
                "TEXTTTTTTTTTTT111111111111111111",
                "TEXTTTTTTTTTTT222222222222222222",
                "TEXTTTTTTTTTTT333333333333333333"
            ]
        }
    ]
    tmp.initialize({"system_properties": {
        "model_dir": "C:\\Users\\pbansal2\\Documents\\PycharmProjects\\embedder_test\\model-store", "gpu_id": 0,
        "batch_size": 1, "server_name": "MMS", "server_version": "0.4.2"}})
    out = tmp.preprocess(requests)
    inputs = tmp.inference(out)
    logger.info(inputs)
    tmp.postprocess(inputs)
My problem here is this piece in the inference function -
for token in token_ids:
    inference.append(self.model(**token))
Is there a way to tell TorchServe to use batch_size and max_batch_delay here during inference so that it batches the requests, rather than me looping over them and computing the results one by one?
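For reference, from the TorchServe documentation, batch_size and max_batch_delay are not set inside the handler; they are supplied when the model is registered (via the management API or config.properties), and TorchServe then passes the handler a list of up to batch_size requests per call. Below is only a rough sketch of how a batch-aware preprocess/inference might look under that setup; it assumes the tokenizer has (or is given) a pad token so the texts can be padded into a single tensor, and the model name in the curl line is a placeholder.

# Sketch only. Server-side batching is enabled at registration time, e.g.:
#   curl -X POST "http://localhost:8081/models?url=embedder.mar&batch_size=8&max_batch_delay=50&initial_workers=1"
# ("embedder.mar" is a placeholder.) TorchServe then calls the handler with a list of
# up to batch_size requests, so preprocess should loop over all of them, not requests[0] only.

def preprocess(self, requests):
    texts = []
    for req in requests:
        data = req.get("data") or req.get("body")
        if isinstance(data, (bytes, bytearray)):
            data = data.decode("utf-8")
        texts.append(data)
    # transfo-xl's tokenizer may not define a pad token; borrow <eos> if needed
    if self.tokenizer.pad_token is None:
        self.tokenizer.pad_token = self.tokenizer.eos_token
    return self.tokenizer(texts, padding=True, truncation=True,
                          max_length=self.max_length, return_tensors="pt").to(self.device)

def inference(self, batch):
    # one forward pass over the whole padded batch; transfo-xl only needs input_ids here
    with torch.no_grad():
        return self.model(input_ids=batch["input_ids"])

One thing to keep in mind with real server-side batching: postprocess has to return a list with exactly one entry per incoming request, in the same order.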
I already tried Python's multiprocessing, but it isn't helping much. I am not sure why, but when I use multiprocessing on an 8-CPU machine (and watch it with the top command), all the CPUs appear to be sleeping and practically nothing happens.
But when I process the tokens one by one, most of the CPUs show usage. I am not certain, but it looks like the model already implements some kind of parallelism.
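This is probably expected: PyTorch already parallelises a single forward pass across CPU cores through its intra-op thread pool, so adding Python-level multiprocessing on top tends to oversubscribe the same cores. A quick diagnostic sketch:

import torch

# Each forward pass is already spread over this many threads (defaults to the core count),
# which is likely why the plain sequential loop keeps most CPUs busy.
print("intra-op threads:", torch.get_num_threads())
print("inter-op threads:", torch.get_num_interop_threads())

# When several workers/processes share one machine, capping threads per process
# helps avoid oversubscription:
torch.set_num_threads(2)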
This is the model documentation - https://huggingface.co/transfo-xl-wt103#how-to-get-started-with-the-model
Any help here is appreciated! Thank you.
To get rid of my for loop, I was able to use the TensorDataset and DataLoader classes:
batch_size = 1
dataset_val = torch.utils.data.TensorDataset(token_ids)
dataloader_val = torch.utils.data.DataLoader(
    dataset_val,
    sampler=torch.utils.data.sampler.SequentialSampler(dataset_val),
    batch_size=batch_size,
)
Then I did -
inference = []
for batch in dataloader_val:
    inference.append(self.model(**batch))
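For completeness, TensorDataset expects plain tensors with matching first dimensions, so in a fuller version the tokenized inputs would need to be padded to one length before being wrapped. A sketch under that assumption, with full_text_list and batch_size standing in for the values used in preprocess:

# Sketch: pad everything to one length so the input_ids stack into a single tensor.
if self.tokenizer.pad_token is None:
    self.tokenizer.pad_token = self.tokenizer.eos_token  # transfo-xl may ship without a pad token
enc = self.tokenizer(full_text_list, padding=True, truncation=True,
                     max_length=self.max_length, return_tensors="pt").to(self.device)

dataset_val = torch.utils.data.TensorDataset(enc["input_ids"])
dataloader_val = torch.utils.data.DataLoader(
    dataset_val,
    sampler=torch.utils.data.sampler.SequentialSampler(dataset_val),
    batch_size=batch_size,
)

inference = []
for (input_ids,) in dataloader_val:
    with torch.no_grad():
        # transfo-xl's forward only needs input_ids for embeddings
        inference.append(self.model(input_ids=input_ids))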
There is one more thing I learnt in the process: with the approach above I am keeping all of the tensors in memory and never giving the garbage collector a chance to clean up, which kept causing OOM errors for me. So I did this -
    def do_inference(self, batch):
        return self.model(**batch)

    def inference(self, token_ids):
        logger.info(f"I am here, inference method, processing token array length: {len(token_ids)}")
        logger.info("Padding: '%s'", self.padding)
        logger.info("Batching: '%s'", self.batching)
        dataset_val = torch.utils.data.TensorDataset(token_ids)
        dataloader_val = torch.utils.data.DataLoader(
            dataset_val,
            sampler=torch.utils.data.sampler.SequentialSampler(dataset_val),
            batch_size=1,
        )
        inference = []
        for batch in dataloader_val:
            inference.append(self.do_inference(batch))
        return inference
Because the helper function returns after each batch, the Python/torch garbage collector can free the previous batch's tensors from memory before the next batch runs. This helped me a lot.
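One more memory detail that may be worth adding: if the forward pass runs without torch.no_grad(), every output keeps its autograd graph alive, so a list of outputs holds far more memory than the embeddings themselves. A hedged variant of the loop with gradient tracking off and only the needed tensor kept (assuming last_hidden_state is the embedding of interest):

def inference(self, token_ids):
    results = []
    with torch.no_grad():              # no autograd graph is kept for any output
        for token in token_ids:
            out = self.model(**token)
            # keep only the tensor we need, moved off the GPU, so the rest of the
            # batch's intermediates can be freed before the next iteration
            results.append(out.last_hidden_state.detach().cpu())
    return results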