Test Time Augmentation (TTA) in FastAI should be easy to apply with learn.tta, yet it has caused numerous issues in my Cloud Run deployment. I have a working Cloud Run deployment that serves base-learner and metalearner scoring as a prediction endpoint, using load_learner from FastAI.
I want to switch from learn.predict to learn.tta, but issues keep arising. FastAI requires a slightly different input for tta (a DataLoader rather than a single item), and tta returns values in a different shape. I wanted something closer to a direct drop-in replacement for learn.predict. This function accomplished that in a minimal test notebook on Colab:
import random
from fastai.vision.all import *

# Run TTA on one image and format the output to match learn.predict
def tta_predict(learner, img):
    # Create a test DataLoader containing just this one image
    test_dl = learner.dls.test_dl([img])
    # Perform TTA; preds has shape (1, n_classes) and is already
    # combined across the augmented passes by learn.tta
    preds, _ = learner.tta(dl=test_dl)
    # Drop the batch dimension to get a 1-D probability tensor
    avg_probs = preds.mean(dim=0)
    # Get the predicted class index (a plain int; learn.predict returns a tensor here)
    pred_idx = avg_probs.argmax().item()
    # Look up the class label
    class_label = learner.dls.vocab[pred_idx]
    # Return (label, index, probabilities), the same structure as learn.predict
    return (class_label, pred_idx, avg_probs)

# Use tta_predict (learn and grayscale_img are assumed to be defined earlier in the notebook)
prediction = tta_predict(learn, grayscale_img)
# Print the results
print(type(prediction))  # Type of the prediction object
print(prediction)        # The full tuple (class label, index, probabilities)
print(prediction[0])     # The predicted class label
print(prediction[2])     # The TTA probabilities
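Because learn.tta already combines the augmented passes before returning, the mean(dim=0) in tta_predict only removes the batch dimension. The sketch below illustrates, in plain Python rather than torch, how fastai's Learner.tta appears to combine one item's predictions by default, going by its source (`aug_preds**beta * preds**(1-beta)` with beta=0.25 and n=4 augmented passes). The helper names blend_tta and to_predict_tuple and the toy probabilities are hypothetical, for illustration only.

```python
def blend_tta(orig_probs, aug_probs_list, beta=0.25):
    """Combine one item's predictions the way fastai's Learner.tta does by default
    (assumption: based on reading fastai 2.7's source).

    orig_probs:     class probabilities from the un-augmented pass
    aug_probs_list: one list of class probabilities per augmented pass
    beta:           weight given to the augmented average (fastai's default is 0.25)
    """
    n_passes = len(aug_probs_list)
    # Average the augmented passes class by class
    aug = [sum(p[c] for p in aug_probs_list) / n_passes for c in range(len(orig_probs))]
    if not beta:
        return aug  # beta=0 or None: use the augmented average only
    # Weighted geometric blend with the un-augmented pass; the result is not renormalised
    return [a ** beta * o ** (1 - beta) for a, o in zip(aug, orig_probs)]


def to_predict_tuple(probs, vocab):
    """Mimic the (label, index, probabilities) tuple returned by tta_predict."""
    idx = max(range(len(probs)), key=probs.__getitem__)
    return vocab[idx], idx, probs


vocab = ["cat", "dog"]
orig = [0.6, 0.4]                                        # un-augmented pass favours "cat"
augs = [[0.2, 0.8], [0.3, 0.7], [0.4, 0.6], [0.3, 0.7]]  # 4 augmented passes favour "dog"
label, idx, probs = to_predict_tuple(blend_tta(orig, augs), vocab)
print(label, idx, [round(p, 3) for p in probs])
```

In this toy case the un-augmented pass still wins, because it carries weight 1 - beta = 0.75; the blended scores also do not sum to 1, which matters if downstream code (such as a metalearner) expects normalised probabilities.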
Although it seemed to work fine in the notebook, when I add the function to the top of my production script and replace learn.predict with tta_predict(learn, img) for my base learners, the entire image fails to build with Python 3.9:
Traceback (most recent call last):
  File "/app/main.py", line 11, in <module>
    from fastai.vision.all import PILImage, BCEWithLogitsLossFlat, load_learner
  File "/usr/local/lib/python3.9/site-packages/fastai/vision/all.py", line 4, in <module>
    from .augment import *
  File "/usr/local/lib/python3.9/site-packages/fastai/vision/augment.py", line 8, in <module>
    from .core import *
  File "/usr/local/lib/python3.9/site-packages/fastai/vision/core.py", line 259, in <module>
    class PointScaler(Transform):
  File "/usr/local/lib/python3.9/site-packages/fasttransform/transform.py", line 75, in __new__
    if funcs: setattr(new_cls, nm, _merge_funcs(*funcs))
  File "/usr/local/lib/python3.9/site-packages/fasttransform/transform.py", line 42, in _merge_funcs
    res = Function(fs[-1].methods[0].implementation)
  File "/usr/local/lib/python3.9/site-packages/plum/function.py", line 181, in methods
    self._resolve_pending_registrations()
  File "/usr/local/lib/python3.9/site-packages/plum/function.py", line 280, in _resolve_pending_registrations
    signature = Signature.from_callable(f, precedence=precedence)
  File "/usr/local/lib/python3.9/site-packages/plum/signature.py", line 88, in from_callable
    types, varargs = _extract_signature(f)
  File "/usr/local/lib/python3.9/site-packages/plum/signature.py", line 346, in _extract_signature
    resolve_pep563(f)
  File "/usr/local/lib/python3.9/site-packages/plum/signature.py", line 329, in resolve_pep563
    beartype_resolve_pep563(f)  # This mutates `f`.
  File "/usr/local/lib/python3.9/site-packages/beartype/peps/_pep563.py", line 263, in resolve_pep563
    arg_name_to_hint[arg_name] = resolve_hint(
  File "/usr/local/lib/python3.9/site-packages/beartype/_check/forward/fwdmain.py", line 308, in resolve_hint
    return _resolve_func_scope_forward_hint(
  File "/usr/local/lib/python3.9/site-packages/beartype/_check/forward/fwdmain.py", line 855, in _resolve_func_scope_forward_hint
    raise exception_cls(exception_message) from exception
beartype.roar.BeartypeDecorHintPep604Exception: Stringified PEP 604 type hint 'PILBase | TensorImageBase' syntactically invalid under Python < 3.10 (i.e., TypeError("unsupported operand type(s) for |: 'BypassNewMeta' and 'torch._C._TensorMeta'")). Consider either:
* Requiring Python >= 3.10. Abandon Python < 3.10 all ye who code here.
* Refactoring PEP 604 type hints into equivalent PEP 484 type hints: e.g.,
    # Instead of this...
    from __future__ import annotations
    def bad_func() -> int | str: ...

    # Do this. Ugly, yet it works. Worky >>>> pretty.
    from typing import Union
I don't see anything in my code that could have caused that, yet there it is. I noticed the traceback mentions "augment", which I took as confirmation that TTA was at fault (it was also the only thing that changed). So I tried switching to Python 3.10. Now the image builds, but it's clearly broken:
ERROR loading model.pkl: Could not import 'Pipeline' from fastcore.transform - this module has been moved to the fasttransform package.
To migrate your code, please see the migration guide at: https://answerdotai.github.io/fasttransform/fastcore_migration_guide.html
The migration guide it mentions says to change

from fastcore.transform import Transform, Pipeline

to

from fasttransform import Transform, Pipeline

but my code never imports Pipeline or Transform directly, nor does it import fastcore directly.
These two bugs stem from a recent change in fastcore 1.8 and the corresponding upgrade of fastai from 2.7.29 to 2.8.0, which affects load_learner for vision models. The contributors at Answer.ai moved Pipeline to their fasttransform package, but fastai still does an import * from .core. That raises an error, because the * import pulls in the placeholder transform.py, which currently exists only to raise the error in question.
For now, the issue is solved by downgrading fastai to v2.7.x. An issue has been opened on GitHub and is being addressed by the Answer.ai team.

If you are using an environment like Colab, you'll need to restart the session after downgrading fastai.
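The restart is needed because Python caches imported modules in sys.modules; running pip in a live session changes the files on disk but not the fastai already loaded in memory. The sketch below compares the loaded module's __version__ with the version pip reports and returns True when they disagree. The function name needs_restart is hypothetical, and it assumes the package exposes a __version__ attribute (fastai does).

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def needs_restart(module_name="fastai", dist_name="fastai"):
    """Return True if `module_name` is already imported in this process but the
    distribution installed on disk now reports a different version."""
    module = sys.modules.get(module_name)
    if module is None:
        return False  # not imported yet: the next import picks up the installed version
    loaded = getattr(module, "__version__", None)
    try:
        installed = version(dist_name)
    except PackageNotFoundError:
        return True  # imported earlier, but no longer installed on disk
    return loaded is not None and loaded != installed


if __name__ == "__main__":
    print("restart needed:", needs_restart())
```

In a notebook, calling needs_restart() right after the downgrade should return True if fastai was imported before pip ran; once the session is restarted it returns False (or False before fastai is imported at all).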