1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
|
//===- Common.h - Linalg dialect RangeOp operation -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef LINALG1_COMMON_H_
#define LINALG1_COMMON_H_
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
namespace linalg {
namespace common {
////////////////////////////////////////////////////////////////////////////////
// Define a few boilerplate objects used across all linalg examples.
////////////////////////////////////////////////////////////////////////////////
/// A 2-D abstraction over a flat contiguous memory region of f32 with symbolic
/// sizes.
template <int N>
inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context,
unsigned memorySpace = 0) {
llvm::SmallVector<int64_t, 4> shape(N, -1);
auto f32 = mlir::FloatType::getF32(context);
return mlir::MemRefType::get(shape, f32, {}, memorySpace);
}
/// A basic function builder
inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name,
llvm::ArrayRef<mlir::Type> types,
llvm::ArrayRef<mlir::Type> resultTypes) {
auto *context = module.getContext();
auto *function = new mlir::Function(
mlir::UnknownLoc::get(context), name,
mlir::FunctionType::get({types}, resultTypes, context));
function->addEntryBlock();
module.getFunctions().push_back(function);
return function;
}
/// A basic pass manager pre-populated with cleanup passes.
inline std::unique_ptr<mlir::PassManager> cleanupPassManager() {
std::unique_ptr<mlir::PassManager> pm(new mlir::PassManager());
pm->addPass(mlir::createCanonicalizerPass());
pm->addPass(mlir::createSimplifyAffineStructuresPass());
pm->addPass(mlir::createCSEPass());
pm->addPass(mlir::createCanonicalizerPass());
return pm;
}
/// A simple function to verify and cleanup the IR before printing it to
/// llvm::outs() for FileCheck'ing.
/// If an error occurs, dump to llvm::errs() and do not print to llvm::outs()
/// which will make the associated FileCheck test fail.
inline void cleanupAndPrintFunction(mlir::Function *f) {
bool printToOuts = true;
auto check = [f, &printToOuts](mlir::LogicalResult result) {
if (failed(result)) {
f->getContext()->emitError(f->getLoc(),
"Verification and cleanup passes failed");
printToOuts = false;
}
};
auto pm = cleanupPassManager();
check(f->getModule()->verify());
check(pm->run(f->getModule()));
if (printToOuts)
f->print(llvm::outs());
}
/// Helper class to sugar building loop nests from indexings that appear in
/// ViewOp and SliceOp.
class LoopNestRangeBuilder {
public:
LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
llvm::ArrayRef<mlir::edsc::ValueHandle> indexings);
LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
llvm::ArrayRef<mlir::Value *> indexings);
mlir::edsc::ValueHandle
operator()(llvm::ArrayRef<mlir::edsc::CapturableHandle> stmts);
private:
llvm::SmallVector<mlir::edsc::LoopBuilder, 4> loops;
};
} // namespace common
} // namespace linalg
#endif // LINALG1_COMMON_H_
|