//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= /// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of /// TensorOps by adding implementations as they are needed in the appropriate /// step in the tutorial. #ifndef LINALG3_TENSOROPS_INL_H_ #define LINALG3_TENSOROPS_INL_H_ #include "linalg1/Common.h" #include "linalg1/Utils.h" #include "linalg2/TensorOps.h" #include "linalg3/Analysis.h" #include "linalg3/Ops.h" template mlir::Value * linalg::TensorContractionBase::getInputView(unsigned viewIndex) { return *(getInputs().begin() + viewIndex); } template mlir::Value * linalg::TensorContractionBase::getOutputView(unsigned viewIndex) { return *(getOutputs().begin() + viewIndex); } template llvm::SmallVector linalg::TensorContractionBase::loopsToOperandRangeMaps() { return static_cast(this)->loopsToOperandRangeMaps(); } template void linalg::TensorContractionBase::emitScalarImplementation( llvm::ArrayRef parallelIvs, llvm::ArrayRef reductionIvs) { static_cast(this)->emitScalarImplementation(parallelIvs, reductionIvs); } template mlir::AffineMap linalg::operandRangesToLoopsMap( linalg::TensorContractionBase &tensorContraction) { mlir::AffineMap current; // Individual submaps may not be invertible but their union must be invertible // by construction. for (auto m : tensorContraction.loopsToOperandRangeMaps()) { if (!m) continue; if (!current) { current = m; continue; } llvm::SmallVector results(current.getResults().begin(), current.getResults().end()); results.append(m.getResults().begin(), m.getResults().end()); current = mlir::AffineMap::get( std::max(current.getNumDims(), m.getNumDims()), current.getNumSymbols() + m.getNumSymbols(), results, {}); } return inverseSubMap(current); } // Extract the ranges from a given ViewOp or SliceOp. // // In the case of a ViewOp, things are simple: just traverse the indexings and // get all the ranges (i.e. drop the indices). // // In the case of a SliceOp, things are trickier because we need to handle a // potential rank-reduction: // 1. Examine the indexing to determine if it is rank-reducing. // 2. If it is rank-reducing, an offset of 1 is added to the dimensions such // that `d >= slicingDim`. This is to account for the rank reduction. // `getRootIndex` is then called on the **parent** view static llvm::SmallVector extractRangesFromViewOrSliceOp(mlir::Value *view) { // This expects a viewType which must come from either ViewOp or SliceOp. assert(view->getType().isa() && "expected ViewType"); if (auto viewOp = view->getDefiningOp()->dyn_cast()) return viewOp.getRanges(); auto sliceOp = view->getDefiningOp()->cast(); unsigned slicingDim = sliceOp.getSlicingDim(); auto *indexing = *(sliceOp.getIndexings().begin()); bool isRankReducing = indexing->getType().isa(); unsigned offset = 0; llvm::SmallVector res; res.reserve(sliceOp.getRank()); for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) { if (d == slicingDim && isRankReducing) offset = 1; auto *parentView = sliceOp.getParentView(); auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset); res.push_back(indexingPosPair.first); } return res; } template static llvm::SmallVector getInputRanges(linalg::TensorContractionBase &tensorContraction) { llvm::SmallVector res; for (auto *in : tensorContraction.getInputs()) { auto subres = extractRangesFromViewOrSliceOp(in); res.append(subres.begin(), subres.end()); } return res; } template static llvm::SmallVector getOutputRanges(linalg::TensorContractionBase &tensorContraction) { llvm::SmallVector res; for (auto *out : tensorContraction.getOutputs()) { auto subres = extractRangesFromViewOrSliceOp(out); res.append(subres.begin(), subres.end()); } return res; } template llvm::SmallVector linalg::getRanges( linalg::TensorContractionBase &tensorContraction) { llvm::SmallVector res = getInputRanges(tensorContraction); llvm::SmallVector tmp = getOutputRanges(tensorContraction); res.append(tmp.begin(), tmp.end()); return res; } #endif // LINALG3_TENSOROPS_INL_H_