//====- LateLowering.cpp - Lowering from Toy+Linalg to LLVM -===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements late lowering of IR mixing Toy and Linalg to LLVM. // It involves intemerdiate steps: // - // - a mix of affine and standard dialect. // //===----------------------------------------------------------------------===// #include "toy/Dialect.h" #include "linalg3/Intrinsics.h" #include "linalg1/ViewOp.h" #include "linalg3/ConvertToLLVMDialect.h" #include "linalg3/TensorOps.h" #include "linalg3/Transforms.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Helpers.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/StandardTypes.h" #include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" #include using namespace mlir; namespace { /// Utility function for type casting: this is making the type checker happy, /// while delaying the actual work involved to convert the type. Most of the /// time both side of the cast (producer and consumer) will be lowered to a /// dialect like LLVM and end up with the same LLVM representation, at which /// point this becomes a no-op and is eliminated. Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) { if (val->getType() == destTy) return val; return builder.create(val->getLoc(), val, destTy) .getResult(); } /// Create a type cast to turn a toy.array into a memref. The Toy Array will be /// lowered to a memref during buffer allocation, at which point the type cast /// becomes useless. Value *memRefTypeCast(FuncBuilder &builder, Value *val) { if (val->getType().isa()) return val; auto toyArrayTy = val->getType().dyn_cast(); if (!toyArrayTy) return val; return typeCast(builder, val, toyArrayTy.toMemref()); } /// Lower a toy.add to an affine loop nest. /// /// This class inherit from `DialectOpConversion` and override `rewrite`, /// similarly to the PatternRewriter introduced in the previous chapter. /// It will be called by the DialectConversion framework (see `LateLowering` /// class below). class AddOpConversion : public DialectOpConversion { public: explicit AddOpConversion(MLIRContext *context) : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {} /// Lower the `op` by generating IR using the `rewriter` builder. The builder /// is setup with a new function, the `operands` array has been populated with /// the rewritten operands for `op` in the new function. /// The results created by the new IR with the builder are returned, and their /// number must match the number of result of `op`. SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto add = op->cast(); auto loc = add.getLoc(); // Create a `toy.alloc` operation to allocate the output buffer for this op. Value *result = memRefTypeCast( rewriter, rewriter.create(loc, add.getResult()->getType()) .getResult()); Value *lhs = memRefTypeCast(rewriter, operands[0]); Value *rhs = memRefTypeCast(rewriter, operands[1]); using namespace edsc; ScopedContext scope(rewriter, loc); ValueHandle zero = intrinsics::constant_index(0); MemRefView vRes(result), vLHS(lhs), vRHS(rhs); IndexedValue iRes(result), iLHS(lhs), iRHS(rhs); IndexHandle i, j, M(vRes.ub(0)); if (vRes.rank() == 1) { LoopNestBuilder({&i}, {zero}, {M}, {1})({iRes(i) = iLHS(i) + iRHS(i)}); } else { assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now"); IndexHandle N(vRes.ub(1)); LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({iRes(i, j) = iLHS(i, j) + iRHS(i, j)}); } // Return the newly allocated buffer, with a type.cast to preserve the // consumers. return {typeCast(rewriter, result, add.getType())}; } }; /// Lowers `toy.print` to a loop nest calling `printf` on every individual /// elements of the array. class PrintOpConversion : public DialectOpConversion { public: explicit PrintOpConversion(MLIRContext *context) : DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {} SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { // Get or create the declaration of the printf function in the module. Function *printfFunc = getPrintf(*op->getFunction()->getModule()); auto print = op->cast(); auto loc = print.getLoc(); // We will operate on a MemRef abstraction, we use a type.cast to get one // if our operand is still a Toy array. Value *operand = memRefTypeCast(rewriter, operands[0]); Type retTy = printfFunc->getType().getResult(0); // Create our loop nest now using namespace edsc; using llvmCall = intrinsics::ValueBuilder; ScopedContext scope(rewriter, loc); ValueHandle zero = intrinsics::constant_index(0); ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f ")); MemRefView vOp(operand); IndexedValue iOp(operand); IndexHandle i, j, M(vOp.ub(0)); ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n")); if (vOp.rank() == 1) { // clang-format off LoopBuilder(&i, zero, M, 1)({ llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtCst, iOp(i)}) }); llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol}); // clang-format on } else { IndexHandle N(vOp.ub(1)); // clang-format off LoopBuilder(&i, zero, M, 1)({ LoopBuilder(&j, zero, N, 1)({ llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtCst, iOp(i, j)}) }), llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol}) }); // clang-format on } return {}; } private: // Turn a string into a toy.alloc (malloc/free abstraction) and a sequence // of stores into the buffer, and return a MemRef into the buffer. Value *getConstantCharBuffer(FuncBuilder &builder, Location loc, StringRef data) const { auto retTy = builder.getMemRefType(data.size() + 1, builder.getIntegerType(8)); Value *result = builder.create(loc, retTy).getResult(); using namespace edsc; using intrinsics::constant_index; using intrinsics::constant_int; ScopedContext scope(builder, loc); MemRefView vOp(result); IndexedValue iOp(result); for (uint64_t i = 0; i < data.size(); ++i) { iOp(constant_index(i)) = constant_int(data[i], 8); } iOp(constant_index(data.size())) = constant_int(0, 8); return result; } /// Return the prototype declaration for printf in the module, create it if /// necessary. Function *getPrintf(Module &module) const { auto *printfFunc = module.getNamedFunction("printf"); if (printfFunc) return printfFunc; // Create a function declaration for printf, signature is `i32 (i8*, ...)` Builder builder(&module); MLIRContext *context = module.getContext(); LLVM::LLVMDialect *llvmDialect = static_cast( module.getContext()->getRegisteredDialect("llvm")); auto &llvmModule = llvmDialect->getLLVMModule(); llvm::IRBuilder<> llvmBuilder(llvmModule.getContext()); auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32)); auto llvmI8PtrTy = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo()); auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty}); printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy); // It should be variadic, but we don't support it fully just yet. printfFunc->setAttr("std.varargs", builder.getBoolAttr(true)); module.getFunctions().push_back(printfFunc); return printfFunc; } }; /// Lowers constant to a sequence of store in a buffer. class ConstantOpConversion : public DialectOpConversion { public: explicit ConstantOpConversion(MLIRContext *context) : DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {} SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { toy::ConstantOp cstOp = op->cast(); auto loc = cstOp.getLoc(); auto retTy = cstOp.getResult()->getType().cast(); auto shape = retTy.getShape(); Value *result = memRefTypeCast( rewriter, rewriter.create(loc, retTy).getResult()); auto cstValue = cstOp.getValue(); auto f64Ty = rewriter.getF64Type(); using namespace edsc; using intrinsics::constant_float; using intrinsics::constant_index; ScopedContext scope(rewriter, loc); MemRefView vOp(result); IndexedValue iOp(result); for (uint64_t i = 0; i < shape[0]; ++i) { if (shape.size() == 1) { auto value = cstValue.getValue(ArrayRef{i}) .cast() .getValue(); iOp(constant_index(i)) = constant_float(value, f64Ty); continue; } for (uint64_t j = 0; j < shape[1]; ++j) { auto value = cstValue.getValue(ArrayRef{i, j}) .cast() .getValue(); iOp(constant_index(i), constant_index(j)) = constant_float(value, f64Ty); } } return {result}; } }; /// Lower transpose operation to an affine loop nest. class TransposeOpConversion : public DialectOpConversion { public: explicit TransposeOpConversion(MLIRContext *context) : DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {} SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto transpose = op->cast(); auto loc = transpose.getLoc(); Value *result = memRefTypeCast( rewriter, rewriter.create(loc, transpose.getResult()->getType()) .getResult()); Value *operand = memRefTypeCast(rewriter, operands[0]); using namespace edsc; ScopedContext scope(rewriter, loc); ValueHandle zero = intrinsics::constant_index(0); MemRefView vRes(result), vOperand(operand); IndexedValue iRes(result), iOperand(operand); IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1)); // clang-format off LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({ iRes(i, j) = iOperand(j, i) }); // clang-format on return {typeCast(rewriter, result, transpose.getType())}; } }; // Lower toy.return to standard return operation. class ReturnOpConversion : public DialectOpConversion { public: explicit ReturnOpConversion(MLIRContext *context) : DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {} SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto retOp = op->cast(); using namespace edsc; auto loc = retOp.getLoc(); // Argument is optional, handle both cases. if (retOp.getNumOperands()) rewriter.create(loc, operands[0]); else rewriter.create(loc); return {}; } }; /// This is the main class registering our individual converter classes with /// the DialectConversion framework in MLIR. class LateLowering : public DialectConversion { protected: /// Initialize the list of converters. llvm::DenseSet initConverters(MLIRContext *context) override { return ConversionListBuilder::build(&allocator, context); } /// Convert a Toy type, this gets called for block and region arguments, and /// attributes. Type convertType(Type t) override { if (auto array = t.cast()) { return array.toMemref(); } return t; } private: llvm::BumpPtrAllocator allocator; }; /// This is lowering to Linalg the parts that can be (matmul and add on arrays) /// and is targeting LLVM otherwise. struct LateLoweringPass : public ModulePass { void runOnModule() override { // Perform Toy specific lowering if (failed(LateLowering().convert(&getModule()))) { getModule().getContext()->emitError( UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n"); signalPassFailure(); } // At this point the IR is almost using only standard and affine dialects. // A few things remain before we emit LLVM IR. First to reuse as much of // MLIR as possible we will try to lower everything to the standard and/or // affine dialect: they already include conversion to the LLVM dialect. // First patch calls type to return memref instead of ToyArray for (auto &function : getModule()) { function.walk([&](Operation *op) { auto callOp = op->dyn_cast(); if (!callOp) return; if (!callOp.getNumResults()) return; auto retToyTy = callOp.getResult(0)->getType().dyn_cast(); if (!retToyTy) return; callOp.getResult(0)->setType(retToyTy.toMemref()); }); } for (auto &function : getModule()) { function.walk([&](Operation *op) { // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free). if (auto allocOp = op->dyn_cast()) { auto result = allocTensor(allocOp); allocOp.replaceAllUsesWith(result); allocOp.erase(); return; } // Eliminate all type.cast before lowering to LLVM. if (auto typeCastOp = op->dyn_cast()) { typeCastOp.replaceAllUsesWith(typeCastOp.getOperand()); typeCastOp.erase(); return; } }); } // Lower Linalg to affine for (auto &function : getModule()) linalg::lowerToLoops(&function); getModule().dump(); // Finally convert to LLVM Dialect linalg::convertLinalg3ToLLVM(getModule()); } /// Allocate buffers (malloc/free) for Toy operations. This can't be done as /// part of dialect conversion framework since we need to insert `dealloc` /// operations just before the return, but the conversion framework is /// operating in a brand new function: we don't have the return to hook the /// dealloc operations. Value *allocTensor(toy::AllocOp alloc) { FuncBuilder builder(alloc); auto retTy = alloc.getResult()->getType(); auto memRefTy = retTy.dyn_cast(); if (!memRefTy) memRefTy = retTy.cast().toMemref(); if (!memRefTy) { alloc.emitOpError("is expected to allocate a Toy array or a MemRef"); llvm_unreachable("fatal error"); } auto loc = alloc.getLoc(); Value *result = builder.create(loc, memRefTy).getResult(); // Insert a `dealloc` operation right before the `return` operations, unless // it is returned itself in which case the caller is responsible for it. builder.getFunction()->walk([&](Operation *op) { auto returnOp = op->dyn_cast(); if (!returnOp) return; if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc) return; builder.setInsertionPoint(returnOp); builder.create(alloc.getLoc(), result); }); return result; } }; } // end anonymous namespace namespace toy { Pass *createLateLoweringPass() { return new LateLoweringPass(); } std::unique_ptr makeToyLateLowering() { return llvm::make_unique(); } } // namespace toy