diff options
Diffstat (limited to 'toy/EarlyLowering.cpp')
-rw-r--r-- | toy/EarlyLowering.cpp | 158 |
1 files changed, 158 insertions, 0 deletions
diff --git a/toy/EarlyLowering.cpp b/toy/EarlyLowering.cpp new file mode 100644 index 0000000..634c72e --- /dev/null +++ b/toy/EarlyLowering.cpp @@ -0,0 +1,158 @@ +//=======- EarlyLowering.cpp - Toy Lowering to Linear Algebra Dialect -=======// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements early lowering of Toy IR to Linalg Dialect: we only +// lower the computationally intensive part of the program (matmul...) to a +// dialect specialized for optimizations. +// +// This is intended to showcase how multiple dialects can cohabit in the same +// function. After this lowering, you would still have toy.print in the IR for +// example. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "linalg3/Intrinsics.h" +#include "linalg1/ViewOp.h" +#include "linalg3/TensorOps.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/LLVMIR/LLVMDialect.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" + +#include <algorithm> + +using namespace mlir; + +namespace { +/// Utility function for type casting: this is making the type checker happy, +/// while delaying the actual work involved to convert the type. Most of the +/// time both side of the cast (producer and consumer) will be lowered to a +/// dialect like LLVM and end up with the same LLVM representation, at which +/// point this becomes a no-op and is eliminated. +Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) { + if (val->getType() == destTy) + return val; + return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy) + .getResult(); +} + +/// Create a type cast to turn a toy.array into a memref. The Toy Array will be +/// lowered to a memref during buffer allocation, at which point the type cast +/// becomes useless. +Value *memRefTypeCast(FuncBuilder &builder, Value *val) { + if (val->getType().isa<MemRefType>()) + return val; + auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>(); + if (!toyArrayTy) + return val; + return typeCast(builder, val, toyArrayTy.toMemref()); +} + +/// Lower toy.mul to Linalg `matmul`. +/// +/// This class inherit from `DialectOpConversion` and override `rewrite`, +/// similarly to the PatternRewriter introduced in the previous chapter. +/// It will be called by the DialectConversion framework (see `LateLowering` +/// class below). +class MulOpConversion : public DialectOpConversion { +public: + explicit MulOpConversion(MLIRContext *context) + : DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {} + + SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, + FuncBuilder &rewriter) const override { + using namespace edsc; + using intrinsics::constant_index; + using linalg::intrinsics::range; + using linalg::intrinsics::view; + toy::MulOp mul = op->cast<toy::MulOp>(); + auto loc = mul.getLoc(); + Value *result = memRefTypeCast( + rewriter, rewriter.create<toy::AllocOp>(loc, mul.getResult()->getType()) + .getResult()); + Value *lhs = memRefTypeCast(rewriter, operands[0]); + auto memrefLHSTy = lhs->getType().cast<MemRefType>(); + Value *rhs = memRefTypeCast(rewriter, operands[1]); + auto memrefRHSTy = rhs->getType().cast<MemRefType>(); + mlir::edsc::ScopedContext scope(rewriter, loc); + edsc::ValueHandle r0 = + range(constant_index(0), constant_index(memrefLHSTy.getDimSize(0)), + constant_index(1)); + edsc::ValueHandle r1 = + range(constant_index(0), constant_index(memrefLHSTy.getDimSize(1)), + constant_index(1)); + edsc::ValueHandle r2 = + range(constant_index(0), constant_index(memrefRHSTy.getDimSize(1)), + constant_index(1)); + auto lhsView = view(lhs, {r0, r1}); + auto rhsView = view(rhs, {r1, r2}); + auto resultView = view(result, {r0, r2}); + rewriter.create<linalg::MatmulOp>(loc, lhsView, rhsView, resultView); + return {typeCast(rewriter, result, mul.getType())}; + } +}; + +// The conversion class from Toy IR Dialect to a mix of Linalg and LLVM. +class EarlyLowering : public DialectConversion { +protected: + // Initialize the list of converters. + llvm::DenseSet<DialectOpConversion *> + initConverters(MLIRContext *context) override { + return ConversionListBuilder<MulOpConversion>::build(&allocator, context); + } + +private: + llvm::BumpPtrAllocator allocator; +}; + +/// This is lowering to Linalg the parts that are computationally intensive +/// (like matmul for example...) while keeping the rest of the code in the Toy +/// dialect. +struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> { + + void runOnModule() override { + if (failed(EarlyLowering().convert(&getModule()))) { + getModule().getContext()->emitError( + mlir::UnknownLoc::get(getModule().getContext()), + "Error lowering Toy\n"); + signalPassFailure(); + } + } +}; +} // end anonymous namespace + +namespace toy { +Pass *createEarlyLoweringPass() { return new EarlyLoweringPass(); } + +std::unique_ptr<mlir::DialectConversion> makeToyEarlyLowering() { + return llvm::make_unique<EarlyLowering>(); +} + +} // namespace toy |