From 0781257b2a8d544abdcce38824a9b8288a04800d Mon Sep 17 00:00:00 2001
From: Tuowen Zhao
Date: Sat, 27 Apr 2019 19:05:25 -0600
Subject: Split toy dialect using static registration

---
 CMakeLists.txt              |  30 +--
 mlir/EarlyLowering.cpp      | 158 ---------------
 mlir/LateLowering.cpp       | 452 -----------------------------------------
 mlir/MLIRGen.cpp            | 480 --------------------------------------------
 mlir/ShapeInferencePass.cpp | 387 -----------------------------------
 mlir/ToyCombine.cpp         | 209 -------------------
 mlir/ToyDialect.cpp         | 405 -------------------------------------
 toy/CMakeLists.txt          |   6 +
 toy/EarlyLowering.cpp       | 158 +++++++++++++++
 toy/LateLowering.cpp        | 452 +++++++++++++++++++++++++++++++++++++++++
 toy/MLIRGen.cpp             | 480 ++++++++++++++++++++++++++++++++++++++++++++
 toy/RegisterDialects.cpp    |  10 +
 toy/ShapeInferencePass.cpp  | 387 +++++++++++++++++++++++++++++++++++
 toy/ToyCombine.cpp          | 209 +++++++++++++++++++
 toy/ToyDialect.cpp          | 405 +++++++++++++++++++++++++++++++++++++
 toyc.cpp                    |   3 -
 16 files changed, 2122 insertions(+), 2109 deletions(-)
 delete mode 100644 mlir/EarlyLowering.cpp
 delete mode 100644 mlir/LateLowering.cpp
 delete mode 100644 mlir/MLIRGen.cpp
 delete mode 100644 mlir/ShapeInferencePass.cpp
 delete mode 100644 mlir/ToyCombine.cpp
 delete mode 100644 mlir/ToyDialect.cpp
 create mode 100644 toy/CMakeLists.txt
 create mode 100644 toy/EarlyLowering.cpp
 create mode 100644 toy/LateLowering.cpp
 create mode 100644 toy/MLIRGen.cpp
 create mode 100644 toy/RegisterDialects.cpp
 create mode 100644 toy/ShapeInferencePass.cpp
 create mode 100644 toy/ToyCombine.cpp
 create mode 100644 toy/ToyDialect.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index f04ac1f..7c084ae 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -11,45 +11,40 @@ else()
   message(FATAL_ERROR "LLVM not found; it is derived from MLIR_INSTALL_PREFIX which has value of ${MLIR_INSTALL_PREFIX}")
 endif()
 
+set (CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR};${CMAKE_MODULE_PATH}")
+include(AddLLVM)
+
 function(whole_archive_link target)
   if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
-    set(link_flags "-L${CMAKE_BINARY_DIR}/lib ")
+    set(link_flags "")
     FOREACH(LIB ${ARGN})
       string(CONCAT link_flags ${link_flags} "-Wl,-force_load ${CMAKE_BINARY_DIR}/lib/lib${LIB}.a ")
     ENDFOREACH(LIB)
   else()
-    set(link_flags "-L${CMAKE_BINARY_DIR}/lib -Wl,--whole-archive,")
+    set(link_flags "-Wl,--whole-archive,")
     FOREACH(LIB ${ARGN})
       string(CONCAT link_flags ${link_flags} "-l${LIB},")
     ENDFOREACH(LIB)
     string(CONCAT link_flags ${link_flags} "--no-whole-archive")
   endif()
-  set_target_properties(${target} PROPERTIES LINK_FLAGS ${link_flags})
+  set_target_properties(${target} PROPERTIES LINK_FLAGS "${link_flags}")
 endfunction(whole_archive_link)
 
 llvm_map_components_to_libnames(llvm_libs support)
 
-if (NOT ${LLVM_ENABLE_RTTI})
-  set(CMAKE_CXX_FLAGS "-fno-rtti ${CMAKE_CXX_FLAGS}")
-endif()
-
 include_directories(
   include
   ${LLVM_INCLUDE_DIR})
-link_directories(${LLVM_LIBRARY_DIR})
+link_directories(
+  ${LLVM_LIBRARY_DIR}
+  ${CMAKE_BINARY_DIR})
 add_definitions(${LLVM_DEFINITIONS})
 
 add_executable(toyc
   toyc.cpp
   parser/AST.cpp
-  mlir/EarlyLowering.cpp
-  mlir/LateLowering.cpp
-  mlir/MLIRGen.cpp
-  mlir/ShapeInferencePass.cpp
-  mlir/ToyDialect.cpp
-  mlir/ToyCombine.cpp
   )
 
 target_link_libraries(toyc
@@ -71,4 +66,9 @@ target_link_libraries(toyc
   MLIRSupport
   )
 
-whole_archive_link(toyc MLIRStandardOps MLIRAffineOps)
+add_dependencies(toyc ToyDialect)
+
+llvm_update_compile_flags(toyc)
+whole_archive_link(toyc ToyDialect MLIRStandardOps MLIRAffineOps)
+
+add_subdirectory(toy)
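The new toy/CMakeLists.txt is not shown in this patch; only the diffstat above records it (6 insertions). Its role follows from the root CMakeLists.txt changes: build the dialect sources into a ToyDialect archive that toyc depends on and force-loads via whole_archive_link -- force-loading matters because static registration relies on global constructors, which the linker would otherwise drop from a static archive nothing references. A plausible minimal version, with the add_llvm_library helper and this file list assumed rather than confirmed:

# Hypothetical reconstruction of toy/CMakeLists.txt -- not part of this patch.
add_llvm_library(ToyDialect
  EarlyLowering.cpp
  LateLowering.cpp
  MLIRGen.cpp
  RegisterDialects.cpp
  ShapeInferencePass.cpp
  ToyCombine.cpp
  ToyDialect.cpp
  )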
diff --git a/mlir/EarlyLowering.cpp b/mlir/EarlyLowering.cpp
deleted file mode 100644
index 634c72e..0000000
--- a/mlir/EarlyLowering.cpp
+++ /dev/null
@@ -1,158 +0,0 @@
-//=======- EarlyLowering.cpp - Toy Lowering to Linear Algebra Dialect -=======//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements early lowering of Toy IR to the Linalg dialect: we only
-// lower the computationally intensive parts of the program (matmul...) to a
-// dialect specialized for optimizations.
-//
-// This is intended to showcase how multiple dialects can cohabit in the same
-// function. After this lowering, you would still have toy.print in the IR for
-// example.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "linalg3/Intrinsics.h"
-#include "linalg1/ViewOp.h"
-#include "linalg3/TensorOps.h"
-#include "mlir/EDSC/Builders.h"
-#include "mlir/EDSC/Helpers.h"
-#include "mlir/EDSC/Intrinsics.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/LLVMIR/LLVMDialect.h"
-#include "mlir/Parser.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Type.h"
-
-#include
-
-using namespace mlir;
-
-namespace {
-/// Utility function for type casting: this is making the type checker happy,
-/// while delaying the actual work involved to convert the type. Most of the
-/// time both sides of the cast (producer and consumer) will be lowered to a
-/// dialect like LLVM and end up with the same LLVM representation, at which
-/// point this becomes a no-op and is eliminated.
-Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) {
-  if (val->getType() == destTy)
-    return val;
-  return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
-      .getResult();
-}
-
-/// Create a type cast to turn a toy.array into a memref. The Toy array will be
-/// lowered to a memref during buffer allocation, at which point the type cast
-/// becomes useless.
-Value *memRefTypeCast(FuncBuilder &builder, Value *val) {
-  if (val->getType().isa<MemRefType>())
-    return val;
-  auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
-  if (!toyArrayTy)
-    return val;
-  return typeCast(builder, val, toyArrayTy.toMemref());
-}
-
-/// Lower toy.mul to the Linalg `matmul` operation.
-///
-/// This class inherits from `DialectOpConversion` and overrides `rewrite`,
-/// similarly to the PatternRewriter introduced in the previous chapter.
-/// It will be called by the DialectConversion framework (see the
-/// `EarlyLowering` class below).
-class MulOpConversion : public DialectOpConversion {
-public:
-  explicit MulOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {}
-
-  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
-                                  FuncBuilder &rewriter) const override {
-    using namespace edsc;
-    using intrinsics::constant_index;
-    using linalg::intrinsics::range;
-    using linalg::intrinsics::view;
-    toy::MulOp mul = op->cast<toy::MulOp>();
-    auto loc = mul.getLoc();
-    Value *result = memRefTypeCast(
-        rewriter,
-        rewriter.create<toy::AllocOp>(loc, mul.getResult()->getType())
-            .getResult());
-    Value *lhs = memRefTypeCast(rewriter, operands[0]);
-    auto memrefLHSTy = lhs->getType().cast<MemRefType>();
-    Value *rhs = memRefTypeCast(rewriter, operands[1]);
-    auto memrefRHSTy = rhs->getType().cast<MemRefType>();
-    mlir::edsc::ScopedContext scope(rewriter, loc);
-    edsc::ValueHandle r0 =
-        range(constant_index(0), constant_index(memrefLHSTy.getDimSize(0)),
-              constant_index(1));
-    edsc::ValueHandle r1 =
-        range(constant_index(0), constant_index(memrefLHSTy.getDimSize(1)),
-              constant_index(1));
-    edsc::ValueHandle r2 =
-        range(constant_index(0), constant_index(memrefRHSTy.getDimSize(1)),
-              constant_index(1));
-    auto lhsView = view(lhs, {r0, r1});
-    auto rhsView = view(rhs, {r1, r2});
-    auto resultView = view(result, {r0, r2});
-    rewriter.create<linalg::MatmulOp>(loc, lhsView, rhsView, resultView);
-    return {typeCast(rewriter, result, mul.getType())};
-  }
-};
-
-// The conversion class from the Toy IR dialect to a mix of Linalg and LLVM.
-class EarlyLowering : public DialectConversion {
-protected:
-  // Initialize the list of converters.
-  llvm::DenseSet<DialectOpConversion *>
-  initConverters(MLIRContext *context) override {
-    return ConversionListBuilder<MulOpConversion>::build(&allocator, context);
-  }
-
-private:
-  llvm::BumpPtrAllocator allocator;
-};
-
-/// This pass lowers to Linalg the parts that are computationally intensive
-/// (like matmul for example...) while keeping the rest of the code in the Toy
-/// dialect.
-struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> {
-
-  void runOnModule() override {
-    if (failed(EarlyLowering().convert(&getModule()))) {
-      getModule().getContext()->emitError(
-          mlir::UnknownLoc::get(getModule().getContext()),
-          "Error lowering Toy\n");
-      signalPassFailure();
-    }
-  }
-};
-} // end anonymous namespace
-
-namespace toy {
-Pass *createEarlyLoweringPass() { return new EarlyLoweringPass(); }
-
-std::unique_ptr<DialectConversion> makeToyEarlyLowering() {
-  return llvm::make_unique<EarlyLowering>();
-}
-
-} // namespace toy
diff --git a/mlir/LateLowering.cpp b/mlir/LateLowering.cpp
deleted file mode 100644
index eeae6ee..0000000
--- a/mlir/LateLowering.cpp
+++ /dev/null
@@ -1,452 +0,0 @@
-//====- LateLowering.cpp - Lowering from Toy+Linalg to LLVM -===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements late lowering of IR mixing Toy and Linalg to LLVM.
-// It involves intermediate steps:
-// -
-// - a mix of affine and standard dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "linalg3/Intrinsics.h"
-#include "linalg1/ViewOp.h"
-#include "linalg3/ConvertToLLVMDialect.h"
-#include "linalg3/TensorOps.h"
-#include "linalg3/Transforms.h"
-#include "mlir/EDSC/Builders.h"
-#include "mlir/EDSC/Helpers.h"
-#include "mlir/EDSC/Intrinsics.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/LLVMIR/LLVMDialect.h"
-#include "mlir/Parser.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Type.h"
-
-#include
-
-using namespace mlir;
-
-namespace {
-/// Utility function for type casting: this is making the type checker happy,
-/// while delaying the actual work involved to convert the type. Most of the
-/// time both sides of the cast (producer and consumer) will be lowered to a
-/// dialect like LLVM and end up with the same LLVM representation, at which
-/// point this becomes a no-op and is eliminated.
-Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) {
-  if (val->getType() == destTy)
-    return val;
-  return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
-      .getResult();
-}
-
-/// Create a type cast to turn a toy.array into a memref. The Toy array will be
-/// lowered to a memref during buffer allocation, at which point the type cast
-/// becomes useless.
-Value *memRefTypeCast(FuncBuilder &builder, Value *val) {
-  if (val->getType().isa<MemRefType>())
-    return val;
-  auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
-  if (!toyArrayTy)
-    return val;
-  return typeCast(builder, val, toyArrayTy.toMemref());
-}
-
-/// Lower a toy.add to an affine loop nest.
-///
-/// This class inherits from `DialectOpConversion` and overrides `rewrite`,
-/// similarly to the PatternRewriter introduced in the previous chapter.
-/// It will be called by the DialectConversion framework (see the
-/// `LateLowering` class below).
-class AddOpConversion : public DialectOpConversion {
-public:
-  explicit AddOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {}
-
-  /// Lower the `op` by generating IR using the `rewriter` builder. The builder
-  /// is set up with a new function, and the `operands` array has been
-  /// populated with the rewritten operands for `op` in the new function.
-  /// The results created by the new IR with the builder are returned, and
-  /// their number must match the number of results of `op`.
-  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
-                                  FuncBuilder &rewriter) const override {
-    auto add = op->cast<toy::AddOp>();
-    auto loc = add.getLoc();
-    // Create a `toy.alloc` operation to allocate the output buffer for this op.
-    Value *result = memRefTypeCast(
-        rewriter,
-        rewriter.create<toy::AllocOp>(loc, add.getResult()->getType())
-            .getResult());
-    Value *lhs = memRefTypeCast(rewriter, operands[0]);
-    Value *rhs = memRefTypeCast(rewriter, operands[1]);
-
-    using namespace edsc;
-    ScopedContext scope(rewriter, loc);
-    ValueHandle zero = intrinsics::constant_index(0);
-    MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
-    IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
-    IndexHandle i, j, M(vRes.ub(0));
-    if (vRes.rank() == 1) {
-      LoopNestBuilder({&i}, {zero}, {M}, {1})({iRes(i) = iLHS(i) + iRHS(i)});
-    } else {
-      assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now");
-      IndexHandle N(vRes.ub(1));
-      LoopNestBuilder({&i, &j}, {zero, zero}, {M, N},
-                      {1, 1})({iRes(i, j) = iLHS(i, j) + iRHS(i, j)});
-    }
-
-    // Return the newly allocated buffer, with a type.cast to preserve the
-    // consumers.
-    return {typeCast(rewriter, result, add.getType())};
-  }
-};
-
-/// Lowers `toy.print` to a loop nest calling `printf` on every individual
-/// element of the array.
-class PrintOpConversion : public DialectOpConversion {
-public:
-  explicit PrintOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {}
-
-  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
-                                  FuncBuilder &rewriter) const override {
-    // Get or create the declaration of the printf function in the module.
-    Function *printfFunc = getPrintf(*op->getFunction()->getModule());
-
-    auto print = op->cast<toy::PrintOp>();
-    auto loc = print.getLoc();
-    // We will operate on a MemRef abstraction; we use a type.cast to get one
-    // if our operand is still a Toy array.
-    Value *operand = memRefTypeCast(rewriter, operands[0]);
-    Type retTy = printfFunc->getType().getResult(0);
-
-    // Create our loop nest now
-    using namespace edsc;
-    using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>;
-    ScopedContext scope(rewriter, loc);
-    ValueHandle zero = intrinsics::constant_index(0);
-    ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f "));
-    MemRefView vOp(operand);
-    IndexedValue iOp(operand);
-    IndexHandle i, j, M(vOp.ub(0));
-
-    ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
-    if (vOp.rank() == 1) {
-      // clang-format off
-      LoopBuilder(&i, zero, M, 1)({
-        llvmCall(retTy,
-                 rewriter.getFunctionAttr(printfFunc),
-                 {fmtCst, iOp(i)})
-      });
-      llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
-      // clang-format on
-    } else {
-      IndexHandle N(vOp.ub(1));
-      // clang-format off
-      LoopBuilder(&i, zero, M, 1)({
-        LoopBuilder(&j, zero, N, 1)({
-          llvmCall(retTy,
-                   rewriter.getFunctionAttr(printfFunc),
-                   {fmtCst, iOp(i, j)})
-        }),
-        llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol})
-      });
-      // clang-format on
-    }
-    return {};
-  }
-
-private:
-  // Turn a string into a toy.alloc (malloc/free abstraction) and a sequence
-  // of stores into the buffer, and return a MemRef into the buffer.
-  Value *getConstantCharBuffer(FuncBuilder &builder, Location loc,
-                               StringRef data) const {
-    auto retTy =
-        builder.getMemRefType(data.size() + 1, builder.getIntegerType(8));
-    Value *result = builder.create<toy::AllocOp>(loc, retTy).getResult();
-    using namespace edsc;
-    using intrinsics::constant_index;
-    using intrinsics::constant_int;
-    ScopedContext scope(builder, loc);
-    MemRefView vOp(result);
-    IndexedValue iOp(result);
-    for (uint64_t i = 0; i < data.size(); ++i) {
-      iOp(constant_index(i)) = constant_int(data[i], 8);
-    }
-    iOp(constant_index(data.size())) = constant_int(0, 8);
-    return result;
-  }
-
-  /// Return the prototype declaration for printf in the module, create it if
-  /// necessary.
-  Function *getPrintf(Module &module) const {
-    auto *printfFunc = module.getNamedFunction("printf");
-    if (printfFunc)
-      return printfFunc;
-
-    // Create a function declaration for printf; the signature is
-    // `i32 (i8*, ...)`.
-    Builder builder(&module);
-    MLIRContext *context = module.getContext();
-    LLVM::LLVMDialect *llvmDialect = static_cast<LLVM::LLVMDialect *>(
-        module.getContext()->getRegisteredDialect("llvm"));
-    auto &llvmModule = llvmDialect->getLLVMModule();
-    llvm::IRBuilder<> llvmBuilder(llvmModule.getContext());
-
-    auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32));
-    auto llvmI8PtrTy =
-        LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo());
-    auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
-    printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
-    // It should be variadic, but we don't support it fully just yet.
-    printfFunc->setAttr("std.varargs", builder.getBoolAttr(true));
-    module.getFunctions().push_back(printfFunc);
-    return printfFunc;
-  }
-};
-
-/// Lowers a constant to a sequence of stores into a buffer.
-class ConstantOpConversion : public DialectOpConversion {
-public:
-  explicit ConstantOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {}
-
-  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
-                                  FuncBuilder &rewriter) const override {
-    toy::ConstantOp cstOp = op->cast<toy::ConstantOp>();
-    auto loc = cstOp.getLoc();
-    auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
-    auto shape = retTy.getShape();
-    Value *result = memRefTypeCast(
-        rewriter, rewriter.create<toy::AllocOp>(loc, retTy).getResult());
-
-    auto cstValue = cstOp.getValue();
-    auto f64Ty = rewriter.getF64Type();
-    using namespace edsc;
-    using intrinsics::constant_float;
-    using intrinsics::constant_index;
-    ScopedContext scope(rewriter, loc);
-    MemRefView vOp(result);
-    IndexedValue iOp(result);
-    for (uint64_t i = 0; i < shape[0]; ++i) {
-      if (shape.size() == 1) {
-        auto value = cstValue.getValue(ArrayRef<uint64_t>{i})
-                         .cast<mlir::FloatAttr>()
-                         .getValue();
-        iOp(constant_index(i)) = constant_float(value, f64Ty);
-        continue;
-      }
-      for (uint64_t j = 0; j < shape[1]; ++j) {
-        auto value = cstValue.getValue(ArrayRef<uint64_t>{i, j})
-                         .cast<mlir::FloatAttr>()
-                         .getValue();
-        iOp(constant_index(i), constant_index(j)) =
-            constant_float(value, f64Ty);
-      }
-    }
-    return {result};
-  }
-};
-
-/// Lower the transpose operation to an affine loop nest.
-class TransposeOpConversion : public DialectOpConversion {
-public:
-  explicit TransposeOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {}
-
-  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
-                                  FuncBuilder &rewriter) const override {
-    auto transpose = op->cast<toy::TransposeOp>();
-    auto loc = transpose.getLoc();
-    Value *result = memRefTypeCast(
-        rewriter,
-        rewriter.create<toy::AllocOp>(loc, transpose.getResult()->getType())
-            .getResult());
-    Value *operand = memRefTypeCast(rewriter, operands[0]);
-
-    using namespace edsc;
-    ScopedContext scope(rewriter, loc);
-    ValueHandle zero = intrinsics::constant_index(0);
-    MemRefView vRes(result), vOperand(operand);
-    IndexedValue iRes(result), iOperand(operand);
-    IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1));
-    // clang-format off
-    LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({
-      iRes(i, j) = iOperand(j, i)
-    });
-    // clang-format on
-
-    return {typeCast(rewriter, result, transpose.getType())};
-  }
-};
-
-// Lower toy.return to the standard return operation.
-class ReturnOpConversion : public DialectOpConversion {
-public:
-  explicit ReturnOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {}
-
-  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
-                                  FuncBuilder &rewriter) const override {
-    auto retOp = op->cast<toy::ReturnOp>();
-    using namespace edsc;
-    auto loc = retOp.getLoc();
-    // The argument is optional, handle both cases.
-    if (retOp.getNumOperands())
-      rewriter.create<ReturnOp>(loc, operands[0]);
-    else
-      rewriter.create<ReturnOp>(loc);
-    return {};
-  }
-};
-
-/// This is the main class registering our individual converter classes with
-/// the DialectConversion framework in MLIR.
-class LateLowering : public DialectConversion {
-protected:
-  /// Initialize the list of converters.
-  llvm::DenseSet<DialectOpConversion *>
-  initConverters(MLIRContext *context) override {
-    return ConversionListBuilder<AddOpConversion, PrintOpConversion,
-                                 ConstantOpConversion, TransposeOpConversion,
-                                 ReturnOpConversion>::build(&allocator,
-                                                            context);
-  }
-
-  /// Convert a Toy type; this gets called for block and region arguments, and
-  /// attributes.
-  Type convertType(Type t) override {
-    if (auto array = t.dyn_cast<toy::ToyArrayType>()) {
-      return array.toMemref();
-    }
-    return t;
-  }
-
-private:
-  llvm::BumpPtrAllocator allocator;
-};
-
-/// This pass lowers to Linalg the parts that can be lowered there (matmul and
-/// add on arrays) and targets LLVM otherwise.
-struct LateLoweringPass : public ModulePass<LateLoweringPass> {
-
-  void runOnModule() override {
-    // Perform Toy-specific lowering.
-    if (failed(LateLowering().convert(&getModule()))) {
-      getModule().getContext()->emitError(
-          UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n");
-      signalPassFailure();
-    }
-    // At this point the IR is almost only using the standard and affine
-    // dialects. A few things remain before we emit LLVM IR. First, to reuse
-    // as much of MLIR as possible, we will try to lower everything to the
-    // standard and/or affine dialect: they already include conversion to the
-    // LLVM dialect.
-
-    // First, patch call result types to return a memref instead of a ToyArray.
-    for (auto &function : getModule()) {
-      function.walk([&](Operation *op) {
-        auto callOp = op->dyn_cast<CallOp>();
-        if (!callOp)
-          return;
-        if (!callOp.getNumResults())
-          return;
-        auto retToyTy =
-            callOp.getResult(0)->getType().dyn_cast<toy::ToyArrayType>();
-        if (!retToyTy)
-          return;
-        callOp.getResult(0)->setType(retToyTy.toMemref());
-      });
-    }
-
-    for (auto &function : getModule()) {
-      function.walk([&](Operation *op) {
-        // Turns toy.alloc into a sequence of alloc/dealloc (later
-        // malloc/free).
-        if (auto allocOp = op->dyn_cast<toy::AllocOp>()) {
-          auto result = allocTensor(allocOp);
-          allocOp.replaceAllUsesWith(result);
-          allocOp.erase();
-          return;
-        }
-        // Eliminate all type.cast before lowering to LLVM.
-        if (auto typeCastOp = op->dyn_cast<toy::TypeCastOp>()) {
-          typeCastOp.replaceAllUsesWith(typeCastOp.getOperand());
-          typeCastOp.erase();
-          return;
-        }
-      });
-    }
-
-    // Lower Linalg to affine
-    for (auto &function : getModule())
-      linalg::lowerToLoops(&function);
-
-    getModule().dump();
-
-    // Finally convert to the LLVM dialect.
-    linalg::convertLinalg3ToLLVM(getModule());
-  }
-
-  /// Allocate buffers (malloc/free) for Toy operations. This can't be done as
-  /// part of the dialect conversion framework since we need to insert
-  /// `dealloc` operations just before the return, but the conversion
-  /// framework is operating in a brand new function: we don't have the return
-  /// to hook the dealloc operations.
-  Value *allocTensor(toy::AllocOp alloc) {
-    FuncBuilder builder(alloc);
-    auto retTy = alloc.getResult()->getType();
-
-    auto memRefTy = retTy.dyn_cast<MemRefType>();
-    if (!memRefTy)
-      memRefTy = retTy.cast<toy::ToyArrayType>().toMemref();
-    if (!memRefTy) {
-      alloc.emitOpError("is expected to allocate a Toy array or a MemRef");
-      llvm_unreachable("fatal error");
-    }
-    auto loc = alloc.getLoc();
-    Value *result = builder.create<AllocOp>(loc, memRefTy).getResult();
-
-    // Insert a `dealloc` operation right before the `return` operations,
-    // unless the buffer is itself returned, in which case the caller is
-    // responsible for it.
-    builder.getFunction()->walk([&](Operation *op) {
-      auto returnOp = op->dyn_cast<ReturnOp>();
-      if (!returnOp)
-        return;
-      if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc)
-        return;
-      builder.setInsertionPoint(returnOp);
-      builder.create<DeallocOp>(alloc.getLoc(), result);
-    });
-    return result;
-  }
-};
-} // end anonymous namespace
-
-namespace toy {
-Pass *createLateLoweringPass() { return new LateLoweringPass(); }
-
-std::unique_ptr<DialectConversion> makeToyLateLowering() {
  return llvm::make_unique<LateLowering>();
-}
-
-} // namespace toy
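With both lowering stages in place, a driver is expected to schedule shape inference first, then early and late lowering. A sketch of that wiring follows; the pass-manager calls and the "toy/Passes.h" header are assumptions for illustration, not part of this patch:

// Sketch only: chain the Toy passes the way a driver such as toyc.cpp would.
#include "mlir/Pass/PassManager.h"
#include "toy/Passes.h" // hypothetical header declaring the create*Pass() factories

static void runToyLoweringPipeline(mlir::Module *module) {
  mlir::PassManager pm;
  pm.addPass(toy::createShapeInferencePass()); // specialize generics, infer shapes
  pm.addPass(toy::createEarlyLoweringPass());  // toy.mul -> linalg.matmul
  pm.addPass(toy::createLateLoweringPass());   // remaining Toy -> affine/std/LLVM
  pm.run(module); // error handling elided in this sketch
}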
diff --git a/mlir/MLIRGen.cpp b/mlir/MLIRGen.cpp
deleted file mode 100644
index e2001fb..0000000
--- a/mlir/MLIRGen.cpp
+++ /dev/null
@@ -1,480 +0,0 @@
-//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements a simple IR generation targeting MLIR from a Module AST
-// for the Toy language.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/MLIRGen.h"
-#include "toy/AST.h"
-#include "toy/Dialect.h"
-
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
-#include "mlir/StandardOps/Ops.h"
-
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/ScopedHashTable.h"
-#include "llvm/Support/raw_ostream.h"
-#include <numeric>
-
-using namespace toy;
-using llvm::cast;
-using llvm::dyn_cast;
-using llvm::isa;
-using llvm::make_unique;
-using llvm::ScopedHashTableScope;
-using llvm::SmallVector;
-using llvm::StringRef;
-using llvm::Twine;
-
-namespace {
-
-/// Implementation of a simple MLIR emission from the Toy AST.
-///
-/// This will emit operations that are specific to the Toy language, preserving
-/// the semantics of the language and (hopefully) allowing us to perform
-/// accurate analysis and transformation based on these high level semantics.
-///
-/// At this point we take advantage of the "raw" MLIR APIs to create operations
-/// that haven't been registered in any way with MLIR. These operations are
-/// unknown to MLIR; custom passes could operate by string-matching the names
-/// of these operations, but no other type checking or semantics are associated
-/// with them natively by MLIR.
-class MLIRGenImpl {
-public:
-  MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
-
-  /// Public API: convert the AST for a Toy module (source file) to an MLIR
-  /// Module.
-  std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
-    // We create an empty MLIR module and codegen functions one at a time and
-    // add them to the module.
-    theModule = make_unique<mlir::Module>(&context);
-
-    for (FunctionAST &F : moduleAST) {
-      auto func = mlirGen(F);
-      if (!func)
-        return nullptr;
-      theModule->getFunctions().push_back(func.release());
-    }
-
-    // FIXME: (in the next chapter...) without registering a dialect in MLIR,
-    // this won't do much, but it should at least check some structural
-    // properties.
-    if (failed(theModule->verify())) {
-      context.emitError(mlir::UnknownLoc::get(&context),
-                        "Module verification error");
-      return nullptr;
-    }
-
-    return std::move(theModule);
-  }
-
-private:
-  /// In MLIR (like in LLVM) a "context" object holds the memory allocation and
-  /// the ownership of many internal structures of the IR and provides a level
-  /// of "uniquing" across multiple modules (types for instance).
-  mlir::MLIRContext &context;
-
-  /// A "module" matches a source file: it contains a list of functions.
-  std::unique_ptr<mlir::Module> theModule;
-
-  /// The builder is a helper class to create IR inside a function. It is
-  /// re-initialized every time we enter a function and kept around as a
-  /// convenience for emitting individual operations.
-  /// The builder is stateful, in particular it keeps an "insertion point":
-  /// this is where the next operations will be introduced.
-  std::unique_ptr<mlir::FuncBuilder> builder;
-
-  /// The symbol table maps a variable name to a value in the current scope.
-  /// Entering a function creates a new scope, and the function arguments are
-  /// added to the mapping. When the processing of a function is terminated,
-  /// the scope is destroyed and the mappings created in this scope are
-  /// dropped.
-  llvm::ScopedHashTable<StringRef, mlir::Value *> symbolTable;
-
-  /// Helper conversion for a Toy AST location to an MLIR location.
-  mlir::FileLineColLoc loc(Location loc) {
-    return mlir::FileLineColLoc::get(
-        mlir::UniquedFilename::get(*loc.file, &context), loc.line, loc.col,
-        &context);
-  }
-
-  /// Declare a variable in the current scope, return true if the variable
-  /// wasn't declared yet.
-  bool declare(llvm::StringRef var, mlir::Value *value) {
-    if (symbolTable.count(var)) {
-      return false;
-    }
-    symbolTable.insert(var, value);
-    return true;
-  }
-
-  /// Create the prototype for an MLIR function with as many arguments as the
-  /// provided Toy AST prototype.
-  mlir::Function *mlirGen(PrototypeAST &proto) {
-    // This is a generic function, the return type will be inferred later.
-    llvm::SmallVector<mlir::Type, 4> ret_types;
-    // Argument types are uniformly a generic array.
-    llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
-                                               getType(VarType{}));
-    auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
-    auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
-                                        func_type, /* attrs = */ {});
-
-    // Mark the function as generic: it'll require type specialization for
-    // every call site.
-    if (function->getNumArguments())
-      function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
-
-    return function;
-  }
-
-  /// Emit a new function and add it to the MLIR module.
-  std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
-    // Create a scope in the symbol table to hold variable declarations.
-    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
-
-    // Create an MLIR function for the given prototype.
-    std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
-    if (!function)
-      return nullptr;
-
-    // Let's start the body of the function now!
-    // In MLIR the entry block of the function is special: it must have the
-    // same argument list as the function itself.
-    function->addEntryBlock();
-
-    auto &entryBlock = function->front();
-    auto &protoArgs = funcAST.getProto()->getArgs();
-    // Declare all the function arguments in the symbol table.
-    for (const auto &name_value :
-         llvm::zip(protoArgs, entryBlock.getArguments())) {
-      declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
-    }
-
-    // Create a builder for the function, it will be used throughout the
-    // codegen to create operations in this function.
-    builder = llvm::make_unique<mlir::FuncBuilder>(function.get());
-
-    // Emit the body of the function.
-    if (!mlirGen(*funcAST.getBody()))
-      return nullptr;
-
-    // Implicitly return void if no return statement was emitted.
-    // FIXME: we may fix the parser instead to always return the last
-    // expression (this would possibly help the REPL case later).
-    if (function->getBlocks().back().back().getName().getStringRef() !=
-        "toy.return") {
-      ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
-      mlirGen(fakeRet);
-    }
-
-    return function;
-  }
-
-  /// Emit a binary operation
-  mlir::Value *mlirGen(BinaryExprAST &binop) {
-    // First emit the operations for each side of the operation before
-    // emitting the operation itself. For example if the expression is
-    // `a + foo(a)`:
-    // 1) First it will visit the LHS, which will return a reference to the
-    //    value holding `a`. This value should have been emitted at
-    //    declaration time and registered in the symbol table, so nothing
-    //    would be codegen'd. If the value is not in the symbol table, an
-    //    error has been emitted and nullptr is returned.
-    // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
-    //    and the result value is returned. If an error occurs we get a
-    //    nullptr and propagate.
-    //
-    mlir::Value *L = mlirGen(*binop.getLHS());
-    if (!L)
-      return nullptr;
-    mlir::Value *R = mlirGen(*binop.getRHS());
-    if (!R)
-      return nullptr;
-    auto location = loc(binop.loc());
-
-    // Derive the operation name from the binary operator. At the moment we
-    // only support '+' and '*'.
-    switch (binop.getOp()) {
-    case '+':
-      return builder->create<AddOp>(location, L, R).getResult();
-    case '*':
-      return builder->create<MulOp>(location, L, R).getResult();
-    default:
-      context.emitError(loc(binop.loc()),
-                        Twine("Error: invalid binary operator '") +
-                            Twine(binop.getOp()) + "'");
-      return nullptr;
-    }
-  }
-
-  // This is a reference to a variable in an expression. The variable is
-  // expected to have been declared and so should have a value in the symbol
-  // table, otherwise emit an error and return nullptr.
-  mlir::Value *mlirGen(VariableExprAST &expr) {
-    if (symbolTable.count(expr.getName()))
-      return symbolTable.lookup(expr.getName());
-    context.emitError(loc(expr.loc()), Twine("Error: unknown variable '") +
-                                           expr.getName() + "'");
-    return nullptr;
-  }
-
-  // Emit a return operation, return true on success.
-  bool mlirGen(ReturnExprAST &ret) {
-    auto location = loc(ret.loc());
-    // `return` takes an optional expression, we need to account for it here.
-    if (!ret.getExpr().hasValue()) {
-      builder->create<ReturnOp>(location);
-      return true;
-    }
-    auto *expr = mlirGen(*ret.getExpr().getValue());
-    if (!expr)
-      return false;
-    builder->create<ReturnOp>(location, expr);
-    return true;
-  }
-
-  // Emit a literal/constant array. It will be emitted as a flattened array of
-  // data in an Attribute attached to a `toy.constant` operation.
-  // See documentation on [Attributes](LangRef.md#attributes) for more details.
-  // Here is an excerpt:
-  //
-  //   Attributes are the mechanism for specifying constant data in MLIR in
-  //   places where a variable is never allowed [...]. They consist of a name
-  //   and a [concrete attribute value](#attribute-values). It is possible to
-  //   attach attributes to operations, functions, and function arguments. The
-  //   set of expected attributes, their structure, and their interpretation
-  //   are all contextually dependent on what they are attached to.
-  //
-  // Example, the source level statement:
-  //   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
-  // will be converted to:
-  //   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
-  //     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
-  //      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
-  //
-  mlir::Value *mlirGen(LiteralExprAST &lit) {
-    auto location = loc(lit.loc());
-    // The attribute is a vector with an attribute per element (number) in the
-    // array, see `collectData()` below for more details.
-    std::vector<mlir::Attribute> data;
-    data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
-                                 std::multiplies<int>()));
-    collectData(lit, data);
-
-    // FIXME: using a tensor type is a HACK here.
-    // Can we do differently without registering a dialect? Using a string
-    // blob?
-    mlir::Type elementType = mlir::FloatType::getF64(&context);
-    auto dataType = builder->getTensorType(lit.getDims(), elementType);
-
-    // This is the attribute that actually holds the list of values for this
-    // array literal.
-    auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
-                             .cast<mlir::DenseElementsAttr>();
-
-    // Build the MLIR op `toy.constant`, only boilerplate below.
-    return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
-        .getResult();
-  }
-
-  // Recursive helper function to accumulate the data that compose an array
-  // literal.
-  // It flattens the nested structure in the supplied vector. For example
-  // with this array:
-  //   [[1, 2], [3, 4]]
-  // we will generate:
-  //   [ 1, 2, 3, 4 ]
-  // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`.
-  // Attributes are the way MLIR attaches constants to operations and
-  // functions.
-  void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
-    if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
-      for (auto &value : lit->getValues())
-        collectData(*value, data);
-      return;
-    }
-    assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
-    mlir::Type elementType = mlir::FloatType::getF64(&context);
-    auto attr = mlir::FloatAttr::getChecked(
-        elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
-    data.push_back(attr);
-  }
-
-  // Emit a call expression. It emits specific operations for the `transpose`
-  // builtin. Other identifiers are assumed to be user-defined functions.
-  mlir::Value *mlirGen(CallExprAST &call) {
-    auto location = loc(call.loc());
-    std::string callee = call.getCallee();
-    if (callee == "transpose") {
-      if (call.getArgs().size() != 1) {
-        context.emitError(
-            location, Twine("MLIR codegen encountered an error: toy.transpose "
-                            "does not accept multiple arguments"));
-        return nullptr;
-      }
-      mlir::Value *arg = mlirGen(*call.getArgs()[0]);
-      return builder->create<TransposeOp>(location, arg).getResult();
-    }
-
-    // Codegen the operands first.
-    SmallVector<mlir::Value *, 4> operands;
-    for (auto &expr : call.getArgs()) {
-      auto *arg = mlirGen(*expr);
-      if (!arg)
-        return nullptr;
-      operands.push_back(arg);
-    }
-    // Calls to user-defined functions are mapped to a custom call that takes
-    // the callee name as an attribute.
-    return builder->create<GenericCallOp>(location, call.getCallee(), operands)
-        .getResult();
-  }
-
-  // Emit the builtin `print(x)` expression; calls to other identifiers are
-  // handled as generic calls above. Return false on failure.
-  bool mlirGen(PrintExprAST &call) {
-    auto *arg = mlirGen(*call.getArg());
-    if (!arg)
-      return false;
-    auto location = loc(call.loc());
-    builder->create<PrintOp>(location, arg);
-    return true;
-  }
-
-  // Emit a constant for a single number (FIXME: semantic? broadcast?)
-  mlir::Value *mlirGen(NumberExprAST &num) {
-    auto location = loc(num.loc());
-    mlir::Type elementType = mlir::FloatType::getF64(&context);
-    auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
-                                            loc(num.loc()));
-    return builder->create<ConstantOp>(location, attr).getResult();
-  }
-
-  // Dispatch codegen for the right expression subclass using RTTI.
-  mlir::Value *mlirGen(ExprAST &expr) {
-    switch (expr.getKind()) {
-    case toy::ExprAST::Expr_BinOp:
-      return mlirGen(cast<BinaryExprAST>(expr));
-    case toy::ExprAST::Expr_Var:
-      return mlirGen(cast<VariableExprAST>(expr));
-    case toy::ExprAST::Expr_Literal:
-      return mlirGen(cast<LiteralExprAST>(expr));
-    case toy::ExprAST::Expr_Call:
-      return mlirGen(cast<CallExprAST>(expr));
-    case toy::ExprAST::Expr_Num:
-      return mlirGen(cast<NumberExprAST>(expr));
-    default:
-      context.emitError(
-          loc(expr.loc()),
-          Twine("MLIR codegen encountered an unhandled expr kind '") +
-              Twine(expr.getKind()) + "'");
-      return nullptr;
-    }
-  }
-
-  // Handle a variable declaration: we'll codegen the expression that forms
-  // the initializer and record the value in the symbol table before returning
-  // it. Future expressions will be able to reference this variable through
-  // symbol table lookup.
-  mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
-    mlir::Value *value = nullptr;
-    auto location = loc(vardecl.loc());
-    if (auto init = vardecl.getInitVal()) {
-      value = mlirGen(*init);
-      if (!value)
-        return nullptr;
-      // We have the initializer value, but in case the variable was declared
-      // with a specific shape, we emit a "reshape" operation. It will get
-      // optimized out later as needed.
-      if (!vardecl.getType().shape.empty()) {
-        value = builder
-                    ->create<ReshapeOp>(
-                        location, value,
-                        getType(vardecl.getType()).cast<ToyArrayType>())
-                    .getResult();
-      }
-    } else {
-      context.emitError(loc(vardecl.loc()),
-                        "Missing initializer in variable declaration");
-      return nullptr;
-    }
-    // Register the value in the symbol table.
-    declare(vardecl.getName(), value);
-    return value;
-  }
-
-  /// Codegen a list of expressions, return false if one of them hit an error.
-  bool mlirGen(ExprASTList &blockAST) {
-    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
-    for (auto &expr : blockAST) {
-      // Specific handling for variable declarations, return statements, and
-      // print. These can only appear in a block list and not in nested
-      // expressions.
-      if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
-        if (!mlirGen(*vardecl))
-          return false;
-        continue;
-      }
-      if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
-        if (!mlirGen(*ret))
-          return false;
-        return true;
-      }
-      if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
-        if (!mlirGen(*print))
-          return false;
-        continue;
-      }
-      // Generic expression dispatch codegen.
-      if (!mlirGen(*expr))
-        return false;
-    }
-    return true;
-  }
-
-  /// Build a type from a list of shape dimensions. Types are `array` followed
-  /// by an optional dimension list, example: array<2, 2>
-  /// They are wrapped in a `toy` dialect (see next chapter) and get printed:
-  ///   !toy.array<2, 2>
-  template <typename T> mlir::Type getType(T shape) {
-    SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
-    return ToyArrayType::get(&context, shape64);
-  }
-
-  /// Build an MLIR type from a Toy AST variable type
-  /// (forward to the generic getType(T) above).
-  mlir::Type getType(const VarType &type) { return getType(type.shape); }
-};
-
-} // namespace
-
-namespace toy {
-
-// The public API for codegen.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
-                                      ModuleAST &moduleAST) {
-  return MLIRGenImpl(context).mlirGen(moduleAST);
-}
-
-} // namespace toy
diff --git a/mlir/ShapeInferencePass.cpp b/mlir/ShapeInferencePass.cpp
deleted file mode 100644
index 7e3ea3f..0000000
--- a/mlir/ShapeInferencePass.cpp
+++ /dev/null
@@ -1,387 +0,0 @@
-//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements a Module level pass performing interprocedural
-// propagation of array shapes through function specialization.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/StandardOps/Ops.h"
-#include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringSet.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/raw_ostream.h"
-#include
-
-#define DEBUG_TYPE "toy-shape-inference"
-
-using namespace toy;
-using llvm::MutableArrayRef;
-using llvm::SmallVector;
-using llvm::SmallVectorImpl;
-using llvm::StringRef;
-using llvm::Twine;
-
-/// Create a mangled name for function specialization. We will simply append
-/// the shape of the arguments to the function name. For example calling
-///
-///   "toy.generic_call"(%1, %3) {callee: "foo"}
-///       : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
-///
-/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could
-/// have provided a function with a similar name. But we will claim this as a
-/// feature: this allows the user to provide custom specializations!
-static std::string mangle(StringRef funcName,
-                          MutableArrayRef<mlir::OpOperand> operands) {
-  std::string mangledName;
-  mangledName.reserve(funcName.size() + operands.size() * 6);
-  mangledName = funcName;
-  for (auto &operand : operands) {
-    auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
-    mangledName += "_";
-    const char *sep = "";
-    for (auto dim : arrayTy.getShape()) {
-      mangledName += (sep + Twine(dim)).str();
-      sep = "x";
-    }
-  }
-  return mangledName;
-}
-
-namespace {
-
-/// The ShapeInferencePass is a ModulePass: it will run on the Module as a
-/// whole. MLIR also supports FunctionPass, which is restricted to modifying a
-/// single function at a time. This pass couldn't be a function pass due to
-/// the nature of its interprocedural transformations.
-///
-/// The algorithm has two levels, first intra-procedurally:
-///
-///   1) Build a worklist containing all the operations that are returning
-///      a generic Toy array: these are the operations that need shape
-///      inference.
-///   2) Iterate on the worklist:
-///     a) find an operation to process: the next ready operation in the
-///        worklist has all of its arguments non-generic,
-///     b) if no operation is found, break out of the loop,
-///     c) remove the operation from the worklist,
-///     d) infer the shape of its output from the argument types.
-///   3) If the worklist is empty, the algorithm succeeded and we infer the
-///      return type for the function from the return operation.
-///
-/// There is a twist though: when a call to a generic function is encountered,
-/// shape inference requires the return type of the callee to be inferred
-/// first. At this point we need to specialize the callee by cloning it. Here
-/// is the inter-procedural flow:
-///
-///   1) Keep a worklist of functions to process. Start with function "main".
-///   2) While the worklist isn't empty:
-///     a) Take the last inserted function in the worklist.
-///     b) Run the intra-procedural shape inference on this function.
-///     c) If the intra-procedural shape inference can't complete, it returns
-///        a Function that needs to be inferred first. In this case, queue
-///        this new function and continue. Otherwise the inference succeeded
-///        and we can pop from the queue.
-///
-class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
-public:
-  // One entry in the inter-procedural worklist. It keeps track of the
-  // function to process, the mangled name for this specialization, and the
-  // types of the arguments on which to specialize.
-  struct FunctionToSpecialize {
-    mlir::Function *function;
-    std::string mangledName;
-    std::vector<mlir::Type> argumentsType;
-  };
-
-  void runOnModule() override {
-    auto &module = getModule();
-    auto *main = module.getNamedFunction("main");
-    if (!main) {
-      module.getContext()->emitError(
-          mlir::UnknownLoc::get(module.getContext()),
-          "Shape inference failed: can't find a main function\n");
-      signalPassFailure();
-      return;
-    }
-
-    /// Inter-procedural loop: initialize with `main` and iterate until the
-    /// full call-graph reachable from `main` has been inferred.
-    SmallVector<FunctionToSpecialize, 8> worklist;
-    worklist.push_back({main, "", {}});
-    while (!worklist.empty()) {
-      if (failed(specialize(worklist)))
-        return;
-    }
-
-    // Delete any generic functions left over.
-    // FIXME: we may want this as a separate pass.
-    for (mlir::Function &function : llvm::make_early_inc_range(module)) {
-      if (auto genericAttr =
-              function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
-        if (genericAttr.getValue())
-          function.erase();
-      }
-    }
-  }
-
-  /// Run inference on a function. If a mangledName is provided, we need to
-  /// specialize the function: to this end clone it first.
-  mlir::LogicalResult
-  specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
-    FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
-    mlir::Function *f = functionToSpecialize.function;
-
-    // Check if cloning for specialization is needed (usually anything but
-    // main). We will create a new function with the concrete types for the
-    // parameters and clone the body into it.
-    if (!functionToSpecialize.mangledName.empty()) {
-      if (getModule().getNamedFunction(functionToSpecialize.mangledName)) {
-        funcWorklist.pop_back();
-        // Function already specialized, move on.
-        return mlir::success();
-      }
-      // Create a new function with a generic array return type, it will be
-      // updated when the inference for the function body completes.
-      auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
-                                          {ToyArrayType::get(&getContext())},
-                                          &getContext());
-      auto *newFunction = new mlir::Function(
-          f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
-      getModule().getFunctions().push_back(newFunction);
-
-      // Clone the function body
-      mlir::BlockAndValueMapping mapper;
-      f->cloneInto(newFunction, mapper);
-      LLVM_DEBUG({
-        llvm::dbgs() << "====== Cloned : \n";
-        f->dump();
-        llvm::dbgs() << "====== Into : \n";
-        newFunction->dump();
-      });
-      f = newFunction;
-      f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
-      // Remap the entry-block arguments
-      // FIXME: this seems like a bug in `cloneInto()` above?
-      auto &entryBlock = f->getBlocks().front();
-      int blockArgSize = entryBlock.getArguments().size();
-      assert(blockArgSize == f->getType().getInputs().size());
-      entryBlock.addArguments(f->getType().getInputs());
-      auto argList = entryBlock.getArguments();
-      for (int argNum = 0; argNum < blockArgSize; ++argNum) {
-        argList[0]->replaceAllUsesWith(argList[blockArgSize]);
-        entryBlock.eraseArgument(0);
-      }
-      assert(succeeded(f->verify()));
-    }
-    LLVM_DEBUG(llvm::dbgs()
-               << "Run shape inference on: '" << f->getName() << "'\n");
-
-    auto *toyDialect = getContext().getRegisteredDialect("toy");
-    if (!toyDialect) {
-      getContext().emitError(mlir::UnknownLoc::get(&getContext()),
-                             "Toy dialect is not registered");
-      signalPassFailure();
-      return mlir::failure();
-    }
-
-    // Populate the worklist with the operations that need shape inference:
-    // these are the Toy operations that return a generic array.
-    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
-    f->walk([&](mlir::Operation *op) {
-      if (op->getDialect() == toyDialect) {
-        if (op->getNumResults() == 1 &&
-            op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
-          opWorklist.insert(op);
-      }
-    });
-
-    // Iterate on the operations in the worklist until all operations have
-    // been inferred or no change happened (fixed point).
-    while (!opWorklist.empty()) {
-      // Find the next operation ready for inference, that is an operation
-      // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
-        return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
-          return !v->getType().cast<ToyArrayType>().isGeneric();
-        });
-      });
-      if (nextop == opWorklist.end())
-        break; // failure: no operations can be inferred.
-
-      mlir::Operation *op = *nextop;
-      opWorklist.erase(op);
-      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
-
-      // The add operation is trivial: propagate the input type as is.
-      if (auto addOp = op->dyn_cast<AddOp>()) {
-        op->getResult(0)->setType(op->getOperand(0)->getType());
-        continue;
-      }
-
-      // Transpose is easy: just invert the dimensions.
-      if (op->getName().getStringRef() == "toy.transpose") {
-        SmallVector<int64_t, 2> dims;
-        auto arrayTy = op->getOperand(0)->getType().cast<ToyArrayType>();
-        dims.insert(dims.end(), arrayTy.getShape().begin(),
-                    arrayTy.getShape().end());
-        if (dims.size() == 2)
-          std::swap(dims[0], dims[1]);
-        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
-        continue;
-      }
-
-      // Multiplication is a bit trickier; handle rank 1 as dot product and
-      // rank 2 as matrix multiplication.
-      // We need to be careful about rank mismatch here: the verifier could
-      // catch it, but shape inference earlier in the pass could generate
-      // invalid IR (from an invalid Toy input of course) and we wouldn't want
-      // to crash here.
-      if (auto mulOp = op->dyn_cast<MulOp>()) {
-        auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
-        auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
-        auto lhsRank = lhs.getShape().size();
-        auto rhsRank = rhs.getShape().size();
-        if (lhsRank != rhsRank) {
-          op->emitError("Shape mismatch: LHS and RHS must have the same "
-                        "rank for multiplication, got " +
-                        Twine(lhsRank) + " vs " + Twine(rhsRank));
-          return mlir::failure();
-        }
-        SmallVector<int64_t, 2> dims;
-        if (lhsRank == 1) {
-          // dot product, result shape is <1>
-          dims.push_back(1);
-        } else {
-          if (lhsRank != 2) {
-            op->emitError(
-                "Shape mismatch: expect rank 1 or 2 for mul operands, got " +
-                Twine(lhsRank));
-            return mlir::failure();
-          }
-          dims.push_back(lhs.getShape()[0]);
-          dims.push_back(rhs.getShape()[1]);
-        }
-        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
-        continue;
-      }
-
-      // Process calls: lookup the callee after mangling the name with the
-      // argument shapes. If the callee does not exist, we stop the inference
-      // for this function, queue the callee in the inter-procedural work
-      // list, and return. The current function stays in the work list and
-      // will restart after the callee is processed.
-      if (auto callOp = op->dyn_cast<GenericCallOp>()) {
-        auto calleeName = callOp.getCalleeName();
-        auto *callee = getModule().getNamedFunction(calleeName);
-        if (!callee) {
-          f->emitError(
-              llvm::Twine("Shape inference failed, call to unknown '") +
-              calleeName + "'");
-          signalPassFailure();
-          return mlir::failure();
-        }
-        auto mangledName = mangle(calleeName, op->getOpOperands());
-        LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
-                                << "', mangled: '" << mangledName << "'\n");
-        auto *mangledCallee = getModule().getNamedFunction(mangledName);
-        if (!mangledCallee) {
-          // Can't find the target: this is where we queue the request for
-          // the callee and stop the inference for the current function now.
-          std::vector<mlir::Type> funcArgs;
-          for (auto operand : op->getOperands())
-            funcArgs.push_back(operand->getType());
-          funcWorklist.push_back(
-              {callee, std::move(mangledName), std::move(funcArgs)});
-          return mlir::success();
-        }
-        // Found a specialized callee! Let's turn this into a normal call
-        // operation.
-        SmallVector<mlir::Value *, 8> operands;
-        for (mlir::Value *v : op->getOperands())
-          operands.push_back(v);
-        mlir::FuncBuilder builder(f);
-        builder.setInsertionPoint(op);
-        auto newCall =
-            builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
-        if (newCall.getNumResults()) {
-          op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
-          op->erase();
-          continue;
-        }
-      }
-    }
-
-    // Done with inference on this function, removing it from the worklist.
-    funcWorklist.pop_back();
-    // Mark the function as non-generic now that inference has succeeded.
-    f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
-
-    // If the operation worklist isn't empty, this indicates a failure.
-    if (!opWorklist.empty()) {
-      std::string str;
-      llvm::raw_string_ostream errorMsg(str);
-      errorMsg << "Shape inference failed, " << opWorklist.size()
-               << " operations couldn't be inferred\n";
-      for (auto *ope : opWorklist)
-        errorMsg << "  - " << *ope << "\n";
-      f->emitError(errorMsg.str());
-      signalPassFailure();
-      return mlir::failure();
-    }
-
-    // Finally, update the return type of the function based on the argument
-    // to the return operation.
-    for (auto &block : f->getBlocks()) {
-      auto ret = block.getTerminator()->dyn_cast<ReturnOp>();
-      if (!ret)
-        continue;
-      if (ret.getNumOperands() &&
-          f->getType().getResult(0) == ret.getOperand()->getType())
-        // type match, we're done
-        break;
-      SmallVector<mlir::Type, 1> retTy;
-      if (ret.getNumOperands())
-        retTy.push_back(ret.getOperand()->getType());
-      mlir::Type elementType = mlir::FloatType::getF64(&getContext());
-      std::vector<mlir::Type> argumentsType;
-      for (auto arg : f->getArguments())
-        argumentsType.push_back(arg->getType());
-      auto newType =
-          mlir::FunctionType::get(argumentsType, retTy, &getContext());
-      f->setType(newType);
-      assert(succeeded(f->verify()));
-      break;
-    }
-    return mlir::success();
-  }
-};
-} // end anonymous namespace
-
-namespace toy {
-mlir::Pass *createShapeInferencePass() { return new ShapeInferencePass(); }
-} // namespace toy
diff --git a/mlir/ToyCombine.cpp b/mlir/ToyCombine.cpp
deleted file mode 100644
index 8d6aed6..0000000
--- a/mlir/ToyCombine.cpp
+++ /dev/null
@@ -1,209 +0,0 @@
-//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements a simple combiner for optimizing patterns in the Toy
-// dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-
-#include <numeric>
-
-namespace toy {
-
-namespace {
-
-/// Fold transpose(transpose(x)) -> x
-struct SimplifyRedundantTranspose : public mlir::RewritePattern {
-  /// We register this pattern to match every toy.transpose in the IR.
-  /// The "benefit" is used by the framework to order the patterns and process
-  /// them in order of profitability.
-  SimplifyRedundantTranspose(mlir::MLIRContext *context)
-      : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
-                       context) {}
-
-  /// This method attempts to match a pattern and rewrite it. The rewriter
-  /// argument is the orchestrator of the sequence of rewrites; any changes to
-  /// the IR from here are expected to be performed through it.
-  mlir::PatternMatchResult
-  matchAndRewrite(mlir::Operation *op,
-                  mlir::PatternRewriter &rewriter) const override {
-    // We can directly cast the current operation as this will only get
-    // invoked on TransposeOp.
-    TransposeOp transpose = op->cast<TransposeOp>();
-    // Look through the input to the current transpose.
-    mlir::Value *transposeInput = transpose.getOperand();
-    mlir::Operation *transposeInputInst = transposeInput->getDefiningOp();
-    // If the input is defined by another Transpose, bingo!
- TransposeOp transposeInputOp = - mlir::dyn_cast_or_null(transposeInputInst); - if (!transposeInputOp) - return matchFailure(); - - // Use the rewriter to perform the replacement - rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); - return matchSuccess(); - } -}; - -/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place. -struct SimplifyReshapeConstant : public mlir::RewritePattern { - SimplifyReshapeConstant(mlir::MLIRContext *context) - : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, - context) {} - - mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = op->cast(); - // look through the input to the current reshape - mlir::Value *reshapeInput = reshape.getOperand(); - mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp(); - // If the input is defined by another reshape, bingo! - ConstantOp constantOp = - mlir::dyn_cast_or_null(reshapeInputInst); - if (!constantOp) - return matchFailure(); - - auto reshapeType = op->getResult(0)->getType().cast(); - if (auto valueAttr = - constantOp.getAttrOfType("value")) { - // FIXME Check matching of element count! - // auto oldType = constantOp.getType(); - auto newType = rewriter.getTensorType( - reshapeType.getShape(), valueAttr.getType().getElementType()); - auto newAttr = - mlir::DenseElementsAttr::get(newType, valueAttr.getRawData()); - auto newConstant = rewriter.create( - constantOp.getLoc(), reshapeType.getShape(), newAttr); - rewriter.replaceOp(op, {newConstant}); - } else if (auto valueAttr = - constantOp.getAttrOfType("value")) { - // Broadcast - auto dataSize = std::accumulate(reshapeType.getShape().begin(), - reshapeType.getShape().end(), 1, - std::multiplies()); - std::vector data(dataSize, valueAttr); - auto tensorTy = rewriter.getTensorType(reshapeType.getShape(), - reshapeType.getElementType()); - auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data); - auto newConstant = rewriter.create( - constantOp.getLoc(), reshapeType.getShape(), newAttr); - rewriter.replaceOp(op, {newConstant}); - } else { - llvm_unreachable("Unsupported Constant format"); - } - return matchSuccess(); - } -}; - -/// Fold reshape(reshape(x)) -> reshape(x) -struct SimplifyReshapeReshape : public mlir::RewritePattern { - SimplifyReshapeReshape(mlir::MLIRContext *context) - : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, - context) {} - - mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = op->cast(); - // look through the input to the current reshape - mlir::Value *reshapeInput = reshape.getOperand(); - mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp(); - // If the input is defined by another reshape, bingo! 
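The FIXME above asks for an element-count check before folding reshape(constant(x)); a minimal sketch of that check, using a hypothetical helper over plain shape vectors:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// A reshape-of-constant fold is only valid when the total element count is
// preserved: rewriting <2, 3> as <3, 2> is fine, <2, 3> as <4> is not.
bool reshapePreservesCount(const std::vector<int64_t> &oldShape,
                           const std::vector<int64_t> &newShape) {
  auto count = [](const std::vector<int64_t> &shape) {
    return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                           std::multiplies<int64_t>());
  };
  return count(oldShape) == count(newShape);
}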
- ReshapeOp reshapeInputOp = - mlir::dyn_cast_or_null(reshapeInputInst); - if (!reshapeInputOp) - return matchFailure(); - - // Use the rewriter to perform the replacement - rewriter.replaceOp(op, {reshapeInputOp}); - return matchSuccess(); - } -}; - -/// Fold reshape(x)) -> x, when input type matches output type -struct SimplifyNullReshape : public mlir::RewritePattern { - SimplifyNullReshape(mlir::MLIRContext *context) - : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, - context) {} - - mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = op->cast(); - if (reshape.getOperand()->getType() != reshape.getResult()->getType()) - return matchFailure(); - rewriter.replaceOp(reshape, {reshape.getOperand()}); - return matchSuccess(); - } -}; - -} // end anonymous namespace. - -// Register our patterns for rewrite by the Canonicalization framework. -void TransposeOp::getCanonicalizationPatterns( - mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.push_back(llvm::make_unique(context)); -} - -// Register our patterns for rewrite by the Canonicalization framework. -void ReshapeOp::getCanonicalizationPatterns( - mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.push_back(llvm::make_unique(context)); - results.push_back(llvm::make_unique(context)); - results.push_back(llvm::make_unique(context)); -} - -namespace { - -/// Fold type.cast(x) -> x, when input type matches output type -struct SimplifyIdentityTypeCast : public mlir::RewritePattern { - SimplifyIdentityTypeCast(mlir::MLIRContext *context) - : RewritePattern(TypeCastOp::getOperationName(), /* benefit = */ 1, - context) {} - - mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - TypeCastOp typeCast = op->cast(); - auto resTy = typeCast.getResult()->getType(); - auto *candidateOp = op; - while (candidateOp && candidateOp->isa()) { - if (resTy == candidateOp->getOperand(0)->getType()) { - rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)}); - return matchSuccess(); - } - candidateOp = candidateOp->getOperand(0)->getDefiningOp(); - } - return matchFailure(); - } -}; - -} // end anonymous namespace. - -void TypeCastOp::getCanonicalizationPatterns( - mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.push_back(llvm::make_unique(context)); -} - -} // namespace toy diff --git a/mlir/ToyDialect.cpp b/mlir/ToyDialect.cpp deleted file mode 100644 index be117f5..0000000 --- a/mlir/ToyDialect.cpp +++ /dev/null @@ -1,405 +0,0 @@ -//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file implements the dialect for the Toy IR: custom type parsing and -// operation verification. 
-// -//===----------------------------------------------------------------------===// - -#include "toy/Dialect.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Support/STLExtras.h" -#include "llvm/ADT/iterator_range.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/Regex.h" -#include "llvm/Support/raw_ostream.h" - -using llvm::ArrayRef; -using llvm::raw_ostream; -using llvm::raw_string_ostream; -using llvm::SmallVector; -using llvm::StringRef; -using llvm::Twine; - -namespace toy { -namespace detail { - -/// This class holds the implementation of the ToyArrayType. -/// It is intended to be uniqued based on its content and owned by the context. -struct ToyArrayTypeStorage : public mlir::TypeStorage { - /// This defines how we unique this type in the context: our key contains - /// only the shape, a more complex type would have multiple entries in the - /// tuple here. - /// The element of the tuples usually matches 1-1 the arguments from the - /// public `get()` method arguments from the facade. - using KeyTy = std::tuple>; - static unsigned hashKey(const KeyTy &key) { - return llvm::hash_combine(std::get<0>(key)); - } - /// When the key hash hits an existing type, we compare the shape themselves - /// to confirm we have the right type. - bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); } - - /// This is a factory method to create our type storage. It is only - /// invoked after looking up the type in the context using the key and not - /// finding it. - static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator, - const KeyTy &key) { - // Copy the shape array into the bumpptr allocator owned by the context. - ArrayRef shape = allocator.copyInto(std::get<0>(key)); - - // Allocate the instance for the ToyArrayTypeStorage itself - auto *storage = allocator.allocate(); - // Initialize the instance using placement new. - return new (storage) ToyArrayTypeStorage(shape); - } - - ArrayRef getShape() const { return shape; } - -private: - ArrayRef shape; - - /// Constructor is only invoked from the `construct()` method above. - ToyArrayTypeStorage(ArrayRef shape) : shape(shape) {} -}; - -} // namespace detail - -mlir::Type ToyArrayType::getElementType() { - return mlir::FloatType::getF64(getContext()); -} - -ToyArrayType ToyArrayType::get(mlir::MLIRContext *context, - ArrayRef shape) { - return Base::get(context, ToyTypeKind::TOY_ARRAY, shape); -} - -ArrayRef ToyArrayType::getShape() { return getImpl()->getShape(); } - -mlir::MemRefType ToyArrayType::toMemref() { - auto memRefType = mlir::MemRefType::get(getShape(), getElementType(), {}, 0); - return memRefType; -} - -/// Dialect creation, the instance will be owned by the context. This is the -/// point of registration of custom types and operations for the dialect. -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { - addOperations(); - addTypes(); -} - -/// Parse a type registered to this dialect, we expect only Toy arrays. -mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const { - // Sanity check: we only support array or array<...> - if (!tyData.startswith("array")) { - getContext()->emitError(loc, "Invalid Toy type '" + tyData + - "', array expected"); - return nullptr; - } - // Drop the "array" prefix from the type name, we expect either an empty - // string or just the shape. 
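To make the uniquing contract above concrete, here is a dialect-free sketch of the same idea using std containers; getUniqued is a hypothetical stand-in for the context's allocator and cache, not an MLIR API:

#include <cstdint>
#include <map>
#include <vector>

struct ArrayTypeStorage {
  std::vector<int64_t> shape; // the KeyTy: identity is purely the shape
};

// Structurally equal shapes map to one shared storage instance, mirroring
// how the context uniques ToyArrayTypeStorage.
ArrayTypeStorage *getUniqued(const std::vector<int64_t> &shape) {
  static std::map<std::vector<int64_t>, ArrayTypeStorage> cache;
  auto it = cache.emplace(shape, ArrayTypeStorage{shape}).first;
  return &it->second;
}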
-  tyData = tyData.drop_front(StringRef("array").size());
-  // This is the generic array case without shape, early return it.
-  if (tyData.empty())
-    return ToyArrayType::get(getContext());
-
-  // Use a regex to parse the shape (for efficiency we should store this regex
-  // in the dialect itself).
-  SmallVector<StringRef, 4> matches;
-  auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
-  if (!shapeRegex.match(tyData, &matches)) {
-    getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'");
-    return nullptr;
-  }
-  SmallVector<int64_t, 4> shape;
-  // Iterate through the captures, skip the first one which is the full string.
-  for (auto dimStr :
-       llvm::make_range(std::next(matches.begin()), matches.end())) {
-    if (dimStr.startswith(","))
-      continue; // POSIX misses non-capturing groups.
-    if (dimStr.empty())
-      continue; // '*' makes it an optional group capture
-    // Convert the capture to an integer
-    unsigned long long dim;
-    if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
-      getContext()->emitError(
-          loc, "Couldn't parse dimension as integer, matched: " + dimStr);
-      return mlir::Type();
-    }
-    shape.push_back(dim);
-  }
-  // Finally we collected all the dimensions in the shape,
-  // create the array type.
-  return ToyArrayType::get(getContext(), shape);
-}
-
-/// Print a Toy array type, for example `array<2, 3, 4>`
-void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
-  auto arrayTy = type.dyn_cast<ToyArrayType>();
-  if (!arrayTy) {
-    os << "unknown toy type";
-    return;
-  }
-  os << "array";
-  if (!arrayTy.getShape().empty()) {
-    os << "<";
-    mlir::interleaveComma(arrayTy.getShape(), os);
-    os << ">";
-  }
-}
-
-////////////////////////////////////////////////////////////////////////////////
-//////////////////// Custom Operations for the Dialect /////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-
-/// Helper to verify that the result of an operation is a Toy array type.
-template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
-  if (!op->getResult()->getType().template isa<ToyArrayType>()) {
-    std::string msg;
-    raw_string_ostream os(msg);
-    os << "expects a Toy Array for its argument, got "
-       << op->getResult()->getType();
-    return op->emitOpError(os.str());
-  }
-  return mlir::success();
-}
-
-/// Helper to verify that the two operands of a binary operation are Toy
-/// arrays.
-template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
-  if (!op->getOperand(0)->getType().template isa<ToyArrayType>()) {
-    std::string msg;
-    raw_string_ostream os(msg);
-    os << "expects a Toy Array for its LHS, got "
-       << op->getOperand(0)->getType();
-    return op->emitOpError(os.str());
-  }
-  if (!op->getOperand(1)->getType().template isa<ToyArrayType>()) {
-    std::string msg;
-    raw_string_ostream os(msg);
-    os << "expects a Toy Array for its RHS, got "
-       << op->getOperand(1)->getType();
-    return op->emitOpError(os.str());
-  }
-  return mlir::success();
-}
-
-/// Build a constant operation.
-/// The builder is passed as an argument, so is the state that this method is
-/// expected to fill in order to build the operation.
-void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state,
-                       ArrayRef<int64_t> shape,
-                       mlir::DenseElementsAttr value) {
-  state->types.push_back(ToyArrayType::get(builder->getContext(), shape));
-  auto dataAttribute = builder->getNamedAttr("value", value);
-  state->attributes.push_back(dataAttribute);
-}
-
-/// Build a constant operation.
-/// The builder is passed as an argument, so is the state that this method is -/// expected to fill in order to build the operation. -void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state, - mlir::FloatAttr value) { - // Broadcast and forward to the other build factory - mlir::Type elementType = mlir::FloatType::getF64(builder->getContext()); - auto dataType = builder->getTensorType({1}, elementType); - auto dataAttribute = builder->getDenseElementsAttr(dataType, {value}) - .cast(); - - ConstantOp::build(builder, state, {1}, dataAttribute); -} - -/// Verifier for constant operation. -mlir::LogicalResult ConstantOp::verify() { - // Ensure that the return type is a Toy array - if (failed(verifyToyReturnArray(this))) - return mlir::failure(); - - // We expect the constant itself to be stored as an attribute. - auto dataAttr = getAttr("value").dyn_cast(); - if (!dataAttr) { - return emitOpError( - "missing valid `value` DenseElementsAttribute on toy.constant()"); - } - auto attrType = dataAttr.getType().dyn_cast(); - if (!attrType) { - return emitOpError( - "missing valid `value` DenseElementsAttribute on toy.constant()"); - } - - // If the return type of the constant is not a generic array, the shape must - // match the shape of the attribute holding the data. - auto resultType = getResult()->getType().cast(); - if (!resultType.isGeneric()) { - if (attrType.getRank() != resultType.getRank()) { - return emitOpError("The rank of the toy.constant return type must match " - "the one of the attached value attribute: " + - Twine(attrType.getRank()) + - " != " + Twine(resultType.getRank())); - } - for (int dim = 0; dim < attrType.getRank(); ++dim) { - if (attrType.getShape()[dim] != resultType.getShape()[dim]) { - std::string msg; - raw_string_ostream os(msg); - return emitOpError( - "Shape mismatch between toy.constant return type and its " - "attribute at dimension " + - Twine(dim) + ": " + Twine(attrType.getShape()[dim]) + - " != " + Twine(resultType.getShape()[dim])); - } - } - } - return mlir::success(); -} - -void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState *state, - StringRef callee, ArrayRef arguments) { - // Generic call always returns a generic ToyArray initially - state->types.push_back(ToyArrayType::get(builder->getContext())); - state->operands.assign(arguments.begin(), arguments.end()); - auto calleeAttr = builder->getStringAttr(callee); - state->attributes.push_back(builder->getNamedAttr("callee", calleeAttr)); -} - -mlir::LogicalResult GenericCallOp::verify() { - // Verify that every operand is a Toy Array - for (int opId = 0, num = getNumOperands(); opId < num; ++opId) { - if (!getOperand(opId)->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its " << opId << " operand, got " - << getOperand(opId)->getType(); - return emitOpError(os.str()); - } - } - return mlir::success(); -} - -/// Return the name of the callee. 
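The dimension-by-dimension comparison performed by this verifier reduces to the following sketch; shapesMatch is a hypothetical helper shown only to isolate the rule (the real verifier also emits a dedicated message per failure):

#include <cstddef>
#include <cstdint>
#include <vector>

// Rank must match first, then every dimension must match.
bool shapesMatch(const std::vector<int64_t> &attrShape,
                 const std::vector<int64_t> &resultShape) {
  if (attrShape.size() != resultShape.size())
    return false; // rank mismatch
  for (std::size_t dim = 0; dim < attrShape.size(); ++dim)
    if (attrShape[dim] != resultShape[dim])
      return false; // shape mismatch at `dim`
  return true;
}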
-StringRef GenericCallOp::getCalleeName() { - return getAttr("callee").cast().getValue(); -} - -template static mlir::LogicalResult verifyToySingleOperand(T *op) { - if (!op->getOperand()->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its argument, got " - << op->getOperand()->getType(); - return op->emitOpError(os.str()); - } - return mlir::success(); -} - -void ReturnOp::build(mlir::Builder *builder, mlir::OperationState *state, - mlir::Value *value) { - // Return does not return any value and has an optional single argument - if (value) - state->operands.push_back(value); -} - -mlir::LogicalResult ReturnOp::verify() { - if (getNumOperands() > 1) - return emitOpError("expects zero or one operand, got " + - Twine(getNumOperands())); - if (hasOperand() && failed(verifyToySingleOperand(this))) - return mlir::failure(); - return mlir::success(); -} - -void PrintOp::build(mlir::Builder *builder, mlir::OperationState *state, - mlir::Value *value) { - // Print does not return any value and has a single argument - state->operands.push_back(value); -} - -mlir::LogicalResult PrintOp::verify() { - if (failed(verifyToySingleOperand(this))) - return mlir::failure(); - return mlir::success(); -} - -void TransposeOp::build(mlir::Builder *builder, mlir::OperationState *state, - mlir::Value *value) { - state->types.push_back(ToyArrayType::get(builder->getContext())); - state->operands.push_back(value); -} - -mlir::LogicalResult TransposeOp::verify() { - if (failed(verifyToySingleOperand(this))) - return mlir::failure(); - return mlir::success(); -} - -void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState *state, - mlir::Value *value, ToyArrayType reshapedType) { - state->types.push_back(reshapedType); - state->operands.push_back(value); -} - -mlir::LogicalResult ReshapeOp::verify() { - if (failed(verifyToySingleOperand(this))) - return mlir::failure(); - auto retTy = getResult()->getType().dyn_cast(); - if (!retTy) - return emitOpError("toy.reshape is expected to produce a Toy array"); - if (retTy.isGeneric()) - return emitOpError("toy.reshape is expected to produce a shaped Toy array, " - "got a generic one."); - return mlir::success(); -} - -void AddOp::build(mlir::Builder *builder, mlir::OperationState *state, - mlir::Value *lhs, mlir::Value *rhs) { - state->types.push_back(ToyArrayType::get(builder->getContext())); - state->operands.push_back(lhs); - state->operands.push_back(rhs); -} - -mlir::LogicalResult AddOp::verify() { - if (failed(verifyToyBinOperands(this))) - return mlir::failure(); - return mlir::success(); -} - -void MulOp::build(mlir::Builder *builder, mlir::OperationState *state, - mlir::Value *lhs, mlir::Value *rhs) { - state->types.push_back(ToyArrayType::get(builder->getContext())); - state->operands.push_back(lhs); - state->operands.push_back(rhs); -} - -mlir::LogicalResult MulOp::verify() { - if (failed(verifyToyBinOperands(this))) - return mlir::failure(); - return mlir::success(); -} - -void AllocOp::build(mlir::Builder *builder, mlir::OperationState *state, - mlir::Type retType) { - state->types.push_back(retType); -} - -void TypeCastOp::build(mlir::Builder *builder, mlir::OperationState *state, - mlir::Value *value, mlir::Type destTy) { - state->operands.push_back(value); - state->types.push_back(destTy); -} - -} // namespace toy diff --git a/toy/CMakeLists.txt b/toy/CMakeLists.txt new file mode 100644 index 0000000..1c7826e --- /dev/null +++ b/toy/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB globbed *.c 
*.cpp) +add_llvm_library(ToyDialect + ${globbed} + ) +#add_dependencies(MLIRStandardOps MLIRStandardOpsIncGen LLVMSupport) +#target_link_libraries(MLIRStandardOps LLVMSupport) diff --git a/toy/EarlyLowering.cpp b/toy/EarlyLowering.cpp new file mode 100644 index 0000000..634c72e --- /dev/null +++ b/toy/EarlyLowering.cpp @@ -0,0 +1,158 @@ +//=======- EarlyLowering.cpp - Toy Lowering to Linear Algebra Dialect -=======// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements early lowering of Toy IR to Linalg Dialect: we only +// lower the computationally intensive part of the program (matmul...) to a +// dialect specialized for optimizations. +// +// This is intended to showcase how multiple dialects can cohabit in the same +// function. After this lowering, you would still have toy.print in the IR for +// example. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "linalg3/Intrinsics.h" +#include "linalg1/ViewOp.h" +#include "linalg3/TensorOps.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/LLVMIR/LLVMDialect.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" + +#include + +using namespace mlir; + +namespace { +/// Utility function for type casting: this is making the type checker happy, +/// while delaying the actual work involved to convert the type. Most of the +/// time both side of the cast (producer and consumer) will be lowered to a +/// dialect like LLVM and end up with the same LLVM representation, at which +/// point this becomes a no-op and is eliminated. +Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) { + if (val->getType() == destTy) + return val; + return builder.create(val->getLoc(), val, destTy) + .getResult(); +} + +/// Create a type cast to turn a toy.array into a memref. The Toy Array will be +/// lowered to a memref during buffer allocation, at which point the type cast +/// becomes useless. +Value *memRefTypeCast(FuncBuilder &builder, Value *val) { + if (val->getType().isa()) + return val; + auto toyArrayTy = val->getType().dyn_cast(); + if (!toyArrayTy) + return val; + return typeCast(builder, val, toyArrayTy.toMemref()); +} + +/// Lower toy.mul to Linalg `matmul`. +/// +/// This class inherit from `DialectOpConversion` and override `rewrite`, +/// similarly to the PatternRewriter introduced in the previous chapter. +/// It will be called by the DialectConversion framework (see `LateLowering` +/// class below). 
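Conceptually, typeCast/memRefTypeCast above are placeholders that keep the IR well-typed until both sides of a cast converge to the same lowered representation. A dialect-free sketch of the rule they implement, where Value and the helper are hypothetical stand-ins:

#include <string>

struct Value {
  std::string type; // stand-in for an MLIR type
};

// Insert a cast only when the types differ; once producer and consumer are
// lowered to the same representation the cast is a no-op and can be erased.
Value typeCast(const Value &val, const std::string &destTy) {
  if (val.type == destTy)
    return val;         // types already agree: nothing to do
  return Value{destTy}; // materialize a conversion to `destTy`
}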
+class MulOpConversion : public DialectOpConversion { +public: + explicit MulOpConversion(MLIRContext *context) + : DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {} + + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + using namespace edsc; + using intrinsics::constant_index; + using linalg::intrinsics::range; + using linalg::intrinsics::view; + toy::MulOp mul = op->cast(); + auto loc = mul.getLoc(); + Value *result = memRefTypeCast( + rewriter, rewriter.create(loc, mul.getResult()->getType()) + .getResult()); + Value *lhs = memRefTypeCast(rewriter, operands[0]); + auto memrefLHSTy = lhs->getType().cast(); + Value *rhs = memRefTypeCast(rewriter, operands[1]); + auto memrefRHSTy = rhs->getType().cast(); + mlir::edsc::ScopedContext scope(rewriter, loc); + edsc::ValueHandle r0 = + range(constant_index(0), constant_index(memrefLHSTy.getDimSize(0)), + constant_index(1)); + edsc::ValueHandle r1 = + range(constant_index(0), constant_index(memrefLHSTy.getDimSize(1)), + constant_index(1)); + edsc::ValueHandle r2 = + range(constant_index(0), constant_index(memrefRHSTy.getDimSize(1)), + constant_index(1)); + auto lhsView = view(lhs, {r0, r1}); + auto rhsView = view(rhs, {r1, r2}); + auto resultView = view(result, {r0, r2}); + rewriter.create(loc, lhsView, rhsView, resultView); + return {typeCast(rewriter, result, mul.getType())}; + } +}; + +// The conversion class from Toy IR Dialect to a mix of Linalg and LLVM. +class EarlyLowering : public DialectConversion { +protected: + // Initialize the list of converters. + llvm::DenseSet + initConverters(MLIRContext *context) override { + return ConversionListBuilder::build(&allocator, context); + } + +private: + llvm::BumpPtrAllocator allocator; +}; + +/// This is lowering to Linalg the parts that are computationally intensive +/// (like matmul for example...) while keeping the rest of the code in the Toy +/// dialect. +struct EarlyLoweringPass : public ModulePass { + + void runOnModule() override { + if (failed(EarlyLowering().convert(&getModule()))) { + getModule().getContext()->emitError( + mlir::UnknownLoc::get(getModule().getContext()), + "Error lowering Toy\n"); + signalPassFailure(); + } + } +}; +} // end anonymous namespace + +namespace toy { +Pass *createEarlyLoweringPass() { return new EarlyLoweringPass(); } + +std::unique_ptr makeToyEarlyLowering() { + return llvm::make_unique(); +} + +} // namespace toy diff --git a/toy/LateLowering.cpp b/toy/LateLowering.cpp new file mode 100644 index 0000000..eeae6ee --- /dev/null +++ b/toy/LateLowering.cpp @@ -0,0 +1,452 @@ +//====- LateLowering.cpp - Lowering from Toy+Linalg to LLVM -===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements late lowering of IR mixing Toy and Linalg to LLVM. +// It involves intemerdiate steps: +// - +// - a mix of affine and standard dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "linalg3/Intrinsics.h" +#include "linalg1/ViewOp.h" +#include "linalg3/ConvertToLLVMDialect.h" +#include "linalg3/TensorOps.h" +#include "linalg3/Transforms.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/LLVMIR/LLVMDialect.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" + +#include + +using namespace mlir; + +namespace { +/// Utility function for type casting: this is making the type checker happy, +/// while delaying the actual work involved to convert the type. Most of the +/// time both side of the cast (producer and consumer) will be lowered to a +/// dialect like LLVM and end up with the same LLVM representation, at which +/// point this becomes a no-op and is eliminated. +Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) { + if (val->getType() == destTy) + return val; + return builder.create(val->getLoc(), val, destTy) + .getResult(); +} + +/// Create a type cast to turn a toy.array into a memref. The Toy Array will be +/// lowered to a memref during buffer allocation, at which point the type cast +/// becomes useless. +Value *memRefTypeCast(FuncBuilder &builder, Value *val) { + if (val->getType().isa()) + return val; + auto toyArrayTy = val->getType().dyn_cast(); + if (!toyArrayTy) + return val; + return typeCast(builder, val, toyArrayTy.toMemref()); +} + +/// Lower a toy.add to an affine loop nest. +/// +/// This class inherit from `DialectOpConversion` and override `rewrite`, +/// similarly to the PatternRewriter introduced in the previous chapter. +/// It will be called by the DialectConversion framework (see `LateLowering` +/// class below). +class AddOpConversion : public DialectOpConversion { +public: + explicit AddOpConversion(MLIRContext *context) + : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {} + + /// Lower the `op` by generating IR using the `rewriter` builder. The builder + /// is setup with a new function, the `operands` array has been populated with + /// the rewritten operands for `op` in the new function. + /// The results created by the new IR with the builder are returned, and their + /// number must match the number of result of `op`. + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + auto add = op->cast(); + auto loc = add.getLoc(); + // Create a `toy.alloc` operation to allocate the output buffer for this op. 
+ Value *result = memRefTypeCast( + rewriter, rewriter.create(loc, add.getResult()->getType()) + .getResult()); + Value *lhs = memRefTypeCast(rewriter, operands[0]); + Value *rhs = memRefTypeCast(rewriter, operands[1]); + + using namespace edsc; + ScopedContext scope(rewriter, loc); + ValueHandle zero = intrinsics::constant_index(0); + MemRefView vRes(result), vLHS(lhs), vRHS(rhs); + IndexedValue iRes(result), iLHS(lhs), iRHS(rhs); + IndexHandle i, j, M(vRes.ub(0)); + if (vRes.rank() == 1) { + LoopNestBuilder({&i}, {zero}, {M}, {1})({iRes(i) = iLHS(i) + iRHS(i)}); + } else { + assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now"); + IndexHandle N(vRes.ub(1)); + LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, + {1, 1})({iRes(i, j) = iLHS(i, j) + iRHS(i, j)}); + } + + // Return the newly allocated buffer, with a type.cast to preserve the + // consumers. + return {typeCast(rewriter, result, add.getType())}; + } +}; + +/// Lowers `toy.print` to a loop nest calling `printf` on every individual +/// elements of the array. +class PrintOpConversion : public DialectOpConversion { +public: + explicit PrintOpConversion(MLIRContext *context) + : DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {} + + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + // Get or create the declaration of the printf function in the module. + Function *printfFunc = getPrintf(*op->getFunction()->getModule()); + + auto print = op->cast(); + auto loc = print.getLoc(); + // We will operate on a MemRef abstraction, we use a type.cast to get one + // if our operand is still a Toy array. + Value *operand = memRefTypeCast(rewriter, operands[0]); + Type retTy = printfFunc->getType().getResult(0); + + // Create our loop nest now + using namespace edsc; + using llvmCall = intrinsics::ValueBuilder; + ScopedContext scope(rewriter, loc); + ValueHandle zero = intrinsics::constant_index(0); + ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f ")); + MemRefView vOp(operand); + IndexedValue iOp(operand); + IndexHandle i, j, M(vOp.ub(0)); + + ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n")); + if (vOp.rank() == 1) { + // clang-format off + LoopBuilder(&i, zero, M, 1)({ + llvmCall(retTy, + rewriter.getFunctionAttr(printfFunc), + {fmtCst, iOp(i)}) + }); + llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol}); + // clang-format on + } else { + IndexHandle N(vOp.ub(1)); + // clang-format off + LoopBuilder(&i, zero, M, 1)({ + LoopBuilder(&j, zero, N, 1)({ + llvmCall(retTy, + rewriter.getFunctionAttr(printfFunc), + {fmtCst, iOp(i, j)}) + }), + llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol}) + }); + // clang-format on + } + return {}; + } + +private: + // Turn a string into a toy.alloc (malloc/free abstraction) and a sequence + // of stores into the buffer, and return a MemRef into the buffer. 
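At runtime, the loop nest emitted above evaluates to what this plain C++ reference does; the format strings match fmtCst ("%f ") and fmtEol ("\n") built below:

#include <cstdio>
#include <vector>

// Reference semantics for the lowered toy.print on a rank-2 array.
void printArray(const std::vector<std::vector<double>> &array) {
  for (const auto &row : array) {
    for (double elt : row)
      std::printf("%f ", elt);
    std::printf("\n");
  }
}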
+ Value *getConstantCharBuffer(FuncBuilder &builder, Location loc, + StringRef data) const { + auto retTy = + builder.getMemRefType(data.size() + 1, builder.getIntegerType(8)); + Value *result = builder.create(loc, retTy).getResult(); + using namespace edsc; + using intrinsics::constant_index; + using intrinsics::constant_int; + ScopedContext scope(builder, loc); + MemRefView vOp(result); + IndexedValue iOp(result); + for (uint64_t i = 0; i < data.size(); ++i) { + iOp(constant_index(i)) = constant_int(data[i], 8); + } + iOp(constant_index(data.size())) = constant_int(0, 8); + return result; + } + + /// Return the prototype declaration for printf in the module, create it if + /// necessary. + Function *getPrintf(Module &module) const { + auto *printfFunc = module.getNamedFunction("printf"); + if (printfFunc) + return printfFunc; + + // Create a function declaration for printf, signature is `i32 (i8*, ...)` + Builder builder(&module); + MLIRContext *context = module.getContext(); + LLVM::LLVMDialect *llvmDialect = static_cast( + module.getContext()->getRegisteredDialect("llvm")); + auto &llvmModule = llvmDialect->getLLVMModule(); + llvm::IRBuilder<> llvmBuilder(llvmModule.getContext()); + + auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32)); + auto llvmI8PtrTy = + LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo()); + auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty}); + printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy); + // It should be variadic, but we don't support it fully just yet. + printfFunc->setAttr("std.varargs", builder.getBoolAttr(true)); + module.getFunctions().push_back(printfFunc); + return printfFunc; + } +}; + +/// Lowers constant to a sequence of store in a buffer. +class ConstantOpConversion : public DialectOpConversion { +public: + explicit ConstantOpConversion(MLIRContext *context) + : DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {} + + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + toy::ConstantOp cstOp = op->cast(); + auto loc = cstOp.getLoc(); + auto retTy = cstOp.getResult()->getType().cast(); + auto shape = retTy.getShape(); + Value *result = memRefTypeCast( + rewriter, rewriter.create(loc, retTy).getResult()); + + auto cstValue = cstOp.getValue(); + auto f64Ty = rewriter.getF64Type(); + using namespace edsc; + using intrinsics::constant_float; + using intrinsics::constant_index; + ScopedContext scope(rewriter, loc); + MemRefView vOp(result); + IndexedValue iOp(result); + for (uint64_t i = 0; i < shape[0]; ++i) { + if (shape.size() == 1) { + auto value = cstValue.getValue(ArrayRef{i}) + .cast() + .getValue(); + iOp(constant_index(i)) = constant_float(value, f64Ty); + continue; + } + for (uint64_t j = 0; j < shape[1]; ++j) { + auto value = cstValue.getValue(ArrayRef{i, j}) + .cast() + .getValue(); + iOp(constant_index(i), constant_index(j)) = + constant_float(value, f64Ty); + } + } + return {result}; + } +}; + +/// Lower transpose operation to an affine loop nest. 
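The buffer produced by getConstantCharBuffer above holds the string bytes plus a NUL terminator, stored one i8 at a time; an equivalent plain C++ construction (hypothetical helper, for reference only):

#include <cstdint>
#include <string>
#include <vector>

std::vector<int8_t> constantCharBuffer(const std::string &data) {
  // The string bytes...
  std::vector<int8_t> buffer(data.begin(), data.end());
  // ...followed by the NUL terminator written by the final constant_int(0, 8).
  buffer.push_back(0);
  return buffer;
}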
+class TransposeOpConversion : public DialectOpConversion {
+public:
+  explicit TransposeOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {
+  }
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto transpose = op->cast<toy::TransposeOp>();
+    auto loc = transpose.getLoc();
+    Value *result = memRefTypeCast(
+        rewriter,
+        rewriter.create<toy::AllocOp>(loc, transpose.getResult()->getType())
+            .getResult());
+    Value *operand = memRefTypeCast(rewriter, operands[0]);
+
+    using namespace edsc;
+    ScopedContext scope(rewriter, loc);
+    ValueHandle zero = intrinsics::constant_index(0);
+    MemRefView vRes(result), vOperand(operand);
+    IndexedValue iRes(result), iOperand(operand);
+    IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1));
+    // clang-format off
+    LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({
+      iRes(i, j) = iOperand(j, i)
+    });
+    // clang-format on
+
+    return {typeCast(rewriter, result, transpose.getType())};
+  }
+};
+
+// Lower toy.return to the standard return operation.
+class ReturnOpConversion : public DialectOpConversion {
+public:
+  explicit ReturnOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto retOp = op->cast<toy::ReturnOp>();
+    using namespace edsc;
+    auto loc = retOp.getLoc();
+    // Argument is optional, handle both cases.
+    if (retOp.getNumOperands())
+      rewriter.create<mlir::ReturnOp>(loc, operands[0]);
+    else
+      rewriter.create<mlir::ReturnOp>(loc);
+    return {};
+  }
+};
+
+/// This is the main class registering our individual converter classes with
+/// the DialectConversion framework in MLIR.
+class LateLowering : public DialectConversion {
+protected:
+  /// Initialize the list of converters.
+  llvm::DenseSet<DialectOpConversion *>
+  initConverters(MLIRContext *context) override {
+    return ConversionListBuilder<AddOpConversion, PrintOpConversion,
+                                 ConstantOpConversion, TransposeOpConversion,
+                                 ReturnOpConversion>::build(&allocator,
+                                                            context);
+  }
+
+  /// Convert a Toy type, this gets called for block and region arguments, and
+  /// attributes.
+  Type convertType(Type t) override {
+    if (auto array = t.dyn_cast<toy::ToyArrayType>()) {
+      return array.toMemref();
+    }
+    return t;
+  }
+
+private:
+  llvm::BumpPtrAllocator allocator;
+};
+
+/// This is lowering to Linalg the parts that can be (matmul and add on arrays)
+/// and is targeting LLVM otherwise.
+struct LateLoweringPass : public ModulePass<LateLoweringPass> {
+
+  void runOnModule() override {
+    // Perform Toy-specific lowering.
+    if (failed(LateLowering().convert(&getModule()))) {
+      getModule().getContext()->emitError(
+          UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n");
+      signalPassFailure();
+    }
+    // At this point the IR is almost using only standard and affine dialects.
+    // A few things remain before we emit LLVM IR. First, to reuse as much of
+    // MLIR as possible we will try to lower everything to the standard and/or
+    // affine dialect: they already include conversion to the LLVM dialect.
+
+    // First, patch call types to return a memref instead of a ToyArray.
+    for (auto &function : getModule()) {
+      function.walk([&](Operation *op) {
+        auto callOp = op->dyn_cast<CallOp>();
+        if (!callOp)
+          return;
+        if (!callOp.getNumResults())
+          return;
+        auto retToyTy =
+            callOp.getResult(0)->getType().dyn_cast<toy::ToyArrayType>();
+        if (!retToyTy)
+          return;
+        callOp.getResult(0)->setType(retToyTy.toMemref());
+      });
+    }
+
+    for (auto &function : getModule()) {
+      function.walk([&](Operation *op) {
+        // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free).
+ if (auto allocOp = op->dyn_cast()) { + auto result = allocTensor(allocOp); + allocOp.replaceAllUsesWith(result); + allocOp.erase(); + return; + } + // Eliminate all type.cast before lowering to LLVM. + if (auto typeCastOp = op->dyn_cast()) { + typeCastOp.replaceAllUsesWith(typeCastOp.getOperand()); + typeCastOp.erase(); + return; + } + }); + } + + // Lower Linalg to affine + for (auto &function : getModule()) + linalg::lowerToLoops(&function); + + getModule().dump(); + + // Finally convert to LLVM Dialect + linalg::convertLinalg3ToLLVM(getModule()); + } + + /// Allocate buffers (malloc/free) for Toy operations. This can't be done as + /// part of dialect conversion framework since we need to insert `dealloc` + /// operations just before the return, but the conversion framework is + /// operating in a brand new function: we don't have the return to hook the + /// dealloc operations. + Value *allocTensor(toy::AllocOp alloc) { + FuncBuilder builder(alloc); + auto retTy = alloc.getResult()->getType(); + + auto memRefTy = retTy.dyn_cast(); + if (!memRefTy) + memRefTy = retTy.cast().toMemref(); + if (!memRefTy) { + alloc.emitOpError("is expected to allocate a Toy array or a MemRef"); + llvm_unreachable("fatal error"); + } + auto loc = alloc.getLoc(); + Value *result = builder.create(loc, memRefTy).getResult(); + + // Insert a `dealloc` operation right before the `return` operations, unless + // it is returned itself in which case the caller is responsible for it. + builder.getFunction()->walk([&](Operation *op) { + auto returnOp = op->dyn_cast(); + if (!returnOp) + return; + if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc) + return; + builder.setInsertionPoint(returnOp); + builder.create(alloc.getLoc(), result); + }); + return result; + } +}; +} // end anonymous namespace + +namespace toy { +Pass *createLateLoweringPass() { return new LateLoweringPass(); } + +std::unique_ptr makeToyLateLowering() { + return llvm::make_unique(); +} + +} // namespace toy diff --git a/toy/MLIRGen.cpp b/toy/MLIRGen.cpp new file mode 100644 index 0000000..e2001fb --- /dev/null +++ b/toy/MLIRGen.cpp @@ -0,0 +1,480 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. 
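The dealloc placement rule implemented by allocTensor earlier in this hunk, restated in isolation: each return gets a dealloc for the buffer, except a return that yields the buffer itself, in which case the caller takes ownership. Types here are hypothetical stand-ins, not MLIR classes:

#include <string>
#include <vector>

struct ReturnSite {
  std::string returnedValue; // empty when the function returns nothing
};

std::vector<bool> deallocBeforeReturn(const std::vector<ReturnSite> &returns,
                                      const std::string &buffer) {
  std::vector<bool> insertDealloc;
  for (const ReturnSite &ret : returns)
    insertDealloc.push_back(ret.returnedValue != buffer);
  return insertDealloc;
}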
+// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/StandardOps/Ops.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace toy; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::make_unique; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +/// +/// At this point we take advantage of the "raw" MLIR APIs to create operations +/// that haven't been registered in any way with MLIR. These operations are +/// unknown to MLIR, custom passes could operate by string-matching the name of +/// these operations, but no other type checking or semantic is associated with +/// them natively by MLIR. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : context(context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module. + std::unique_ptr mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = make_unique(&context); + + for (FunctionAST &F : moduleAST) { + auto func = mlirGen(F); + if (!func) + return nullptr; + theModule->getFunctions().push_back(func.release()); + } + + // FIXME: (in the next chapter...) without registering a dialect in MLIR, + // this won't do much, but it should at least check some structural + // properties. + if (failed(theModule->verify())) { + context.emitError(mlir::UnknownLoc::get(&context), + "Module verification error"); + return nullptr; + } + + return std::move(theModule); + } + +private: + /// In MLIR (like in LLVM) a "context" object holds the memory allocation and + /// the ownership of many internal structure of the IR and provide a level + /// of "uniquing" across multiple modules (types for instance). + mlir::MLIRContext &context; + + /// A "module" matches a source file: it contains a list of functions. + std::unique_ptr theModule; + + /// The builder is a helper class to create IR inside a function. It is + /// re-initialized every time we enter a function and kept around as a + /// convenience for emitting individual operations. + /// The builder is stateful, in particular it keeeps an "insertion point": + /// this is where the next operations will be introduced. + std::unique_ptr builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. 
+  mlir::FileLineColLoc loc(Location loc) {
+    return mlir::FileLineColLoc::get(
+        mlir::UniquedFilename::get(*loc.file, &context), loc.line, loc.col,
+        &context);
+  }
+
+  /// Declare a variable in the current scope, return true if the variable
+  /// wasn't declared yet.
+  bool declare(llvm::StringRef var, mlir::Value *value) {
+    if (symbolTable.count(var)) {
+      return false;
+    }
+    symbolTable.insert(var, value);
+    return true;
+  }
+
+  /// Create the prototype for an MLIR function with as many arguments as the
+  /// provided Toy AST prototype.
+  mlir::Function *mlirGen(PrototypeAST &proto) {
+    // This is a generic function, the return type will be inferred later.
+    llvm::SmallVector<mlir::Type, 4> ret_types;
+    // Argument types are uniformly a generic array.
+    llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
+                                               getType(VarType{}));
+    auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
+    auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
+                                        func_type, /* attrs = */ {});
+
+    // Mark the function as generic: it'll require type specialization for
+    // every call site.
+    if (function->getNumArguments())
+      function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
+
+    return function;
+  }
+
+  /// Emit a new function and add it to the MLIR module.
+  std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
+    // Create a scope in the symbol table to hold variable declarations.
+    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
+
+    // Create an MLIR function for the given prototype.
+    std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
+    if (!function)
+      return nullptr;
+
+    // Let's start the body of the function now!
+    // In MLIR the entry block of the function is special: it must have the
+    // same argument list as the function itself.
+    function->addEntryBlock();
+
+    auto &entryBlock = function->front();
+    auto &protoArgs = funcAST.getProto()->getArgs();
+    // Declare all the function arguments in the symbol table.
+    for (const auto &name_value :
+         llvm::zip(protoArgs, entryBlock.getArguments())) {
+      declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
+    }
+
+    // Create a builder for the function, it will be used throughout the
+    // codegen to create operations in this function.
+    builder = llvm::make_unique<mlir::FuncBuilder>(function.get());
+
+    // Emit the body of the function.
+    if (!mlirGen(*funcAST.getBody()))
+      return nullptr;
+
+    // Implicitly return void if no return statement was emitted.
+    // FIXME: we may fix the parser instead to always return the last
+    // expression (this would possibly help the REPL case later).
+    if (function->getBlocks().back().back().getName().getStringRef() !=
+        "toy.return") {
+      ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
+      mlirGen(fakeRet);
+    }
+
+    return function;
+  }
+
+  /// Emit a binary operation.
+  mlir::Value *mlirGen(BinaryExprAST &binop) {
+    // First emit the operations for each side of the operation before emitting
+    // the operation itself. For example if the expression is `a + foo(a)`:
+    // 1) First it will visit the LHS, which will return a reference to the
+    //    value holding `a`. This value should have been emitted at declaration
+    //    time and registered in the symbol table, so nothing would be
+    //    codegen'd. If the value is not in the symbol table, an error has been
+    //    emitted and nullptr is returned.
+    // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
+    //    and the result value is returned. If an error occurs we get a nullptr
+    //    and propagate.
+ // + mlir::Value *L = mlirGen(*binop.getLHS()); + if (!L) + return nullptr; + mlir::Value *R = mlirGen(*binop.getRHS()); + if (!R) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder->create(location, L, R).getResult(); + break; + case '*': + return builder->create(location, L, R).getResult(); + default: + context.emitError(loc(binop.loc()), + Twine("Error: invalid binary operator '") + + Twine(binop.getOp()) + "'"); + return nullptr; + } + } + + // This is a reference to a variable in an expression. The variable is + // expected to have been declared and so should have a value in the symbol + // table, otherwise emit an error and return nullptr. + mlir::Value *mlirGen(VariableExprAST &expr) { + if (symbolTable.count(expr.getName())) + return symbolTable.lookup(expr.getName()); + context.emitError(loc(expr.loc()), Twine("Error: unknown variable '") + + expr.getName() + "'"); + return nullptr; + } + + // Emit a return operation, return true on success. + bool mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + // `return` takes an optional expression, we need to account for it here. + if (!ret.getExpr().hasValue()) { + builder->create(location); + return true; + } + auto *expr = mlirGen(*ret.getExpr().getValue()); + if (!expr) + return false; + builder->create(location, expr); + return true; + } + + // Emit a literal/constant array. It will be emitted as a flattened array of + // data in an Attribute attached to a `toy.constant` operation. + // See documentation on [Attributes](LangRef.md#attributes) for more details. + // Here is an excerpt: + // + // Attributes are the mechanism for specifying constant data in MLIR in + // places where a variable is never allowed [...]. They consist of a name + // and a [concrete attribute value](#attribute-values). It is possible to + // attach attributes to operations, functions, and function arguments. The + // set of expected attributes, their structure, and their interpretation + // are all contextually dependent on what they are attached to. + // + // Example, the source level statement: + // var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + // will be converted to: + // %0 = "toy.constant"() {value: dense, + // [[1.000000e+00, 2.000000e+00, 3.000000e+00], + // [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64> + // + mlir::Value *mlirGen(LiteralExprAST &lit) { + auto location = loc(lit.loc()); + // The attribute is a vector with an attribute per element (number) in the + // array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // FIXME: using a tensor type is a HACK here. + // Can we do differently without registering a dialect? Using a string blob? + mlir::Type elementType = mlir::FloatType::getF64(&context); + auto dataType = builder->getTensorType(lit.getDims(), elementType); + + // This is the actual attribute that actually hold the list of values for + // this array literal. + auto dataAttribute = builder->getDenseElementsAttr(dataType, data) + .cast(); + + // Build the MLIR op `toy.constant`, only boilerplate below. + return builder->create(location, lit.getDims(), dataAttribute) + .getResult(); + } + + // Recursive helper function to accumulate the data that compose an array + // literal. 
It flattens the nested structure in the supplied vector. For + // example with this array: + // [[1, 2], [3, 4]] + // we will generate: + // [ 1, 2, 3, 4 ] + // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`. + // Attributes are the way MLIR attaches constant to operations and functions. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + assert(isa(expr) && "expected literal or number expr"); + mlir::Type elementType = mlir::FloatType::getF64(&context); + auto attr = mlir::FloatAttr::getChecked( + elementType, cast(expr).getValue(), loc(expr.loc())); + data.push_back(attr); + } + + // Emit a call expression. It emits specific operations for the `transpose` + // builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value *mlirGen(CallExprAST &call) { + auto location = loc(call.loc()); + std::string callee = call.getCallee(); + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + context.emitError( + location, Twine("MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments")); + return nullptr; + } + mlir::Value *arg = mlirGen(*call.getArgs()[0]); + return builder->create(location, arg).getResult(); + } + + // Codegen the operands first + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto *arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + // Calls to user-defined function are mapped to a custom call that takes + // the callee name as an attribute. + return builder->create(location, call.getCallee(), operands) + .getResult(); + } + + // Emit a call expression. It emits specific operations for two builtins: + // transpose(x) and print(x). Other identifiers are assumed to be user-defined + // functions. Return false on failure. + bool mlirGen(PrintExprAST &call) { + auto *arg = mlirGen(*call.getArg()); + if (!arg) + return false; + auto location = loc(call.loc()); + builder->create(location, arg); + return true; + } + + // Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value *mlirGen(NumberExprAST &num) { + auto location = loc(num.loc()); + mlir::Type elementType = mlir::FloatType::getF64(&context); + auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(), + loc(num.loc())); + return builder->create(location, attr).getResult(); + } + + // Dispatch codegen for the right expression subclass using RTTI. + mlir::Value *mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + context.emitError( + loc(expr.loc()), + Twine("MLIR codegen encountered an unhandled expr kind '") + + Twine(expr.getKind()) + "'"); + return nullptr; + } + } + + // Handle a variable declaration, we'll codegen the expression that forms the + // initializer and record the value in the symbol table before returning it. + // Future expressions will be able to reference this variable through symbol + // table lookup. 
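The flattening described above, as a standalone recursive sketch over a simplified literal node; the Literal type is a hypothetical stand-in for the AST classes:

#include <vector>

struct Literal {
  double number = 0;           // payload when this node is a leaf
  std::vector<Literal> values; // non-empty when this node nests literals
};

// [[1, 2], [3, 4]] flattens to [1, 2, 3, 4], matching collectData above.
void flatten(const Literal &lit, std::vector<double> &out) {
  if (lit.values.empty()) {
    out.push_back(lit.number);
    return;
  }
  for (const Literal &sub : lit.values)
    flatten(sub, out);
}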
+ mlir::Value *mlirGen(VarDeclExprAST &vardecl) { + mlir::Value *value = nullptr; + auto location = loc(vardecl.loc()); + if (auto init = vardecl.getInitVal()) { + value = mlirGen(*init); + if (!value) + return nullptr; + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder + ->create( + location, value, + getType(vardecl.getType()).cast()) + .getResult(); + } + } else { + context.emitError(loc(vardecl.loc()), + "Missing initializer in variable declaration"); + return nullptr; + } + // Register the value in the symbol table + declare(vardecl.getName(), value); + return value; + } + + /// Codegen a list of expression, return false if one of them hit an error. + bool mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope var_scope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return false; + continue; + } + if (auto *ret = dyn_cast(expr.get())) { + if (!mlirGen(*ret)) + return false; + return true; + } + if (auto *print = dyn_cast(expr.get())) { + if (!mlirGen(*print)) + return false; + continue; + } + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return false; + } + return true; + } + + /// Build a type from a list of shape dimensions. Types are `array` followed + /// by an optional dimension list, example: array<2, 2> + /// They are wrapped in a `toy` dialect (see next chapter) and get printed: + /// !toy.array<2, 2> + template mlir::Type getType(T shape) { + SmallVector shape64(shape.begin(), shape.end()); + return ToyArrayType::get(&context, shape64); + } + + /// Build an MLIR type from a Toy AST variable type + /// (forward to the generic getType(T) above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +std::unique_ptr mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/toy/RegisterDialects.cpp b/toy/RegisterDialects.cpp new file mode 100644 index 0000000..7314cb2 --- /dev/null +++ b/toy/RegisterDialects.cpp @@ -0,0 +1,10 @@ +// +// Created by ztuowen on 4/27/19. +// + +#include "mlir/IR/Dialect.h" +#include "toy/Dialect.h" +#include "linalg1/Dialect.h" + +static mlir::DialectRegistration Toy; +static mlir::DialectRegistration Linalg; diff --git a/toy/ShapeInferencePass.cpp b/toy/ShapeInferencePass.cpp new file mode 100644 index 0000000..7e3ea3f --- /dev/null +++ b/toy/ShapeInferencePass.cpp @@ -0,0 +1,387 @@ +//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a Module level pass performing interprocedural
+// propagation of array shapes through function specialization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+
+#define DEBUG_TYPE "toy-shape-inference"
+
+using namespace toy;
+using llvm::MutableArrayRef;
+using llvm::SmallVector;
+using llvm::SmallVectorImpl;
+using llvm::StringRef;
+using llvm::Twine;
+
+/// Create a mangled name for function specialization. We will simply append
+/// the shape of the arguments to the function name. For example, calling
+///
+///   "toy.generic_call"(%1, %3) {callee: "foo"}
+///       : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
+///
+/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could
+/// have provided a function with a similar name, but we will claim this as a
+/// feature: it allows the user to provide custom specializations!
+static std::string mangle(StringRef funcName,
+                          MutableArrayRef<mlir::OpOperand> operands) {
+  std::string mangledName;
+  mangledName.reserve(funcName.size() + operands.size() * 6);
+  mangledName = funcName;
+  for (auto &operand : operands) {
+    auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
+    mangledName += "_";
+    const char *sep = "";
+    for (auto dim : arrayTy.getShape()) {
+      mangledName += (sep + Twine(dim)).str();
+      sep = "x";
+    }
+  }
+  return mangledName;
+}
+
+namespace {
+
+/// The ShapeInferencePass is a ModulePass: it will run on the Module as a
+/// whole. MLIR also supports FunctionPass, which is restricted to modifying a
+/// single function at a time. This pass couldn't be a function pass due to
+/// the nature of its interprocedural transformations.
+///
+/// The algorithm has two levels. First, intra-procedurally:
+///
+///   1) Build a worklist containing all the operations that return a generic
+///      Toy array: these are the operations that need shape inference.
+///   2) Iterate on the worklist:
+///      a) find an operation to process: the next ready operation in the
+///         worklist has all of its arguments non-generic,
+///      b) if no operation is found, break out of the loop,
+///      c) remove the operation from the worklist,
+///      d) infer the shape of its output from the argument types.
+///   3) If the worklist is empty, the algorithm succeeded and we infer the
+///      return type for the function from the return operation.
+///
+/// There is a twist though: when a call to a generic function is encountered,
+/// shape inference requires the return type of the callee to be inferred
+/// first. At this point we need to specialize the callee by cloning it. Here
+/// is the inter-procedural flow:
+///
+///   1) Keep a worklist of functions to process. Start with function "main".
+///   2) While the worklist isn't empty:
+///      a) Take the last inserted function in the worklist.
+///      b) Run the intra-procedural shape inference on this function.
+///      c) If the intra-procedural shape inference can't complete, it returns
+///         a Function that needs to be inferred first. In this case, queue
+///         this new function and continue. Otherwise the inference succeeded
+///         and we can pop it from the queue.
+///
+class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
+public:
+  // One entry in the inter-procedural worklist. It keeps track of the
+  // function to process, the mangled name for this specialization, and the
+  // types of the arguments on which to specialize.
+  struct FunctionToSpecialize {
+    mlir::Function *function;
+    std::string mangledName;
+    std::vector<mlir::Type> argumentsType;
+  };
+
+  void runOnModule() override {
+    auto &module = getModule();
+    auto *main = module.getNamedFunction("main");
+    if (!main) {
+      module.getContext()->emitError(
+          mlir::UnknownLoc::get(module.getContext()),
+          "Shape inference failed: can't find a main function\n");
+      signalPassFailure();
+      return;
+    }
+
+    /// Inter-procedural loop: initialize with `main` and iterate until we
+    /// successfully infer the full call-graph reachable from main.
+    SmallVector<FunctionToSpecialize, 8> worklist;
+    worklist.push_back({main, "", {}});
+    while (!worklist.empty()) {
+      if (failed(specialize(worklist)))
+        return;
+    }
+
+    // Delete any generic function left.
+    // FIXME: we may want this as a separate pass.
+    for (mlir::Function &function : llvm::make_early_inc_range(module)) {
+      if (auto genericAttr =
+              function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
+        if (genericAttr.getValue())
+          function.erase();
+      }
+    }
+  }
+
+  /// Run inference on a function. If a mangledName is provided, we need to
+  /// specialize the function: to this end we clone it first.
+  mlir::LogicalResult
+  specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
+    FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
+    mlir::Function *f = functionToSpecialize.function;
+
+    // Check if cloning for specialization is needed (usually anything but
+    // main). We will create a new function with the concrete types for the
+    // parameters and clone the body into it.
+    if (!functionToSpecialize.mangledName.empty()) {
+      if (getModule().getNamedFunction(functionToSpecialize.mangledName)) {
+        funcWorklist.pop_back();
+        // Function already specialized, move on.
+        return mlir::success();
+      }
+      // Create a new function with a generic array return type; it will be
+      // updated when the inference for the function body completes.
+      auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
+                                          {ToyArrayType::get(&getContext())},
+                                          &getContext());
+      auto *newFunction = new mlir::Function(
+          f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
+      getModule().getFunctions().push_back(newFunction);
+
+      // Clone the function body.
+      mlir::BlockAndValueMapping mapper;
+      f->cloneInto(newFunction, mapper);
+      LLVM_DEBUG({
+        llvm::dbgs() << "====== Cloned : \n";
+        f->dump();
+        llvm::dbgs() << "====== Into : \n";
+        newFunction->dump();
+      });
+      f = newFunction;
+      f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
+      // Remap the entry-block arguments.
+      // FIXME: this seems like a bug in `cloneInto()` above?
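+      // Illustration (argument counts are hypothetical): with two
+      // parameters, `addArguments` below leaves the entry block with
+      // [old0, old1, new0, new1]; the loop then rewires all uses of old0
+      // and old1 to new0 and new1 and erases the old arguments.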
+      auto &entryBlock = f->getBlocks().front();
+      int blockArgSize = entryBlock.getArguments().size();
+      assert(blockArgSize == f->getType().getInputs().size());
+      entryBlock.addArguments(f->getType().getInputs());
+      auto argList = entryBlock.getArguments();
+      for (int argNum = 0; argNum < blockArgSize; ++argNum) {
+        argList[0]->replaceAllUsesWith(argList[blockArgSize]);
+        entryBlock.eraseArgument(0);
+      }
+      assert(succeeded(f->verify()));
+    }
+    LLVM_DEBUG(llvm::dbgs()
+               << "Run shape inference on : '" << f->getName() << "'\n");
+
+    auto *toyDialect = getContext().getRegisteredDialect("toy");
+    if (!toyDialect) {
+      getContext().emitError(mlir::UnknownLoc::get(&getContext()),
+                             "Toy dialect is not registered");
+      signalPassFailure();
+      return mlir::failure();
+    }
+
+    // Populate the worklist with the operations that need shape inference:
+    // these are the Toy operations that return a generic array.
+    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
+    f->walk([&](mlir::Operation *op) {
+      if (op->getDialect() == toyDialect) {
+        if (op->getNumResults() == 1 &&
+            op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
+          opWorklist.insert(op);
+      }
+    });
+
+    // Iterate on the operations in the worklist until all operations have
+    // been inferred or no change happened (fixed point).
+    while (!opWorklist.empty()) {
+      // Find the next operation ready for inference, that is an operation
+      // with all operands already resolved (non-generic).
+      auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
+        return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
+          return !v->getType().cast<ToyArrayType>().isGeneric();
+        });
+      });
+      if (nextop == opWorklist.end())
+        break; // failure: no operations can be inferred.
+
+      mlir::Operation *op = *nextop;
+      opWorklist.erase(op);
+      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
+
+      // The add operation is trivial: propagate the input type as is.
+      if (auto addOp = op->dyn_cast<AddOp>()) {
+        op->getResult(0)->setType(op->getOperand(0)->getType());
+        continue;
+      }
+
+      // Transpose is easy: just invert the dimensions.
+      if (op->getName().getStringRef() == "toy.transpose") {
+        SmallVector<int64_t, 2> dims;
+        auto arrayTy = op->getOperand(0)->getType().cast<ToyArrayType>();
+        dims.insert(dims.end(), arrayTy.getShape().begin(),
+                    arrayTy.getShape().end());
+        if (dims.size() == 2)
+          std::swap(dims[0], dims[1]);
+        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
+        continue;
+      }
+
+      // Multiplication is a bit trickier; handle rank 1 as dot product and
+      // rank 2 as matrix multiplication.
+      // We need to be careful about rank mismatch here: the verifier could
+      // catch it, but shape inference earlier in the pass could generate
+      // invalid IR (from an invalid Toy input, of course) and we wouldn't
+      // want to crash here.
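+      // For example (shapes illustrative): <2, 3> x <3, 4> infers <2, 4>,
+      // while <3> x <3> (a dot product) infers <1>.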
+      if (auto mulOp = op->dyn_cast<MulOp>()) {
+        auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
+        auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
+        auto lhsRank = lhs.getShape().size();
+        auto rhsRank = rhs.getShape().size();
+        if (lhsRank != rhsRank) {
+          op->emitError("Shape mismatch: LHS and RHS must have the same "
+                        "rank for multiplication, got " +
+                        Twine(lhsRank) + " vs " + Twine(rhsRank));
+          return mlir::failure();
+        }
+        SmallVector<int64_t, 2> dims;
+        if (lhsRank == 1) {
+          // Dot product, result shape is <1>.
+          dims.push_back(1);
+        } else {
+          if (lhsRank != 2) {
+            op->emitError(
+                "Shape mismatch: expect rank 1 or 2 for mul operands, got " +
+                Twine(lhsRank));
+            return mlir::failure();
+          }
+          dims.push_back(lhs.getShape()[0]);
+          dims.push_back(rhs.getShape()[1]);
+        }
+        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
+        continue;
+      }
+
+      // Process calls: lookup the callee after mangling the name with the
+      // argument shapes. If the callee does not exist, we stop the inference
+      // for this function, queue the callee in the inter-procedural work
+      // list, and return. The current function stays in the work list and
+      // will restart after the callee is processed.
+      if (auto callOp = op->dyn_cast<GenericCallOp>()) {
+        auto calleeName = callOp.getCalleeName();
+        auto *callee = getModule().getNamedFunction(calleeName);
+        if (!callee) {
+          f->emitError(
+              llvm::Twine("Shape inference failed, call to unknown '") +
+              calleeName + "'");
+          signalPassFailure();
+          return mlir::failure();
+        }
+        auto mangledName = mangle(calleeName, op->getOpOperands());
+        LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
+                                << "', mangled: '" << mangledName << "'\n");
+        auto *mangledCallee = getModule().getNamedFunction(mangledName);
+        if (!mangledCallee) {
+          // Can't find the target: this is where we queue the request for
+          // the callee and stop the inference for the current function now.
+          std::vector<mlir::Type> funcArgs;
+          for (auto operand : op->getOperands())
+            funcArgs.push_back(operand->getType());
+          funcWorklist.push_back(
+              {callee, std::move(mangledName), std::move(funcArgs)});
+          return mlir::success();
+        }
+        // Found a specialized callee! Let's turn this into a normal call
+        // operation.
+        SmallVector<mlir::Value *, 8> operands;
+        for (mlir::Value *v : op->getOperands())
+          operands.push_back(v);
+        mlir::FuncBuilder builder(f);
+        builder.setInsertionPoint(op);
+        auto newCall = builder.create<mlir::CallOp>(op->getLoc(),
+                                                    mangledCallee, operands);
+        if (newCall.getNumResults()) {
+          op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
+          op->erase();
+          continue;
+        }
+      }
+    }
+
+    // Done with inference on this function, removing it from the worklist.
+    funcWorklist.pop_back();
+    // Mark the function as non-generic now that inference has succeeded.
+    f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
+
+    // If the operation worklist isn't empty, this indicates a failure.
+    if (!opWorklist.empty()) {
+      std::string str;
+      llvm::raw_string_ostream errorMsg(str);
+      errorMsg << "Shape inference failed, " << opWorklist.size()
+               << " operations couldn't be inferred\n";
+      for (auto *ope : opWorklist)
+        errorMsg << " - " << *ope << "\n";
+      f->emitError(errorMsg.str());
+      signalPassFailure();
+      return mlir::failure();
+    }
+
+    // Finally, update the return type of the function based on the argument
+    // to the return operation.
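+    // For example (illustrative): a function declared as () -> !toy.array
+    // whose return operand was inferred as !toy.array<2, 2> is updated
+    // below to () -> !toy.array<2, 2>.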
+    for (auto &block : f->getBlocks()) {
+      auto ret = block.getTerminator()->dyn_cast<ReturnOp>();
+      if (!ret)
+        continue;
+      if (ret.getNumOperands() &&
+          f->getType().getResult(0) == ret.getOperand()->getType())
+        // Type match: we're done.
+        break;
+      SmallVector<mlir::Type, 1> retTy;
+      if (ret.getNumOperands())
+        retTy.push_back(ret.getOperand()->getType());
+      std::vector<mlir::Type> argumentsType;
+      for (auto arg : f->getArguments())
+        argumentsType.push_back(arg->getType());
+      auto newType =
+          mlir::FunctionType::get(argumentsType, retTy, &getContext());
+      f->setType(newType);
+      assert(succeeded(f->verify()));
+      break;
+    }
+    return mlir::success();
+  }
+};
+} // end anonymous namespace
+
+namespace toy {
+mlir::Pass *createShapeInferencePass() { return new ShapeInferencePass(); }
+} // namespace toy
diff --git a/toy/ToyCombine.cpp b/toy/ToyCombine.cpp
new file mode 100644
index 0000000..8d6aed6
--- /dev/null
+++ b/toy/ToyCombine.cpp
@@ -0,0 +1,209 @@
+//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple combiner for optimizing patterns in the Toy
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+
+#include <numeric>
+
+namespace toy {
+
+namespace {
+
+/// Fold transpose(transpose(x)) -> x.
+struct SimplifyRedundantTranspose : public mlir::RewritePattern {
+  /// We register this pattern to match every toy.transpose in the IR.
+  /// The "benefit" is used by the framework to order the patterns and process
+  /// them in order of profitability.
+  SimplifyRedundantTranspose(mlir::MLIRContext *context)
+      : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
+                       context) {}
+
+  /// This method attempts to match a pattern and rewrite it. The rewriter
+  /// argument orchestrates the sequence of rewrites; any change to the IR
+  /// from here is expected to go through it.
+  mlir::PatternMatchResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // We can directly cast the current operation, as this will only get
+    // invoked on TransposeOp.
+    TransposeOp transpose = op->cast<TransposeOp>();
+    // Look through the input of the current transpose.
+    mlir::Value *transposeInput = transpose.getOperand();
+    mlir::Operation *transposeInputInst = transposeInput->getDefiningOp();
+    // If the input is defined by another Transpose, bingo!
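+    // Illustrative IR (value names are hypothetical):
+    //   %1 = "toy.transpose"(%0) : (!toy.array<2, 3>) -> !toy.array<3, 2>
+    //   %2 = "toy.transpose"(%1) : (!toy.array<3, 2>) -> !toy.array<2, 3>
+    // The rewrite below replaces all uses of %2 with %0.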
+    TransposeOp transposeInputOp =
+        mlir::dyn_cast_or_null<TransposeOp>(transposeInputInst);
+    if (!transposeInputOp)
+      return matchFailure();
+
+    // Use the rewriter to perform the replacement.
+    rewriter.replaceOp(op, {transposeInputOp.getOperand()},
+                       {transposeInputOp});
+    return matchSuccess();
+  }
+};
+
+/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in
+/// place.
+struct SimplifyReshapeConstant : public mlir::RewritePattern {
+  SimplifyReshapeConstant(mlir::MLIRContext *context)
+      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
+                       context) {}
+
+  mlir::PatternMatchResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    ReshapeOp reshape = op->cast<ReshapeOp>();
+    // Look through the input of the current reshape.
+    mlir::Value *reshapeInput = reshape.getOperand();
+    mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
+    // If the input is defined by a constant, bingo!
+    ConstantOp constantOp =
+        mlir::dyn_cast_or_null<ConstantOp>(reshapeInputInst);
+    if (!constantOp)
+      return matchFailure();
+
+    auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>();
+    if (auto valueAttr =
+            constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
+      // FIXME: check matching of element count!
+      // auto oldType = constantOp.getType();
+      auto newType = rewriter.getTensorType(
+          reshapeType.getShape(), valueAttr.getType().getElementType());
+      auto newAttr =
+          mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
+      auto newConstant = rewriter.create<ConstantOp>(
+          constantOp.getLoc(), reshapeType.getShape(), newAttr);
+      rewriter.replaceOp(op, {newConstant});
+    } else if (auto valueAttr =
+                   constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
+      // Broadcast.
+      auto dataSize = std::accumulate(reshapeType.getShape().begin(),
+                                      reshapeType.getShape().end(), 1,
+                                      std::multiplies<int>());
+      std::vector<mlir::Attribute> data(dataSize, valueAttr);
+      auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
+                                             reshapeType.getElementType());
+      auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
+      auto newConstant = rewriter.create<ConstantOp>(
+          constantOp.getLoc(), reshapeType.getShape(), newAttr);
+      rewriter.replaceOp(op, {newConstant});
+    } else {
+      llvm_unreachable("Unsupported Constant format");
+    }
+    return matchSuccess();
+  }
+};
+
+/// Fold reshape(reshape(x)) -> reshape(x).
+struct SimplifyReshapeReshape : public mlir::RewritePattern {
+  SimplifyReshapeReshape(mlir::MLIRContext *context)
+      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
+                       context) {}
+
+  mlir::PatternMatchResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    ReshapeOp reshape = op->cast<ReshapeOp>();
+    // Look through the input of the current reshape.
+    mlir::Value *reshapeInput = reshape.getOperand();
+    mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
+    // If the input is defined by another reshape, bingo!
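+    // Illustrative IR (value names are hypothetical):
+    //   %1 = "toy.reshape"(%0) : ...
+    //   %2 = "toy.reshape"(%1) : ...
+    // The rewrite below forwards uses of %2 to %1, eliding the outer
+    // reshape.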
+    ReshapeOp reshapeInputOp =
+        mlir::dyn_cast_or_null<ReshapeOp>(reshapeInputInst);
+    if (!reshapeInputOp)
+      return matchFailure();
+
+    // Use the rewriter to perform the replacement.
+    rewriter.replaceOp(op, {reshapeInputOp});
+    return matchSuccess();
+  }
+};
+
+/// Fold reshape(x) -> x, when the input type matches the output type.
+struct SimplifyNullReshape : public mlir::RewritePattern {
+  SimplifyNullReshape(mlir::MLIRContext *context)
+      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
+                       context) {}
+
+  mlir::PatternMatchResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    ReshapeOp reshape = op->cast<ReshapeOp>();
+    if (reshape.getOperand()->getType() != reshape.getResult()->getType())
+      return matchFailure();
+    rewriter.replaceOp(reshape, {reshape.getOperand()});
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace.
+
+// Register our patterns for rewrite by the canonicalization framework.
+void TransposeOp::getCanonicalizationPatterns(
+    mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
+  results.push_back(llvm::make_unique<SimplifyRedundantTranspose>(context));
+}
+
+// Register our patterns for rewrite by the canonicalization framework.
+void ReshapeOp::getCanonicalizationPatterns(
+    mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
+  results.push_back(llvm::make_unique<SimplifyReshapeConstant>(context));
+  results.push_back(llvm::make_unique<SimplifyReshapeReshape>(context));
+  results.push_back(llvm::make_unique<SimplifyNullReshape>(context));
+}
+
+namespace {
+
+/// Fold type.cast(x) -> x, when the input type matches the output type.
+struct SimplifyIdentityTypeCast : public mlir::RewritePattern {
+  SimplifyIdentityTypeCast(mlir::MLIRContext *context)
+      : RewritePattern(TypeCastOp::getOperationName(), /* benefit = */ 1,
+                       context) {}
+
+  mlir::PatternMatchResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    TypeCastOp typeCast = op->cast<TypeCastOp>();
+    auto resTy = typeCast.getResult()->getType();
+    auto *candidateOp = op;
+    while (candidateOp && candidateOp->isa<TypeCastOp>()) {
+      if (resTy == candidateOp->getOperand(0)->getType()) {
+        rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)});
+        return matchSuccess();
+      }
+      candidateOp = candidateOp->getOperand(0)->getDefiningOp();
+    }
+    return matchFailure();
+  }
+};
+
+} // end anonymous namespace.
+
+void TypeCastOp::getCanonicalizationPatterns(
+    mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
+  results.push_back(llvm::make_unique<SimplifyIdentityTypeCast>(context));
+}
+
+} // namespace toy
diff --git a/toy/ToyDialect.cpp b/toy/ToyDialect.cpp
new file mode 100644
index 0000000..be117f5
--- /dev/null
+++ b/toy/ToyDialect.cpp
@@ -0,0 +1,405 @@
+//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the dialect for the Toy IR: custom type parsing and
+// operation verification.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/raw_ostream.h"
+
+using llvm::ArrayRef;
+using llvm::raw_ostream;
+using llvm::raw_string_ostream;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace toy {
+namespace detail {
+
+/// This class holds the implementation of the ToyArrayType.
+/// It is intended to be uniqued based on its content and owned by the
+/// context.
+struct ToyArrayTypeStorage : public mlir::TypeStorage {
+  /// This defines how we unique this type in the context: our key contains
+  /// only the shape; a more complex type would have multiple entries in the
+  /// tuple here.
+  /// The elements of the tuple usually match 1-1 the arguments of the public
+  /// `get()` method on the facade.
+  using KeyTy = std::tuple<ArrayRef<int64_t>>;
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(std::get<0>(key));
+  }
+  /// When the key hash hits an existing type, we compare the shapes
+  /// themselves to confirm we have the right type.
+  bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }
+
+  /// This is a factory method to create our type storage. It is only
+  /// invoked after looking up the type in the context using the key and not
+  /// finding it.
+  static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
+                                        const KeyTy &key) {
+    // Copy the shape array into the bump-pointer allocator owned by the
+    // context.
+    ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
+
+    // Allocate the instance of the ToyArrayTypeStorage itself.
+    auto *storage = allocator.allocate<ToyArrayTypeStorage>();
+    // Initialize the instance using placement new.
+    return new (storage) ToyArrayTypeStorage(shape);
+  }
+
+  ArrayRef<int64_t> getShape() const { return shape; }
+
+private:
+  ArrayRef<int64_t> shape;
+
+  /// The constructor is only invoked from the `construct()` method above.
+  ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
+};
+
+} // namespace detail
+
+mlir::Type ToyArrayType::getElementType() {
+  return mlir::FloatType::getF64(getContext());
+}
+
+ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
+                               ArrayRef<int64_t> shape) {
+  return Base::get(context, ToyTypeKind::TOY_ARRAY, shape);
+}
+
+ArrayRef<int64_t> ToyArrayType::getShape() { return getImpl()->getShape(); }
+
+mlir::MemRefType ToyArrayType::toMemref() {
+  auto memRefType = mlir::MemRefType::get(getShape(), getElementType(), {}, 0);
+  return memRefType;
+}
+
+/// Dialect creation: the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+  addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
+                MulOp, AddOp, ReturnOp, AllocOp, TypeCastOp>();
+  addTypes<ToyArrayType>();
+}
+
+/// Parse a type registered to this dialect; we expect only Toy arrays.
+mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
+  // Sanity check: we only support array or array<...>
+  if (!tyData.startswith("array")) {
+    getContext()->emitError(loc, "Invalid Toy type '" + tyData +
+                                     "', array expected");
+    return nullptr;
+  }
+  // Drop the "array" prefix from the type name; we expect either an empty
+  // string or just the shape.
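+  // For example (illustrative): for `array<2, 3>` the remaining tyData is
+  // `<2, 3>`, while for a plain `array` it is empty.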
+  tyData = tyData.drop_front(StringRef("array").size());
+  // This is the generic array case without a shape; early return it.
+  if (tyData.empty())
+    return ToyArrayType::get(getContext());
+
+  // Use a regex to parse the shape (for efficiency we should store this
+  // regex in the dialect itself).
+  SmallVector<StringRef, 4> matches;
+  auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
+  if (!shapeRegex.match(tyData, &matches)) {
+    getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'");
+    return nullptr;
+  }
+  SmallVector<int64_t, 4> shape;
+  // Iterate through the captures, skipping the first one which is the full
+  // string.
+  for (auto dimStr :
+       llvm::make_range(std::next(matches.begin()), matches.end())) {
+    if (dimStr.startswith(","))
+      continue; // POSIX misses non-capturing groups.
+    if (dimStr.empty())
+      continue; // '*' makes it an optional group capture.
+    // Convert the capture to an integer.
+    unsigned long long dim;
+    if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
+      getContext()->emitError(
+          loc, "Couldn't parse dimension as integer, matched: " + dimStr);
+      return mlir::Type();
+    }
+    shape.push_back(dim);
+  }
+  // Finally we collected all the dimensions in the shape; create the array
+  // type.
+  return ToyArrayType::get(getContext(), shape);
+}
+
+/// Print a Toy array type, for example `array<2, 3, 4>`.
+void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
+  auto arrayTy = type.dyn_cast<ToyArrayType>();
+  if (!arrayTy) {
+    os << "unknown toy type";
+    return;
+  }
+  os << "array";
+  if (!arrayTy.getShape().empty()) {
+    os << "<";
+    mlir::interleaveComma(arrayTy.getShape(), os);
+    os << ">";
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//////////////////// Custom Operations for the Dialect /////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+/// Helper to verify that the result of an operation is a Toy array type.
+template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
+  if (!op->getResult()->getType().template isa<ToyArrayType>()) {
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects a Toy Array for its argument, got "
+       << op->getResult()->getType();
+    return op->emitOpError(os.str());
+  }
+  return mlir::success();
+}
+
+/// Helper to verify that the two operands of a binary operation are Toy
+/// arrays.
+template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
+  if (!op->getOperand(0)->getType().template isa<ToyArrayType>()) {
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects a Toy Array for its LHS, got "
+       << op->getOperand(0)->getType();
+    return op->emitOpError(os.str());
+  }
+  if (!op->getOperand(1)->getType().template isa<ToyArrayType>()) {
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects a Toy Array for its RHS, got "
+       << op->getOperand(1)->getType();
+    return op->emitOpError(os.str());
+  }
+  return mlir::success();
+}
+
+/// Build a constant operation.
+/// The builder is passed as an argument, and so is the state that this
+/// method is expected to fill in order to build the operation.
+void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                       ArrayRef<int64_t> shape,
+                       mlir::DenseElementsAttr value) {
+  state->types.push_back(ToyArrayType::get(builder->getContext(), shape));
+  auto dataAttribute = builder->getNamedAttr("value", value);
+  state->attributes.push_back(dataAttribute);
+}
+
+/// Build a constant operation.
+/// The builder is passed as an argument, and so is the state that this
+/// method is expected to fill in order to build the operation.
+void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                       mlir::FloatAttr value) {
+  // Broadcast and forward to the other build factory.
+  mlir::Type elementType = mlir::FloatType::getF64(builder->getContext());
+  auto dataType = builder->getTensorType({1}, elementType);
+  auto dataAttribute = builder->getDenseElementsAttr(dataType, {value})
+                           .cast<mlir::DenseElementsAttr>();
+
+  ConstantOp::build(builder, state, {1}, dataAttribute);
+}
+
+/// Verifier for the constant operation.
+mlir::LogicalResult ConstantOp::verify() {
+  // Ensure that the return type is a Toy array.
+  if (failed(verifyToyReturnArray(this)))
+    return mlir::failure();
+
+  // We expect the constant itself to be stored as an attribute.
+  auto dataAttr = getAttr("value").dyn_cast<mlir::DenseElementsAttr>();
+  if (!dataAttr) {
+    return emitOpError(
+        "missing valid `value` DenseElementsAttribute on toy.constant()");
+  }
+  auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
+  if (!attrType) {
+    return emitOpError(
+        "missing valid `value` DenseElementsAttribute on toy.constant()");
+  }
+
+  // If the return type of the constant is not a generic array, the shape
+  // must match the shape of the attribute holding the data.
+  auto resultType = getResult()->getType().cast<ToyArrayType>();
+  if (!resultType.isGeneric()) {
+    if (attrType.getRank() != resultType.getRank()) {
+      return emitOpError("The rank of the toy.constant return type must match "
+                         "the one of the attached value attribute: " +
+                         Twine(attrType.getRank()) +
+                         " != " + Twine(resultType.getRank()));
+    }
+    for (int dim = 0; dim < attrType.getRank(); ++dim) {
+      if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
+        return emitOpError(
+            "Shape mismatch between toy.constant return type and its "
+            "attribute at dimension " +
+            Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
+            " != " + Twine(resultType.getShape()[dim]));
+      }
+    }
+  }
+  return mlir::success();
+}
+
+void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                          StringRef callee,
+                          ArrayRef<mlir::Value *> arguments) {
+  // Generic call always returns a generic ToyArray initially.
+  state->types.push_back(ToyArrayType::get(builder->getContext()));
+  state->operands.assign(arguments.begin(), arguments.end());
+  auto calleeAttr = builder->getStringAttr(callee);
+  state->attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
+}
+
+mlir::LogicalResult GenericCallOp::verify() {
+  // Verify that every operand is a Toy array.
+  for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
+    if (!getOperand(opId)->getType().isa<ToyArrayType>()) {
+      std::string msg;
+      raw_string_ostream os(msg);
+      os << "expects a Toy Array for its " << opId << " operand, got "
+         << getOperand(opId)->getType();
+      return emitOpError(os.str());
+    }
+  }
+  return mlir::success();
+}
+
+/// Return the name of the callee.
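+/// For example (illustrative IR), for
+///   "toy.generic_call"(%0) {callee: "foo"} : (!toy.array) -> !toy.array
+/// this returns "foo".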
+StringRef GenericCallOp::getCalleeName() {
+  return getAttr("callee").cast<mlir::StringAttr>().getValue();
+}
+
+template <typename T>
+static mlir::LogicalResult verifyToySingleOperand(T *op) {
+  if (!op->getOperand()->getType().template isa<ToyArrayType>()) {
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects a Toy Array for its argument, got "
+       << op->getOperand()->getType();
+    return op->emitOpError(os.str());
+  }
+  return mlir::success();
+}
+
+void ReturnOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                     mlir::Value *value) {
+  // Return does not produce any value and has an optional single argument.
+  if (value)
+    state->operands.push_back(value);
+}
+
+mlir::LogicalResult ReturnOp::verify() {
+  if (getNumOperands() > 1)
+    return emitOpError("expects zero or one operand, got " +
+                       Twine(getNumOperands()));
+  if (hasOperand() && failed(verifyToySingleOperand(this)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+void PrintOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Value *value) {
+  // Print does not return any value and has a single argument.
+  state->operands.push_back(value);
+}
+
+mlir::LogicalResult PrintOp::verify() {
+  if (failed(verifyToySingleOperand(this)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+void TransposeOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                        mlir::Value *value) {
+  state->types.push_back(ToyArrayType::get(builder->getContext()));
+  state->operands.push_back(value);
+}
+
+mlir::LogicalResult TransposeOp::verify() {
+  if (failed(verifyToySingleOperand(this)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                      mlir::Value *value, ToyArrayType reshapedType) {
+  state->types.push_back(reshapedType);
+  state->operands.push_back(value);
+}
+
+mlir::LogicalResult ReshapeOp::verify() {
+  if (failed(verifyToySingleOperand(this)))
+    return mlir::failure();
+  auto retTy = getResult()->getType().dyn_cast<ToyArrayType>();
+  if (!retTy)
+    return emitOpError("toy.reshape is expected to produce a Toy array");
+  if (retTy.isGeneric())
+    return emitOpError("toy.reshape is expected to produce a shaped Toy "
+                       "array, got a generic one.");
+  return mlir::success();
+}
+
+void AddOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                  mlir::Value *lhs, mlir::Value *rhs) {
+  state->types.push_back(ToyArrayType::get(builder->getContext()));
+  state->operands.push_back(lhs);
+  state->operands.push_back(rhs);
+}
+
+mlir::LogicalResult AddOp::verify() {
+  if (failed(verifyToyBinOperands(this)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+void MulOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                  mlir::Value *lhs, mlir::Value *rhs) {
+  state->types.push_back(ToyArrayType::get(builder->getContext()));
+  state->operands.push_back(lhs);
+  state->operands.push_back(rhs);
+}
+
+mlir::LogicalResult MulOp::verify() {
+  if (failed(verifyToyBinOperands(this)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+void AllocOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Type retType) {
+  state->types.push_back(retType);
+}
+
+void TypeCastOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                       mlir::Value *value, mlir::Type destTy) {
+  state->operands.push_back(value);
+  state->types.push_back(destTy);
+}
+
+} // namespace toy
diff --git a/toyc.cpp b/toyc.cpp
index 6c50191..506a141 100644
--- a/toyc.cpp
+++ b/toyc.cpp
@@ -297,9 +297,6 @@ int dumpAST() {
 
 int main(int argc, char **argv) {
   // Register our Dialects with MLIR
-  mlir::registerDialect<toy::ToyDialect>();
-  mlir::registerDialect<linalg::LinalgDialect>();
-
   mlir::registerPassManagerCLOptions();
   cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
--
cgit v1.2.3-70-g09d2