I'm trying to deploy a TorchScripted model with Python and Flask. As I realized (at least as mentioned here), scripted models need to be "warmed up" before use, so the first run of such models takes much longer than subsequent ones. My question is: is there any way to load TorchScripted models in a Flask route and predict without the warm-up cost? Can I store a "warmed-up" model somewhere to avoid warming it up on every request? I wrote simple code that reproduces the warm-up pass:
import torchvision, torch, time

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model = torch.jit.script(model)
model.eval()
x = [torch.randn((3, 224, 224))]
for i in range(3):
    start = time.time()
    model(x)
    print('Time elapsed: {}'.format(time.time() - start))
Output:
Time elapsed: 38.29
Time elapsed: 6.65
Time elapsed: 6.65
And the Flask code:
import torch, torchvision, os, time
from flask import Flask

app = Flask(__name__)

@app.route('/')
def test_scripted_model(path='/tmp/scripted_model.pth'):
    if os.path.exists(path):
        model = torch.jit.load(path, map_location='cpu')
    else:
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
        model = torch.jit.script(model)
        torch.jit.save(model, path)
    model.eval()
    x = [torch.randn((3, 224, 224))]
    out = ''
    for i in range(3):
        start = time.time()
        model(x)
        out += 'Run {} time: {};\t'.format(i + 1, round((time.time() - start), 2))
    return out

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=False)
Output:
Run 1 time: 46.01; Run 2 time: 8.76; Run 3 time: 8.55;
OS: Ubuntu 18.04 & Windows 10
Python version: 3.6.9
Flask: 1.1.1
Torch: 1.4.0
Torchvision: 0.5.0
Update:
Solved the warm-up problem with:
with torch.jit.optimized_execution(False):
    model(x)
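For completeness, a minimal self-contained version of the timing script with this fix applied (the comment reflects my understanding of what optimized_execution(False) does, not documented behavior):

import torchvision, torch, time

model = torch.jit.script(
    torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
).eval()
x = [torch.randn((3, 224, 224))]

# Disabling the optimizing executor means the first pass skips the
# expensive graph optimization, so it runs at roughly steady-state speed.
with torch.jit.optimized_execution(False):
    for i in range(3):
        start = time.time()
        model(x)
        print('Time elapsed: {}'.format(time.time() - start))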
Update 2: Solved the Flask problem (as mentioned in the answer below) by creating a global Python model object before the server starts and warming it up there. Then in each request the model is ready to use:
model = torch.jit.load(path, map_location='cpu').eval()
model(x)
app = Flask(__name__)
and then in @app.route:
@app.route('/')
def test_scripted_model():
    global model
    ...
Can I store a "warmed-up" model somewhere to avoid warming it up on every request?
Yes, just instantiate your model outside of the test_scripted_model function and refer to it from within the function.
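For example, here is a minimal sketch assembled from the question's own code (same cache path, model, and dummy input; only the load and warm-up are moved out of the route):

import os, time, torch, torchvision
from flask import Flask

MODEL_PATH = '/tmp/scripted_model.pth'  # same cache path as in the question

# Load (or script and cache) the model once, at module import time.
if os.path.exists(MODEL_PATH):
    model = torch.jit.load(MODEL_PATH, map_location='cpu').eval()
else:
    model = torch.jit.script(
        torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    ).eval()
    torch.jit.save(model, MODEL_PATH)

# Warm the model up once with a dummy input before serving any requests.
model([torch.randn((3, 224, 224))])

app = Flask(__name__)

@app.route('/')
def test_scripted_model():
    # The module-level model is already warmed up. Note that a
    # `global model` declaration is only needed if the function
    # reassigns the name; plain reads work without it.
    x = [torch.randn((3, 224, 224))]
    out = ''
    for i in range(3):
        start = time.time()
        model(x)
        out += 'Run {} time: {};\t'.format(i + 1, round(time.time() - start, 2))
    return out

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=False)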