I have the following function:
```python
from lxml import etree
from typing import Union

def _get_inner_xml(element: Union[etree._Element, None]) -> Union[str, None]:
    if element is None:
        return None
    # See https://stackoverflow.com/a/51124963
    return (str(element.text or "") + "".join(etree.tostring(child, encoding="unicode") for child in element)).strip()

root = etree.fromstring('<html><body>TEXT<br/>TAIL</body></html>')
innerXML = _get_inner_xml(root)
print(innerXML)
```
My understanding is that if I pass `None` as an argument, I always get `None` as the return value. On the other hand, passing an `etree._Element` as the argument will always result in a `str` being returned.
If I write the following in VS Code using Pylance (which uses Pyright under the hood):
```python
def test(element: etree._Element):
    variable = _get_inner_xml(element)
```
In this case I get the type hint `(variable) variable: str | None`. I would expect Pylance to know that `variable` should be of type `str`. Am I overlooking something? Is this maybe a bug?
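The behaviour is expected: the type checker does not trace which argument type leads to which branch of the body; it reads the declared signature, and that signature pairs every accepted argument with the single return type `Union[str, None]`. The sketch below uses a hypothetical stand-in function `first_word` (not from the question) and the standard `typing.get_type_hints` to show that only one return type is recorded for the function, whatever is passed in.

```python
from typing import Optional, get_type_hints


def first_word(text: Optional[str]) -> Optional[str]:
    # Hypothetical stand-in for _get_inner_xml: None in gives None out and
    # a str in gives a str out, but this single annotation cannot express that link.
    if text is None:
        return None
    words = text.split()
    return words[0] if words else ""


hints = get_type_hints(first_word)
# Exactly one return type is stored, shared by every call regardless of the argument.
print(hints["return"])
```

Because `hints["return"]` is the same union for a `str` argument and a `None` argument, a checker that relies on this signature can only report `str | None` at the call site.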
If this works as intended, is there a way to manually tell Pylance that "whenever this function gets an `etree._Element`, it will return a `str`, and whenever I pass `None`, it returns `None`"?
The answer here is to use `typing.overload` (see its entry in the `typing` module documentation), which allows you to register multiple different signatures for one function. Function definitions decorated with `@overload` are ignored at runtime (they exist only for the type checker), so the body of each one can be a literal ellipsis `...`, `pass`, or just a docstring. You also need to provide a "concrete" implementation of the function that is not decorated with `@overload`, placed after the overloaded signatures.
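As a minimal, library-independent sketch of the pattern just described, here is a hypothetical helper `strip_or_none` (not part of the original question) with two `@overload` signatures followed by the single undecorated implementation that actually runs.

```python
from typing import Optional, overload


@overload
def strip_or_none(value: str) -> str: ...
@overload
def strip_or_none(value: None) -> None: ...


def strip_or_none(value: Optional[str]) -> Optional[str]:
    # Concrete implementation: the only definition that exists at runtime.
    if value is None:
        return None
    return value.strip()


print(strip_or_none("  hi  "))  # hi
print(strip_or_none(None))      # None
```

A type checker matches each call against the overloads in order, so `strip_or_none("  hi  ")` is typed as `str` and `strip_or_none(None)` as `None`, while the runtime behaviour is entirely that of the last definition.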
```python
from lxml import etree
from typing import Union, overload

@overload
def _get_inner_xml(element: etree._Element) -> str:
    """Signature when `element` is of type `etree._Element`"""

@overload
def _get_inner_xml(element: None) -> None:
    """Signature when `element` is of type `None`"""

# Concrete (undecorated) implementation: this is what runs at runtime.
def _get_inner_xml(element: Union[etree._Element, None]) -> Union[str, None]:
    if element is None:
        return None
    # See https://stackoverflow.com/a/51124963
    return (str(element.text or "") + "".join(etree.tostring(child, encoding="unicode") for child in element)).strip()

root = etree.fromstring('<html><body>TEXT<br/>TAIL</body></html>')
innerXML = _get_inner_xml(root)
print(innerXML)
```