author    Tuowen Zhao <ztuowen@gmail.com>    2019-04-27 19:05:25 -0600
committer Tuowen Zhao <ztuowen@gmail.com>    2019-04-27 19:05:25 -0600
commit    0781257b2a8d544abdcce38824a9b8288a04800d (patch)
tree      365cea96de343e354913f90b35fc944e4459b2e9 /mlir
parent    4127831a28e31ac53ffdb1d7e7a88dd7d6317c6e (diff)
download  mlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.tar.gz
          mlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.tar.bz2
          mlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.zip
Split toy dialect using static registration
Diffstat (limited to 'mlir')
-rw-r--r--  mlir/EarlyLowering.cpp       158
-rw-r--r--  mlir/LateLowering.cpp        452
-rw-r--r--  mlir/MLIRGen.cpp             480
-rw-r--r--  mlir/ShapeInferencePass.cpp  387
-rw-r--r--  mlir/ToyCombine.cpp          209
-rw-r--r--  mlir/ToyDialect.cpp          405
6 files changed, 0 insertions, 2091 deletions
diff --git a/mlir/EarlyLowering.cpp b/mlir/EarlyLowering.cpp
deleted file mode 100644
index 634c72e..0000000
--- a/mlir/EarlyLowering.cpp
+++ /dev/null
@@ -1,158 +0,0 @@
-//=======- EarlyLowering.cpp - Toy Lowering to Linear Algebra Dialect -=======//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements early lowering of Toy IR to Linalg Dialect: we only
-// lower the computationally intensive part of the program (matmul...) to a
-// dialect specialized for optimizations.
-//
-// This is intended to showcase how multiple dialects can cohabit in the same
-// function. After this lowering, you would still have toy.print in the IR for
-// example.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "linalg3/Intrinsics.h"
-#include "linalg1/ViewOp.h"
-#include "linalg3/TensorOps.h"
-#include "mlir/EDSC/Builders.h"
-#include "mlir/EDSC/Helpers.h"
-#include "mlir/EDSC/Intrinsics.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/LLVMIR/LLVMDialect.h"
-#include "mlir/Parser.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Type.h"
-
-#include <algorithm>
-
-using namespace mlir;
-
-namespace {
-/// Utility function for type casting: it keeps the type checker happy while
-/// delaying the actual work involved in converting the type. Most of the
-/// time both sides of the cast (producer and consumer) will be lowered to a
-/// dialect like LLVM and end up with the same LLVM representation, at which
-/// point this becomes a no-op and is eliminated.
-Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) {
- if (val->getType() == destTy)
- return val;
- return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
- .getResult();
-}
-
-/// Create a type cast to turn a toy.array into a memref. The Toy Array will be
-/// lowered to a memref during buffer allocation, at which point the type cast
-/// becomes useless.
-Value *memRefTypeCast(FuncBuilder &builder, Value *val) {
- if (val->getType().isa<MemRefType>())
- return val;
- auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
- if (!toyArrayTy)
- return val;
- return typeCast(builder, val, toyArrayTy.toMemref());
-}
-
-/// Lower toy.mul to Linalg `matmul`.
-///
-/// This class inherits from `DialectOpConversion` and overrides `rewrite`,
-/// similarly to the PatternRewriter introduced in the previous chapter.
-/// It will be called by the DialectConversion framework (see the
-/// `EarlyLowering` class below).
-class MulOpConversion : public DialectOpConversion {
-public:
- explicit MulOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {}
-
- SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
- FuncBuilder &rewriter) const override {
- using namespace edsc;
- using intrinsics::constant_index;
- using linalg::intrinsics::range;
- using linalg::intrinsics::view;
- toy::MulOp mul = op->cast<toy::MulOp>();
- auto loc = mul.getLoc();
- Value *result = memRefTypeCast(
- rewriter, rewriter.create<toy::AllocOp>(loc, mul.getResult()->getType())
- .getResult());
- Value *lhs = memRefTypeCast(rewriter, operands[0]);
- auto memrefLHSTy = lhs->getType().cast<MemRefType>();
- Value *rhs = memRefTypeCast(rewriter, operands[1]);
- auto memrefRHSTy = rhs->getType().cast<MemRefType>();
- mlir::edsc::ScopedContext scope(rewriter, loc);
- edsc::ValueHandle r0 =
- range(constant_index(0), constant_index(memrefLHSTy.getDimSize(0)),
- constant_index(1));
- edsc::ValueHandle r1 =
- range(constant_index(0), constant_index(memrefLHSTy.getDimSize(1)),
- constant_index(1));
- edsc::ValueHandle r2 =
- range(constant_index(0), constant_index(memrefRHSTy.getDimSize(1)),
- constant_index(1));
- auto lhsView = view(lhs, {r0, r1});
- auto rhsView = view(rhs, {r1, r2});
- auto resultView = view(result, {r0, r2});
- rewriter.create<linalg::MatmulOp>(loc, lhsView, rhsView, resultView);
- return {typeCast(rewriter, result, mul.getType())};
- }
-};
-
-// The conversion class from the Toy IR dialect to the Linalg dialect.
-class EarlyLowering : public DialectConversion {
-protected:
- // Initialize the list of converters.
- llvm::DenseSet<DialectOpConversion *>
- initConverters(MLIRContext *context) override {
- return ConversionListBuilder<MulOpConversion>::build(&allocator, context);
- }
-
-private:
- llvm::BumpPtrAllocator allocator;
-};
-
-/// This pass lowers the computationally intensive parts (like matmul, for
-/// example...) to Linalg while keeping the rest of the code in the Toy
-/// dialect.
-struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> {
-
- void runOnModule() override {
- if (failed(EarlyLowering().convert(&getModule()))) {
- getModule().getContext()->emitError(
- mlir::UnknownLoc::get(getModule().getContext()),
- "Error lowering Toy\n");
- signalPassFailure();
- }
- }
-};
-} // end anonymous namespace
-
-namespace toy {
-Pass *createEarlyLoweringPass() { return new EarlyLoweringPass(); }
-
-std::unique_ptr<mlir::DialectConversion> makeToyEarlyLowering() {
- return llvm::make_unique<EarlyLowering>();
-}
-
-} // namespace toy
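
For context, here is a minimal sketch of how a driver might schedule the pass
returned by createEarlyLoweringPass() above. It assumes the 2019-era
mlir::PassManager API in which run() takes a Module pointer; the
runEarlyLowering wrapper and its error handling are illustrative, not part of
this commit:

    #include "mlir/Pass/PassManager.h"
    #include "llvm/Support/raw_ostream.h"

    // Illustrative driver: lower toy.mul to linalg.matmul across a module.
    // Assumes the 2019-era PassManager::run(Module *) signature.
    void runEarlyLowering(mlir::Module &module) {
      mlir::PassManager pm;
      pm.addPass(toy::createEarlyLoweringPass());
      if (mlir::failed(pm.run(&module)))
        llvm::errs() << "early lowering failed\n";
    }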
diff --git a/mlir/LateLowering.cpp b/mlir/LateLowering.cpp
deleted file mode 100644
index eeae6ee..0000000
--- a/mlir/LateLowering.cpp
+++ /dev/null
@@ -1,452 +0,0 @@
-//====- LateLowering.cpp - Lowering from Toy+Linalg to LLVM -===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements late lowering of IR mixing Toy and Linalg to LLVM.
-// It involves an intermediate step through a mix of the affine and standard
-// dialects.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "linalg3/Intrinsics.h"
-#include "linalg1/ViewOp.h"
-#include "linalg3/ConvertToLLVMDialect.h"
-#include "linalg3/TensorOps.h"
-#include "linalg3/Transforms.h"
-#include "mlir/EDSC/Builders.h"
-#include "mlir/EDSC/Helpers.h"
-#include "mlir/EDSC/Intrinsics.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/LLVMIR/LLVMDialect.h"
-#include "mlir/Parser.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Type.h"
-
-#include <algorithm>
-
-using namespace mlir;
-
-namespace {
-/// Utility function for type casting: it keeps the type checker happy while
-/// delaying the actual work involved in converting the type. Most of the
-/// time both sides of the cast (producer and consumer) will be lowered to a
-/// dialect like LLVM and end up with the same LLVM representation, at which
-/// point this becomes a no-op and is eliminated.
-Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) {
- if (val->getType() == destTy)
- return val;
- return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
- .getResult();
-}
-
-/// Create a type cast to turn a toy.array into a memref. The Toy Array will be
-/// lowered to a memref during buffer allocation, at which point the type cast
-/// becomes useless.
-Value *memRefTypeCast(FuncBuilder &builder, Value *val) {
- if (val->getType().isa<MemRefType>())
- return val;
- auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
- if (!toyArrayTy)
- return val;
- return typeCast(builder, val, toyArrayTy.toMemref());
-}
-
-/// Lower a toy.add to an affine loop nest.
-///
-/// This class inherits from `DialectOpConversion` and overrides `rewrite`,
-/// similarly to the PatternRewriter introduced in the previous chapter.
-/// It will be called by the DialectConversion framework (see the
-/// `LateLowering` class below).
-class AddOpConversion : public DialectOpConversion {
-public:
- explicit AddOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {}
-
-  /// Lower the `op` by generating IR using the `rewriter` builder. The builder
-  /// is set up with a new function, and the `operands` array has been populated
-  /// with the rewritten operands for `op` in the new function.
-  /// The results created by the new IR with the builder are returned, and their
-  /// number must match the number of results of `op`.
- SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
- FuncBuilder &rewriter) const override {
- auto add = op->cast<toy::AddOp>();
- auto loc = add.getLoc();
- // Create a `toy.alloc` operation to allocate the output buffer for this op.
- Value *result = memRefTypeCast(
- rewriter, rewriter.create<toy::AllocOp>(loc, add.getResult()->getType())
- .getResult());
- Value *lhs = memRefTypeCast(rewriter, operands[0]);
- Value *rhs = memRefTypeCast(rewriter, operands[1]);
-
- using namespace edsc;
- ScopedContext scope(rewriter, loc);
- ValueHandle zero = intrinsics::constant_index(0);
- MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
- IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
- IndexHandle i, j, M(vRes.ub(0));
- if (vRes.rank() == 1) {
- LoopNestBuilder({&i}, {zero}, {M}, {1})({iRes(i) = iLHS(i) + iRHS(i)});
- } else {
- assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now");
- IndexHandle N(vRes.ub(1));
- LoopNestBuilder({&i, &j}, {zero, zero}, {M, N},
- {1, 1})({iRes(i, j) = iLHS(i, j) + iRHS(i, j)});
- }
-
- // Return the newly allocated buffer, with a type.cast to preserve the
- // consumers.
- return {typeCast(rewriter, result, add.getType())};
- }
-};
-
-/// Lowers `toy.print` to a loop nest calling `printf` on every individual
-/// element of the array.
-class PrintOpConversion : public DialectOpConversion {
-public:
- explicit PrintOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {}
-
- SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
- FuncBuilder &rewriter) const override {
- // Get or create the declaration of the printf function in the module.
- Function *printfFunc = getPrintf(*op->getFunction()->getModule());
-
- auto print = op->cast<toy::PrintOp>();
- auto loc = print.getLoc();
-    // We will operate on a MemRef abstraction; we use a type.cast to get one
-    // if our operand is still a Toy array.
- Value *operand = memRefTypeCast(rewriter, operands[0]);
- Type retTy = printfFunc->getType().getResult(0);
-
- // Create our loop nest now
- using namespace edsc;
- using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>;
- ScopedContext scope(rewriter, loc);
- ValueHandle zero = intrinsics::constant_index(0);
- ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f "));
- MemRefView vOp(operand);
- IndexedValue iOp(operand);
- IndexHandle i, j, M(vOp.ub(0));
-
- ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
- if (vOp.rank() == 1) {
- // clang-format off
- LoopBuilder(&i, zero, M, 1)({
- llvmCall(retTy,
- rewriter.getFunctionAttr(printfFunc),
- {fmtCst, iOp(i)})
- });
- llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
- // clang-format on
- } else {
- IndexHandle N(vOp.ub(1));
- // clang-format off
- LoopBuilder(&i, zero, M, 1)({
- LoopBuilder(&j, zero, N, 1)({
- llvmCall(retTy,
- rewriter.getFunctionAttr(printfFunc),
- {fmtCst, iOp(i, j)})
- }),
- llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol})
- });
- // clang-format on
- }
- return {};
- }
-
-private:
- // Turn a string into a toy.alloc (malloc/free abstraction) and a sequence
- // of stores into the buffer, and return a MemRef into the buffer.
- Value *getConstantCharBuffer(FuncBuilder &builder, Location loc,
- StringRef data) const {
- auto retTy =
- builder.getMemRefType(data.size() + 1, builder.getIntegerType(8));
- Value *result = builder.create<toy::AllocOp>(loc, retTy).getResult();
- using namespace edsc;
- using intrinsics::constant_index;
- using intrinsics::constant_int;
- ScopedContext scope(builder, loc);
- MemRefView vOp(result);
- IndexedValue iOp(result);
- for (uint64_t i = 0; i < data.size(); ++i) {
- iOp(constant_index(i)) = constant_int(data[i], 8);
- }
- iOp(constant_index(data.size())) = constant_int(0, 8);
- return result;
- }
-
- /// Return the prototype declaration for printf in the module, create it if
- /// necessary.
- Function *getPrintf(Module &module) const {
- auto *printfFunc = module.getNamedFunction("printf");
- if (printfFunc)
- return printfFunc;
-
- // Create a function declaration for printf, signature is `i32 (i8*, ...)`
- Builder builder(&module);
- MLIRContext *context = module.getContext();
- LLVM::LLVMDialect *llvmDialect = static_cast<LLVM::LLVMDialect *>(
- module.getContext()->getRegisteredDialect("llvm"));
- auto &llvmModule = llvmDialect->getLLVMModule();
- llvm::IRBuilder<> llvmBuilder(llvmModule.getContext());
-
- auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32));
- auto llvmI8PtrTy =
- LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo());
- auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
- printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
- // It should be variadic, but we don't support it fully just yet.
- printfFunc->setAttr("std.varargs", builder.getBoolAttr(true));
- module.getFunctions().push_back(printfFunc);
- return printfFunc;
- }
-};
-
-/// Lowers a constant to a sequence of stores into a buffer.
-class ConstantOpConversion : public DialectOpConversion {
-public:
- explicit ConstantOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {}
-
- SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
- FuncBuilder &rewriter) const override {
- toy::ConstantOp cstOp = op->cast<toy::ConstantOp>();
- auto loc = cstOp.getLoc();
- auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
- auto shape = retTy.getShape();
- Value *result = memRefTypeCast(
- rewriter, rewriter.create<toy::AllocOp>(loc, retTy).getResult());
-
- auto cstValue = cstOp.getValue();
- auto f64Ty = rewriter.getF64Type();
- using namespace edsc;
- using intrinsics::constant_float;
- using intrinsics::constant_index;
- ScopedContext scope(rewriter, loc);
- MemRefView vOp(result);
- IndexedValue iOp(result);
- for (uint64_t i = 0; i < shape[0]; ++i) {
- if (shape.size() == 1) {
- auto value = cstValue.getValue(ArrayRef<uint64_t>{i})
- .cast<FloatAttr>()
- .getValue();
- iOp(constant_index(i)) = constant_float(value, f64Ty);
- continue;
- }
- for (uint64_t j = 0; j < shape[1]; ++j) {
- auto value = cstValue.getValue(ArrayRef<uint64_t>{i, j})
- .cast<FloatAttr>()
- .getValue();
- iOp(constant_index(i), constant_index(j)) =
- constant_float(value, f64Ty);
- }
- }
- return {result};
- }
-};
-
-/// Lower transpose operation to an affine loop nest.
-class TransposeOpConversion : public DialectOpConversion {
-public:
- explicit TransposeOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {}
-
- SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
- FuncBuilder &rewriter) const override {
- auto transpose = op->cast<toy::TransposeOp>();
- auto loc = transpose.getLoc();
- Value *result = memRefTypeCast(
- rewriter,
- rewriter.create<toy::AllocOp>(loc, transpose.getResult()->getType())
- .getResult());
- Value *operand = memRefTypeCast(rewriter, operands[0]);
-
- using namespace edsc;
- ScopedContext scope(rewriter, loc);
- ValueHandle zero = intrinsics::constant_index(0);
- MemRefView vRes(result), vOperand(operand);
- IndexedValue iRes(result), iOperand(operand);
- IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1));
- // clang-format off
- LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({
- iRes(i, j) = iOperand(j, i)
- });
- // clang-format on
-
- return {typeCast(rewriter, result, transpose.getType())};
- }
-};
-
-// Lower toy.return to the standard return operation.
-class ReturnOpConversion : public DialectOpConversion {
-public:
- explicit ReturnOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {}
-
- SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
- FuncBuilder &rewriter) const override {
- auto retOp = op->cast<toy::ReturnOp>();
- using namespace edsc;
- auto loc = retOp.getLoc();
- // Argument is optional, handle both cases.
- if (retOp.getNumOperands())
- rewriter.create<ReturnOp>(loc, operands[0]);
- else
- rewriter.create<ReturnOp>(loc);
- return {};
- }
-};
-
-/// This is the main class registering our individual converter classes with
-/// the DialectConversion framework in MLIR.
-class LateLowering : public DialectConversion {
-protected:
- /// Initialize the list of converters.
- llvm::DenseSet<DialectOpConversion *>
- initConverters(MLIRContext *context) override {
- return ConversionListBuilder<AddOpConversion, PrintOpConversion,
- ConstantOpConversion, TransposeOpConversion,
- ReturnOpConversion>::build(&allocator,
- context);
- }
-
-  /// Convert a Toy type; this gets called for block and region arguments, and
-  /// for attributes.
- Type convertType(Type t) override {
-    if (auto array = t.dyn_cast<toy::ToyArrayType>()) {
- return array.toMemref();
- }
- return t;
- }
-
-private:
- llvm::BumpPtrAllocator allocator;
-};
-
-/// This pass lowers to Linalg the parts that can be lowered (matmul and add
-/// on arrays) and targets LLVM for everything else.
-struct LateLoweringPass : public ModulePass<LateLoweringPass> {
-
- void runOnModule() override {
- // Perform Toy specific lowering
- if (failed(LateLowering().convert(&getModule()))) {
- getModule().getContext()->emitError(
- UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n");
- signalPassFailure();
- }
-    // At this point the IR is almost using only standard and affine dialects.
-    // A few things remain before we emit LLVM IR. First, to reuse as much of
-    // MLIR as possible, we will try to lower everything to the standard and/or
-    // affine dialects: they already include conversions to the LLVM dialect.
-
-    // Start by patching call result types to return memref instead of ToyArray.
- for (auto &function : getModule()) {
- function.walk([&](Operation *op) {
- auto callOp = op->dyn_cast<CallOp>();
- if (!callOp)
- return;
- if (!callOp.getNumResults())
- return;
- auto retToyTy =
- callOp.getResult(0)->getType().dyn_cast<toy::ToyArrayType>();
- if (!retToyTy)
- return;
- callOp.getResult(0)->setType(retToyTy.toMemref());
- });
- }
-
- for (auto &function : getModule()) {
- function.walk([&](Operation *op) {
-        // Turn toy.alloc into a sequence of alloc/dealloc (later malloc/free).
- if (auto allocOp = op->dyn_cast<toy::AllocOp>()) {
- auto result = allocTensor(allocOp);
- allocOp.replaceAllUsesWith(result);
- allocOp.erase();
- return;
- }
- // Eliminate all type.cast before lowering to LLVM.
- if (auto typeCastOp = op->dyn_cast<toy::TypeCastOp>()) {
- typeCastOp.replaceAllUsesWith(typeCastOp.getOperand());
- typeCastOp.erase();
- return;
- }
- });
- }
-
- // Lower Linalg to affine
- for (auto &function : getModule())
- linalg::lowerToLoops(&function);
-
- getModule().dump();
-
- // Finally convert to LLVM Dialect
- linalg::convertLinalg3ToLLVM(getModule());
- }
-
-  /// Allocate buffers (malloc/free) for Toy operations. This can't be done as
-  /// part of the dialect conversion framework since we need to insert
-  /// `dealloc` operations just before the return, but the conversion framework
-  /// is operating in a brand new function: we don't have the return to hook
-  /// the dealloc operations onto.
- Value *allocTensor(toy::AllocOp alloc) {
- FuncBuilder builder(alloc);
- auto retTy = alloc.getResult()->getType();
-
- auto memRefTy = retTy.dyn_cast<MemRefType>();
- if (!memRefTy)
- memRefTy = retTy.cast<toy::ToyArrayType>().toMemref();
- if (!memRefTy) {
- alloc.emitOpError("is expected to allocate a Toy array or a MemRef");
- llvm_unreachable("fatal error");
- }
- auto loc = alloc.getLoc();
- Value *result = builder.create<AllocOp>(loc, memRefTy).getResult();
-
-    // Insert a `dealloc` operation right before the `return` operations,
-    // unless the buffer is itself returned, in which case the caller is
-    // responsible for it.
- builder.getFunction()->walk([&](Operation *op) {
- auto returnOp = op->dyn_cast<ReturnOp>();
- if (!returnOp)
- return;
- if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc)
- return;
- builder.setInsertionPoint(returnOp);
- builder.create<DeallocOp>(alloc.getLoc(), result);
- });
- return result;
- }
-};
-} // end anonymous namespace
-
-namespace toy {
-Pass *createLateLoweringPass() { return new LateLoweringPass(); }
-
-std::unique_ptr<DialectConversion> makeToyLateLowering() {
- return llvm::make_unique<LateLowering>();
-}
-
-} // namespace toy
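
The makeToyEarlyLowering() and makeToyLateLowering() factories above exist so
a driver can also apply the conversion directly, outside the pass manager. A
minimal sketch, assuming the DialectConversion::convert(Module *) entry point
used by the passes in this file (the lowerModule wrapper is illustrative):

    #include "mlir/Transforms/DialectConversion.h"

    // Illustrative: apply the late lowering conversion to a module in place,
    // via the same convert() entry point LateLoweringPass::runOnModule uses.
    mlir::LogicalResult lowerModule(mlir::Module &module) {
      auto lowering = toy::makeToyLateLowering();
      return lowering->convert(&module);
    }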
diff --git a/mlir/MLIRGen.cpp b/mlir/MLIRGen.cpp
deleted file mode 100644
index e2001fb..0000000
--- a/mlir/MLIRGen.cpp
+++ /dev/null
@@ -1,480 +0,0 @@
-//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements a simple IR generation targeting MLIR from a Module AST
-// for the Toy language.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/MLIRGen.h"
-#include "toy/AST.h"
-#include "toy/Dialect.h"
-
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
-#include "mlir/StandardOps/Ops.h"
-
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/ScopedHashTable.h"
-#include "llvm/Support/raw_ostream.h"
-#include <numeric>
-
-using namespace toy;
-using llvm::cast;
-using llvm::dyn_cast;
-using llvm::isa;
-using llvm::make_unique;
-using llvm::ScopedHashTableScope;
-using llvm::SmallVector;
-using llvm::StringRef;
-using llvm::Twine;
-
-namespace {
-
-/// Implementation of a simple MLIR emission from the Toy AST.
-///
-/// This will emit operations that are specific to the Toy language, preserving
-/// the semantics of the language and (hopefully) allowing accurate analyses
-/// and transformations based on these high-level semantics.
-///
-/// At this point we take advantage of the "raw" MLIR APIs to create operations
-/// that haven't been registered in any way with MLIR. These operations are
-/// unknown to MLIR; custom passes could operate by string-matching the names
-/// of these operations, but no other type checking or semantics are associated
-/// with them natively by MLIR.
-class MLIRGenImpl {
-public:
- MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
-
- /// Public API: convert the AST for a Toy module (source file) to an MLIR
- /// Module.
- std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
- // We create an empty MLIR module and codegen functions one at a time and
- // add them to the module.
- theModule = make_unique<mlir::Module>(&context);
-
- for (FunctionAST &F : moduleAST) {
- auto func = mlirGen(F);
- if (!func)
- return nullptr;
- theModule->getFunctions().push_back(func.release());
- }
-
- // FIXME: (in the next chapter...) without registering a dialect in MLIR,
- // this won't do much, but it should at least check some structural
- // properties.
- if (failed(theModule->verify())) {
- context.emitError(mlir::UnknownLoc::get(&context),
- "Module verification error");
- return nullptr;
- }
-
- return std::move(theModule);
- }
-
-private:
-  /// In MLIR (like in LLVM) a "context" object holds the memory allocations
-  /// and the ownership of many internal structures of the IR, and provides a
-  /// level of "uniquing" across multiple modules (types for instance).
- mlir::MLIRContext &context;
-
- /// A "module" matches a source file: it contains a list of functions.
- std::unique_ptr<mlir::Module> theModule;
-
-  /// The builder is a helper class to create IR inside a function. It is
-  /// re-initialized every time we enter a function and kept around as a
-  /// convenience for emitting individual operations.
-  /// The builder is stateful, in particular it keeps an "insertion point":
-  /// this is where the next operations will be introduced.
- std::unique_ptr<mlir::FuncBuilder> builder;
-
- /// The symbol table maps a variable name to a value in the current scope.
- /// Entering a function creates a new scope, and the function arguments are
- /// added to the mapping. When the processing of a function is terminated, the
- /// scope is destroyed and the mappings created in this scope are dropped.
- llvm::ScopedHashTable<StringRef, mlir::Value *> symbolTable;
-
- /// Helper conversion for a Toy AST location to an MLIR location.
- mlir::FileLineColLoc loc(Location loc) {
- return mlir::FileLineColLoc::get(
- mlir::UniquedFilename::get(*loc.file, &context), loc.line, loc.col,
- &context);
- }
-
- /// Declare a variable in the current scope, return true if the variable
- /// wasn't declared yet.
- bool declare(llvm::StringRef var, mlir::Value *value) {
- if (symbolTable.count(var)) {
- return false;
- }
- symbolTable.insert(var, value);
- return true;
- }
-
- /// Create the prototype for an MLIR function with as many arguments as the
- /// provided Toy AST prototype.
- mlir::Function *mlirGen(PrototypeAST &proto) {
-    // This is a generic function: the return type will be inferred later.
- llvm::SmallVector<mlir::Type, 4> ret_types;
-    // Argument types are uniformly the generic array type.
- llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
- getType(VarType{}));
- auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
- auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
- func_type, /* attrs = */ {});
-
- // Mark the function as generic: it'll require type specialization for every
- // call site.
- if (function->getNumArguments())
- function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
-
- return function;
- }
-
- /// Emit a new function and add it to the MLIR module.
- std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
- // Create a scope in the symbol table to hold variable declarations.
- ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
-
- // Create an MLIR function for the given prototype.
- std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
- if (!function)
- return nullptr;
-
- // Let's start the body of the function now!
- // In MLIR the entry block of the function is special: it must have the same
- // argument list as the function itself.
- function->addEntryBlock();
-
- auto &entryBlock = function->front();
- auto &protoArgs = funcAST.getProto()->getArgs();
- // Declare all the function arguments in the symbol table.
- for (const auto &name_value :
- llvm::zip(protoArgs, entryBlock.getArguments())) {
- declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
- }
-
-    // Create a builder for the function; it will be used throughout the
-    // codegen to create operations in this function.
- builder = llvm::make_unique<mlir::FuncBuilder>(function.get());
-
- // Emit the body of the function.
- if (!mlirGen(*funcAST.getBody()))
- return nullptr;
-
-    // Implicitly return void if no return statement was emitted.
- // FIXME: we may fix the parser instead to always return the last expression
- // (this would possibly help the REPL case later)
- if (function->getBlocks().back().back().getName().getStringRef() !=
- "toy.return") {
- ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
- mlirGen(fakeRet);
- }
-
- return function;
- }
-
- /// Emit a binary operation
- mlir::Value *mlirGen(BinaryExprAST &binop) {
-    // First emit the operations for each side of the operation before emitting
-    // the operation itself. For example, if the expression is `a + foo(a)`:
-    // 1) First it will visit the LHS, which will return a reference to the
-    //    value holding `a`. This value should have been emitted at declaration
-    //    time and registered in the symbol table, so nothing would be
-    //    codegen'd. If the value is not in the symbol table, an error has been
-    //    emitted and nullptr is returned.
-    // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
-    //    and the result value is returned. If an error occurs we get a nullptr
-    //    and propagate.
- //
- mlir::Value *L = mlirGen(*binop.getLHS());
- if (!L)
- return nullptr;
- mlir::Value *R = mlirGen(*binop.getRHS());
- if (!R)
- return nullptr;
- auto location = loc(binop.loc());
-
- // Derive the operation name from the binary operator. At the moment we only
- // support '+' and '*'.
- switch (binop.getOp()) {
-    case '+':
-      return builder->create<AddOp>(location, L, R).getResult();
- case '*':
- return builder->create<MulOp>(location, L, R).getResult();
- default:
- context.emitError(loc(binop.loc()),
- Twine("Error: invalid binary operator '") +
- Twine(binop.getOp()) + "'");
- return nullptr;
- }
- }
-
- // This is a reference to a variable in an expression. The variable is
- // expected to have been declared and so should have a value in the symbol
- // table, otherwise emit an error and return nullptr.
- mlir::Value *mlirGen(VariableExprAST &expr) {
- if (symbolTable.count(expr.getName()))
- return symbolTable.lookup(expr.getName());
- context.emitError(loc(expr.loc()), Twine("Error: unknown variable '") +
- expr.getName() + "'");
- return nullptr;
- }
-
- // Emit a return operation, return true on success.
- bool mlirGen(ReturnExprAST &ret) {
- auto location = loc(ret.loc());
-    // `return` takes an optional expression; we need to account for it here.
- if (!ret.getExpr().hasValue()) {
- builder->create<ReturnOp>(location);
- return true;
- }
- auto *expr = mlirGen(*ret.getExpr().getValue());
- if (!expr)
- return false;
- builder->create<ReturnOp>(location, expr);
- return true;
- }
-
- // Emit a literal/constant array. It will be emitted as a flattened array of
- // data in an Attribute attached to a `toy.constant` operation.
- // See documentation on [Attributes](LangRef.md#attributes) for more details.
- // Here is an excerpt:
- //
- // Attributes are the mechanism for specifying constant data in MLIR in
- // places where a variable is never allowed [...]. They consist of a name
- // and a [concrete attribute value](#attribute-values). It is possible to
- // attach attributes to operations, functions, and function arguments. The
- // set of expected attributes, their structure, and their interpretation
- // are all contextually dependent on what they are attached to.
- //
- // Example, the source level statement:
- // var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
- // will be converted to:
- // %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
- // [[1.000000e+00, 2.000000e+00, 3.000000e+00],
- // [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
- //
- mlir::Value *mlirGen(LiteralExprAST &lit) {
- auto location = loc(lit.loc());
- // The attribute is a vector with an attribute per element (number) in the
- // array, see `collectData()` below for more details.
- std::vector<mlir::Attribute> data;
- data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
- std::multiplies<int>()));
- collectData(lit, data);
-
- // FIXME: using a tensor type is a HACK here.
- // Can we do differently without registering a dialect? Using a string blob?
- mlir::Type elementType = mlir::FloatType::getF64(&context);
- auto dataType = builder->getTensorType(lit.getDims(), elementType);
-
-    // This is the attribute that actually holds the list of values for this
-    // array literal.
- auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
- .cast<mlir::DenseElementsAttr>();
-
- // Build the MLIR op `toy.constant`, only boilerplate below.
- return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
- .getResult();
- }
-
- // Recursive helper function to accumulate the data that compose an array
- // literal. It flattens the nested structure in the supplied vector. For
- // example with this array:
- // [[1, 2], [3, 4]]
- // we will generate:
- // [ 1, 2, 3, 4 ]
-  // Individual numbers are wrapped in a light wrapper, `mlir::FloatAttr`.
-  // Attributes are the way MLIR attaches constants to operations and
-  // functions.
- void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
- if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
- for (auto &value : lit->getValues())
- collectData(*value, data);
- return;
- }
- assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
- mlir::Type elementType = mlir::FloatType::getF64(&context);
- auto attr = mlir::FloatAttr::getChecked(
- elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
- data.push_back(attr);
- }
-
- // Emit a call expression. It emits specific operations for the `transpose`
- // builtin. Other identifiers are assumed to be user-defined functions.
- mlir::Value *mlirGen(CallExprAST &call) {
- auto location = loc(call.loc());
- std::string callee = call.getCallee();
- if (callee == "transpose") {
- if (call.getArgs().size() != 1) {
- context.emitError(
- location, Twine("MLIR codegen encountered an error: toy.transpose "
- "does not accept multiple arguments"));
- return nullptr;
- }
- mlir::Value *arg = mlirGen(*call.getArgs()[0]);
- return builder->create<TransposeOp>(location, arg).getResult();
- }
-
- // Codegen the operands first
- SmallVector<mlir::Value *, 4> operands;
- for (auto &expr : call.getArgs()) {
- auto *arg = mlirGen(*expr);
- if (!arg)
- return nullptr;
- operands.push_back(arg);
- }
-    // Calls to user-defined functions are mapped to a custom call that takes
-    // the callee name as an attribute.
- return builder->create<GenericCallOp>(location, call.getCallee(), operands)
- .getResult();
- }
-
-  // Emit a print expression: codegen the argument and emit a `toy.print`
-  // operation on it. Return false on failure.
- bool mlirGen(PrintExprAST &call) {
- auto *arg = mlirGen(*call.getArg());
- if (!arg)
- return false;
- auto location = loc(call.loc());
- builder->create<PrintOp>(location, arg);
- return true;
- }
-
- // Emit a constant for a single number (FIXME: semantic? broadcast?)
- mlir::Value *mlirGen(NumberExprAST &num) {
- auto location = loc(num.loc());
- mlir::Type elementType = mlir::FloatType::getF64(&context);
- auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
- loc(num.loc()));
- return builder->create<ConstantOp>(location, attr).getResult();
- }
-
- // Dispatch codegen for the right expression subclass using RTTI.
- mlir::Value *mlirGen(ExprAST &expr) {
- switch (expr.getKind()) {
- case toy::ExprAST::Expr_BinOp:
- return mlirGen(cast<BinaryExprAST>(expr));
- case toy::ExprAST::Expr_Var:
- return mlirGen(cast<VariableExprAST>(expr));
- case toy::ExprAST::Expr_Literal:
- return mlirGen(cast<LiteralExprAST>(expr));
- case toy::ExprAST::Expr_Call:
- return mlirGen(cast<CallExprAST>(expr));
- case toy::ExprAST::Expr_Num:
- return mlirGen(cast<NumberExprAST>(expr));
- default:
- context.emitError(
- loc(expr.loc()),
- Twine("MLIR codegen encountered an unhandled expr kind '") +
- Twine(expr.getKind()) + "'");
- return nullptr;
- }
- }
-
-  // Handle a variable declaration: we'll codegen the expression that forms the
-  // initializer and record the value in the symbol table before returning it.
- // Future expressions will be able to reference this variable through symbol
- // table lookup.
- mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
- mlir::Value *value = nullptr;
- auto location = loc(vardecl.loc());
- if (auto init = vardecl.getInitVal()) {
- value = mlirGen(*init);
- if (!value)
- return nullptr;
-      // We have the initializer value, but in case the variable was declared
-      // with a specific shape, we emit a "reshape" operation. It will get
-      // optimized out later as needed.
- if (!vardecl.getType().shape.empty()) {
- value = builder
- ->create<ReshapeOp>(
- location, value,
- getType(vardecl.getType()).cast<ToyArrayType>())
- .getResult();
- }
- } else {
- context.emitError(loc(vardecl.loc()),
- "Missing initializer in variable declaration");
- return nullptr;
- }
- // Register the value in the symbol table
- declare(vardecl.getName(), value);
- return value;
- }
-
-  /// Codegen a list of expressions; return false if one of them hit an error.
- bool mlirGen(ExprASTList &blockAST) {
- ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
- for (auto &expr : blockAST) {
-      // Specific handling for variable declarations, return statements, and
-      // print. These can only appear in a block list and not in nested
-      // expressions.
- if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
- if (!mlirGen(*vardecl))
- return false;
- continue;
- }
- if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
- if (!mlirGen(*ret))
- return false;
- return true;
- }
- if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
- if (!mlirGen(*print))
- return false;
- continue;
- }
- // Generic expression dispatch codegen.
- if (!mlirGen(*expr))
- return false;
- }
- return true;
- }
-
- /// Build a type from a list of shape dimensions. Types are `array` followed
- /// by an optional dimension list, example: array<2, 2>
- /// They are wrapped in a `toy` dialect (see next chapter) and get printed:
- /// !toy.array<2, 2>
- template <typename T> mlir::Type getType(T shape) {
- SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
- return ToyArrayType::get(&context, shape64);
- }
-
- /// Build an MLIR type from a Toy AST variable type
- /// (forward to the generic getType(T) above).
- mlir::Type getType(const VarType &type) { return getType(type.shape); }
-};
-
-} // namespace
-
-namespace toy {
-
-// The public API for codegen.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
- ModuleAST &moduleAST) {
- return MLIRGenImpl(context).mlirGen(moduleAST);
-}
-
-} // namespace toy
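
The mlirGen() entry point above is what a compiler driver calls after parsing.
A minimal sketch of such a driver, assuming only APIs appearing in this commit
(the emitMLIR wrapper is illustrative, and the ModuleAST is assumed to come
from the tutorial's parser):

    #include "toy/MLIRGen.h"
    #include "mlir/IR/MLIRContext.h"

    // Illustrative driver: AST in, generated MLIR dumped to stderr.
    int emitMLIR(toy::ModuleAST &moduleAST) {
      mlir::MLIRContext context;
      auto module = toy::mlirGen(context, moduleAST);
      if (!module)
        return 1;     // mlirGen already emitted a diagnostic
      module->dump(); // same dump() used in LateLowering.cpp above
      return 0;
    }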
diff --git a/mlir/ShapeInferencePass.cpp b/mlir/ShapeInferencePass.cpp
deleted file mode 100644
index 7e3ea3f..0000000
--- a/mlir/ShapeInferencePass.cpp
+++ /dev/null
@@ -1,387 +0,0 @@
-//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements a module-level pass performing interprocedural
-// propagation of array shapes through function specialization.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/StandardOps/Ops.h"
-#include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringSet.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/raw_ostream.h"
-#include <algorithm>
-
-#define DEBUG_TYPE "toy-shape-inference"
-
-using namespace toy;
-using llvm::MutableArrayRef;
-using llvm::SmallVector;
-using llvm::SmallVectorImpl;
-using llvm::StringRef;
-using llvm::Twine;
-
-/// Create a mangled name for function specialization. We will simply append
-/// the shape of the arguments to the function name. For example calling
-///
-///   "toy.generic_call"(%1, %3) {callee: "foo"}
-///     : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
-///
-/// would be mangled as foo_2x3_2x3. This mangling isn't robust as the user
-/// could have provided a function with a similar name. But we will claim this
-/// as a feature: it allows the user to provide custom specializations!
-static std::string mangle(StringRef funcName,
- MutableArrayRef<mlir::OpOperand> operands) {
- std::string mangledName;
- mangledName.reserve(funcName.size() + operands.size() * 6);
- mangledName = funcName;
- for (auto &operand : operands) {
- auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
- mangledName += "_";
- const char *sep = "";
- for (auto dim : arrayTy.getShape()) {
- mangledName += (sep + Twine(dim)).str();
- sep = "x";
- }
- }
- return mangledName;
-}
-
-namespace {
-
-/// The ShapeInferencePass is a ModulePass: it will run on the Module as a
-/// whole. MLIR also supports FunctionPasses, which are restricted to modifying
-/// a single function at a time. This pass couldn't be a function pass due to
-/// the nature of its interprocedural transformations.
-///
-/// The algorithm has two levels, first intra-procedurally:
-///
-/// 1) Build a worklist containing all the operations that are returning
-/// a generic Toy array: these are the operations that need shape
-/// inference.
-/// 2) Iterate on the worklist:
-/// a) find an operation to process: the next ready operation in the
-/// worklist has all of its arguments non-generic,
-/// b) if no operation is found, break out of the loop,
-/// c) remove the operation from the worklist,
-/// d) infer the shape of its output from the arguments type.
-/// 3) If the worklist is empty, the algorithm succeeded and we infer the
-/// return type for the function from the return operation.
-///
-/// There is a twist though: when a call to a generic function is encountered,
-/// shape inference requires the return type of the callee to be inferred
-/// first. At this point we need to specialize the callee by cloning it. Here
-/// is the inter-procedural flow:
-///
-/// 1) Keep a worklist of functions to process. Start with function "main".
-/// 2) While the worklist isn't empty:
-/// a) Take the last inserted function in the worklist.
-/// b) Run the intra-procedural shape inference on this function.
-/// c) If the intra-procedural shape inference can't complete, it returns
-/// a Function that needs to be inferred first. In this case, queue this
-/// new function and continue. Otherwise the inference succeeded and we
-/// can pop from the queue.
-///
-class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
-public:
- // One entry in the inter-procedural worklist. It keeps track of the
- // function to process, the mangled name for this specialization, and the
- // types of the arguments on which to specialize.
- struct FunctionToSpecialize {
- mlir::Function *function;
- std::string mangledName;
- std::vector<mlir::Type> argumentsType;
- };
-
- void runOnModule() override {
- auto &module = getModule();
- auto *main = module.getNamedFunction("main");
- if (!main) {
- module.getContext()->emitError(
- mlir::UnknownLoc::get(module.getContext()),
- "Shape inference failed: can't find a main function\n");
- signalPassFailure();
- return;
- }
-
-    /// Inter-procedural loop: initialize with `main` and iterate until the
-    /// full call-graph reachable from main has been successfully inferred.
- SmallVector<FunctionToSpecialize, 8> worklist;
- worklist.push_back({main, "", {}});
- while (!worklist.empty()) {
- if (failed(specialize(worklist)))
- return;
- }
-
-    // Delete any generic functions left over.
- // FIXME: we may want this as a separate pass.
- for (mlir::Function &function : llvm::make_early_inc_range(module)) {
- if (auto genericAttr =
- function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
- if (genericAttr.getValue())
- function.erase();
- }
- }
- }
-
-  /// Run inference on a function. If a mangledName is provided, we need to
-  /// specialize the function: to this end, clone it first.
- mlir::LogicalResult
- specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
- FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
- mlir::Function *f = functionToSpecialize.function;
-
- // Check if cloning for specialization is needed (usually anything but main)
- // We will create a new function with the concrete types for the parameters
- // and clone the body into it.
- if (!functionToSpecialize.mangledName.empty()) {
- if (getModule().getNamedFunction(functionToSpecialize.mangledName)) {
- funcWorklist.pop_back();
- // Function already specialized, move on.
- return mlir::success();
- }
-      // Create a new function with a generic array return type; it will be
-      // updated when the inference for the function body completes.
- auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
- {ToyArrayType::get(&getContext())},
- &getContext());
- auto *newFunction = new mlir::Function(
- f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
- getModule().getFunctions().push_back(newFunction);
-
- // Clone the function body
- mlir::BlockAndValueMapping mapper;
- f->cloneInto(newFunction, mapper);
- LLVM_DEBUG({
- llvm::dbgs() << "====== Cloned : \n";
- f->dump();
- llvm::dbgs() << "====== Into : \n";
- newFunction->dump();
- });
- f = newFunction;
- f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
- // Remap the entry-block arguments
- // FIXME: this seems like a bug in `cloneInto()` above?
- auto &entryBlock = f->getBlocks().front();
- int blockArgSize = entryBlock.getArguments().size();
- assert(blockArgSize == f->getType().getInputs().size());
- entryBlock.addArguments(f->getType().getInputs());
- auto argList = entryBlock.getArguments();
- for (int argNum = 0; argNum < blockArgSize; ++argNum) {
- argList[0]->replaceAllUsesWith(argList[blockArgSize]);
- entryBlock.eraseArgument(0);
- }
- assert(succeeded(f->verify()));
- }
- LLVM_DEBUG(llvm::dbgs()
- << "Run shape inference on : '" << f->getName() << "'\n");
-
- auto *toyDialect = getContext().getRegisteredDialect("toy");
- if (!toyDialect) {
- getContext().emitError(mlir::UnknownLoc::get(&getContext()),
- "Toy dialect is not registered");
- signalPassFailure();
- return mlir::failure();
- }
-
- // Populate the worklist with the operations that need shape inference:
- // these are the Toy operations that return a generic array.
- llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
- f->walk([&](mlir::Operation *op) {
- if (op->getDialect() == toyDialect) {
- if (op->getNumResults() == 1 &&
- op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
- opWorklist.insert(op);
- }
- });
-
-    // Iterate on the operations in the worklist until all operations have been
-    // inferred or no change happens (fixed point).
- while (!opWorklist.empty()) {
- // Find the next operation ready for inference, that is an operation
- // with all operands already resolved (non-generic).
- auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
- return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
- return !v->getType().cast<ToyArrayType>().isGeneric();
- });
- });
- if (nextop == opWorklist.end())
- break; // failure: no operations can be inferred.
-
- mlir::Operation *op = *nextop;
- opWorklist.erase(op);
- LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
-
- // The add operation is trivial: propagate the input type as is.
- if (auto addOp = op->dyn_cast<AddOp>()) {
- op->getResult(0)->setType(op->getOperand(0)->getType());
- continue;
- }
-
- // Transpose is easy: just invert the dimensions.
- if (op->getName().getStringRef() == "toy.transpose") {
- SmallVector<int64_t, 2> dims;
- auto arrayTy = op->getOperand(0)->getType().cast<ToyArrayType>();
- dims.insert(dims.end(), arrayTy.getShape().begin(),
- arrayTy.getShape().end());
- if (dims.size() == 2)
- std::swap(dims[0], dims[1]);
- op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
- continue;
- }
-
-      // Multiplication is a bit trickier: handle rank 1 as a dot product and
-      // rank 2 as matrix multiplication.
-      // We need to be careful about rank mismatches here: the verifier could
-      // catch it, but shape inference earlier in the pass could generate
-      // invalid IR (from an invalid Toy input of course) and we wouldn't want
-      // to crash here.
- if (auto mulOp = op->dyn_cast<MulOp>()) {
- auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
- auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
- auto lhsRank = lhs.getShape().size();
- auto rhsRank = rhs.getShape().size();
- if (lhsRank != rhsRank) {
- op->emitError("Shape mismatch: LHS and RHS must have the same "
- "rank for multiplication, got " +
- Twine(lhsRank) + " vs " + Twine(lhsRank));
- return mlir::failure();
- }
- SmallVector<int64_t, 2> dims;
- if (lhsRank == 1) {
- // dot product, result shape is <1>
- dims.push_back(1);
- } else {
- if (lhsRank != 2) {
- op->emitError(
- "Shape mismatch: expect rank 1 or 2 for mul operands, got " +
- Twine(lhsRank));
- return mlir::failure();
- }
- dims.push_back(lhs.getShape()[0]);
- dims.push_back(rhs.getShape()[1]);
- }
- op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
- continue;
- }
-
-      // Process calls: look up the callee after mangling the name with the
-      // argument shapes. If the callee does not exist, we stop the inference
-      // for this function, queue the callee in the inter-procedural work list,
-      // and return. The current function stays in the work list and will
-      // restart after the callee is processed.
- if (auto callOp = op->dyn_cast<GenericCallOp>()) {
- auto calleeName = callOp.getCalleeName();
- auto *callee = getModule().getNamedFunction(calleeName);
- if (!callee) {
- f->emitError(
- llvm::Twine("Shape inference failed, call to unknown '") +
- calleeName + "'");
- signalPassFailure();
- return mlir::failure();
- }
- auto mangledName = mangle(calleeName, op->getOpOperands());
- LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
- << "', mangled: '" << mangledName << "'\n");
- auto *mangledCallee = getModule().getNamedFunction(mangledName);
- if (!mangledCallee) {
-        // Can't find the target: this is where we queue the request for the
-        // callee and stop the inference for the current function.
- std::vector<mlir::Type> funcArgs;
- for (auto operand : op->getOperands())
- funcArgs.push_back(operand->getType());
- funcWorklist.push_back(
- {callee, std::move(mangledName), std::move(funcArgs)});
- return mlir::success();
- }
- // Found a specialized callee! Let's turn this into a normal call
- // operation.
- SmallVector<mlir::Value *, 8> operands;
- for (mlir::Value *v : op->getOperands())
- operands.push_back(v);
- mlir::FuncBuilder builder(f);
- builder.setInsertionPoint(op);
- auto newCall =
- builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
- if (newCall.getNumResults()) {
- op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
- op->erase();
- continue;
- }
- }
- }
-
-    // Done with inference on this function; remove it from the worklist.
-    funcWorklist.pop_back();
-    // Mark the function as non-generic now that inference has succeeded.
- f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
-
- // If the operation worklist isn't empty, this indicates a failure.
- if (!opWorklist.empty()) {
- std::string str;
- llvm::raw_string_ostream errorMsg(str);
- errorMsg << "Shape inference failed, " << opWorklist.size()
- << " operations couldn't be inferred\n";
- for (auto *ope : opWorklist)
- errorMsg << " - " << *ope << "\n";
- f->emitError(errorMsg.str());
- signalPassFailure();
- return mlir::failure();
- }
-
- // Finally, update the return type of the function based on the argument to
- // the return operation.
- for (auto &block : f->getBlocks()) {
-      auto ret = block.getTerminator()->dyn_cast<ReturnOp>();
- if (!ret)
- continue;
- if (ret.getNumOperands() &&
- f->getType().getResult(0) == ret.getOperand()->getType())
- // type match, we're done
- break;
- SmallVector<mlir::Type, 1> retTy;
- if (ret.getNumOperands())
- retTy.push_back(ret.getOperand()->getType());
- mlir::Type elementType = mlir::FloatType::getF64(&getContext());
- std::vector<mlir::Type> argumentsType;
- for (auto arg : f->getArguments())
- argumentsType.push_back(arg->getType());
- auto newType =
- mlir::FunctionType::get(argumentsType, retTy, &getContext());
- f->setType(newType);
- assert(succeeded(f->verify()));
- break;
- }
- return mlir::success();
- }
-};
-} // end anonymous namespace
-
-namespace toy {
-mlir::Pass *createShapeInferencePass() { return new ShapeInferencePass(); }
-} // namespace toy
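
The mangling scheme is central to the specialization above, so here is a
self-contained illustration. The mangleShapes helper below is hypothetical
(not part of this commit); it mirrors mangle() over plain shape vectors so it
can run outside MLIR:

    #include <cstdint>
    #include <string>
    #include <vector>

    // Hypothetical mirror of mangle() above: append "_AxB" per operand shape,
    // so a call foo(<2, 3>, <2, 3>) specializes to "foo_2x3_2x3".
    std::string mangleShapes(std::string name,
                             const std::vector<std::vector<int64_t>> &shapes) {
      for (const auto &shape : shapes) {
        name += "_";
        const char *sep = "";
        for (int64_t dim : shape) {
          name += sep + std::to_string(dim);
          sep = "x";
        }
      }
      return name; // mangleShapes("foo", {{2, 3}, {2, 3}}) == "foo_2x3_2x3"
    }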
diff --git a/mlir/ToyCombine.cpp b/mlir/ToyCombine.cpp
deleted file mode 100644
index 8d6aed6..0000000
--- a/mlir/ToyCombine.cpp
+++ /dev/null
@@ -1,209 +0,0 @@
-//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements a simple combiner for optimizing patterns in the Toy
-// dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-
-#include <numeric>
-
-namespace toy {
-
-namespace {
-
-/// Fold transpose(transpose(x)) -> x
-struct SimplifyRedundantTranspose : public mlir::RewritePattern {
- /// We register this pattern to match every toy.transpose in the IR.
- /// The "benefit" is used by the framework to order the patterns and process
- /// them in order of profitability.
- SimplifyRedundantTranspose(mlir::MLIRContext *context)
- : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
- context) {}
-
- /// This method attempts to match a pattern and rewrite it. The rewriter
- /// argument orchestrates the sequence of rewrites; any change to the IR
- /// made from here is expected to go through it.
- mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
- mlir::PatternRewriter &rewriter) const override {
- // We can directly cast the current operation as this will only get invoked
- // on TransposeOp.
- TransposeOp transpose = op->cast<TransposeOp>();
- // Look through the input of the current transpose.
- mlir::Value *transposeInput = transpose.getOperand();
- mlir::Operation *transposeInputInst = transposeInput->getDefiningOp();
- // If the input is defined by another Transpose, bingo!
- TransposeOp transposeInputOp =
- mlir::dyn_cast_or_null<TransposeOp>(transposeInputInst);
- if (!transposeInputOp)
- return matchFailure();
-
- // Use the rewriter to perform the replacement
- rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
- return matchSuccess();
- }
-};
-
-/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
-struct SimplifyReshapeConstant : public mlir::RewritePattern {
- SimplifyReshapeConstant(mlir::MLIRContext *context)
- : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
- context) {}
-
- mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
- mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = op->cast<ReshapeOp>();
- // Look through the input of the current reshape.
- mlir::Value *reshapeInput = reshape.getOperand();
- mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
- // If the input is defined by a constant, we can fold the reshape into it.
- ConstantOp constantOp =
- mlir::dyn_cast_or_null<ConstantOp>(reshapeInputInst);
- if (!constantOp)
- return matchFailure();
-
- auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>();
- if (auto valueAttr =
- constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
- // FIXME: check that the element count of the attribute matches the new
- // shape!
- auto newType = rewriter.getTensorType(
- reshapeType.getShape(), valueAttr.getType().getElementType());
- auto newAttr =
- mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
- auto newConstant = rewriter.create<ConstantOp>(
- constantOp.getLoc(), reshapeType.getShape(), newAttr);
- rewriter.replaceOp(op, {newConstant});
- } else if (auto valueAttr =
- constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
- // Broadcast
- auto dataSize = std::accumulate(reshapeType.getShape().begin(),
- reshapeType.getShape().end(), 1,
- std::multiplies<int>());
- std::vector<mlir::Attribute> data(dataSize, valueAttr);
- auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
- reshapeType.getElementType());
- auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
- auto newConstant = rewriter.create<ConstantOp>(
- constantOp.getLoc(), reshapeType.getShape(), newAttr);
- rewriter.replaceOp(op, {newConstant});
- } else {
- llvm_unreachable("Unsupported Constant format");
- }
- return matchSuccess();
- }
-};
-
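To make the broadcast arithmetic in SimplifyReshapeConstant concrete, here is a self-contained sketch of the same std::accumulate reduction (the shape is illustrative): reshaping a scalar constant to array<2, 3> replicates it 2 * 3 = 6 times.

  #include <cassert>
  #include <cstdint>
  #include <functional>
  #include <numeric>
  #include <vector>

  int main() {
    // Same reduction as the pattern above: multiply all dimensions together.
    std::vector<int64_t> shape = {2, 3};
    auto dataSize = std::accumulate(shape.begin(), shape.end(), 1,
                                    std::multiplies<int>());
    assert(dataSize == 6); // one copy of the FloatAttr per element
    return 0;
  }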
-/// Fold reshape(reshape(x)) -> reshape(x)
-struct SimplifyReshapeReshape : public mlir::RewritePattern {
- SimplifyReshapeReshape(mlir::MLIRContext *context)
- : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
- context) {}
-
- mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
- mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = op->cast<ReshapeOp>();
- // Look through the input of the current reshape.
- mlir::Value *reshapeInput = reshape.getOperand();
- mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
- // If the input is defined by another reshape, bingo!
- ReshapeOp reshapeInputOp =
- mlir::dyn_cast_or_null<ReshapeOp>(reshapeInputInst);
- if (!reshapeInputOp)
- return matchFailure();
-
- // Use the rewriter to perform the replacement
- rewriter.replaceOp(op, {reshapeInputOp});
- return matchSuccess();
- }
-};
-
-/// Fold reshape(x) -> x, when input type matches output type
-struct SimplifyNullReshape : public mlir::RewritePattern {
- SimplifyNullReshape(mlir::MLIRContext *context)
- : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
- context) {}
-
- mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
- mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = op->cast<ReshapeOp>();
- if (reshape.getOperand()->getType() != reshape.getResult()->getType())
- return matchFailure();
- rewriter.replaceOp(reshape, {reshape.getOperand()});
- return matchSuccess();
- }
-};
-
-} // end anonymous namespace.
-
-// Register our patterns for rewrite by the Canonicalization framework.
-void TransposeOp::getCanonicalizationPatterns(
- mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyRedundantTranspose>(context));
-}
-
-// Register our patterns for rewrite by the Canonicalization framework.
-void ReshapeOp::getCanonicalizationPatterns(
- mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyReshapeConstant>(context));
- results.push_back(llvm::make_unique<SimplifyReshapeReshape>(context));
- results.push_back(llvm::make_unique<SimplifyNullReshape>(context));
-}
-
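These hooks are not called directly: MLIR's generic canonicalizer collects the patterns per operation. A minimal wiring sketch, assuming mlir::createCanonicalizerPass() from mlir/Transforms/Passes.h:

  #include "mlir/Pass/PassManager.h"
  #include "mlir/Transforms/Passes.h"

  // The canonicalizer queries each op's getCanonicalizationPatterns() hook,
  // picking up the Toy patterns registered above.
  void addToyCanonicalization(mlir::PassManager &pm) {
    pm.addPass(mlir::createCanonicalizerPass());
  }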
-namespace {
-
-/// Fold typecast(x) -> x by walking back through a chain of typecasts until
-/// an input with the result type is found.
-struct SimplifyIdentityTypeCast : public mlir::RewritePattern {
- SimplifyIdentityTypeCast(mlir::MLIRContext *context)
- : RewritePattern(TypeCastOp::getOperationName(), /* benefit = */ 1,
- context) {}
-
- mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
- mlir::PatternRewriter &rewriter) const override {
- TypeCastOp typeCast = op->cast<TypeCastOp>();
- auto resTy = typeCast.getResult()->getType();
- auto *candidateOp = op;
- while (candidateOp && candidateOp->isa<TypeCastOp>()) {
- if (resTy == candidateOp->getOperand(0)->getType()) {
- rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)});
- return matchSuccess();
- }
- candidateOp = candidateOp->getOperand(0)->getDefiningOp();
- }
- return matchFailure();
- }
-};
-
-} // end anonymous namespace.
-
-void TypeCastOp::getCanonicalizationPatterns(
- mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyIdentityTypeCast>(context));
-}
-
-} // namespace toy
diff --git a/mlir/ToyDialect.cpp b/mlir/ToyDialect.cpp
deleted file mode 100644
index be117f5..0000000
--- a/mlir/ToyDialect.cpp
+++ /dev/null
@@ -1,405 +0,0 @@
-//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements the dialect for the Toy IR: custom type parsing and
-// operation verification.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/iterator_range.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/Regex.h"
-#include "llvm/Support/raw_ostream.h"
-
-using llvm::ArrayRef;
-using llvm::raw_ostream;
-using llvm::raw_string_ostream;
-using llvm::SmallVector;
-using llvm::StringRef;
-using llvm::Twine;
-
-namespace toy {
-namespace detail {
-
-/// This class holds the implementation of the ToyArrayType.
-/// It is intended to be uniqued based on its content and owned by the context.
-struct ToyArrayTypeStorage : public mlir::TypeStorage {
- /// This defines how we unique this type in the context: our key contains
- /// only the shape; a more complex type would have multiple entries in the
- /// tuple here.
- /// The elements of the tuple usually match 1-1 the arguments of the
- /// public `get()` method on the facade.
- using KeyTy = std::tuple<ArrayRef<int64_t>>;
- static unsigned hashKey(const KeyTy &key) {
- return llvm::hash_combine(std::get<0>(key));
- }
- /// When the key hash hits an existing type, we compare the shapes themselves
- /// to confirm we have the right type.
- bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }
-
- /// This is a factory method to create our type storage. It is only
- /// invoked after looking up the type in the context using the key and not
- /// finding it.
- static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
- const KeyTy &key) {
- // Copy the shape array into the bump-pointer allocator owned by the context.
- ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
-
- // Allocate the instance for the ToyArrayTypeStorage itself
- auto *storage = allocator.allocate<ToyArrayTypeStorage>();
- // Initialize the instance using placement new.
- return new (storage) ToyArrayTypeStorage(shape);
- }
-
- ArrayRef<int64_t> getShape() const { return shape; }
-
-private:
- ArrayRef<int64_t> shape;
-
- /// Constructor is only invoked from the `construct()` method above.
- ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
-};
-
-} // namespace detail
-
-mlir::Type ToyArrayType::getElementType() {
- return mlir::FloatType::getF64(getContext());
-}
-
-ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
- ArrayRef<int64_t> shape) {
- return Base::get(context, ToyTypeKind::TOY_ARRAY, shape);
-}
-
-ArrayRef<int64_t> ToyArrayType::getShape() { return getImpl()->getShape(); }
-
-mlir::MemRefType ToyArrayType::toMemref() {
- auto memRefType = mlir::MemRefType::get(getShape(), getElementType(), {}, 0);
- return memRefType;
-}
-
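A hedged usage sketch of the type facade above; the helper name is illustrative and the assertions merely restate the accessors:

  #include "toy/Dialect.h"
  #include <cassert>

  // Build a 2x3 Toy array type and inspect it through the facade.
  void demoToyArrayType(mlir::MLIRContext *ctx) {
    auto ty = toy::ToyArrayType::get(ctx, /*shape=*/{2, 3});
    assert(ty.getShape().size() == 2);
    assert(ty.getElementType().isF64()); // element type is always f64
    auto memref = ty.toMemref();         // memref<2x3xf64>
    (void)memref;
  }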
-/// Dialect creation, the instance will be owned by the context. This is the
-/// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
- addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
- MulOp, AddOp, ReturnOp, AllocOp, TypeCastOp>();
- addTypes<ToyArrayType>();
-}
-
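Per the commit subject, the dialect is hooked up through static registration; a minimal sketch assuming MLIR's DialectRegistration helper of this era, placed in any translation unit linked into the tool:

  #include "mlir/IR/Dialect.h"
  #include "toy/Dialect.h"

  // Registers ToyDialect with every MLIRContext constructed afterwards; only
  // the static initializer matters, the variable name is arbitrary.
  static mlir::DialectRegistration<toy::ToyDialect> toyDialectRegistration;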
-/// Parse a type registered to this dialect, we expect only Toy arrays.
-mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
- // Sanity check: we only support array or array<...>
- if (!tyData.startswith("array")) {
- getContext()->emitError(loc, "Invalid Toy type '" + tyData +
- "', array expected");
- return nullptr;
- }
- // Drop the "array" prefix from the type name, we expect either an empty
- // string or just the shape.
- tyData = tyData.drop_front(StringRef("array").size());
- // This is the generic, unshaped array case; return it early.
- if (tyData.empty())
- return ToyArrayType::get(getContext());
-
- // Use a regex to parse the shape (for efficiency, we should store this
- // regex in the dialect itself).
- SmallVector<StringRef, 4> matches;
- auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
- if (!shapeRegex.match(tyData, &matches)) {
- getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'");
- return nullptr;
- }
- SmallVector<int64_t, 4> shape;
- // Iterate through the captures, skipping the first one, which is the full match.
- for (auto dimStr :
- llvm::make_range(std::next(matches.begin()), matches.end())) {
- if (dimStr.startswith(","))
- continue; // POSIX misses non-capturing groups.
- if (dimStr.empty())
- continue; // '*' makes it an optional group capture
- // Convert the capture to an integer
- unsigned long long dim;
- if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
- getContext()->emitError(
- loc, "Couldn't parse dimension as integer, matched: " + dimStr);
- return mlir::Type();
- }
- shape.push_back(dim);
- }
- // Finally, all the dimensions have been collected in the shape; create
- // the array type.
- return ToyArrayType::get(getContext(), shape);
-}
-
-/// Print a Toy array type, for example `array<2, 3, 4>`
-void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
- auto arrayTy = type.dyn_cast<ToyArrayType>();
- if (!arrayTy) {
- os << "unknown toy type";
- return;
- }
- os << "array";
- if (!arrayTy.getShape().empty()) {
- os << "<";
- mlir::interleaveComma(arrayTy.getShape(), os);
- os << ">";
- }
-}
-
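Taken together, the two hooks round-trip the custom syntax; a hedged sketch that calls them directly, outside the normal parser/printer plumbing:

  #include "toy/Dialect.h"
  #include "llvm/Support/raw_ostream.h"
  #include <string>

  // Parse "array<2, 3, 4>" and print it back; both spellings should agree.
  void demoTypeRoundTrip(const toy::ToyDialect &dialect, mlir::Location loc) {
    mlir::Type ty = dialect.parseType("array<2, 3, 4>", loc);
    std::string printed;
    llvm::raw_string_ostream os(printed);
    dialect.printType(ty, os);
    // os.str() == "array<2, 3, 4>"
  }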
-//===----------------------------------------------------------------------===//
-// Custom Operations for the Dialect
-//===----------------------------------------------------------------------===//
-
-/// Helper to verify that the result of an operation is a Toy array type.
-template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
- if (!op->getResult()->getType().template isa<ToyArrayType>()) {
- std::string msg;
- raw_string_ostream os(msg);
- os << "expects a Toy Array for its argument, got "
- << op->getResult()->getType();
- return op->emitOpError(os.str());
- }
- return mlir::success();
-}
-
-/// Helper to verify that the two operands of a binary operation are Toy
-/// arrays.
-template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
- if (!op->getOperand(0)->getType().template isa<ToyArrayType>()) {
- std::string msg;
- raw_string_ostream os(msg);
- os << "expects a Toy Array for its LHS, got "
- << op->getOperand(0)->getType();
- return op->emitOpError(os.str());
- }
- if (!op->getOperand(1)->getType().template isa<ToyArrayType>()) {
- std::string msg;
- raw_string_ostream os(msg);
- os << "expects a Toy Array for its LHS, got "
- << op->getOperand(0)->getType();
- return op->emitOpError(os.str());
- }
- return mlir::success();
-}
-
-/// Build a constant operation.
-/// The builder is passed as an argument, as is the state that this method is
-/// expected to fill in order to build the operation.
-void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state,
- ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
- state->types.push_back(ToyArrayType::get(builder->getContext(), shape));
- auto dataAttribute = builder->getNamedAttr("value", value);
- state->attributes.push_back(dataAttribute);
-}
-
-/// Build a constant operation.
-/// The builder is passed as an argument, as is the state that this method is
-/// expected to fill in order to build the operation.
-void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state,
- mlir::FloatAttr value) {
- // Broadcast and forward to the other build factory
- mlir::Type elementType = mlir::FloatType::getF64(builder->getContext());
- auto dataType = builder->getTensorType({1}, elementType);
- auto dataAttribute = builder->getDenseElementsAttr(dataType, {value})
- .cast<mlir::DenseElementsAttr>();
-
- ConstantOp::build(builder, state, {1}, dataAttribute);
-}
-
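A hedged sketch of exercising these factories via builder.create<>, which forwards to the build() methods above; the builder, location, and shaped attribute are assumed to exist:

  #include "toy/Dialect.h"
  #include "mlir/IR/Builders.h"

  // Emit one broadcast scalar constant and one shaped constant.
  void emitConstants(mlir::FuncBuilder &builder, mlir::Location loc,
                     mlir::DenseElementsAttr shapedValue) {
    // Scalar overload: 42.0 is broadcast into a <1> array.
    builder.create<toy::ConstantOp>(loc, builder.getF64FloatAttr(42.0));
    // Shaped overload: the shape must match the attribute, per verify().
    builder.create<toy::ConstantOp>(loc, llvm::ArrayRef<int64_t>{2, 3},
                                    shapedValue);
  }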
-/// Verifier for constant operation.
-mlir::LogicalResult ConstantOp::verify() {
- // Ensure that the return type is a Toy array
- if (failed(verifyToyReturnArray(this)))
- return mlir::failure();
-
- // We expect the constant itself to be stored as an attribute.
- auto dataAttr = getAttr("value").dyn_cast<mlir::DenseElementsAttr>();
- if (!dataAttr) {
- return emitOpError(
- "missing valid `value` DenseElementsAttribute on toy.constant()");
- }
- auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
- if (!attrType) {
- return emitOpError(
- "expected the `value` attribute on toy.constant() to have a tensor type");
- }
-
- // If the return type of the constant is not a generic array, the shape must
- // match the shape of the attribute holding the data.
- auto resultType = getResult()->getType().cast<ToyArrayType>();
- if (!resultType.isGeneric()) {
- if (attrType.getRank() != resultType.getRank()) {
- return emitOpError("The rank of the toy.constant return type must match "
- "the one of the attached value attribute: " +
- Twine(attrType.getRank()) +
- " != " + Twine(resultType.getRank()));
- }
- for (int dim = 0; dim < attrType.getRank(); ++dim) {
- if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
- return emitOpError(
- "Shape mismatch between toy.constant return type and its "
- "attribute at dimension " +
- Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
- " != " + Twine(resultType.getShape()[dim]));
- }
- }
- }
- return mlir::success();
-}
-
-void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState *state,
- StringRef callee, ArrayRef<mlir::Value *> arguments) {
- // Generic call always returns a generic ToyArray initially
- state->types.push_back(ToyArrayType::get(builder->getContext()));
- state->operands.assign(arguments.begin(), arguments.end());
- auto calleeAttr = builder->getStringAttr(callee);
- state->attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
-}
-
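And a matching sketch for the generic call factory; the callee name multiply_transpose is illustrative (borrowed from the classic Toy example):

  #include "toy/Dialect.h"
  #include "mlir/IR/Builders.h"

  // Emit a call to a not-yet-specialized Toy function.
  void emitGenericCall(mlir::FuncBuilder &builder, mlir::Location loc,
                       mlir::Value *lhs, mlir::Value *rhs) {
    builder.create<toy::GenericCallOp>(loc, "multiply_transpose",
                                       llvm::ArrayRef<mlir::Value *>{lhs, rhs});
  }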
-mlir::LogicalResult GenericCallOp::verify() {
- // Verify that every operand is a Toy Array
- for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
- if (!getOperand(opId)->getType().template isa<ToyArrayType>()) {
- std::string msg;
- raw_string_ostream os(msg);
- os << "expects a Toy Array for its " << opId << " operand, got "
- << getOperand(opId)->getType();
- return emitOpError(os.str());
- }
- }
- return mlir::success();
-}
-
-/// Return the name of the callee.
-StringRef GenericCallOp::getCalleeName() {
- return getAttr("callee").cast<mlir::StringAttr>().getValue();
-}
-
-template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
- if (!op->getOperand()->getType().template isa<ToyArrayType>()) {
- std::string msg;
- raw_string_ostream os(msg);
- os << "expects a Toy Array for its argument, got "
- << op->getOperand()->getType();
- return op->emitOpError(os.str());
- }
- return mlir::success();
-}
-
-void ReturnOp::build(mlir::Builder *builder, mlir::OperationState *state,
- mlir::Value *value) {
- // Return does not return any value and has an optional single argument
- if (value)
- state->operands.push_back(value);
-}
-
-mlir::LogicalResult ReturnOp::verify() {
- if (getNumOperands() > 1)
- return emitOpError("expects zero or one operand, got " +
- Twine(getNumOperands()));
- if (hasOperand() && failed(verifyToySingleOperand(this)))
- return mlir::failure();
- return mlir::success();
-}
-
-void PrintOp::build(mlir::Builder *builder, mlir::OperationState *state,
- mlir::Value *value) {
- // Print does not return any value and has a single argument
- state->operands.push_back(value);
-}
-
-mlir::LogicalResult PrintOp::verify() {
- if (failed(verifyToySingleOperand(this)))
- return mlir::failure();
- return mlir::success();
-}
-
-void TransposeOp::build(mlir::Builder *builder, mlir::OperationState *state,
- mlir::Value *value) {
- state->types.push_back(ToyArrayType::get(builder->getContext()));
- state->operands.push_back(value);
-}
-
-mlir::LogicalResult TransposeOp::verify() {
- if (failed(verifyToySingleOperand(this)))
- return mlir::failure();
- return mlir::success();
-}
-
-void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState *state,
- mlir::Value *value, ToyArrayType reshapedType) {
- state->types.push_back(reshapedType);
- state->operands.push_back(value);
-}
-
-mlir::LogicalResult ReshapeOp::verify() {
- if (failed(verifyToySingleOperand(this)))
- return mlir::failure();
- auto retTy = getResult()->getType().dyn_cast<ToyArrayType>();
- if (!retTy)
- return emitOpError("toy.reshape is expected to produce a Toy array");
- if (retTy.isGeneric())
- return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
- "got a generic one.");
- return mlir::success();
-}
-
-void AddOp::build(mlir::Builder *builder, mlir::OperationState *state,
- mlir::Value *lhs, mlir::Value *rhs) {
- state->types.push_back(ToyArrayType::get(builder->getContext()));
- state->operands.push_back(lhs);
- state->operands.push_back(rhs);
-}
-
-mlir::LogicalResult AddOp::verify() {
- if (failed(verifyToyBinOperands(this)))
- return mlir::failure();
- return mlir::success();
-}
-
-void MulOp::build(mlir::Builder *builder, mlir::OperationState *state,
- mlir::Value *lhs, mlir::Value *rhs) {
- state->types.push_back(ToyArrayType::get(builder->getContext()));
- state->operands.push_back(lhs);
- state->operands.push_back(rhs);
-}
-
-mlir::LogicalResult MulOp::verify() {
- if (failed(verifyToyBinOperands(this)))
- return mlir::failure();
- return mlir::success();
-}
-
-void AllocOp::build(mlir::Builder *builder, mlir::OperationState *state,
- mlir::Type retType) {
- state->types.push_back(retType);
-}
-
-void TypeCastOp::build(mlir::Builder *builder, mlir::OperationState *state,
- mlir::Value *value, mlir::Type destTy) {
- state->operands.push_back(value);
- state->types.push_back(destTy);
-}
-
-} // namespace toy