summaryrefslogtreecommitdiff
path: root/toy/MLIRGen.cpp
diff options
context:
space:
mode:
authorTuowen Zhao <ztuowen@gmail.com>2019-04-27 19:05:25 -0600
committerTuowen Zhao <ztuowen@gmail.com>2019-04-27 19:05:25 -0600
commit0781257b2a8d544abdcce38824a9b8288a04800d (patch)
tree365cea96de343e354913f90b35fc944e4459b2e9 /toy/MLIRGen.cpp
parent4127831a28e31ac53ffdb1d7e7a88dd7d6317c6e (diff)
downloadmlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.tar.gz
mlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.tar.bz2
mlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.zip
Split toy dialect using static registration
Diffstat (limited to 'toy/MLIRGen.cpp')
-rw-r--r--toy/MLIRGen.cpp480
1 files changed, 480 insertions, 0 deletions
diff --git a/toy/MLIRGen.cpp b/toy/MLIRGen.cpp
new file mode 100644
index 0000000..e2001fb
--- /dev/null
+++ b/toy/MLIRGen.cpp
@@ -0,0 +1,480 @@
+//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple IR generation targeting MLIR from a Module AST
+// for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/MLIRGen.h"
+#include "toy/AST.h"
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/StandardOps/Ops.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/raw_ostream.h"
+#include <numeric>
+
+using namespace toy;
+using llvm::cast;
+using llvm::dyn_cast;
+using llvm::isa;
+using llvm::make_unique;
+using llvm::ScopedHashTableScope;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace {
+
+/// Implementation of a simple MLIR emission from the Toy AST.
+///
+/// This will emit operations that are specific to the Toy language, preserving
+/// the semantics of the language and (hopefully) allow to perform accurate
+/// analysis and transformation based on these high level semantics.
+///
+/// At this point we take advantage of the "raw" MLIR APIs to create operations
+/// that haven't been registered in any way with MLIR. These operations are
+/// unknown to MLIR, custom passes could operate by string-matching the name of
+/// these operations, but no other type checking or semantic is associated with
+/// them natively by MLIR.
+class MLIRGenImpl {
+public:
+ MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
+
+ /// Public API: convert the AST for a Toy module (source file) to an MLIR
+ /// Module.
+ std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
+ // We create an empty MLIR module and codegen functions one at a time and
+ // add them to the module.
+ theModule = make_unique<mlir::Module>(&context);
+
+ for (FunctionAST &F : moduleAST) {
+ auto func = mlirGen(F);
+ if (!func)
+ return nullptr;
+ theModule->getFunctions().push_back(func.release());
+ }
+
+ // FIXME: (in the next chapter...) without registering a dialect in MLIR,
+ // this won't do much, but it should at least check some structural
+ // properties.
+ if (failed(theModule->verify())) {
+ context.emitError(mlir::UnknownLoc::get(&context),
+ "Module verification error");
+ return nullptr;
+ }
+
+ return std::move(theModule);
+ }
+
+private:
+ /// In MLIR (like in LLVM) a "context" object holds the memory allocation and
+ /// the ownership of many internal structure of the IR and provide a level
+ /// of "uniquing" across multiple modules (types for instance).
+ mlir::MLIRContext &context;
+
+ /// A "module" matches a source file: it contains a list of functions.
+ std::unique_ptr<mlir::Module> theModule;
+
+ /// The builder is a helper class to create IR inside a function. It is
+ /// re-initialized every time we enter a function and kept around as a
+ /// convenience for emitting individual operations.
+ /// The builder is stateful, in particular it keeeps an "insertion point":
+ /// this is where the next operations will be introduced.
+ std::unique_ptr<mlir::FuncBuilder> builder;
+
+ /// The symbol table maps a variable name to a value in the current scope.
+ /// Entering a function creates a new scope, and the function arguments are
+ /// added to the mapping. When the processing of a function is terminated, the
+ /// scope is destroyed and the mappings created in this scope are dropped.
+ llvm::ScopedHashTable<StringRef, mlir::Value *> symbolTable;
+
+ /// Helper conversion for a Toy AST location to an MLIR location.
+ mlir::FileLineColLoc loc(Location loc) {
+ return mlir::FileLineColLoc::get(
+ mlir::UniquedFilename::get(*loc.file, &context), loc.line, loc.col,
+ &context);
+ }
+
+ /// Declare a variable in the current scope, return true if the variable
+ /// wasn't declared yet.
+ bool declare(llvm::StringRef var, mlir::Value *value) {
+ if (symbolTable.count(var)) {
+ return false;
+ }
+ symbolTable.insert(var, value);
+ return true;
+ }
+
+ /// Create the prototype for an MLIR function with as many arguments as the
+ /// provided Toy AST prototype.
+ mlir::Function *mlirGen(PrototypeAST &proto) {
+ // This is a generic function, the return type will be inferred later.
+ llvm::SmallVector<mlir::Type, 4> ret_types;
+ // Arguments type is uniformly a generic array.
+ llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
+ getType(VarType{}));
+ auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
+ auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
+ func_type, /* attrs = */ {});
+
+ // Mark the function as generic: it'll require type specialization for every
+ // call site.
+ if (function->getNumArguments())
+ function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
+
+ return function;
+ }
+
+ /// Emit a new function and add it to the MLIR module.
+ std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
+ // Create a scope in the symbol table to hold variable declarations.
+ ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
+
+ // Create an MLIR function for the given prototype.
+ std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
+ if (!function)
+ return nullptr;
+
+ // Let's start the body of the function now!
+ // In MLIR the entry block of the function is special: it must have the same
+ // argument list as the function itself.
+ function->addEntryBlock();
+
+ auto &entryBlock = function->front();
+ auto &protoArgs = funcAST.getProto()->getArgs();
+ // Declare all the function arguments in the symbol table.
+ for (const auto &name_value :
+ llvm::zip(protoArgs, entryBlock.getArguments())) {
+ declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
+ }
+
+ // Create a builder for the function, it will be used throughout the codegen
+ // to create operations in this function.
+ builder = llvm::make_unique<mlir::FuncBuilder>(function.get());
+
+ // Emit the body of the function.
+ if (!mlirGen(*funcAST.getBody()))
+ return nullptr;
+
+ // Implicitly return void if no return statement was emited.
+ // FIXME: we may fix the parser instead to always return the last expression
+ // (this would possibly help the REPL case later)
+ if (function->getBlocks().back().back().getName().getStringRef() !=
+ "toy.return") {
+ ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
+ mlirGen(fakeRet);
+ }
+
+ return function;
+ }
+
+ /// Emit a binary operation
+ mlir::Value *mlirGen(BinaryExprAST &binop) {
+ // First emit the operations for each side of the operation before emitting
+ // the operation itself. For example if the expression is `a + foo(a)`
+ // 1) First it will visiting the LHS, which will return a reference to the
+ // value holding `a`. This value should have been emitted at declaration
+ // time and registered in the symbol table, so nothing would be
+ // codegen'd. If the value is not in the symbol table, an error has been
+ // emitted and nullptr is returned.
+ // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
+ // and the result value is returned. If an error occurs we get a nullptr
+ // and propagate.
+ //
+ mlir::Value *L = mlirGen(*binop.getLHS());
+ if (!L)
+ return nullptr;
+ mlir::Value *R = mlirGen(*binop.getRHS());
+ if (!R)
+ return nullptr;
+ auto location = loc(binop.loc());
+
+ // Derive the operation name from the binary operator. At the moment we only
+ // support '+' and '*'.
+ switch (binop.getOp()) {
+ case '+':
+ return builder->create<AddOp>(location, L, R).getResult();
+ break;
+ case '*':
+ return builder->create<MulOp>(location, L, R).getResult();
+ default:
+ context.emitError(loc(binop.loc()),
+ Twine("Error: invalid binary operator '") +
+ Twine(binop.getOp()) + "'");
+ return nullptr;
+ }
+ }
+
+ // This is a reference to a variable in an expression. The variable is
+ // expected to have been declared and so should have a value in the symbol
+ // table, otherwise emit an error and return nullptr.
+ mlir::Value *mlirGen(VariableExprAST &expr) {
+ if (symbolTable.count(expr.getName()))
+ return symbolTable.lookup(expr.getName());
+ context.emitError(loc(expr.loc()), Twine("Error: unknown variable '") +
+ expr.getName() + "'");
+ return nullptr;
+ }
+
+ // Emit a return operation, return true on success.
+ bool mlirGen(ReturnExprAST &ret) {
+ auto location = loc(ret.loc());
+ // `return` takes an optional expression, we need to account for it here.
+ if (!ret.getExpr().hasValue()) {
+ builder->create<ReturnOp>(location);
+ return true;
+ }
+ auto *expr = mlirGen(*ret.getExpr().getValue());
+ if (!expr)
+ return false;
+ builder->create<ReturnOp>(location, expr);
+ return true;
+ }
+
+ // Emit a literal/constant array. It will be emitted as a flattened array of
+ // data in an Attribute attached to a `toy.constant` operation.
+ // See documentation on [Attributes](LangRef.md#attributes) for more details.
+ // Here is an excerpt:
+ //
+ // Attributes are the mechanism for specifying constant data in MLIR in
+ // places where a variable is never allowed [...]. They consist of a name
+ // and a [concrete attribute value](#attribute-values). It is possible to
+ // attach attributes to operations, functions, and function arguments. The
+ // set of expected attributes, their structure, and their interpretation
+ // are all contextually dependent on what they are attached to.
+ //
+ // Example, the source level statement:
+ // var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+ // will be converted to:
+ // %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
+ // [[1.000000e+00, 2.000000e+00, 3.000000e+00],
+ // [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
+ //
+ mlir::Value *mlirGen(LiteralExprAST &lit) {
+ auto location = loc(lit.loc());
+ // The attribute is a vector with an attribute per element (number) in the
+ // array, see `collectData()` below for more details.
+ std::vector<mlir::Attribute> data;
+ data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
+ std::multiplies<int>()));
+ collectData(lit, data);
+
+ // FIXME: using a tensor type is a HACK here.
+ // Can we do differently without registering a dialect? Using a string blob?
+ mlir::Type elementType = mlir::FloatType::getF64(&context);
+ auto dataType = builder->getTensorType(lit.getDims(), elementType);
+
+ // This is the actual attribute that actually hold the list of values for
+ // this array literal.
+ auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
+ .cast<mlir::DenseElementsAttr>();
+
+ // Build the MLIR op `toy.constant`, only boilerplate below.
+ return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
+ .getResult();
+ }
+
+ // Recursive helper function to accumulate the data that compose an array
+ // literal. It flattens the nested structure in the supplied vector. For
+ // example with this array:
+ // [[1, 2], [3, 4]]
+ // we will generate:
+ // [ 1, 2, 3, 4 ]
+ // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`.
+ // Attributes are the way MLIR attaches constant to operations and functions.
+ void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
+ if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
+ for (auto &value : lit->getValues())
+ collectData(*value, data);
+ return;
+ }
+ assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
+ mlir::Type elementType = mlir::FloatType::getF64(&context);
+ auto attr = mlir::FloatAttr::getChecked(
+ elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
+ data.push_back(attr);
+ }
+
+ // Emit a call expression. It emits specific operations for the `transpose`
+ // builtin. Other identifiers are assumed to be user-defined functions.
+ mlir::Value *mlirGen(CallExprAST &call) {
+ auto location = loc(call.loc());
+ std::string callee = call.getCallee();
+ if (callee == "transpose") {
+ if (call.getArgs().size() != 1) {
+ context.emitError(
+ location, Twine("MLIR codegen encountered an error: toy.transpose "
+ "does not accept multiple arguments"));
+ return nullptr;
+ }
+ mlir::Value *arg = mlirGen(*call.getArgs()[0]);
+ return builder->create<TransposeOp>(location, arg).getResult();
+ }
+
+ // Codegen the operands first
+ SmallVector<mlir::Value *, 4> operands;
+ for (auto &expr : call.getArgs()) {
+ auto *arg = mlirGen(*expr);
+ if (!arg)
+ return nullptr;
+ operands.push_back(arg);
+ }
+ // Calls to user-defined function are mapped to a custom call that takes
+ // the callee name as an attribute.
+ return builder->create<GenericCallOp>(location, call.getCallee(), operands)
+ .getResult();
+ }
+
+ // Emit a call expression. It emits specific operations for two builtins:
+ // transpose(x) and print(x). Other identifiers are assumed to be user-defined
+ // functions. Return false on failure.
+ bool mlirGen(PrintExprAST &call) {
+ auto *arg = mlirGen(*call.getArg());
+ if (!arg)
+ return false;
+ auto location = loc(call.loc());
+ builder->create<PrintOp>(location, arg);
+ return true;
+ }
+
+ // Emit a constant for a single number (FIXME: semantic? broadcast?)
+ mlir::Value *mlirGen(NumberExprAST &num) {
+ auto location = loc(num.loc());
+ mlir::Type elementType = mlir::FloatType::getF64(&context);
+ auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
+ loc(num.loc()));
+ return builder->create<ConstantOp>(location, attr).getResult();
+ }
+
+ // Dispatch codegen for the right expression subclass using RTTI.
+ mlir::Value *mlirGen(ExprAST &expr) {
+ switch (expr.getKind()) {
+ case toy::ExprAST::Expr_BinOp:
+ return mlirGen(cast<BinaryExprAST>(expr));
+ case toy::ExprAST::Expr_Var:
+ return mlirGen(cast<VariableExprAST>(expr));
+ case toy::ExprAST::Expr_Literal:
+ return mlirGen(cast<LiteralExprAST>(expr));
+ case toy::ExprAST::Expr_Call:
+ return mlirGen(cast<CallExprAST>(expr));
+ case toy::ExprAST::Expr_Num:
+ return mlirGen(cast<NumberExprAST>(expr));
+ default:
+ context.emitError(
+ loc(expr.loc()),
+ Twine("MLIR codegen encountered an unhandled expr kind '") +
+ Twine(expr.getKind()) + "'");
+ return nullptr;
+ }
+ }
+
+ // Handle a variable declaration, we'll codegen the expression that forms the
+ // initializer and record the value in the symbol table before returning it.
+ // Future expressions will be able to reference this variable through symbol
+ // table lookup.
+ mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
+ mlir::Value *value = nullptr;
+ auto location = loc(vardecl.loc());
+ if (auto init = vardecl.getInitVal()) {
+ value = mlirGen(*init);
+ if (!value)
+ return nullptr;
+ // We have the initializer value, but in case the variable was declared
+ // with specific shape, we emit a "reshape" operation. It will get
+ // optimized out later as needed.
+ if (!vardecl.getType().shape.empty()) {
+ value = builder
+ ->create<ReshapeOp>(
+ location, value,
+ getType(vardecl.getType()).cast<ToyArrayType>())
+ .getResult();
+ }
+ } else {
+ context.emitError(loc(vardecl.loc()),
+ "Missing initializer in variable declaration");
+ return nullptr;
+ }
+ // Register the value in the symbol table
+ declare(vardecl.getName(), value);
+ return value;
+ }
+
+ /// Codegen a list of expression, return false if one of them hit an error.
+ bool mlirGen(ExprASTList &blockAST) {
+ ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
+ for (auto &expr : blockAST) {
+ // Specific handling for variable declarations, return statement, and
+ // print. These can only appear in block list and not in nested
+ // expressions.
+ if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
+ if (!mlirGen(*vardecl))
+ return false;
+ continue;
+ }
+ if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
+ if (!mlirGen(*ret))
+ return false;
+ return true;
+ }
+ if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
+ if (!mlirGen(*print))
+ return false;
+ continue;
+ }
+ // Generic expression dispatch codegen.
+ if (!mlirGen(*expr))
+ return false;
+ }
+ return true;
+ }
+
+ /// Build a type from a list of shape dimensions. Types are `array` followed
+ /// by an optional dimension list, example: array<2, 2>
+ /// They are wrapped in a `toy` dialect (see next chapter) and get printed:
+ /// !toy.array<2, 2>
+ template <typename T> mlir::Type getType(T shape) {
+ SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
+ return ToyArrayType::get(&context, shape64);
+ }
+
+ /// Build an MLIR type from a Toy AST variable type
+ /// (forward to the generic getType(T) above).
+ mlir::Type getType(const VarType &type) { return getType(type.shape); }
+};
+
+} // namespace
+
+namespace toy {
+
+// The public API for codegen.
+std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
+ ModuleAST &moduleAST) {
+ return MLIRGenImpl(context).mlirGen(moduleAST);
+}
+
+} // namespace toy