c++ llvm llvm-ir

How does `mlir::cast` transform objects to derived classes?


The LLVM Programmer's Manual gives a brief idea of how llvm::cast works:

cast<>:

The cast<> operator is a “checked cast” operation. It converts a pointer or reference from a base class to a derived class, causing an assertion failure if it is not really an instance of the right type. This should be used in cases where you have some information that makes you believe that something is of the right type. An example of the isa<> and cast<> template is:

static bool isLoopInvariant(const Value *V, const Loop *L) {
  if (isa<Constant>(V) || isa<Argument>(V) || isa<GlobalValue>(V))
    return true;

  // Otherwise, it must be an instruction...
  return !L->contains(cast<Instruction>(V)->getParent());
}

However, it doesn't mention any uses beyond pointers and references. There is also a separate mlir::cast, which intuitively should follow the same convention but has no documentation of its own. While developing an MLIR pass, I came across the following problem:

// op is mlir::tosa::TransposeOp
auto val = op.getOperand(0); // returns mlir::Value
if (val.getType().getShape().size() > 4) // compilation failure
  ...

The code fails to compile because mlir::Value::getType returns mlir::Type, a base class that doesn't provide any information about the tensor dimensions. However, by the definition of tosa.transpose, I know the operand must be a tensor. At the same time, mlir::tosa::TransposeOp::getOperand(int) returns an mlir::Value (not a reference or a pointer), so it doesn't seem that I can simply "cast" it to the derived type. However, if I do the following:

auto val = mlir::cast<mlir::TypedValue<mlir::RankedTensorType>>(op.getOperand(0));
if (val.getType().getShape().size() > 4)
  ...

Everything seems to work (it even returns the correct shape).

I struggle to comprehend how this cast could possibly work correctly. Does mlir::Value somehow magically provide the information required to construct an mlir::TypedValue, or am I observing some kind of UB that happens to give the expected output by chance?


Solution

  • First of all,

    auto val = op.getOperand(0); // returns mlir::Value
    

    if op is indeed a tosa::TransposeOp, then consider using the generated helpers (e.g. op.getInput1()), which return typed operand values. In general, try not to use the lower-level getOperand(0), etc., unless you're writing something that is generic across different ops.

    A couple of things are important for understanding how LLVM-style RTTI works, and there is no magic here:

    1. First of all, mlir::Value, mlir::Type, etc. are just thin wrappers around the corresponding pointers. Therefore casts never involve object construction or the like. After all, a Value is something produced by an operation and a Type belongs to some context, so you cannot materialize them out of thin air. mlir::TypedValue is essentially an mlir::Value with a statically known type coming from a template argument. Nothing fancy :)

    2. LLVM-style RTTI is quite simple: it involves a type check (an is-a predicate) implemented by the derived class, followed by a plain pointer conversion. Some implementation details are outlined in https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html#advanced-use-cases. There are a few extra details when multiple inheritance is involved, since one needs to be able to cast from multiple bases, but the derived class already knows its layout and can therefore perform the necessary this-adjustment. And since cast<T> and isa<T> are just normal function calls, they simply take either pointer arguments or constant references (under the hood they normalize references to pointers, do the cast, and then dereference the result).
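As a rough illustration of the "is-a predicate plus pointer conversion" scheme from point 2, here is a minimal sketch. The Shape/Circle/Square hierarchy and the toy_isa/toy_cast helpers are hypothetical names made up for this example; they only imitate the general idea and are not LLVM's actual implementation.

```cpp
#include <cassert>

// Hypothetical base class: it stores a kind tag that derived classes set.
class Shape {
public:
  enum ShapeKind { SK_Circle, SK_Square };
  explicit Shape(ShapeKind K) : Kind(K) {}
  ShapeKind getKind() const { return Kind; }

private:
  const ShapeKind Kind;
};

class Circle : public Shape {
public:
  explicit Circle(double R) : Shape(SK_Circle), Radius(R) {}
  double getRadius() const { return Radius; }
  // The is-a predicate lives in the derived class.
  static bool classof(const Shape *S) { return S->getKind() == SK_Circle; }

private:
  double Radius;
};

class Square : public Shape {
public:
  explicit Square(double S) : Shape(SK_Square), Side(S) {}
  double getSide() const { return Side; }
  static bool classof(const Shape *S) { return S->getKind() == SK_Square; }

private:
  double Side;
};

// Toy isa<>: simply asks the target class whether the object qualifies.
template <typename To, typename From> bool toy_isa(const From *V) {
  return To::classof(V);
}

// Toy cast<>: a checked static_cast. No object is constructed.
template <typename To, typename From> const To *toy_cast(const From *V) {
  assert(toy_isa<To>(V) && "toy_cast argument of incompatible type");
  return static_cast<const To *>(V);
}
```

Given a `Circle C(2.0)` viewed through a `const Shape *`, `toy_isa<Circle>` returns true and `toy_cast<Circle>` hands back the same address with a more specific static type.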

    The implementation itself is quite clear; see e.g. https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/Support/Casting.h#L566 and further down (cast_convert_val, etc.).
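To make point 1 concrete for value-like handles, the following is a simplified sketch of why casting a Value to a TypedValue does not need to construct anything new. All classes here (TypeStorage, ValueStorage, Type, RankedTensorType, Value, TypedValue, toy_isa, toy_cast) are hypothetical stand-ins written for this example. Real MLIR classes are considerably richer and route the cast through llvm::CastInfo, so treat this as an approximation of the idea rather than the real code.

```cpp
#include <cassert>
#include <vector>

// Hypothetical storage objects, assumed to be owned elsewhere (in MLIR, by
// the context or by the defining operation).
struct TypeStorage {
  enum Kind { Integer, RankedTensor };
  Kind kind;
  std::vector<long> shape;
};
struct ValueStorage {
  const TypeStorage *type;
};

// A Type handle is just a pointer to its storage.
class Type {
public:
  explicit Type(const TypeStorage *impl) : impl(impl) {}
  const TypeStorage *getImpl() const { return impl; }

protected:
  const TypeStorage *impl;
};

class RankedTensorType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) {
    return t.getImpl()->kind == TypeStorage::RankedTensor;
  }
  const std::vector<long> &getShape() const { return impl->shape; }
};

// A Value handle is likewise just a pointer.
class Value {
public:
  explicit Value(const ValueStorage *impl) : impl(impl) {}
  Type getType() const { return Type(impl->type); }
  const ValueStorage *getImpl() const { return impl; }

protected:
  const ValueStorage *impl;
};

// Same pointer as Value; only the static return type of getType() differs.
template <typename Ty> class TypedValue : public Value {
public:
  using Value::Value;
  static bool classof(Value v) { return Ty::classof(v.getType()); }
  Ty getType() const { return Ty(impl->type); }
};

template <typename To, typename From> bool toy_isa(From v) {
  return To::classof(v);
}

// The "cast" checks the predicate and re-wraps the very same pointer.
template <typename To, typename From> To toy_cast(From v) {
  assert(toy_isa<To>(v) && "toy_cast argument of incompatible type");
  return To(v.getImpl());
}
```

In this sketch the result of `toy_cast<TypedValue<RankedTensorType>>(v)` refers to the same storage as `v`, so `getShape()` reads data that was there all along. That is why the cast in the question yields the correct shape rather than relying on undefined behaviour.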