
TorchServe custom handler - how to pass a list of tensors for batch inference


I am trying to write a custom handler in TorchServe, and I also want to use TorchServe's batching capability to parallelize inference and make the best use of resources. I cannot work out how to write a custom handler for this kind of inference.

Problem: I have a 20-page document and have OCRed each of its pages. I am using the "transfo-xl-wt103" model for inference.

Here is my code:

import json
import logging
import os
import re
import sys

sys.path.append(os.getcwd())

# Deep learning
import torch, transformers
from transformers import AutoModel, AutoTokenizer
from ts.torch_handler.base_handler import BaseHandler
import concurrent.futures
from abc import ABC

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)8.8s] %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)


class transformer_embedding_handler(BaseHandler, ABC):
    """Transformers handler class for text embedding

    Args:
        BaseHandler (Class): Base default handler to load torch based models
        ABC (Class): Helper class that provides a standard way to create ABC using inheritance
    """

    def __init__(self):
        """Class constructor"""
        logger.info("I am here, __init__ method")
        logger.info(f"transformer version:{transformers.__version__}")
        logger.info(f"torch version:{torch.__version__}")
        # run the constructor of the base classes
        super(transformer_embedding_handler, self).__init__()
        # will be set to true once initialize() function is completed
        self.initialized = False
        # configurations
        self.model_name = "transfo-xl-wt103"
        self.do_lower_case = True
        self.max_length = 1024

        # want batching?
        self.batching = False
        # if batching is set to true, padding should be true
        self.padding = False
        # Num of tensors in each batch
        self.batch_size = 8
        self.torch_worker = os.getenv('torch_worker') if os.getenv('torch_worker') else 5

    def initialize(self, ctx):
        """Get the properties for the serving context

        Args:
            ctx (object): context object that contains model server system properties
        """
        logger.info("I am here, initialize method")
        # properties = ctx.system_properties

        properties = ctx["system_properties"]
        logger.info(f"Properties: {properties}")
        model_dir = properties.get("model_dir")
        logger.info("Model dir: '%s'", model_dir)

        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else "cpu"
        )

        # Load the model and the tokenizer
        self.model = AutoModel.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, do_lower_case=self.do_lower_case)
        self.model.to(self.device)
        self.model.eval()
        logger.debug(
            "Transformer model from path {0} loaded successfully".format(model_dir)
        )

        # Initialization is done
        self.initialized = True

    def preprocess(self, requests):
        """Method for preprocessing the requests received before inference

        Args:
            requests (list): list of requests to perform inference on

        Returns:
            list: list of tokenizer outputs (token tensors), one per input text
        """
        logger.info("I am here, preprocess method")

        # decode the request
        full_text_list = requests[0].get("data")
        logger.info(f"Text from data {full_text_list}. type: {type(full_text_list)}")

        prepare_for_transformers_config = {"max_lens": self.max_length}
        with concurrent.futures.ThreadPoolExecutor(max_workers=int(self.torch_worker)) as executor:
            futures = [executor.submit(self.tokenizer.encode_plus, full_text, pad_to_max_length=False,
                                       add_special_tokens=True, return_tensors='pt') for full_text in full_text_list]
        token_ids = [future.result().to(self.device) for future in
                     futures]  # Not using as_completed function will maintain order of the list

        logger.info(f"Token Generation completed! Number of items in array: {len(token_ids)}")

        return token_ids

    def inference(self, token_ids):
        logger.info(f"I am here, inference method, processing token array length: {len(token_ids)}")
        logger.info("Padding: '%s'", self.padding)
        logger.info("Batching: '%s'", self.batching)
        inference = []

        for token in token_ids:
            inference.append(self.model(**token))

        return inference

    def postprocess(self, outputs):
        logger.info("I am here, postprocess method")
        out_file = "C:\\Users\\pbansal2\\Documents\\PycharmProjects\\embedder_test\\output\\embeddings.pt"
        with open(out_file, 'wb') as f:
            torch.save(outputs, f)

        return json.dumps({"success": True})


if __name__ == '__main__':
    tmp = transformer_embedding_handler()
    requests = [
        {
            "data": [
                "TEXTTTTTTTTTTT111111111111111111",
                "TEXTTTTTTTTTTT222222222222222222",
                "TEXTTTTTTTTTTT333333333333333333"
            ]
        }
    ]

    tmp.initialize({"system_properties": {
        "model_dir": "C:\\Users\\pbansal2\\Documents\\PycharmProjects\\embedder_test\\model-store", "gpu_id": 0,
        "batch_size": 1, "server_name": "MMS", "server_version": "0.4.2"}})
    out = tmp.preprocess(requests)
    inputs = tmp.inference(out)
    logger.info(inputs)
    tmp.postprocess(inputs)

My problem is this piece of the inference function:

        for token in token_ids:
            inference.append(self.model(**token))

Is there a way to tell TorchServe to use batch_size and max_batch_delay here during inference, so that it batches the requests instead of looping over them and computing the results one by one?

I already tried Python's multiprocessing, but it isn't helping much. I am not sure why, but when I use multiprocessing on an 8-CPU machine (and watch it with top), all the CPUs seem to be in a sleep state and practically nothing is happening.

But when I run the inputs one by one, most of the CPUs show usage. I am not sure, but it seems the model already implements some kind of parallelism.

This is the model documentation - https://huggingface.co/transfo-xl-wt103#how-to-get-started-with-the-model

Any help here is appreciated! Thank you.


Solution

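  • To replace my for loop, I was able to use the TensorDataset and DataLoader classes:

    batch_size = 1
    # TensorDataset takes tensors of equal length, not a list, so pad and stack the token ids first
    # (assumes right-padding with 0 is acceptable for this model)
    input_ids = torch.nn.utils.rnn.pad_sequence([t["input_ids"][0] for t in token_ids], batch_first=True)
    dataset_val = torch.utils.data.TensorDataset(input_ids)
    dataloader_val = torch.utils.data.DataLoader(dataset_val, sampler=torch.utils.data.sampler.SequentialSampler(dataset_val), batch_size=batch_size)
    

    Then I did -

    inference = []
    for batch in dataloader_val:
        # each batch is a one-element list holding the stacked input_ids
        inference.append(self.model(input_ids=batch[0]))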
    

    There is one more thing I learnt in the process: written the way shown above, Python/torch keeps the tensors in memory, because nothing tells the GC to clean them up. This always caused OOM for me. So I did this -

    def do_inference(self, batch):
        # batch is the one-element list yielded by a TensorDataset-backed DataLoader
        return self.model(input_ids=batch[0])
    
    def inference(self, token_ids):
        logger.info(f"I am here, inference method, processing token array length: {len(token_ids)}")
        logger.info("Padding: '%s'", self.padding)
        logger.info("Batching: '%s'", self.batching)
        inference = []
    
        # dataloader_val is the DataLoader built from token_ids as shown above
        for batch in dataloader_val:
            inference.append(self.do_inference(batch))
    
        return inference
    

    Because each call to do_inference returns before the next one starts, its local tensors go out of scope, so Python/torch can clean them up from memory before the next batch is executed. This helped me a lot.
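
On the original question about batch_size and max_batch_delay: in TorchServe these are normally set when the model is registered, and the handler then receives a list with one entry per aggregated request and is expected to return one response per entry. The following is only a rough sketch of that bookkeeping, written in plain Python so it runs without a model. The names flatten_requests, chunked, regroup, handle_batch and the embed_batch callable are hypothetical, and reading the texts from the "data" key is an assumption taken from the test harness in the question.

```python
def flatten_requests(requests):
    """Collect the texts of all requests and remember how many each one sent."""
    texts, counts = [], []
    for req in requests:
        # Assumption: each request holds a list of strings under "data",
        # as in the test harness from the question.
        items = req.get("data") or []
        texts.extend(items)
        counts.append(len(items))
    return texts, counts


def chunked(items, size):
    """Yield consecutive slices of at most `size` items, preserving order."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def regroup(results, counts):
    """Split the flat result list back into one list per original request."""
    grouped, pos = [], 0
    for n in counts:
        grouped.append(results[pos:pos + n])
        pos += n
    return grouped


def handle_batch(requests, embed_batch, batch_size=8):
    """Run embed_batch once per chunk and return one result list per request."""
    texts, counts = flatten_requests(requests)
    results = []
    for chunk in chunked(texts, batch_size):
        # embed_batch is a hypothetical stand-in for tokenizer + model:
        # it takes a list of texts and returns one result per text.
        results.extend(embed_batch(chunk))
    return regroup(results, counts)


# Demo with a fake "model" that returns the length of each text.
demo = handle_batch(
    [{"data": ["a", "bb", "ccc"]}, {"data": ["dddd"]}],
    embed_batch=lambda chunk: [len(t) for t in chunk],
    batch_size=2,
)
print(demo)
```

In a real handler, embed_batch would tokenize the chunk with padding and call the model once on the padded tensor. The regrouping step matters because the number of responses has to match the number of requests the server handed in.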