From 22bb32ed1b9505ae49145ca7765def6398f4803d Mon Sep 17 00:00:00 2001
From: Tuowen Zhao <ztuowen@gmail.com>
Date: Wed, 24 Apr 2019 10:53:07 -0600
Subject: Initial commit

---
 include/linalg3/TensorOps-inl.h | 145 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 145 insertions(+)
 create mode 100644 include/linalg3/TensorOps-inl.h

(limited to 'include/linalg3/TensorOps-inl.h')

diff --git a/include/linalg3/TensorOps-inl.h b/include/linalg3/TensorOps-inl.h
new file mode 100644
index 0000000..b651053
--- /dev/null
+++ b/include/linalg3/TensorOps-inl.h
@@ -0,0 +1,145 @@
+//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#ifndef LINALG3_TENSOROPS_INL_H_
+#define LINALG3_TENSOROPS_INL_H_
+
+#include "linalg1/Common.h"
+#include "linalg1/Utils.h"
+#include "linalg2/TensorOps.h"
+#include "linalg3/Analysis.h"
+#include "linalg3/Ops.h"
+
+template <class ConcreteOp>
+mlir::Value *
+linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned viewIndex) {
+  return *(getInputs().begin() + viewIndex);
+}
+
+template <class ConcreteOp>
+mlir::Value *
+linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned viewIndex) {
+  return *(getOutputs().begin() + viewIndex);
+}
+
+template <class ConcreteOp>
+llvm::SmallVector<mlir::AffineMap, 8>
+linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangeMaps() {
+  return static_cast<ConcreteOp *>(this)->loopsToOperandRangeMaps();
+}
+
+template <class ConcreteOp>
+void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
+    llvm::ArrayRef<mlir::Value *> parallelIvs,
+    llvm::ArrayRef<mlir::Value *> reductionIvs) {
+  static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs,
+                                                            reductionIvs);
+}
+
+template <class ConcreteOp>
+mlir::AffineMap linalg::operandRangesToLoopsMap(
+    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  mlir::AffineMap current;
+  // Individual submaps may not be invertible but their union must be invertible
+  // by construction.
+  for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
+    if (!m)
+      continue;
+    if (!current) {
+      current = m;
+      continue;
+    }
+    llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
+                                                   current.getResults().end());
+    results.append(m.getResults().begin(), m.getResults().end());
+    current = mlir::AffineMap::get(
+        std::max(current.getNumDims(), m.getNumDims()),
+        current.getNumSymbols() + m.getNumSymbols(), results, {});
+  }
+  return inverseSubMap(current);
+}
+
+// Extract the ranges from a given ViewOp or SliceOp.
+//
+// In the case of a ViewOp, things are simple: just traverse the indexings and
+// get all the ranges (i.e. drop the indices).
+//
+// In the case of a SliceOp, things are trickier because we need to handle a
+// potential rank-reduction:
+//   1. Examine the indexing to determine if it is rank-reducing.
+//   2. If it is rank-reducing, an offset of 1 is added to the dimensions such
+//      that `d >= slicingDim`. This is to account for the rank reduction.
+// `getRootIndex` is then called on the **parent** view
+static llvm::SmallVector<mlir::Value *, 8>
+extractRangesFromViewOrSliceOp(mlir::Value *view) {
+  // This expects a viewType which must come from either ViewOp or SliceOp.
+  assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
+  if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
+    return viewOp.getRanges();
+
+  auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
+  unsigned slicingDim = sliceOp.getSlicingDim();
+  auto *indexing = *(sliceOp.getIndexings().begin());
+  bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
+  unsigned offset = 0;
+  llvm::SmallVector<mlir::Value *, 8> res;
+  res.reserve(sliceOp.getRank());
+  for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
+    if (d == slicingDim && isRankReducing)
+      offset = 1;
+    auto *parentView = sliceOp.getParentView();
+    auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
+    res.push_back(indexingPosPair.first);
+  }
+  return res;
+}
+
+template <class ConcreteOp>
+static llvm::SmallVector<mlir::Value *, 8>
+getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res;
+  for (auto *in : tensorContraction.getInputs()) {
+    auto subres = extractRangesFromViewOrSliceOp(in);
+    res.append(subres.begin(), subres.end());
+  }
+  return res;
+}
+
+template <class ConcreteOp>
+static llvm::SmallVector<mlir::Value *, 8>
+getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res;
+  for (auto *out : tensorContraction.getOutputs()) {
+    auto subres = extractRangesFromViewOrSliceOp(out);
+    res.append(subres.begin(), subres.end());
+  }
+  return res;
+}
+
+template <class ConcreteOp>
+llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
+    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
+  llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
+  res.append(tmp.begin(), tmp.end());
+  return res;
+}
+
+#endif // LINALG3_TENSOROPS_INL_H_
-- 
cgit v1.2.3-70-g09d2