summaryrefslogtreecommitdiff
path: root/include/linalg1/Common.h
blob: 6573c728ff349f292a5eb4136a4281e706e52346 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
//===- Common.h - Linalg dialect common helpers ---------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG1_COMMON_H_
#define LINALG1_COMMON_H_

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"

namespace linalg {
namespace common {

////////////////////////////////////////////////////////////////////////////////
// Define a few boilerplate objects used across all linalg examples.
////////////////////////////////////////////////////////////////////////////////

/// An N-D abstraction over a flat contiguous memory region of f32 with
/// symbolic sizes: every one of the N dimensions is left dynamic (-1).
///
/// \tparam N           rank of the returned memref type; must be non-negative.
/// \param context      context in which the f32 element type and the memref
///                     type are created.
/// \param memorySpace  memory space of the memref type (defaults to 0).
template <int N>
inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context,
                                        unsigned memorySpace = 0) {
  // A negative N would be converted to a huge element count below.
  static_assert(N >= 0, "floatMemRefType requires a non-negative rank");
  llvm::SmallVector<int64_t, 4> shape(N, -1);
  auto f32 = mlir::FloatType::getF32(context);
  return mlir::MemRefType::get(shape, f32, {}, memorySpace);
}

/// Creates a function named `name` taking `types` and returning `resultTypes`,
/// located at an UnknownLoc. An entry block is added, the function is appended
/// to `module`'s function list, and a pointer to it is returned.
inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name,
                                    llvm::ArrayRef<mlir::Type> types,
                                    llvm::ArrayRef<mlir::Type> resultTypes) {
  auto *ctx = module.getContext();
  auto loc = mlir::UnknownLoc::get(ctx);
  auto fnType = mlir::FunctionType::get(types, resultTypes, ctx);

  // NOTE(review): presumably the module's function list takes ownership of
  // `fn` once pushed; confirm against Module::getFunctions().
  auto *fn = new mlir::Function(loc, name, fnType);
  fn->addEntryBlock();
  module.getFunctions().push_back(fn);
  return fn;
}

/// A basic pass manager pre-populated with cleanup passes.
inline std::unique_ptr<mlir::PassManager> cleanupPassManager() {
  std::unique_ptr<mlir::PassManager> pm(new mlir::PassManager());
  pm->addPass(mlir::createCanonicalizerPass());
  pm->addPass(mlir::createSimplifyAffineStructuresPass());
  pm->addPass(mlir::createCSEPass());
  pm->addPass(mlir::createCanonicalizerPass());
  return pm;
}

/// Verifies the module containing `f`, runs the cleanup passes from
/// cleanupPassManager() on that module, then prints `f` to llvm::outs() for
/// FileCheck'ing.
///
/// On the first failure (verification or cleanup), an error is emitted at
/// `f`'s location through its MLIRContext and nothing is printed to
/// llvm::outs(), which makes the associated FileCheck test fail.
/// The cleanup passes are skipped entirely when verification fails, so they
/// never run on invalid IR, and the error is reported at most once.
inline void cleanupAndPrintFunction(mlir::Function *f) {
  auto reportFailure = [f]() {
    f->getContext()->emitError(f->getLoc(),
                               "Verification and cleanup passes failed");
  };

  auto *module = f->getModule();
  if (failed(module->verify())) {
    reportFailure();
    return;
  }

  auto pm = cleanupPassManager();
  if (failed(pm->run(module))) {
    reportFailure();
    return;
  }

  f->print(llvm::outs());
}

/// Helper class to sugar building loop nests from indexings that appear in
/// ViewOp and SliceOp.
///
/// Only declarations live in this header; the constructors and the call
/// operator are defined out of line.
/// NOTE(review): presumably one loop is built per entry of `ivs`, with its
/// range taken from the matching entry of `indexings`; confirm in the
/// implementation file.
class LoopNestRangeBuilder {
public:
  /// Constructs the builder from `indexings` given as EDSC value handles.
  /// Each entry of `ivs` points to a ValueHandle that presumably receives the
  /// induction variable of the corresponding loop — TODO confirm.
  LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
                       llvm::ArrayRef<mlir::edsc::ValueHandle> indexings);
  /// Same as above, but with `indexings` given as raw mlir::Value pointers.
  LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
                       llvm::ArrayRef<mlir::Value *> indexings);
  /// Takes the statements to emit for the nest.
  /// NOTE(review): presumably `stmts` are emitted in the innermost loop body;
  /// the meaning of the returned ValueHandle is not visible here — confirm
  /// against the out-of-line definition.
  mlir::edsc::ValueHandle
  operator()(llvm::ArrayRef<mlir::edsc::CapturableHandle> stmts);

private:
  /// One LoopBuilder per loop of the nest (presumably outermost first) —
  /// TODO confirm.
  llvm::SmallVector<mlir::edsc::LoopBuilder, 4> loops;
};

} // namespace common
} // namespace linalg

#endif // LINALG1_COMMON_H_