diff options
author | Tuowen Zhao <ztuowen@gmail.com> | 2019-04-27 19:05:25 -0600 |
---|---|---|
committer | Tuowen Zhao <ztuowen@gmail.com> | 2019-04-27 19:05:25 -0600 |
commit | 0781257b2a8d544abdcce38824a9b8288a04800d (patch) | |
tree | 365cea96de343e354913f90b35fc944e4459b2e9 /mlir/LateLowering.cpp | |
parent | 4127831a28e31ac53ffdb1d7e7a88dd7d6317c6e (diff) | |
download | mlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.tar.gz mlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.tar.bz2 mlir-toy-0781257b2a8d544abdcce38824a9b8288a04800d.zip |
Split toy dialect using static registration
Diffstat (limited to 'mlir/LateLowering.cpp')
-rw-r--r-- | mlir/LateLowering.cpp | 452 |
1 files changed, 0 insertions, 452 deletions
diff --git a/mlir/LateLowering.cpp b/mlir/LateLowering.cpp deleted file mode 100644 index eeae6ee..0000000 --- a/mlir/LateLowering.cpp +++ /dev/null @@ -1,452 +0,0 @@ -//====- LateLowering.cpp - Lowering from Toy+Linalg to LLVM -===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file implements late lowering of IR mixing Toy and Linalg to LLVM. -// It involves intemerdiate steps: -// - -// - a mix of affine and standard dialect. -// -//===----------------------------------------------------------------------===// - -#include "toy/Dialect.h" - -#include "linalg3/Intrinsics.h" -#include "linalg1/ViewOp.h" -#include "linalg3/ConvertToLLVMDialect.h" -#include "linalg3/TensorOps.h" -#include "linalg3/Transforms.h" -#include "mlir/EDSC/Builders.h" -#include "mlir/EDSC/Helpers.h" -#include "mlir/EDSC/Intrinsics.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/LLVMIR/LLVMDialect.h" -#include "mlir/Parser.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Type.h" - -#include <algorithm> - -using namespace mlir; - -namespace { -/// Utility function for type casting: this is making the type checker happy, -/// while delaying the actual work involved to convert the type. Most of the -/// time both side of the cast (producer and consumer) will be lowered to a -/// dialect like LLVM and end up with the same LLVM representation, at which -/// point this becomes a no-op and is eliminated. -Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) { - if (val->getType() == destTy) - return val; - return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy) - .getResult(); -} - -/// Create a type cast to turn a toy.array into a memref. The Toy Array will be -/// lowered to a memref during buffer allocation, at which point the type cast -/// becomes useless. -Value *memRefTypeCast(FuncBuilder &builder, Value *val) { - if (val->getType().isa<MemRefType>()) - return val; - auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>(); - if (!toyArrayTy) - return val; - return typeCast(builder, val, toyArrayTy.toMemref()); -} - -/// Lower a toy.add to an affine loop nest. -/// -/// This class inherit from `DialectOpConversion` and override `rewrite`, -/// similarly to the PatternRewriter introduced in the previous chapter. -/// It will be called by the DialectConversion framework (see `LateLowering` -/// class below). -class AddOpConversion : public DialectOpConversion { -public: - explicit AddOpConversion(MLIRContext *context) - : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {} - - /// Lower the `op` by generating IR using the `rewriter` builder. The builder - /// is setup with a new function, the `operands` array has been populated with - /// the rewritten operands for `op` in the new function. - /// The results created by the new IR with the builder are returned, and their - /// number must match the number of result of `op`. - SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, - FuncBuilder &rewriter) const override { - auto add = op->cast<toy::AddOp>(); - auto loc = add.getLoc(); - // Create a `toy.alloc` operation to allocate the output buffer for this op. - Value *result = memRefTypeCast( - rewriter, rewriter.create<toy::AllocOp>(loc, add.getResult()->getType()) - .getResult()); - Value *lhs = memRefTypeCast(rewriter, operands[0]); - Value *rhs = memRefTypeCast(rewriter, operands[1]); - - using namespace edsc; - ScopedContext scope(rewriter, loc); - ValueHandle zero = intrinsics::constant_index(0); - MemRefView vRes(result), vLHS(lhs), vRHS(rhs); - IndexedValue iRes(result), iLHS(lhs), iRHS(rhs); - IndexHandle i, j, M(vRes.ub(0)); - if (vRes.rank() == 1) { - LoopNestBuilder({&i}, {zero}, {M}, {1})({iRes(i) = iLHS(i) + iRHS(i)}); - } else { - assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now"); - IndexHandle N(vRes.ub(1)); - LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, - {1, 1})({iRes(i, j) = iLHS(i, j) + iRHS(i, j)}); - } - - // Return the newly allocated buffer, with a type.cast to preserve the - // consumers. - return {typeCast(rewriter, result, add.getType())}; - } -}; - -/// Lowers `toy.print` to a loop nest calling `printf` on every individual -/// elements of the array. -class PrintOpConversion : public DialectOpConversion { -public: - explicit PrintOpConversion(MLIRContext *context) - : DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {} - - SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, - FuncBuilder &rewriter) const override { - // Get or create the declaration of the printf function in the module. - Function *printfFunc = getPrintf(*op->getFunction()->getModule()); - - auto print = op->cast<toy::PrintOp>(); - auto loc = print.getLoc(); - // We will operate on a MemRef abstraction, we use a type.cast to get one - // if our operand is still a Toy array. - Value *operand = memRefTypeCast(rewriter, operands[0]); - Type retTy = printfFunc->getType().getResult(0); - - // Create our loop nest now - using namespace edsc; - using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>; - ScopedContext scope(rewriter, loc); - ValueHandle zero = intrinsics::constant_index(0); - ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f ")); - MemRefView vOp(operand); - IndexedValue iOp(operand); - IndexHandle i, j, M(vOp.ub(0)); - - ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n")); - if (vOp.rank() == 1) { - // clang-format off - LoopBuilder(&i, zero, M, 1)({ - llvmCall(retTy, - rewriter.getFunctionAttr(printfFunc), - {fmtCst, iOp(i)}) - }); - llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol}); - // clang-format on - } else { - IndexHandle N(vOp.ub(1)); - // clang-format off - LoopBuilder(&i, zero, M, 1)({ - LoopBuilder(&j, zero, N, 1)({ - llvmCall(retTy, - rewriter.getFunctionAttr(printfFunc), - {fmtCst, iOp(i, j)}) - }), - llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol}) - }); - // clang-format on - } - return {}; - } - -private: - // Turn a string into a toy.alloc (malloc/free abstraction) and a sequence - // of stores into the buffer, and return a MemRef into the buffer. - Value *getConstantCharBuffer(FuncBuilder &builder, Location loc, - StringRef data) const { - auto retTy = - builder.getMemRefType(data.size() + 1, builder.getIntegerType(8)); - Value *result = builder.create<toy::AllocOp>(loc, retTy).getResult(); - using namespace edsc; - using intrinsics::constant_index; - using intrinsics::constant_int; - ScopedContext scope(builder, loc); - MemRefView vOp(result); - IndexedValue iOp(result); - for (uint64_t i = 0; i < data.size(); ++i) { - iOp(constant_index(i)) = constant_int(data[i], 8); - } - iOp(constant_index(data.size())) = constant_int(0, 8); - return result; - } - - /// Return the prototype declaration for printf in the module, create it if - /// necessary. - Function *getPrintf(Module &module) const { - auto *printfFunc = module.getNamedFunction("printf"); - if (printfFunc) - return printfFunc; - - // Create a function declaration for printf, signature is `i32 (i8*, ...)` - Builder builder(&module); - MLIRContext *context = module.getContext(); - LLVM::LLVMDialect *llvmDialect = static_cast<LLVM::LLVMDialect *>( - module.getContext()->getRegisteredDialect("llvm")); - auto &llvmModule = llvmDialect->getLLVMModule(); - llvm::IRBuilder<> llvmBuilder(llvmModule.getContext()); - - auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32)); - auto llvmI8PtrTy = - LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo()); - auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty}); - printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy); - // It should be variadic, but we don't support it fully just yet. - printfFunc->setAttr("std.varargs", builder.getBoolAttr(true)); - module.getFunctions().push_back(printfFunc); - return printfFunc; - } -}; - -/// Lowers constant to a sequence of store in a buffer. -class ConstantOpConversion : public DialectOpConversion { -public: - explicit ConstantOpConversion(MLIRContext *context) - : DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {} - - SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, - FuncBuilder &rewriter) const override { - toy::ConstantOp cstOp = op->cast<toy::ConstantOp>(); - auto loc = cstOp.getLoc(); - auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>(); - auto shape = retTy.getShape(); - Value *result = memRefTypeCast( - rewriter, rewriter.create<toy::AllocOp>(loc, retTy).getResult()); - - auto cstValue = cstOp.getValue(); - auto f64Ty = rewriter.getF64Type(); - using namespace edsc; - using intrinsics::constant_float; - using intrinsics::constant_index; - ScopedContext scope(rewriter, loc); - MemRefView vOp(result); - IndexedValue iOp(result); - for (uint64_t i = 0; i < shape[0]; ++i) { - if (shape.size() == 1) { - auto value = cstValue.getValue(ArrayRef<uint64_t>{i}) - .cast<FloatAttr>() - .getValue(); - iOp(constant_index(i)) = constant_float(value, f64Ty); - continue; - } - for (uint64_t j = 0; j < shape[1]; ++j) { - auto value = cstValue.getValue(ArrayRef<uint64_t>{i, j}) - .cast<FloatAttr>() - .getValue(); - iOp(constant_index(i), constant_index(j)) = - constant_float(value, f64Ty); - } - } - return {result}; - } -}; - -/// Lower transpose operation to an affine loop nest. -class TransposeOpConversion : public DialectOpConversion { -public: - explicit TransposeOpConversion(MLIRContext *context) - : DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {} - - SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, - FuncBuilder &rewriter) const override { - auto transpose = op->cast<toy::TransposeOp>(); - auto loc = transpose.getLoc(); - Value *result = memRefTypeCast( - rewriter, - rewriter.create<toy::AllocOp>(loc, transpose.getResult()->getType()) - .getResult()); - Value *operand = memRefTypeCast(rewriter, operands[0]); - - using namespace edsc; - ScopedContext scope(rewriter, loc); - ValueHandle zero = intrinsics::constant_index(0); - MemRefView vRes(result), vOperand(operand); - IndexedValue iRes(result), iOperand(operand); - IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1)); - // clang-format off - LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({ - iRes(i, j) = iOperand(j, i) - }); - // clang-format on - - return {typeCast(rewriter, result, transpose.getType())}; - } -}; - -// Lower toy.return to standard return operation. -class ReturnOpConversion : public DialectOpConversion { -public: - explicit ReturnOpConversion(MLIRContext *context) - : DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {} - - SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, - FuncBuilder &rewriter) const override { - auto retOp = op->cast<toy::ReturnOp>(); - using namespace edsc; - auto loc = retOp.getLoc(); - // Argument is optional, handle both cases. - if (retOp.getNumOperands()) - rewriter.create<ReturnOp>(loc, operands[0]); - else - rewriter.create<ReturnOp>(loc); - return {}; - } -}; - -/// This is the main class registering our individual converter classes with -/// the DialectConversion framework in MLIR. -class LateLowering : public DialectConversion { -protected: - /// Initialize the list of converters. - llvm::DenseSet<DialectOpConversion *> - initConverters(MLIRContext *context) override { - return ConversionListBuilder<AddOpConversion, PrintOpConversion, - ConstantOpConversion, TransposeOpConversion, - ReturnOpConversion>::build(&allocator, - context); - } - - /// Convert a Toy type, this gets called for block and region arguments, and - /// attributes. - Type convertType(Type t) override { - if (auto array = t.cast<toy::ToyArrayType>()) { - return array.toMemref(); - } - return t; - } - -private: - llvm::BumpPtrAllocator allocator; -}; - -/// This is lowering to Linalg the parts that can be (matmul and add on arrays) -/// and is targeting LLVM otherwise. -struct LateLoweringPass : public ModulePass<LateLoweringPass> { - - void runOnModule() override { - // Perform Toy specific lowering - if (failed(LateLowering().convert(&getModule()))) { - getModule().getContext()->emitError( - UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n"); - signalPassFailure(); - } - // At this point the IR is almost using only standard and affine dialects. - // A few things remain before we emit LLVM IR. First to reuse as much of - // MLIR as possible we will try to lower everything to the standard and/or - // affine dialect: they already include conversion to the LLVM dialect. - - // First patch calls type to return memref instead of ToyArray - for (auto &function : getModule()) { - function.walk([&](Operation *op) { - auto callOp = op->dyn_cast<CallOp>(); - if (!callOp) - return; - if (!callOp.getNumResults()) - return; - auto retToyTy = - callOp.getResult(0)->getType().dyn_cast<toy::ToyArrayType>(); - if (!retToyTy) - return; - callOp.getResult(0)->setType(retToyTy.toMemref()); - }); - } - - for (auto &function : getModule()) { - function.walk([&](Operation *op) { - // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free). - if (auto allocOp = op->dyn_cast<toy::AllocOp>()) { - auto result = allocTensor(allocOp); - allocOp.replaceAllUsesWith(result); - allocOp.erase(); - return; - } - // Eliminate all type.cast before lowering to LLVM. - if (auto typeCastOp = op->dyn_cast<toy::TypeCastOp>()) { - typeCastOp.replaceAllUsesWith(typeCastOp.getOperand()); - typeCastOp.erase(); - return; - } - }); - } - - // Lower Linalg to affine - for (auto &function : getModule()) - linalg::lowerToLoops(&function); - - getModule().dump(); - - // Finally convert to LLVM Dialect - linalg::convertLinalg3ToLLVM(getModule()); - } - - /// Allocate buffers (malloc/free) for Toy operations. This can't be done as - /// part of dialect conversion framework since we need to insert `dealloc` - /// operations just before the return, but the conversion framework is - /// operating in a brand new function: we don't have the return to hook the - /// dealloc operations. - Value *allocTensor(toy::AllocOp alloc) { - FuncBuilder builder(alloc); - auto retTy = alloc.getResult()->getType(); - - auto memRefTy = retTy.dyn_cast<MemRefType>(); - if (!memRefTy) - memRefTy = retTy.cast<toy::ToyArrayType>().toMemref(); - if (!memRefTy) { - alloc.emitOpError("is expected to allocate a Toy array or a MemRef"); - llvm_unreachable("fatal error"); - } - auto loc = alloc.getLoc(); - Value *result = builder.create<AllocOp>(loc, memRefTy).getResult(); - - // Insert a `dealloc` operation right before the `return` operations, unless - // it is returned itself in which case the caller is responsible for it. - builder.getFunction()->walk([&](Operation *op) { - auto returnOp = op->dyn_cast<ReturnOp>(); - if (!returnOp) - return; - if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc) - return; - builder.setInsertionPoint(returnOp); - builder.create<DeallocOp>(alloc.getLoc(), result); - }); - return result; - } -}; -} // end anonymous namespace - -namespace toy { -Pass *createLateLoweringPass() { return new LateLoweringPass(); } - -std::unique_ptr<DialectConversion> makeToyLateLowering() { - return llvm::make_unique<LateLowering>(); -} - -} // namespace toy |