summaryrefslogtreecommitdiff
path: root/toy/MLIRGen.cpp
blob: e2001fb575e47de38c4b2253a36e58d13003cf3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a simple IR generation targeting MLIR from a Module AST
// for the Toy language.
//
//===----------------------------------------------------------------------===//

#include "toy/MLIRGen.h"
#include "toy/AST.h"
#include "toy/Dialect.h"

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/StandardOps/Ops.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>

using namespace toy;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;
using llvm::make_unique;
using llvm::ScopedHashTableScope;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;

namespace {

/// Implementation of a simple MLIR emission from the Toy AST.
///
/// This will emit operations that are specific to the Toy language, preserving
/// the semantics of the language and (hopefully) allow to perform accurate
/// analysis and transformation based on these high level semantics.
///
/// At this point we take advantage of the "raw" MLIR APIs to create operations
/// that haven't been registered in any way with MLIR. These operations are
/// unknown to MLIR, custom passes could operate by string-matching the name of
/// these operations, but no other type checking or semantic is associated with
/// them natively by MLIR.
class MLIRGenImpl {
public:
  MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}

  /// Public API: lower the AST of a whole Toy module (one source file) into an
  /// MLIR Module. Returns nullptr if any function fails to lower or if the
  /// resulting module does not verify.
  std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
    // Start from a fresh, empty module; functions are appended one by one.
    theModule = llvm::make_unique<mlir::Module>(&context);

    for (FunctionAST &funcAST : moduleAST) {
      std::unique_ptr<mlir::Function> lowered = mlirGen(funcAST);
      // The first function that fails to lower aborts the whole module.
      if (!lowered)
        return nullptr;
      // Ownership is released here and handed to the module's function list.
      theModule->getFunctions().push_back(lowered.release());
    }

    // FIXME: (in the next chapter...) without registering a dialect in MLIR,
    // this won't do much, but it should at least check some structural
    // properties.
    const bool verificationFailed = failed(theModule->verify());
    if (verificationFailed) {
      context.emitError(mlir::UnknownLoc::get(&context),
                        "Module verification error");
      return nullptr;
    }

    // `theModule` is a data member, so the move must be spelled explicitly.
    return std::move(theModule);
  }

private:
  /// In MLIR (like in LLVM) a "context" object holds the memory allocation and
  /// the ownership of many internal structures of the IR and provides a level
  /// of "uniquing" across multiple modules (types for instance).
  mlir::MLIRContext &context;

  /// A "module" matches a source file: it contains a list of functions.
  std::unique_ptr<mlir::Module> theModule;

  /// The builder is a helper class to create IR inside a function. It is
  /// re-initialized every time we enter a function and kept around as a
  /// convenience for emitting individual operations.
  /// The builder is stateful, in particular it keeps an "insertion point":
  /// this is where the next operations will be introduced.
  std::unique_ptr<mlir::FuncBuilder> builder;

  /// The symbol table maps a variable name to a value in the current scope.
  /// Entering a function creates a new scope, and the function arguments are
  /// added to the mapping. When the processing of a function is terminated, the
  /// scope is destroyed and the mappings created in this scope are dropped.
  llvm::ScopedHashTable<StringRef, mlir::Value *> symbolTable;

  /// Translate a Toy AST source location (file, line, column) into the
  /// matching MLIR file/line/column location.
  mlir::FileLineColLoc loc(Location loc) {
    // The file name is uniqued in the context before building the location.
    auto filename = mlir::UniquedFilename::get(*loc.file, &context);
    return mlir::FileLineColLoc::get(filename, loc.line, loc.col, &context);
  }

  /// Bind `var` to `value` in the symbol table unless the name is already
  /// known. Returns true when a new binding was created and false when the
  /// name was already present (the existing binding is left untouched).
  bool declare(llvm::StringRef var, mlir::Value *value) {
    const bool isNewName = symbolTable.count(var) == 0;
    if (isNewName)
      symbolTable.insert(var, value);
    return isNewName;
  }

  /// Build the MLIR function matching a Toy prototype: same name, one argument
  /// per Toy parameter and no result type. The function is heap-allocated and
  /// the caller takes ownership of the returned pointer.
  mlir::Function *mlirGen(PrototypeAST &proto) {
    // Every parameter gets the type built from a default-constructed VarType,
    // i.e. the uniform generic array type.
    const mlir::Type genericArrayType = getType(VarType{});
    llvm::SmallVector<mlir::Type, 4> paramTypes(proto.getArgs().size(),
                                                genericArrayType);

    // The function is generic: the return type will be inferred later, so the
    // result list starts out empty.
    llvm::SmallVector<mlir::Type, 4> resultTypes;

    auto signature =
        mlir::FunctionType::get(paramTypes, resultTypes, &context);
    auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
                                        signature, /* attrs = */ {});

    // A function taking arguments is tagged as generic: it will require type
    // specialization for every call site.
    if (function->getNumArguments() != 0)
      function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));

    return function;
  }

  /// Emit a new MLIR function for `funcAST` and return it to the caller, which
  /// is responsible for adding it to the module.
  ///
  /// Returns nullptr when the prototype cannot be created, when two parameters
  /// share the same name, or when the body fails to lower. A function whose
  /// body does not end with a `toy.return` (including a completely empty body)
  /// gets an implicit void return appended.
  std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
    // Create a scope in the symbol table to hold variable declarations.
    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);

    // Create an MLIR function for the given prototype.
    std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
    if (!function)
      return nullptr;

    // Let's start the body of the function now!
    // In MLIR the entry block of the function is special: it must have the same
    // argument list as the function itself.
    function->addEntryBlock();

    auto &entryBlock = function->front();
    auto &protoArgs = funcAST.getProto()->getArgs();
    // Declare all the function arguments in the symbol table. `declare`
    // refuses a name that is already bound, which here can only mean that two
    // parameters share a name: report it instead of silently dropping one.
    for (const auto &name_value :
         llvm::zip(protoArgs, entryBlock.getArguments())) {
      if (!declare(std::get<0>(name_value)->getName(),
                   std::get<1>(name_value))) {
        context.emitError(loc(funcAST.getProto()->loc()),
                          Twine("Error: duplicate parameter name '") +
                              std::get<0>(name_value)->getName() + "'");
        return nullptr;
      }
    }

    // Create a builder for the function, it will be used throughout the codegen
    // to create operations in this function.
    builder = llvm::make_unique<mlir::FuncBuilder>(function.get());

    // Emit the body of the function.
    if (!mlirGen(*funcAST.getBody()))
      return nullptr;

    // Implicitly return void if no return statement was emitted.
    // An empty body leaves the block without any operation, so the last
    // operation may only be inspected after checking that the block is not
    // empty.
    // FIXME: we may fix the parser instead to always return the last expression
    // (this would possibly help the REPL case later)
    auto &lastBlock = function->getBlocks().back();
    if (lastBlock.empty() ||
        lastBlock.back().getName().getStringRef() != "toy.return") {
      ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
      if (!mlirGen(fakeRet))
        return nullptr;
    }

    return function;
  }

  /// Emit a binary operation and return the value it produces, or nullptr on
  /// error (an operand failed to lower, or the operator is not supported).
  mlir::Value *mlirGen(BinaryExprAST &binop) {
    // First emit the operations for each side of the operation before emitting
    // the operation itself. For example if the expression is `a + foo(a)`
    // 1) First the LHS is visited, which returns a reference to the value
    //    holding `a`. This value should have been emitted at declaration time
    //    and registered in the symbol table, so nothing would be codegen'd.
    //    If the value is not in the symbol table, an error has been emitted
    //    and nullptr is returned.
    // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
    //    and the result value is returned. If an error occurs we get a nullptr
    //    and propagate.
    mlir::Value *lhs = mlirGen(*binop.getLHS());
    if (!lhs)
      return nullptr;
    mlir::Value *rhs = mlirGen(*binop.getRHS());
    if (!rhs)
      return nullptr;
    auto location = loc(binop.loc());

    // Derive the operation from the binary operator. At the moment we only
    // support '+' and '*'.
    switch (binop.getOp()) {
    case '+':
      return builder->create<AddOp>(location, lhs, rhs).getResult();
    case '*':
      return builder->create<MulOp>(location, lhs, rhs).getResult();
    default:
      context.emitError(location,
                        Twine("Error: invalid binary operator '") +
                            Twine(binop.getOp()) + "'");
      return nullptr;
    }
  }

  // Resolve a variable referenced in an expression. The name must already be
  // bound in the symbol table; if it is not, an error is emitted at the
  // expression location and nullptr is returned.
  mlir::Value *mlirGen(VariableExprAST &expr) {
    const bool isKnown = symbolTable.count(expr.getName()) != 0;
    if (!isKnown) {
      context.emitError(loc(expr.loc()), Twine("Error: unknown variable '") +
                                             expr.getName() + "'");
      return nullptr;
    }
    return symbolTable.lookup(expr.getName());
  }

  // Emit a return operation, with or without an operand. Returns true on
  // success and false when the returned expression fails to lower.
  bool mlirGen(ReturnExprAST &ret) {
    auto location = loc(ret.loc());
    // `return` takes an optional expression: lower it first when present.
    if (ret.getExpr().hasValue()) {
      auto *operand = mlirGen(*ret.getExpr().getValue());
      if (!operand)
        return false;
      builder->create<ReturnOp>(location, operand);
    } else {
      builder->create<ReturnOp>(location);
    }
    return true;
  }

  // Emit a literal/constant array. It will be emitted as a flattened array of
  // data in an Attribute attached to a `toy.constant` operation.
  // See documentation on [Attributes](LangRef.md#attributes) for more details.
  // Here is an excerpt:
  //
  //   Attributes are the mechanism for specifying constant data in MLIR in
  //   places where a variable is never allowed [...]. They consist of a name
  //   and a [concrete attribute value](#attribute-values). It is possible to
  //   attach attributes to operations, functions, and function arguments. The
  //   set of expected attributes, their structure, and their interpretation
  //   are all contextually dependent on what they are attached to.
  //
  // Example, the source level statement:
  //   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  // will be converted to:
  //   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
  //     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
  //      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
  //
  mlir::Value *mlirGen(LiteralExprAST &lit) {
    auto location = loc(lit.loc());

    // One attribute per number in the literal; the product of the dimensions
    // is used to size the buffer up front. See `collectData()` for how the
    // nested literal is flattened.
    const auto elementCount =
        std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
                        std::multiplies<int>());
    std::vector<mlir::Attribute> flattened;
    flattened.reserve(elementCount);
    collectData(lit, flattened);

    // FIXME: using a tensor type is a HACK here.
    // Can we do differently without registering a dialect? Using a string blob?
    mlir::Type f64Type = mlir::FloatType::getF64(&context);
    auto tensorType = builder->getTensorType(lit.getDims(), f64Type);

    // The dense attribute is what actually holds the list of values for this
    // array literal.
    auto denseAttr = builder->getDenseElementsAttr(tensorType, flattened)
                         .cast<mlir::DenseElementsAttr>();

    // Build the MLIR op `toy.constant`, only boilerplate below.
    auto constantOp =
        builder->create<ConstantOp>(location, lit.getDims(), denseAttr);
    return constantOp.getResult();
  }

  // Recursively flatten an array literal into `data`, appending one
  // `mlir::FloatAttr` (f64) per number in source order. For example
  //  [[1, 2], [3, 4]]
  // is appended as:
  //  [ 1, 2, 3, 4 ]
  // `expr` must be either a literal or a number expression (asserted).
  void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
    auto *nested = dyn_cast<LiteralExprAST>(&expr);
    if (!nested) {
      // Leaf: a single number, wrapped in a float attribute.
      assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
      mlir::Type f64Type = mlir::FloatType::getF64(&context);
      data.push_back(mlir::FloatAttr::getChecked(
          f64Type, cast<NumberExprAST>(expr).getValue(), loc(expr.loc())));
      return;
    }

    // Inner node: visit every child in order.
    for (auto &child : nested->getValues())
      collectData(*child, data);
  }

  // Emit a call expression. The `transpose` builtin gets a dedicated
  // operation; every other identifier is assumed to be a user-defined function
  // and is lowered to a generic call that carries the callee name.
  // Returns nullptr when `transpose` is not given exactly one argument or when
  // any argument fails to lower.
  mlir::Value *mlirGen(CallExprAST &call) {
    auto location = loc(call.loc());
    if (call.getCallee() == "transpose") {
      if (call.getArgs().size() != 1) {
        context.emitError(
            location, Twine("MLIR codegen encountered an error: toy.transpose "
                            "does not accept multiple arguments"));
        return nullptr;
      }
      mlir::Value *arg = mlirGen(*call.getArgs()[0]);
      // The operand may have failed to lower (an error was already emitted):
      // never build an operation on a null value.
      if (!arg)
        return nullptr;
      return builder->create<TransposeOp>(location, arg).getResult();
    }

    // Codegen the operands first
    SmallVector<mlir::Value *, 4> operands;
    for (auto &expr : call.getArgs()) {
      auto *arg = mlirGen(*expr);
      if (!arg)
        return nullptr;
      operands.push_back(arg);
    }
    // Calls to user-defined function are mapped to a custom call that takes
    // the callee name as an attribute.
    return builder->create<GenericCallOp>(location, call.getCallee(), operands)
        .getResult();
  }

  // Emit the `print` builtin: lower its single argument, then create the print
  // operation on the resulting value. Returns false when the argument fails to
  // lower, true otherwise.
  bool mlirGen(PrintExprAST &call) {
    mlir::Value *operand = mlirGen(*call.getArg());
    if (operand == nullptr)
      return false;
    builder->create<PrintOp>(loc(call.loc()), operand);
    return true;
  }

  // Emit a constant operation holding a single f64 number
  // (FIXME: semantic? broadcast?)
  mlir::Value *mlirGen(NumberExprAST &num) {
    mlir::Type f64Type = mlir::FloatType::getF64(&context);
    auto valueAttr =
        mlir::FloatAttr::getChecked(f64Type, num.getValue(), loc(num.loc()));
    auto constantOp = builder->create<ConstantOp>(loc(num.loc()), valueAttr);
    return constantOp.getResult();
  }

  // Dispatch codegen to the overload matching the dynamic kind of `expr`,
  // using the kind tag and `llvm::cast`. Kinds without a handler here produce
  // an error and a nullptr result.
  mlir::Value *mlirGen(ExprAST &expr) {
    switch (expr.getKind()) {
    case toy::ExprAST::Expr_BinOp:
      return mlirGen(cast<BinaryExprAST>(expr));
    case toy::ExprAST::Expr_Var:
      return mlirGen(cast<VariableExprAST>(expr));
    case toy::ExprAST::Expr_Literal:
      return mlirGen(cast<LiteralExprAST>(expr));
    case toy::ExprAST::Expr_Call:
      return mlirGen(cast<CallExprAST>(expr));
    case toy::ExprAST::Expr_Num:
      return mlirGen(cast<NumberExprAST>(expr));
    default:
      break;
    }

    // Only reached for a kind that has no case above.
    context.emitError(
        loc(expr.loc()),
        Twine("MLIR codegen encountered an unhandled expr kind '") +
            Twine(expr.getKind()) + "'");
    return nullptr;
  }

  // Handle a variable declaration: codegen the expression that forms the
  // initializer, reshape it when the declaration carries an explicit shape,
  // and record the value in the symbol table before returning it. Future
  // expressions will be able to reference this variable through symbol table
  // lookup.
  //
  // Returns nullptr (after emitting an error where applicable) when the
  // initializer is missing or fails to lower, or when the name is already
  // declared.
  mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
    auto location = loc(vardecl.loc());

    auto init = vardecl.getInitVal();
    if (!init) {
      context.emitError(location,
                        "Missing initializer in variable declaration");
      return nullptr;
    }

    mlir::Value *value = mlirGen(*init);
    if (!value)
      return nullptr;

    // We have the initializer value, but in case the variable was declared
    // with specific shape, we emit a "reshape" operation. It will get
    // optimized out later as needed.
    if (!vardecl.getType().shape.empty()) {
      value = builder
                  ->create<ReshapeOp>(
                      location, value,
                      getType(vardecl.getType()).cast<ToyArrayType>())
                  .getResult();
    }

    // Register the value in the symbol table. `declare` refuses a name that is
    // already bound; ignoring that would leave later uses of the name pointing
    // at the previous value, so report it as an error instead.
    if (!declare(vardecl.getName(), value)) {
      context.emitError(location, Twine("Error: redefinition of variable '") +
                                      vardecl.getName() + "'");
      return nullptr;
    }
    return value;
  }

  /// Codegen a list of expressions inside a fresh symbol-table scope. Returns
  /// false as soon as one of them hits an error. A return statement ends the
  /// processing of the list: expressions after it are not emitted.
  bool mlirGen(ExprASTList &blockAST) {
    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
    for (auto &expr : blockAST) {
      auto *node = expr.get();
      // Variable declarations, return statements and print get specific
      // handling: they can only appear in a block list and not in nested
      // expressions. Everything else goes through the generic dispatch.
      if (auto *vardecl = dyn_cast<VarDeclExprAST>(node)) {
        if (!mlirGen(*vardecl))
          return false;
      } else if (auto *ret = dyn_cast<ReturnExprAST>(node)) {
        // Success or failure, nothing after a return is processed.
        return mlirGen(*ret);
      } else if (auto *print = dyn_cast<PrintExprAST>(node)) {
        if (!mlirGen(*print))
          return false;
      } else if (!mlirGen(*expr)) {
        return false;
      }
    }
    return true;
  }

  /// Build a Toy array type from a list of shape dimensions. Types are `array`
  /// followed by an optional dimension list, example: array<2, 2>
  /// They are wrapped in a `toy` dialect (see next chapter) and get printed:
  ///   !toy.array<2, 2>
  template <typename T> mlir::Type getType(T shape) {
    // Copy the dimensions into a vector of 64-bit integers.
    SmallVector<int64_t, 8> dims;
    dims.append(shape.begin(), shape.end());
    return ToyArrayType::get(&context, dims);
  }

  /// Build an MLIR type from a Toy AST variable type by forwarding its shape
  /// to the generic getType(T) overload.
  mlir::Type getType(const VarType &type) {
    const auto &dims = type.shape;
    return getType(dims);
  }
};

} // namespace

namespace toy {

// The public API for codegen: lower `moduleAST` with a generator bound to
// `context` and return whatever module (possibly null) the generator produces.
std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
                                      ModuleAST &moduleAST) {
  MLIRGenImpl generator(context);
  return generator.mlirGen(moduleAST);
}

} // namespace toy