diff options
author | Tuowen Zhao <ztuowen@gmail.com> | 2019-04-27 19:05:25 -0600 |
---|---|---|
committer | Tuowen Zhao <ztuowen@gmail.com> | 2019-04-27 19:05:25 -0600 |
commit | 0781257b2a8d544abdcce38824a9b8288a04800d (patch) | |
tree | 365cea96de343e354913f90b35fc944e4459b2e9 /mlir/MLIRGen.cpp | |
parent | 4127831a28e31ac53ffdb1d7e7a88dd7d6317c6e (diff) | |
download | mlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.tar.gz mlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.tar.bz2 mlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.zip |
Split toy dialect using static registration
Diffstat (limited to 'mlir/MLIRGen.cpp')
-rw-r--r-- | mlir/MLIRGen.cpp | 480 |
1 files changed, 0 insertions, 480 deletions
diff --git a/mlir/MLIRGen.cpp b/mlir/MLIRGen.cpp deleted file mode 100644 index e2001fb..0000000 --- a/mlir/MLIRGen.cpp +++ /dev/null @@ -1,480 +0,0 @@ -//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file implements a simple IR generation targeting MLIR from a Module AST -// for the Toy language. -// -//===----------------------------------------------------------------------===// - -#include "toy/MLIRGen.h" -#include "toy/AST.h" -#include "toy/Dialect.h" - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/Types.h" -#include "mlir/StandardOps/Ops.h" - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopedHashTable.h" -#include "llvm/Support/raw_ostream.h" -#include <numeric> - -using namespace toy; -using llvm::cast; -using llvm::dyn_cast; -using llvm::isa; -using llvm::make_unique; -using llvm::ScopedHashTableScope; -using llvm::SmallVector; -using llvm::StringRef; -using llvm::Twine; - -namespace { - -/// Implementation of a simple MLIR emission from the Toy AST. -/// -/// This will emit operations that are specific to the Toy language, preserving -/// the semantics of the language and (hopefully) allow to perform accurate -/// analysis and transformation based on these high level semantics. -/// -/// At this point we take advantage of the "raw" MLIR APIs to create operations -/// that haven't been registered in any way with MLIR. These operations are -/// unknown to MLIR, custom passes could operate by string-matching the name of -/// these operations, but no other type checking or semantic is associated with -/// them natively by MLIR. -class MLIRGenImpl { -public: - MLIRGenImpl(mlir::MLIRContext &context) : context(context) {} - - /// Public API: convert the AST for a Toy module (source file) to an MLIR - /// Module. - std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) { - // We create an empty MLIR module and codegen functions one at a time and - // add them to the module. - theModule = make_unique<mlir::Module>(&context); - - for (FunctionAST &F : moduleAST) { - auto func = mlirGen(F); - if (!func) - return nullptr; - theModule->getFunctions().push_back(func.release()); - } - - // FIXME: (in the next chapter...) without registering a dialect in MLIR, - // this won't do much, but it should at least check some structural - // properties. - if (failed(theModule->verify())) { - context.emitError(mlir::UnknownLoc::get(&context), - "Module verification error"); - return nullptr; - } - - return std::move(theModule); - } - -private: - /// In MLIR (like in LLVM) a "context" object holds the memory allocation and - /// the ownership of many internal structure of the IR and provide a level - /// of "uniquing" across multiple modules (types for instance). - mlir::MLIRContext &context; - - /// A "module" matches a source file: it contains a list of functions. - std::unique_ptr<mlir::Module> theModule; - - /// The builder is a helper class to create IR inside a function. It is - /// re-initialized every time we enter a function and kept around as a - /// convenience for emitting individual operations. - /// The builder is stateful, in particular it keeeps an "insertion point": - /// this is where the next operations will be introduced. - std::unique_ptr<mlir::FuncBuilder> builder; - - /// The symbol table maps a variable name to a value in the current scope. - /// Entering a function creates a new scope, and the function arguments are - /// added to the mapping. When the processing of a function is terminated, the - /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable<StringRef, mlir::Value *> symbolTable; - - /// Helper conversion for a Toy AST location to an MLIR location. - mlir::FileLineColLoc loc(Location loc) { - return mlir::FileLineColLoc::get( - mlir::UniquedFilename::get(*loc.file, &context), loc.line, loc.col, - &context); - } - - /// Declare a variable in the current scope, return true if the variable - /// wasn't declared yet. - bool declare(llvm::StringRef var, mlir::Value *value) { - if (symbolTable.count(var)) { - return false; - } - symbolTable.insert(var, value); - return true; - } - - /// Create the prototype for an MLIR function with as many arguments as the - /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { - // This is a generic function, the return type will be inferred later. - llvm::SmallVector<mlir::Type, 4> ret_types; - // Arguments type is uniformly a generic array. - llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(), - getType(VarType{})); - auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); - - // Mark the function as generic: it'll require type specialization for every - // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); - - return function; - } - - /// Emit a new function and add it to the MLIR module. - std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) { - // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable); - - // Create an MLIR function for the given prototype. - std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto())); - if (!function) - return nullptr; - - // Let's start the body of the function now! - // In MLIR the entry block of the function is special: it must have the same - // argument list as the function itself. - function->addEntryBlock(); - - auto &entryBlock = function->front(); - auto &protoArgs = funcAST.getProto()->getArgs(); - // Declare all the function arguments in the symbol table. - for (const auto &name_value : - llvm::zip(protoArgs, entryBlock.getArguments())) { - declare(std::get<0>(name_value)->getName(), std::get<1>(name_value)); - } - - // Create a builder for the function, it will be used throughout the codegen - // to create operations in this function. - builder = llvm::make_unique<mlir::FuncBuilder>(function.get()); - - // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) - return nullptr; - - // Implicitly return void if no return statement was emited. - // FIXME: we may fix the parser instead to always return the last expression - // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != - "toy.return") { - ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); - mlirGen(fakeRet); - } - - return function; - } - - /// Emit a binary operation - mlir::Value *mlirGen(BinaryExprAST &binop) { - // First emit the operations for each side of the operation before emitting - // the operation itself. For example if the expression is `a + foo(a)` - // 1) First it will visiting the LHS, which will return a reference to the - // value holding `a`. This value should have been emitted at declaration - // time and registered in the symbol table, so nothing would be - // codegen'd. If the value is not in the symbol table, an error has been - // emitted and nullptr is returned. - // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted - // and the result value is returned. If an error occurs we get a nullptr - // and propagate. - // - mlir::Value *L = mlirGen(*binop.getLHS()); - if (!L) - return nullptr; - mlir::Value *R = mlirGen(*binop.getRHS()); - if (!R) - return nullptr; - auto location = loc(binop.loc()); - - // Derive the operation name from the binary operator. At the moment we only - // support '+' and '*'. - switch (binop.getOp()) { - case '+': - return builder->create<AddOp>(location, L, R).getResult(); - break; - case '*': - return builder->create<MulOp>(location, L, R).getResult(); - default: - context.emitError(loc(binop.loc()), - Twine("Error: invalid binary operator '") + - Twine(binop.getOp()) + "'"); - return nullptr; - } - } - - // This is a reference to a variable in an expression. The variable is - // expected to have been declared and so should have a value in the symbol - // table, otherwise emit an error and return nullptr. - mlir::Value *mlirGen(VariableExprAST &expr) { - if (symbolTable.count(expr.getName())) - return symbolTable.lookup(expr.getName()); - context.emitError(loc(expr.loc()), Twine("Error: unknown variable '") + - expr.getName() + "'"); - return nullptr; - } - - // Emit a return operation, return true on success. - bool mlirGen(ReturnExprAST &ret) { - auto location = loc(ret.loc()); - // `return` takes an optional expression, we need to account for it here. - if (!ret.getExpr().hasValue()) { - builder->create<ReturnOp>(location); - return true; - } - auto *expr = mlirGen(*ret.getExpr().getValue()); - if (!expr) - return false; - builder->create<ReturnOp>(location, expr); - return true; - } - - // Emit a literal/constant array. It will be emitted as a flattened array of - // data in an Attribute attached to a `toy.constant` operation. - // See documentation on [Attributes](LangRef.md#attributes) for more details. - // Here is an excerpt: - // - // Attributes are the mechanism for specifying constant data in MLIR in - // places where a variable is never allowed [...]. They consist of a name - // and a [concrete attribute value](#attribute-values). It is possible to - // attach attributes to operations, functions, and function arguments. The - // set of expected attributes, their structure, and their interpretation - // are all contextually dependent on what they are attached to. - // - // Example, the source level statement: - // var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; - // will be converted to: - // %0 = "toy.constant"() {value: dense<tensor<2x3xf64>, - // [[1.000000e+00, 2.000000e+00, 3.000000e+00], - // [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64> - // - mlir::Value *mlirGen(LiteralExprAST &lit) { - auto location = loc(lit.loc()); - // The attribute is a vector with an attribute per element (number) in the - // array, see `collectData()` below for more details. - std::vector<mlir::Attribute> data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies<int>())); - collectData(lit, data); - - // FIXME: using a tensor type is a HACK here. - // Can we do differently without registering a dialect? Using a string blob? - mlir::Type elementType = mlir::FloatType::getF64(&context); - auto dataType = builder->getTensorType(lit.getDims(), elementType); - - // This is the actual attribute that actually hold the list of values for - // this array literal. - auto dataAttribute = builder->getDenseElementsAttr(dataType, data) - .cast<mlir::DenseElementsAttr>(); - - // Build the MLIR op `toy.constant`, only boilerplate below. - return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute) - .getResult(); - } - - // Recursive helper function to accumulate the data that compose an array - // literal. It flattens the nested structure in the supplied vector. For - // example with this array: - // [[1, 2], [3, 4]] - // we will generate: - // [ 1, 2, 3, 4 ] - // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`. - // Attributes are the way MLIR attaches constant to operations and functions. - void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) { - if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) { - for (auto &value : lit->getValues()) - collectData(*value, data); - return; - } - assert(isa<NumberExprAST>(expr) && "expected literal or number expr"); - mlir::Type elementType = mlir::FloatType::getF64(&context); - auto attr = mlir::FloatAttr::getChecked( - elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc())); - data.push_back(attr); - } - - // Emit a call expression. It emits specific operations for the `transpose` - // builtin. Other identifiers are assumed to be user-defined functions. - mlir::Value *mlirGen(CallExprAST &call) { - auto location = loc(call.loc()); - std::string callee = call.getCallee(); - if (callee == "transpose") { - if (call.getArgs().size() != 1) { - context.emitError( - location, Twine("MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments")); - return nullptr; - } - mlir::Value *arg = mlirGen(*call.getArgs()[0]); - return builder->create<TransposeOp>(location, arg).getResult(); - } - - // Codegen the operands first - SmallVector<mlir::Value *, 4> operands; - for (auto &expr : call.getArgs()) { - auto *arg = mlirGen(*expr); - if (!arg) - return nullptr; - operands.push_back(arg); - } - // Calls to user-defined function are mapped to a custom call that takes - // the callee name as an attribute. - return builder->create<GenericCallOp>(location, call.getCallee(), operands) - .getResult(); - } - - // Emit a call expression. It emits specific operations for two builtins: - // transpose(x) and print(x). Other identifiers are assumed to be user-defined - // functions. Return false on failure. - bool mlirGen(PrintExprAST &call) { - auto *arg = mlirGen(*call.getArg()); - if (!arg) - return false; - auto location = loc(call.loc()); - builder->create<PrintOp>(location, arg); - return true; - } - - // Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::Value *mlirGen(NumberExprAST &num) { - auto location = loc(num.loc()); - mlir::Type elementType = mlir::FloatType::getF64(&context); - auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(), - loc(num.loc())); - return builder->create<ConstantOp>(location, attr).getResult(); - } - - // Dispatch codegen for the right expression subclass using RTTI. - mlir::Value *mlirGen(ExprAST &expr) { - switch (expr.getKind()) { - case toy::ExprAST::Expr_BinOp: - return mlirGen(cast<BinaryExprAST>(expr)); - case toy::ExprAST::Expr_Var: - return mlirGen(cast<VariableExprAST>(expr)); - case toy::ExprAST::Expr_Literal: - return mlirGen(cast<LiteralExprAST>(expr)); - case toy::ExprAST::Expr_Call: - return mlirGen(cast<CallExprAST>(expr)); - case toy::ExprAST::Expr_Num: - return mlirGen(cast<NumberExprAST>(expr)); - default: - context.emitError( - loc(expr.loc()), - Twine("MLIR codegen encountered an unhandled expr kind '") + - Twine(expr.getKind()) + "'"); - return nullptr; - } - } - - // Handle a variable declaration, we'll codegen the expression that forms the - // initializer and record the value in the symbol table before returning it. - // Future expressions will be able to reference this variable through symbol - // table lookup. - mlir::Value *mlirGen(VarDeclExprAST &vardecl) { - mlir::Value *value = nullptr; - auto location = loc(vardecl.loc()); - if (auto init = vardecl.getInitVal()) { - value = mlirGen(*init); - if (!value) - return nullptr; - // We have the initializer value, but in case the variable was declared - // with specific shape, we emit a "reshape" operation. It will get - // optimized out later as needed. - if (!vardecl.getType().shape.empty()) { - value = builder - ->create<ReshapeOp>( - location, value, - getType(vardecl.getType()).cast<ToyArrayType>()) - .getResult(); - } - } else { - context.emitError(loc(vardecl.loc()), - "Missing initializer in variable declaration"); - return nullptr; - } - // Register the value in the symbol table - declare(vardecl.getName(), value); - return value; - } - - /// Codegen a list of expression, return false if one of them hit an error. - bool mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable); - for (auto &expr : blockAST) { - // Specific handling for variable declarations, return statement, and - // print. These can only appear in block list and not in nested - // expressions. - if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) { - if (!mlirGen(*vardecl)) - return false; - continue; - } - if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) { - if (!mlirGen(*ret)) - return false; - return true; - } - if (auto *print = dyn_cast<PrintExprAST>(expr.get())) { - if (!mlirGen(*print)) - return false; - continue; - } - // Generic expression dispatch codegen. - if (!mlirGen(*expr)) - return false; - } - return true; - } - - /// Build a type from a list of shape dimensions. Types are `array` followed - /// by an optional dimension list, example: array<2, 2> - /// They are wrapped in a `toy` dialect (see next chapter) and get printed: - /// !toy.array<2, 2> - template <typename T> mlir::Type getType(T shape) { - SmallVector<int64_t, 8> shape64(shape.begin(), shape.end()); - return ToyArrayType::get(&context, shape64); - } - - /// Build an MLIR type from a Toy AST variable type - /// (forward to the generic getType(T) above). - mlir::Type getType(const VarType &type) { return getType(type.shape); } -}; - -} // namespace - -namespace toy { - -// The public API for codegen. -std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST) { - return MLIRGenImpl(context).mlirGen(moduleAST); -} - -} // namespace toy |