//=======- EarlyLowering.cpp - Toy Lowering to Linear Algebra Dialect -=======//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements early lowering of Toy IR to Linalg Dialect: we only
// lower the computationally intensive part of the program (matmul...) to a
// dialect specialized for optimizations.
//
// This is intended to showcase how multiple dialects can cohabit in the same
// function. After this lowering, you would still have toy.print in the IR for
// example.
//
//===----------------------------------------------------------------------===//

#include "toy/Dialect.h"

#include "linalg3/Intrinsics.h"
#include "linalg1/ViewOp.h"
#include "linalg3/TensorOps.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"

#include <algorithm>

using namespace mlir;

namespace {
/// Helper that reconciles a value's type with the type a consumer expects.
///
/// If `val` already has type `destTy` it is returned untouched. Otherwise a
/// `toy::TypeCastOp` is created at the location of `val` and its result is
/// returned. The cast only keeps the type checker happy and defers the real
/// conversion work: most of the time both sides of the cast (producer and
/// consumer) will be lowered to a dialect like LLVM and end up with the same
/// LLVM representation, at which point the cast becomes a no-op and is
/// eliminated.
Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) {
  const bool alreadyMatches = (val->getType() == destTy);
  if (alreadyMatches)
    return val;

  auto castOp = builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy);
  return castOp.getResult();
}

/// Insert a type cast that presents a toy.array value as a memref.
///
/// Values that are already memrefs, and values whose type is not a
/// `toy::ToyArrayType`, are returned unchanged. The Toy Array will be lowered
/// to a memref during buffer allocation, at which point the type cast becomes
/// useless.
Value *memRefTypeCast(FuncBuilder &builder, Value *val) {
  Type valueTy = val->getType();
  auto arrayTy = valueTy.dyn_cast<toy::ToyArrayType>();

  // Nothing to do for memrefs, nor for anything that is not a Toy array.
  if (valueTy.isa<MemRefType>() || !arrayTy)
    return val;

  return typeCast(builder, val, arrayTy.toMemref());
}

/// Lower toy.mul to Linalg `matmul`.
///
/// This class inherits from `DialectOpConversion` and overrides `rewrite`,
/// similarly to the PatternRewriter introduced in the previous chapter.
/// It will be called by the DialectConversion framework (see the
/// `EarlyLowering` class below).
class MulOpConversion : public DialectOpConversion {
public:
  explicit MulOpConversion(MLIRContext *context)
      : DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {}

  /// Rewrite a single `toy::MulOp`.
  ///
  /// `op` is the operation being converted and `operands` holds its two
  /// inputs (LHS at index 0, RHS at index 1). New operations are created
  /// through `rewriter`. The returned vector contains the one value standing
  /// in for the original result: the freshly allocated output buffer, cast
  /// back to the type of the original toy.mul.
  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                  FuncBuilder &rewriter) const override {
    using namespace edsc;
    using intrinsics::constant_index;
    using linalg::intrinsics::range;
    using linalg::intrinsics::view;

    toy::MulOp mulOp = op->cast<toy::MulOp>();
    auto loc = mulOp.getLoc();

    // Allocate the output buffer first, then make sure every buffer involved
    // is seen as a memref.
    auto outAlloc =
        rewriter.create<toy::AllocOp>(loc, mulOp.getResult()->getType());
    Value *outBuf = memRefTypeCast(rewriter, outAlloc.getResult());
    Value *lhsBuf = memRefTypeCast(rewriter, operands[0]);
    auto lhsMemRefTy = lhsBuf->getType().cast<MemRefType>();
    Value *rhsBuf = memRefTypeCast(rewriter, operands[1]);
    auto rhsMemRefTy = rhsBuf->getType().cast<MemRefType>();

    // The EDSC intrinsics used below emit through `rewriter` at `loc`.
    edsc::ScopedContext edscScope(rewriter, loc);

    // One range per dimension involved, each built from the constants 0, the
    // dimension size and 1. The shared dimension is read from the LHS only;
    // this code does not check that it matches dimension 0 of the RHS.
    edsc::ValueHandle rowRange =
        range(constant_index(0), constant_index(lhsMemRefTy.getDimSize(0)),
              constant_index(1));
    edsc::ValueHandle sharedRange =
        range(constant_index(0), constant_index(lhsMemRefTy.getDimSize(1)),
              constant_index(1));
    edsc::ValueHandle colRange =
        range(constant_index(0), constant_index(rhsMemRefTy.getDimSize(1)),
              constant_index(1));

    // Views over the full buffers, which is what the matmul consumes.
    auto lhsOperandView = view(lhsBuf, {rowRange, sharedRange});
    auto rhsOperandView = view(rhsBuf, {sharedRange, colRange});
    auto outOperandView = view(outBuf, {rowRange, colRange});
    rewriter.create<linalg::MatmulOp>(loc, lhsOperandView, rhsOperandView,
                                      outOperandView);

    // Hand back the output buffer under the type the users of toy.mul expect.
    Value *replacement = typeCast(rewriter, outBuf, mulOp.getType());
    return {replacement};
  }
};

// The conversion class for the early lowering of Toy IR: only toy.mul is
// converted (to Linalg), every other operation is left untouched, so the
// result is a mix of the Toy and Linalg dialects.
class EarlyLowering : public DialectConversion {
protected:
  // Build the set of operation converters used by this lowering. The only
  // converter registered is `MulOpConversion`. The allocator is handed to the
  // list builder, presumably as backing storage for the converter objects.
  llvm::DenseSet<DialectOpConversion *>
  initConverters(MLIRContext *context) override {
    using Converters = ConversionListBuilder<MulOpConversion>;
    return Converters::build(&converterArena, context);
  }

private:
  // Allocator passed to the converter list builder in `initConverters`.
  llvm::BumpPtrAllocator converterArena;
};

/// Module pass running the early lowering: the computationally intensive
/// parts (like matmul for example...) are lowered to Linalg while the rest of
/// the code is kept in the Toy dialect.
struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> {

  /// Run the `EarlyLowering` conversion on the current module. When the
  /// conversion fails, an error is emitted at an unknown location and the
  /// pass is marked as failed.
  void runOnModule() override {
    auto &currentModule = getModule();
    if (!failed(EarlyLowering().convert(&currentModule)))
      return;

    auto *context = currentModule.getContext();
    context->emitError(mlir::UnknownLoc::get(context), "Error lowering Toy\n");
    signalPassFailure();
  }
};
} // end anonymous namespace

namespace toy {
/// Create a new instance of the early lowering pass. The pass is heap
/// allocated and returned as a raw pointer; the caller is responsible for it.
Pass *createEarlyLoweringPass() {
  Pass *earlyLoweringPass = new EarlyLoweringPass();
  return earlyLoweringPass;
}

std::unique_ptr<mlir::DialectConversion> makeToyEarlyLowering() {
  return llvm::make_unique<EarlyLowering>();
}

} // namespace toy