diff options
author | Tuowen Zhao <ztuowen@gmail.com> | 2019-04-24 10:53:07 -0600 |
---|---|---|
committer | Tuowen Zhao <ztuowen@gmail.com> | 2019-04-24 10:53:07 -0600 |
commit | 22bb32ed1b9505ae49145ca7765def6398f4803d (patch) | |
tree | fce88de88ed7ffda0856fb4798d0be58460d07c3 /include/linalg2/TensorOps-inl.h | |
download | mlir-toy-22bb32ed1b9505ae49145ca7765def6398f4803d.tar.gz mlir-toy-22bb32ed1b9505ae49145ca7765def6398f4803d.tar.bz2 mlir-toy-22bb32ed1b9505ae49145ca7765def6398f4803d.zip |
Initial commit
Diffstat (limited to 'include/linalg2/TensorOps-inl.h')
-rw-r--r-- | include/linalg2/TensorOps-inl.h | 120 |
1 files changed, 120 insertions, 0 deletions
diff --git a/include/linalg2/TensorOps-inl.h b/include/linalg2/TensorOps-inl.h new file mode 100644 index 0000000..940f8d7 --- /dev/null +++ b/include/linalg2/TensorOps-inl.h @@ -0,0 +1,120 @@ +//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#ifndef LINALG2_TENSOROPS_INL_H_ +#define LINALG2_TENSOROPS_INL_H_ + +#include "linalg2/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" + +namespace linalg { + +template <class ConcreteOp> +mlir::Operation::operand_range +linalg::TensorContractionBase<ConcreteOp>::getInputs() { + auto *op = static_cast<ConcreteOp *>(this)->getOperation(); + return {op->operand_begin(), op->operand_begin() + getNumInputs()}; +} + +template <class ConcreteOp> +mlir::Operation::operand_range +linalg::TensorContractionBase<ConcreteOp>::getOutputs() { + auto *op = static_cast<ConcreteOp *>(this)->getOperation(); + return {op->operand_begin() + getNumInputs(), + op->operand_begin() + getNumInputs() + getNumOutputs()}; +} + +template <class ConcreteOp> +mlir::Operation::operand_range +linalg::TensorContractionBase<ConcreteOp>::getInputsAndOutputs() { + return {getInputs().begin(), getOutputs().end()}; +} + +template <class ConcreteOp> +mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() { + auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation(); + if (getNumInputs() <= 0) + concreteOp->emitOpError("expected at least one input"); + if (getNumOutputs() <= 0) + concreteOp->emitOpError("expected at least one output"); + if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) { + concreteOp->emitOpError("expected " + + llvm::Twine(getNumInputs() + getNumOutputs()) + + " operands"); + } + for (unsigned i = 0, e = getNumInputs(); i < e; ++i) { + if (!concreteOp->getOperand(i)->getType().template isa<ViewType>()) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " not a ViewType"); + } + for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e; + ++i) { + auto viewType = + concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>(); + if (!viewType) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " not a ViewType"); + if (viewType.getRank() != getNumParallelDims()) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " must be of rank " + + llvm::Twine(getNumParallelDims())); + } + return mlir::success(); +} + +template <class ConcreteOp> +bool linalg::TensorContractionBase<ConcreteOp>::parse( + mlir::OpAsmParser *parser, mlir::OperationState *result) { + llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); +} + +// A TensorContraction prints as: +// +// ```{.mlir} +// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types +// ``` +// +// for example: +// +// ``` +// linalg.matmul(%0, %1, %2) : view<?x?xf32> +// ``` +// +// Where %0, %1 and %2 are ssa-values of type ViewType. +template <class ConcreteOp> +void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter *p) { + *p << static_cast<ConcreteOp *>(this)->getOperationName() << "("; + auto *last = *std::prev(getInputsAndOutputs().end()); + for (auto *i : getInputsAndOutputs()) { + *p << *i << ((i == last) ? "" : ", "); + } + *p << ") : "; + auto *lastOutput = *std::prev(getOutputs().end()); + for (auto *o : getOutputs()) { + *p << o->getType() << ((o == lastOutput) ? "" : ","); + } +} + +} // namespace linalg + +#endif // LINALG2_TENSOROPS_INL_H_ |