Diffstat (limited to 'include/linalg3/TensorOps-inl.h')
-rw-r--r--  include/linalg3/TensorOps-inl.h  145
1 file changed, 145 insertions, 0 deletions
diff --git a/include/linalg3/TensorOps-inl.h b/include/linalg3/TensorOps-inl.h
new file mode 100644
index 0000000..b651053
--- /dev/null
+++ b/include/linalg3/TensorOps-inl.h
@@ -0,0 +1,145 @@
+//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+/// The TensorOps-inl.h inclusion pattern is chosen to allow gradual extension
+/// of TensorOps by adding implementations as they are needed at the
+/// appropriate step of the tutorial.
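+///
+/// A minimal sketch of the pattern (illustrative; the tutorial's actual
+/// include sites may differ): a client of step 3 includes the declarations
+/// and then this header to obtain the template definitions it needs:
+///
+///   #include "linalg2/TensorOps.h"     // op declarations
+///   #include "linalg3/TensorOps-inl.h" // step-3 template implementations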
+#ifndef LINALG3_TENSOROPS_INL_H_
+#define LINALG3_TENSOROPS_INL_H_
+
+#include "linalg1/Common.h"
+#include "linalg1/Utils.h"
+#include "linalg2/TensorOps.h"
+#include "linalg3/Analysis.h"
+#include "linalg3/Ops.h"
+
+template <class ConcreteOp>
+mlir::Value *
+linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned viewIndex) {
+  return *(getInputs().begin() + viewIndex);
+}
+
+template <class ConcreteOp>
+mlir::Value *
+linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned viewIndex) {
+  return *(getOutputs().begin() + viewIndex);
+}
+
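+// TensorContractionBase is a CRTP base: the methods below forward to the
+// derived class via static_cast<ConcreteOp *>(this). A hypothetical concrete
+// op (the name `SomeContractionOp` and its body are illustrative, not the
+// tutorial's actual definitions) would therefore provide:
+//
+//   class SomeContractionOp
+//       : public linalg::TensorContractionBase<SomeContractionOp> {
+//   public:
+//     llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+//     void emitScalarImplementation(
+//         llvm::ArrayRef<mlir::Value *> parallelIvs,
+//         llvm::ArrayRef<mlir::Value *> reductionIvs);
+//   };
+//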
+template <class ConcreteOp>
+llvm::SmallVector<mlir::AffineMap, 8>
+linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangeMaps() {
+  return static_cast<ConcreteOp *>(this)->loopsToOperandRangeMaps();
+}
+
+template <class ConcreteOp>
+void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
+    llvm::ArrayRef<mlir::Value *> parallelIvs,
+    llvm::ArrayRef<mlir::Value *> reductionIvs) {
+  static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs,
+                                                             reductionIvs);
+}
+
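+// Worked example (matmul-like, illustrative): with loops (m, n, k) and
+// operand-range maps A: (m, k), B: (k, n), C: (m, n), the concatenation
+// below yields (m, n, k) -> (m, k, k, n, m, n). inverseSubMap (from
+// Analysis.h) then extracts an invertible sub-map of this union and inverts
+// it, recovering the loop indices (m, n, k) from the operand ranges.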
+template <class ConcreteOp>
+mlir::AffineMap linalg::operandRangesToLoopsMap(
+    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  mlir::AffineMap current;
+  // Individual submaps may not be invertible, but their union must be
+  // invertible by construction.
+  for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
+    if (!m)
+      continue;
+    if (!current) {
+      current = m;
+      continue;
+    }
+    llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
+                                                   current.getResults().end());
+    results.append(m.getResults().begin(), m.getResults().end());
+    current = mlir::AffineMap::get(
+        std::max(current.getNumDims(), m.getNumDims()),
+        current.getNumSymbols() + m.getNumSymbols(), results, {});
+  }
+  return inverseSubMap(current);
+}
+
+// Extract the ranges from a given ViewOp or SliceOp.
+//
+// In the case of a ViewOp, things are simple: just traverse the indexings and
+// get all the ranges (i.e. drop the indices).
+//
+// In the case of a SliceOp, things are trickier because we need to handle a
+// potential rank reduction:
+//   1. Examine the indexing to determine whether it is rank-reducing.
+//   2. If it is rank-reducing, add an offset of 1 to every dimension
+//      `d >= slicingDim` to account for the dropped dimension.
+//   3. `getViewRootIndexing` is then called on the **parent** view with the
+//      adjusted dimension.
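+//
+// Example (illustrative): slicing dimension 1 of a 3-d parent view with a
+// single index produces a rank-2 view; its dimensions d = 0, 1 then map to
+// parent dimensions 0 and 2, i.e. d + offset with offset becoming 1 once
+// d >= slicingDim.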
+static llvm::SmallVector<mlir::Value *, 8>
+extractRangesFromViewOrSliceOp(mlir::Value *view) {
+  // This expects a ViewType, which must come from either a ViewOp or a
+  // SliceOp.
+  assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
+  if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
+    return viewOp.getRanges();
+
+  auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
+  unsigned slicingDim = sliceOp.getSlicingDim();
+  auto *indexing = *(sliceOp.getIndexings().begin());
+  bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
+  unsigned offset = 0;
+  llvm::SmallVector<mlir::Value *, 8> res;
+  res.reserve(sliceOp.getRank());
+  // The parent view is loop-invariant; query it once.
+  auto *parentView = sliceOp.getParentView();
+  for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
+    if (d == slicingDim && isRankReducing)
+      offset = 1;
+    auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
+    res.push_back(indexingPosPair.first);
+  }
+  return res;
+}
+
+template <class ConcreteOp>
+static llvm::SmallVector<mlir::Value *, 8>
+getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res;
+  for (auto *in : tensorContraction.getInputs()) {
+    auto subres = extractRangesFromViewOrSliceOp(in);
+    res.append(subres.begin(), subres.end());
+  }
+  return res;
+}
+
+template <class ConcreteOp>
+static llvm::SmallVector<mlir::Value *, 8>
+getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res;
+  for (auto *out : tensorContraction.getOutputs()) {
+    auto subres = extractRangesFromViewOrSliceOp(out);
+    res.append(subres.begin(), subres.end());
+  }
+  return res;
+}
+
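+// Returns the ranges of all inputs followed by the ranges of all outputs, in
+// operand order. For a matmul-like contraction C(m, n) += A(m, k) * B(k, n)
+// (illustrative), this is the ranges of A, then B, then C.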
+template <class ConcreteOp>
+llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
+    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
+  llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
+  res.append(tmp.begin(), tmp.end());
+  return res;
+}
+
+#endif // LINALG3_TENSOROPS_INL_H_