//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a simple IR generation targeting MLIR from a Module AST
// for the Toy language.
//
//===----------------------------------------------------------------------===//

#include "toy/MLIRGen.h"
#include "toy/AST.h"
#include "toy/Dialect.h"

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/StandardOps/Ops.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"

#include <functional>
#include <numeric>

using namespace toy;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;
using llvm::make_unique;
using llvm::ScopedHashTableScope;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;

namespace {

/// Implementation of a simple MLIR emission from the Toy AST.
///
/// This will emit operations that are specific to the Toy language, preserving
/// the semantics of the language and (hopefully) allow to perform accurate
/// analysis and transformation based on these high level semantics.
/// /// At this point we take advantage of the "raw" MLIR APIs to create operations /// that haven't been registered in any way with MLIR. These operations are /// unknown to MLIR, custom passes could operate by string-matching the name of /// these operations, but no other type checking or semantic is associated with /// them natively by MLIR. class MLIRGenImpl { public: MLIRGenImpl(mlir::MLIRContext &context) : context(context) {} /// Public API: convert the AST for a Toy module (source file) to an MLIR /// Module. std::unique_ptr mlirGen(ModuleAST &moduleAST) { // We create an empty MLIR module and codegen functions one at a time and // add them to the module. theModule = make_unique(&context); for (FunctionAST &F : moduleAST) { auto func = mlirGen(F); if (!func) return nullptr; theModule->getFunctions().push_back(func.release()); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, // this won't do much, but it should at least check some structural // properties. if (failed(theModule->verify())) { context.emitError(mlir::UnknownLoc::get(&context), "Module verification error"); return nullptr; } return std::move(theModule); } private: /// In MLIR (like in LLVM) a "context" object holds the memory allocation and /// the ownership of many internal structure of the IR and provide a level /// of "uniquing" across multiple modules (types for instance). mlir::MLIRContext &context; /// A "module" matches a source file: it contains a list of functions. std::unique_ptr theModule; /// The builder is a helper class to create IR inside a function. It is /// re-initialized every time we enter a function and kept around as a /// convenience for emitting individual operations. /// The builder is stateful, in particular it keeeps an "insertion point": /// this is where the next operations will be introduced. std::unique_ptr builder; /// The symbol table maps a variable name to a value in the current scope. 
/// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. llvm::ScopedHashTable symbolTable; /// Helper conversion for a Toy AST location to an MLIR location. mlir::FileLineColLoc loc(Location loc) { return mlir::FileLineColLoc::get( mlir::UniquedFilename::get(*loc.file, &context), loc.line, loc.col, &context); } /// Declare a variable in the current scope, return true if the variable /// wasn't declared yet. bool declare(llvm::StringRef var, mlir::Value *value) { if (symbolTable.count(var)) { return false; } symbolTable.insert(var, value); return true; } /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. mlir::Function *mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. if (function->getNumArguments()) function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. std::unique_ptr mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. std::unique_ptr function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! 
// In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. function->addEntryBlock(); auto &entryBlock = function->front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : llvm::zip(protoArgs, entryBlock.getArguments())) { declare(std::get<0>(name_value)->getName(), std::get<1>(name_value)); } // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. builder = llvm::make_unique(function.get()); // Emit the body of the function. if (!mlirGen(*funcAST.getBody())) return nullptr; // Implicitly return void if no return statement was emited. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) if (function->getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); } return function; } /// Emit a binary operation mlir::Value *mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the // value holding `a`. This value should have been emitted at declaration // time and registered in the symbol table, so nothing would be // codegen'd. If the value is not in the symbol table, an error has been // emitted and nullptr is returned. // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted // and the result value is returned. If an error occurs we get a nullptr // and propagate. // mlir::Value *L = mlirGen(*binop.getLHS()); if (!L) return nullptr; mlir::Value *R = mlirGen(*binop.getRHS()); if (!R) return nullptr; auto location = loc(binop.loc()); // Derive the operation name from the binary operator. 
At the moment we only // support '+' and '*'. switch (binop.getOp()) { case '+': return builder->create(location, L, R).getResult(); break; case '*': return builder->create(location, L, R).getResult(); default: context.emitError(loc(binop.loc()), Twine("Error: invalid binary operator '") + Twine(binop.getOp()) + "'"); return nullptr; } } // This is a reference to a variable in an expression. The variable is // expected to have been declared and so should have a value in the symbol // table, otherwise emit an error and return nullptr. mlir::Value *mlirGen(VariableExprAST &expr) { if (symbolTable.count(expr.getName())) return symbolTable.lookup(expr.getName()); context.emitError(loc(expr.loc()), Twine("Error: unknown variable '") + expr.getName() + "'"); return nullptr; } // Emit a return operation, return true on success. bool mlirGen(ReturnExprAST &ret) { auto location = loc(ret.loc()); // `return` takes an optional expression, we need to account for it here. if (!ret.getExpr().hasValue()) { builder->create(location); return true; } auto *expr = mlirGen(*ret.getExpr().getValue()); if (!expr) return false; builder->create(location, expr); return true; } // Emit a literal/constant array. It will be emitted as a flattened array of // data in an Attribute attached to a `toy.constant` operation. // See documentation on [Attributes](LangRef.md#attributes) for more details. // Here is an excerpt: // // Attributes are the mechanism for specifying constant data in MLIR in // places where a variable is never allowed [...]. They consist of a name // and a [concrete attribute value](#attribute-values). It is possible to // attach attributes to operations, functions, and function arguments. The // set of expected attributes, their structure, and their interpretation // are all contextually dependent on what they are attached to. 
// // Example, the source level statement: // var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; // will be converted to: // %0 = "toy.constant"() {value: dense, // [[1.000000e+00, 2.000000e+00, 3.000000e+00], // [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64> // mlir::Value *mlirGen(LiteralExprAST &lit) { auto location = loc(lit.loc()); // The attribute is a vector with an attribute per element (number) in the // array, see `collectData()` below for more details. std::vector data; data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, std::multiplies())); collectData(lit, data); // FIXME: using a tensor type is a HACK here. // Can we do differently without registering a dialect? Using a string blob? mlir::Type elementType = mlir::FloatType::getF64(&context); auto dataType = builder->getTensorType(lit.getDims(), elementType); // This is the actual attribute that actually hold the list of values for // this array literal. auto dataAttribute = builder->getDenseElementsAttr(dataType, data) .cast(); // Build the MLIR op `toy.constant`, only boilerplate below. return builder->create(location, lit.getDims(), dataAttribute) .getResult(); } // Recursive helper function to accumulate the data that compose an array // literal. It flattens the nested structure in the supplied vector. For // example with this array: // [[1, 2], [3, 4]] // we will generate: // [ 1, 2, 3, 4 ] // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`. // Attributes are the way MLIR attaches constant to operations and functions. void collectData(ExprAST &expr, std::vector &data) { if (auto *lit = dyn_cast(&expr)) { for (auto &value : lit->getValues()) collectData(*value, data); return; } assert(isa(expr) && "expected literal or number expr"); mlir::Type elementType = mlir::FloatType::getF64(&context); auto attr = mlir::FloatAttr::getChecked( elementType, cast(expr).getValue(), loc(expr.loc())); data.push_back(attr); } // Emit a call expression. 
It emits specific operations for the `transpose` // builtin. Other identifiers are assumed to be user-defined functions. mlir::Value *mlirGen(CallExprAST &call) { auto location = loc(call.loc()); std::string callee = call.getCallee(); if (callee == "transpose") { if (call.getArgs().size() != 1) { context.emitError( location, Twine("MLIR codegen encountered an error: toy.transpose " "does not accept multiple arguments")); return nullptr; } mlir::Value *arg = mlirGen(*call.getArgs()[0]); return builder->create(location, arg).getResult(); } // Codegen the operands first SmallVector operands; for (auto &expr : call.getArgs()) { auto *arg = mlirGen(*expr); if (!arg) return nullptr; operands.push_back(arg); } // Calls to user-defined function are mapped to a custom call that takes // the callee name as an attribute. return builder->create(location, call.getCallee(), operands) .getResult(); } // Emit a call expression. It emits specific operations for two builtins: // transpose(x) and print(x). Other identifiers are assumed to be user-defined // functions. Return false on failure. bool mlirGen(PrintExprAST &call) { auto *arg = mlirGen(*call.getArg()); if (!arg) return false; auto location = loc(call.loc()); builder->create(location, arg); return true; } // Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value *mlirGen(NumberExprAST &num) { auto location = loc(num.loc()); mlir::Type elementType = mlir::FloatType::getF64(&context); auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(), loc(num.loc())); return builder->create(location, attr).getResult(); } // Dispatch codegen for the right expression subclass using RTTI. 
mlir::Value *mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); case toy::ExprAST::Expr_Var: return mlirGen(cast(expr)); case toy::ExprAST::Expr_Literal: return mlirGen(cast(expr)); case toy::ExprAST::Expr_Call: return mlirGen(cast(expr)); case toy::ExprAST::Expr_Num: return mlirGen(cast(expr)); default: context.emitError( loc(expr.loc()), Twine("MLIR codegen encountered an unhandled expr kind '") + Twine(expr.getKind()) + "'"); return nullptr; } } // Handle a variable declaration, we'll codegen the expression that forms the // initializer and record the value in the symbol table before returning it. // Future expressions will be able to reference this variable through symbol // table lookup. mlir::Value *mlirGen(VarDeclExprAST &vardecl) { mlir::Value *value = nullptr; auto location = loc(vardecl.loc()); if (auto init = vardecl.getInitVal()) { value = mlirGen(*init); if (!value) return nullptr; // We have the initializer value, but in case the variable was declared // with specific shape, we emit a "reshape" operation. It will get // optimized out later as needed. if (!vardecl.getType().shape.empty()) { value = builder ->create( location, value, getType(vardecl.getType()).cast()) .getResult(); } } else { context.emitError(loc(vardecl.loc()), "Missing initializer in variable declaration"); return nullptr; } // Register the value in the symbol table declare(vardecl.getName(), value); return value; } /// Codegen a list of expression, return false if one of them hit an error. bool mlirGen(ExprASTList &blockAST) { ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested // expressions. 
if (auto *vardecl = dyn_cast(expr.get())) { if (!mlirGen(*vardecl)) return false; continue; } if (auto *ret = dyn_cast(expr.get())) { if (!mlirGen(*ret)) return false; return true; } if (auto *print = dyn_cast(expr.get())) { if (!mlirGen(*print)) return false; continue; } // Generic expression dispatch codegen. if (!mlirGen(*expr)) return false; } return true; } /// Build a type from a list of shape dimensions. Types are `array` followed /// by an optional dimension list, example: array<2, 2> /// They are wrapped in a `toy` dialect (see next chapter) and get printed: /// !toy.array<2, 2> template mlir::Type getType(T shape) { SmallVector shape64(shape.begin(), shape.end()); return ToyArrayType::get(&context, shape64); } /// Build an MLIR type from a Toy AST variable type /// (forward to the generic getType(T) above). mlir::Type getType(const VarType &type) { return getType(type.shape); } }; } // namespace namespace toy { // The public API for codegen. std::unique_ptr mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST) { return MLIRGenImpl(context).mlirGen(moduleAST); } } // namespace toy