Edgar Luque's Website

Intro to LLVM and MLIR with Rust and Melior

rust MLIR LLVM

If you haven't heard of MLIR yet, it is a relatively new project born within LLVM. It is also what powers Mojo, the Torch-MLIR project, the high-level IR of Flang, IREE, and more.

But in case you don't know much about LLVM yet, I'll try to explain a bit.

A primer on LLVM

LLVM, as its website says, is a collection of modular and reusable compiler and toolchain technologies. In this post I will focus on the LLVM Core libraries, which provide a source- and target-independent optimizer, along with code generation for many CPUs.

The basis of LLVM is LLVM IR, a language in static single-assignment (SSA) form that looks like pseudo-assembly. The main property of SSA is that each variable is assigned exactly once and is defined before it is used.
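To make the single-assignment property concrete, here is a small sketch in plain Rust (no LLVM involved). The `to_ssa` helper, the `Stmt` tuple format and the `x.0`, `x.1` naming scheme are all hypothetical and made up for this illustration. It only handles straight-line code, so no PHI nodes are needed.

```rust
use std::collections::HashMap;

/// A straight-line statement `dest = lhs + rhs` (hypothetical toy format).
type Stmt<'a> = (&'a str, &'a str, &'a str);

/// Current SSA name of a variable. Names that were never assigned
/// (function inputs, for example) are left untouched.
fn current_name(version: &HashMap<&str, usize>, name: &str) -> String {
    match version.get(name) {
        Some(v) => format!("{name}.{v}"),
        None => name.to_string(),
    }
}

/// Rename variables so that every destination is assigned exactly once.
/// Each reassignment of `x` produces a fresh name (`x.0`, `x.1`, ...),
/// and later uses refer to the most recent version.
fn to_ssa(stmts: &[Stmt]) -> Vec<String> {
    let mut version: HashMap<&str, usize> = HashMap::new();
    let mut out = Vec::new();

    for &(dest, lhs, rhs) in stmts {
        // Operands are resolved before the destination gets a new version.
        let lhs = current_name(&version, lhs);
        let rhs = current_name(&version, rhs);
        let v = *version.entry(dest).and_modify(|v| *v += 1).or_insert(0);
        out.push(format!("{dest}.{v} = {lhs} + {rhs}"));
    }
    out
}

fn main() {
    // x = a + b; x = x + c; y = x + x
    let program = [("x", "a", "b"), ("x", "x", "c"), ("y", "x", "x")];
    for line in to_ssa(&program) {
        println!("{line}");
    }
}
```

After renaming, the second assignment to `x` becomes a new variable, so every name has exactly one definition. This is what lets an optimizer reason about a value without tracking where it might have been overwritten.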

LLVM IR is what LLVM's optimization passes operate on. Its SSA form is one of the major enablers of the optimizations LLVM can do (and what makes them easier to implement). To name a few (from Wikipedia):

If you want to learn LLVM IR in detail, you should look at the LLVM Language Reference Manual. If you already know assembly, it shouldn't be too hard. One interesting property is that LLVM IR has an unlimited number of virtual registers, so you don't need to worry about register allocation. It has a simple type system in which you can work with integers of any bit size (i1, i32, i1942652), although unless you use a recent version (17+) you may run into bugs with big integers.
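As a rough illustration of what arithmetic on an arbitrary-width integer means, the sketch below emulates a plain LLVM `add` (without the nsw/nuw flags) on an iN type, which wraps modulo 2^N. The helper names `mask` and `add_in` are hypothetical, and the sketch only covers widths up to 64 bits because it stores values in a `u64`.

```rust
/// Mask with the low `bits` bits set. `bits` must be in 1..=64.
fn mask(bits: u32) -> u64 {
    assert!((1..=64).contains(&bits), "this sketch supports i1 to i64 only");
    if bits == 64 {
        u64::MAX
    } else {
        (1u64 << bits) - 1
    }
}

/// Emulates `add iN a, b` without overflow flags:
/// the result is the mathematical sum modulo 2^N.
fn add_in(a: u64, b: u64, bits: u32) -> u64 {
    let m = mask(bits);
    (a & m).wrapping_add(b & m) & m
}

fn main() {
    // i1: 1 + 1 wraps around to 0.
    println!("i1:  1 + 1 = {}", add_in(1, 1, 1));
    // i3 holds 0..=7, so 7 + 2 wraps around to 1.
    println!("i3:  7 + 2 = {}", add_in(7, 2, 3));
    // i32 behaves like Rust's u32::wrapping_add.
    println!("i32: max + 1 = {}", add_in(u32::MAX as u64, 1, 32));
}
```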

Probably the biggest wall you will hit is learning about GEPs (GetElementPtr instructions). How they work is often misunderstood, so there is even an entire documentation page for them: https://llvm.org/docs/GetElementPtr.html. Another thing that may need attention is PHI nodes, which are how LLVM selects a value depending on which control flow branch was taken, something SSA makes necessary.
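One way to think about a GEP is that it only computes an address and never reads memory. The sketch below models that offset computation in plain Rust. The `Ty` enum and `gep_offset` function are hypothetical, and the sketch assumes packed structs (no padding) and constant indices, which is a simplification of the real layout rules.

```rust
/// A tiny model of LLVM types, just enough to compute sizes.
enum Ty {
    Int { bytes: u64 },
    Array { elem: Box<Ty>, len: u64 },
    /// Assumption: fields are packed, with no padding between them.
    Struct { fields: Vec<Ty> },
}

impl Ty {
    fn size(&self) -> u64 {
        match self {
            Ty::Int { bytes } => *bytes,
            Ty::Array { elem, len } => elem.size() * len,
            Ty::Struct { fields } => fields.iter().map(Ty::size).sum(),
        }
    }
}

/// Byte offset that `getelementptr base, ptr %p, indices...` adds to `%p`.
fn gep_offset(base: &Ty, indices: &[i64]) -> i64 {
    let (&first, rest) = indices.split_first().expect("a GEP needs at least one index");
    // The first index steps over whole `base` values, like pointer arithmetic.
    let mut offset = first * base.size() as i64;
    let mut ty = base;
    // The remaining indices walk into the type itself.
    for &idx in rest {
        match ty {
            Ty::Array { elem, .. } => {
                offset += idx * elem.size() as i64;
                ty = &**elem;
            }
            Ty::Struct { fields } => {
                let idx = idx as usize;
                offset += fields[..idx].iter().map(Ty::size).sum::<u64>() as i64;
                ty = &fields[idx];
            }
            Ty::Int { .. } => panic!("cannot index into an integer"),
        }
    }
    offset
}

/// Models `{ i32, [4 x i16] }`.
fn example_struct() -> Ty {
    Ty::Struct {
        fields: vec![
            Ty::Int { bytes: 4 },
            Ty::Array { elem: Box::new(Ty::Int { bytes: 2 }), len: 4 },
        ],
    }
}

fn main() {
    let s = example_struct();
    // Second struct (1 * 12 bytes), field 1 (+4 bytes), element 2 (+2 * 2 bytes).
    println!("offset = {}", gep_offset(&s, &[1, 1, 2]));
}
```

The part that tends to surprise people is the first index: it does not select a field, it moves across whole values of the base type. Only the following indices descend into the struct or array.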

The API to build this IR has the following structure:

If you want to use LLVM with Rust in a type-safe manner, I recommend the really well done inkwell crate. Check out its README to see how the previously mentioned structures are used.

MLIR

So what is MLIR? It goes a level above, in that LLVM IR itself is one of its dialects.

MLIR is kind of an IR of IRs, and it supports many of them through "dialects". For example, you may have heard of NVVM IR (CUDA). MLIR supports modeling it through the NVVM dialect (or ROCDL for AMD), but there is also a more generic and higher-level GPU dialect.

Dialects define conversion passes between them, meaning that you can, for example, convert IR code using the GPU dialect to the NVVM dialect.

They may also define dialect passes, for example -gpu-map-parallel-loops, which greedily maps loops to GPU hardware dimensions.

Some notable dialects:

You can also make your own dialect, which is useful for building a domain-specific language, for example. In this dialect you can define transformations to other dialects, passes, etc.

All these dialects can exist in your MLIR code at the same time, but in the end you want to execute your code. For this there are targets, one of which is LLVM IR itself. In this case, you need to use passes to convert all dialects to the LLVM dialect, and then you can translate from MLIR to LLVM IR.
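The idea of "convert everything to the LLVM dialect, then translate" can be sketched with a toy model that represents operations only by their names. Everything here is a stand-in: `convert_to_llvm`, `lower` and `ready_for_translation` are hypothetical helpers, not MLIR or melior APIs, and the three-entry table only hints at what the real conversion passes do. The toy also ignores the top-level module operation.

```rust
/// Dialect prefix of an operation name, e.g. "arith" for "arith.addi".
fn dialect_of(op_name: &str) -> &str {
    op_name.split('.').next().unwrap_or(op_name)
}

/// Toy conversion table standing in for real conversion passes.
fn convert_to_llvm(op_name: &str) -> Option<&'static str> {
    match op_name {
        "func.func" => Some("llvm.func"),
        "func.return" => Some("llvm.return"),
        "arith.addi" => Some("llvm.add"),
        _ => None,
    }
}

/// Apply the toy conversion to every operation.
/// Operations without a known conversion are left untouched.
fn lower(ops: &[&str]) -> Vec<String> {
    ops.iter()
        .map(|&op| match convert_to_llvm(op) {
            Some(llvm_op) => llvm_op.to_string(),
            None => op.to_string(),
        })
        .collect()
}

/// Translation to LLVM IR only makes sense once nothing but
/// LLVM dialect operations remain.
fn ready_for_translation(ops: &[String]) -> bool {
    ops.iter().all(|op| dialect_of(op) == "llvm")
}

fn main() {
    let lowered = lower(&["func.func", "arith.addi", "func.return"]);
    println!("{lowered:?} ready: {}", ready_for_translation(&lowered));

    // scf.while has no entry in the toy table, so this one is not ready yet.
    let partial = lower(&["func.func", "scf.while", "func.return"]);
    println!("{partial:?} ready: {}", ready_for_translation(&partial));
}
```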

The structure of MLIR is recursive as follows:

Region -> Block(s) -> Operation(s) -> Region(s)

The top-level module is also an operation, which holds a single region with a single block.

A region contains a list of blocks, each block contains a list of operations, and an operation can in turn hold zero or more regions.
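The recursion is perhaps easier to see as data types. The following is a minimal sketch in plain Rust, not how MLIR or melior actually store things: the `Region`, `Block` and `Operation` structs and the helper functions are hypothetical and drop everything except the nesting.

```rust
struct Region {
    blocks: Vec<Block>,
}

struct Block {
    operations: Vec<Operation>,
}

struct Operation {
    name: &'static str,
    regions: Vec<Region>,
}

/// An operation without regions, such as an addition.
fn leaf(name: &'static str) -> Operation {
    Operation { name, regions: vec![] }
}

/// An operation holding a single region with a single block.
fn with_body(name: &'static str, operations: Vec<Operation>) -> Operation {
    Operation {
        name,
        regions: vec![Region { blocks: vec![Block { operations }] }],
    }
}

/// A module containing one function that adds and returns.
fn example_module() -> Operation {
    with_body(
        "builtin.module",
        vec![with_body("func.func", vec![leaf("arith.addi"), leaf("func.return")])],
    )
}

/// Visit an operation and everything nested inside it,
/// recording each operation name indented by nesting depth.
fn walk(op: &Operation, depth: usize, out: &mut Vec<String>) {
    out.push(format!("{}{}", "  ".repeat(depth), op.name));
    for region in &op.regions {
        for block in &region.blocks {
            for inner in &block.operations {
                walk(inner, depth + 1, out);
            }
        }
    }
}

fn main() {
    let mut lines = Vec::new();
    walk(&example_module(), 0, &mut lines);
    for line in &lines {
        println!("{line}");
    }
}
```

Because the module itself is just another operation, a single recursive function is enough to traverse a whole program.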

Operations

Operations provide the functionality and make up the bulk of MLIR.

An operation has the following properties:

You can read more about operations in the Operation Definition Specification.

To use MLIR with Rust, I recommend melior. Here is a snippet that builds a function adding two numbers:

use melior::{
    Context,
    dialect::{arith, DialectRegistry, func},
    ir::{*, attribute::{StringAttribute, TypeAttribute}, r#type::FunctionType},
    utility::register_all_dialects,
};

// We need a registry to hold all the dialects
let registry = DialectRegistry::new();
// Register all dialects that come with MLIR.
register_all_dialects(&registry);

// The MLIR context, like the LLVM one.
let context = Context::new();
context.append_dialect_registry(&registry);
context.load_all_available_dialects();

// A location is a debug location like in LLVM. In MLIR, all
// operations need a location, even if it's "unknown".
let location = Location::unknown(&context);

// An MLIR module is akin to an LLVM module.
let module = Module::new(location);

// An integer-like type with platform-dependent bit width (like size_t or usize).
// This is a type defined in the Builtin dialect.
let index_type = Type::index(&context);

// Append a `func::func` operation to the body (a block) of the module.
// This operation accepts a string attribute, which is the name.
// A type attribute, which contains a function type in this case.
// Then it accepts a single region, which is where the body
// of the function will be, this region can have
// multiple blocks, which is how you may implement
// control flow within the function.
// These blocks each can have more operations.
module.body().append_operation(func::func(
    &context,
    // accepts a StringAttribute which is the function name.
    StringAttribute::new(&context, "add"),
    // A type attribute, defining the function signature.
    TypeAttribute::new(
        FunctionType::new(&context, &[index_type, index_type], &[index_type]).into(),
    ),
    {
        // The first block within the region. Blocks accept arguments:
        // in regions with control flow, MLIR leverages
        // this structure to implicitly represent
        // the passage of control-flow dependent values without the complex nuances
        // of PHI nodes in traditional SSA representations.
        let block = Block::new(&[(index_type, location), (index_type, location)]);

        // Use the arith dialect to add the 2 arguments.
        let sum = block.append_operation(arith::addi(
            block.argument(0).unwrap().into(),
            block.argument(1).unwrap().into(),
            location
        ));

        // Return the result using the "func" dialect return operation.
        block.append_operation(func::r#return(
            &[sum.result(0).unwrap().into()],
            location,
        ));

        // The Func operation requires a region,
        // we add the block we created to the region and return it,
        // which is passed as an argument to the `func::func` function.
        let region = Region::new();
        region.append_block(block);
        region
    },
    &[],
    location,
));

assert!(module.as_operation().verify());

Here is a more complex function, using the SCF dialect, which allows us to use a while loop:

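// In addition to the imports of the previous snippet, this one needs
// melior::dialect::scf, melior::ir::attribute::{FloatAttribute, IntegerAttribute}
// and melior::ir::r#type::IntegerType.

// Same setup as before: register all dialects and load them into the context.
let registry = DialectRegistry::new();
register_all_dialects(&registry);

let context = Context::new();
context.append_dialect_registry(&registry);
context.load_all_available_dialects();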

let location = Location::unknown(&context);
let module = Module::new(location);
let index_type = Type::index(&context);
let float_type = Type::float64(&context);

module.body().append_operation(func::func(
    &context,
    StringAttribute::new(&context, "foo"),
    TypeAttribute::new(FunctionType::new(&context, &[], &[]).into()),
    {
        let block = Block::new(&[]);

        let initial = block.append_operation(arith::constant(
            &context,
            IntegerAttribute::new(0, index_type).into(),
            location,
        ));

        block.append_operation(scf::r#while(
            &[initial.result(0).unwrap().into()],
            &[float_type],
            {
                let block = Block::new(&[(index_type, location)]);

                let condition = block.append_operation(arith::constant(
                    &context,
                    IntegerAttribute::new(0, IntegerType::new(&context, 1).into())
                        .into(),
                    location,
                ));

                let result = block.append_operation(arith::constant(
                    &context,
                    FloatAttribute::new(&context, 42.0, float_type).into(),
                    location,
                ));

                block.append_operation(scf::condition(
                    condition.result(0).unwrap().into(),
                    &[result.result(0).unwrap().into()],
                    location,
                ));

                let region = Region::new();
                region.append_block(block);
                region
            },
            {
                let block = Block::new(&[(float_type, location)]);

                let result = block.append_operation(arith::constant(
                    &context,
                    IntegerAttribute::new(42, Type::index(&context)).into(),
                    location,
                ));

                block.append_operation(scf::r#yield(
                    &[result.result(0).unwrap().into()],
                    location,
                ));

                let region = Region::new();
                region.append_block(block);
                region
            },
            location,
        ));

        block.append_operation(func::r#return(&[], location));

        let region = Region::new();
        region.append_block(block);
        region
    },
    &[],
    location,
));

assert!(module.as_operation().verify());

This code generates the following MLIR:

module {
  func.func @foo() {
    %c0 = arith.constant 0 : index
    %0 = scf.while (%arg0 = %c0) : (index) -> f64 {
      %false = arith.constant false
      %cst = arith.constant 4.200000e+01 : f64
      scf.condition(%false) %cst : f64
    } do {
    ^bb0(%arg0: f64):
      %c42 = arith.constant 42 : index
      scf.yield %c42 : index
    }
    return
  }
}
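To read the scf.while above, it helps to know how control moves between its two regions. The sketch below models that in plain Rust with closures. The `scf_while` function is a hypothetical helper written for this post, not a melior API: `before` stands for the first region and returns what scf.condition receives, while `after` stands for the do region and returns what scf.yield passes back.

```rust
/// Models the control flow of `scf.while`.
/// `before` returns a flag plus the values forwarded by `scf.condition`.
/// `after` returns the values that `scf.yield` hands back to `before`.
fn scf_while<T, U>(
    init: T,
    before: impl Fn(T) -> (bool, U),
    after: impl Fn(U) -> T,
) -> U {
    let mut state = init;
    loop {
        let (keep_going, forwarded) = before(state);
        if !keep_going {
            // The forwarded values become the results of the whole operation.
            return forwarded;
        }
        // Otherwise they become the arguments of the `do` region.
        state = after(forwarded);
    }
}

fn main() {
    // Mirrors the IR above: the condition is a constant false, so the
    // `do` region never runs and the f64 constant is the result.
    let result = scf_while(0usize, |_arg0| (false, 42.0f64), |_arg0| 42usize);
    println!("result = {result}");

    // A loop that actually iterates: count from 0 up to 5.
    let n = scf_while(0u32, |i| (i < 5, i), |i| i + 1);
    println!("n = {n}");
}
```

This also explains why the two regions take different argument types in the generated IR: the first region receives the index values, while the do region receives the f64 forwarded by scf.condition.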

There is way more to MLIR, but this is meant to be a small introduction.
