python · pytorch · prediction · fast-ai

Do Python version issues with TTA lead to fasttransform vs. fastcore bugs in Python >= 3.10?


Test Time Augmentation (TTA) in FastAI should be as simple as calling learn.tta, yet it has caused numerous issues in my Cloud Run deployment. The deployment currently works: it serves a prediction endpoint that scores base learners and a metalearner, loading the models with load_learner from FastAI.
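For context, the endpoint is shaped roughly like the sketch below; the Flask app, route, file field, and export.pkl path are placeholders rather than my real deployment code, which also loads and scores a metalearner:

from flask import Flask, request, jsonify
from fastai.vision.all import PILImage, load_learner

app = Flask(__name__)
learn = load_learner("export.pkl")  # placeholder path; the metalearner is loaded the same way

@app.route("/predict", methods=["POST"])
def predict():
    # Build a fastai image from the uploaded bytes and run a single prediction
    img = PILImage.create(request.files["file"].read())
    class_label, pred_idx, probs = learn.predict(img)
    return jsonify({"label": str(class_label), "probs": probs.tolist()})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8080)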

I want to switch from learn.predict to learn.tta, but issues keep arising. learn.tta expects its input in a different form than learn.predict (a DataLoader rather than a single item) and returns differently shaped values, so I wrote a wrapper to make it behave as a direct drop-in replacement for learn.predict. This function accomplished that in a minimal test notebook on Colab:

import random
from fastai.vision.all import *

# Function to perform TTA and format the output to match predict
def tta_predict(learner, img):
    # Create a DataLoader for the single image using the test DataLoader
    test_dl = learner.dls.test_dl([img])
    
    # Perform TTA on the single image using the test DataLoader
    preds, _ = learner.tta(dl=test_dl)
    
    # learner.tta already averages over the augmented passes, so preds is
    # (n_items, n_classes); mean(dim=0) collapses the single-item batch to a 1-D tensor
    avg_probs = preds.mean(dim=0)
    
    # Get the predicted class index
    pred_idx = avg_probs.argmax().item()
    
    # Get the class label
    class_label = learner.dls.vocab[pred_idx]
    
    # Format the output to match the structure of the predict method
    return (class_label, pred_idx, avg_probs)

# Use the tta_predict function
prediction = tta_predict(learn, grayscale_img)

# Print the results
print(type(prediction))  # Print the type of the prediction object
print(prediction)  # Print the prediction itself (class label, index, probabilities)
print(prediction[0])  # Print the predicted class label
print(prediction[2])  # Print the average probabilities

Although it seemed to work fine in the notebook, when I add that function to the top of my production script and switch learn.predict to tta_predict(learn, img) for my base learners, the entire image fails to build with Python 3.9:

Traceback (most recent call last):
  File "/app/main.py", line 11, in <module>
    from fastai.vision.all import PILImage, BCEWithLogitsLossFlat, load_learner
  File "/usr/local/lib/python3.9/site-packages/fastai/vision/all.py", line 4, in <module>
    from .augment import *
  File "/usr/local/lib/python3.9/site-packages/fastai/vision/augment.py", line 8, in <module>
    from .core import *
  File "/usr/local/lib/python3.9/site-packages/fastai/vision/core.py", line 259, in <module>
    class PointScaler(Transform):
  File "/usr/local/lib/python3.9/site-packages/fasttransform/transform.py", line 75, in __new__
    if funcs: setattr(new_cls, nm, _merge_funcs(*funcs))
  File "/usr/local/lib/python3.9/site-packages/fasttransform/transform.py", line 42, in _merge_funcs
    res = Function(fs[-1].methods[0].implementation)
  File "/usr/local/lib/python3.9/site-packages/plum/function.py", line 181, in methods
    self._resolve_pending_registrations()
  File "/usr/local/lib/python3.9/site-packages/plum/function.py", line 280, in _resolve_pending_registrations
    signature = Signature.from_callable(f, precedence=precedence)
  File "/usr/local/lib/python3.9/site-packages/plum/signature.py", line 88, in from_callable
    types, varargs = _extract_signature(f)
  File "/usr/local/lib/python3.9/site-packages/plum/signature.py", line 346, in _extract_signature
    resolve_pep563(f)
  File "/usr/local/lib/python3.9/site-packages/plum/signature.py", line 329, in resolve_pep563
    beartype_resolve_pep563(f)  # This mutates `f`.
  File "/usr/local/lib/python3.9/site-packages/beartype/peps/_pep563.py", line 263, in resolve_pep563
    arg_name_to_hint[arg_name] = resolve_hint(
  File "/usr/local/lib/python3.9/site-packages/beartype/_check/forward/fwdmain.py", line 308, in resolve_hint
    return _resolve_func_scope_forward_hint(
  File "/usr/local/lib/python3.9/site-packages/beartype/_check/forward/fwdmain.py", line 855, in _resolve_func_scope_forward_hint
    raise exception_cls(exception_message) from exception
beartype.roar.BeartypeDecorHintPep604Exception: Stringified PEP 604 type hint 'PILBase | TensorImageBase' syntactically invalid under Python < 3.10 (i.e., TypeError("unsupported operand type(s) for |: 'BypassNewMeta' and 'torch._C._TensorMeta'")). Consider either:
        * Requiring Python >= 3.10. Abandon Python < 3.10 all ye who code here.
        * Refactoring PEP 604 type hints into equivalent PEP 484 type hints: e.g.,
        # Instead of this...
        from __future__ import annotations
        def bad_func() -> int | str: ...
        # Do this. Ugly, yet it works. Worky >>>> pretty.
        from typing import Union

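For reference, the refactor beartype is suggesting looks like this in the generic case (parse is a made-up function for illustration, not anything from fastai or my code): a stringified X | Y hint has to be evaluated at runtime, and the | operator on classes only exists from Python 3.10, whereas typing.Union works on 3.9:

from typing import Union

# PEP 604 style: evaluating "int | str" on the class objects fails before Python 3.10
# def parse(x: str) -> int | str: ...

# PEP 484 equivalent that Python 3.9 can evaluate
def parse(x: str) -> Union[int, str]:
    return int(x) if x.isdigit() else x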
I don't see anything in my code that could've caused that, yet there it is. I noticed somewhere in those messages it mentions "augment", which I take as confirmation that TTA is at fault (it was also the only thing that changed). So, I tried switching the Python version to 3.10. Now it builds but it's clearly broken:

ERROR loading model.pkl: Could not import 'Pipeline' from fastcore.transform - this module has been moved to the fasttransform package.
To migrate your code, please see the migration guide at: https://answerdotai.github.io/fasttransform/fastcore_migration_guide.html

The migration guide it mentions says to change

from fastcore.transform import Transform, Pipeline

to

from fasttransform import Transform, Pipeline

but my code never directly imports Pipeline or Transform, nor does it directly import fastcore.


Solution

  • These two bugs stem from a recent change in fastcore 1.8 and the corresponding upgrade of fastai from 2.7.29 to 2.8.0, which affects load_learner for vision models. The Answer.ai contributors moved Pipeline into their fasttransform package, but fastai still does an import * from .core; that wildcard import pulls in the placeholder transform.py, which currently exists only to raise the error in question.

    For now, the issue is solved by downgrading fastai to v2.7.x; a pip pin sketch follows below. An issue has been opened on GitHub and is being addressed by the Answer.ai team.

    If using an environment like Colab, you'll need to restart the session after downgrading fastai.
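    A minimal way to apply the downgrade in the container image or a notebook cell (the exact bounds are an assumption; any fastai in the 2.7 series predates the fasttransform split, and pinning fastcore below 1.8 as well is an extra safeguard in case the fastai pin alone is not enough):

    pip install "fastai>=2.7,<2.8" "fastcore<1.8"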