diff options
Diffstat (limited to 'toy/ToyCombine.cpp')
-rw-r--r-- | toy/ToyCombine.cpp | 209 |
1 files changed, 209 insertions, 0 deletions
diff --git a/toy/ToyCombine.cpp b/toy/ToyCombine.cpp new file mode 100644 index 0000000..8d6aed6 --- /dev/null +++ b/toy/ToyCombine.cpp @@ -0,0 +1,209 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple combiner for optimizing pattern in the Toy +// dialect. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" + +#include <numeric> + +namespace toy { + +namespace { + +/// Fold transpose(transpose(x)) -> transpose(x) +struct SimplifyRedundantTranspose : public mlir::RewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1, + context) {} + + /// This method is attempting to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. It is expected + /// to interact with it to perform any changes to the IR from here. + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + // We can directly cast the current operation as this will only get invoked + // on TransposeOp. + TransposeOp transpose = op->cast<TransposeOp>(); + // look through the input to the current transpose + mlir::Value *transposeInput = transpose.getOperand(); + mlir::Operation *transposeInputInst = transposeInput->getDefiningOp(); + // If the input is defined by another Transpose, bingo! + TransposeOp transposeInputOp = + mlir::dyn_cast_or_null<TransposeOp>(transposeInputInst); + if (!transposeInputOp) + return matchFailure(); + + // Use the rewriter to perform the replacement + rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + return matchSuccess(); + } +}; + +/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place. +struct SimplifyReshapeConstant : public mlir::RewritePattern { + SimplifyReshapeConstant(mlir::MLIRContext *context) + : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, + context) {} + + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + ReshapeOp reshape = op->cast<ReshapeOp>(); + // look through the input to the current reshape + mlir::Value *reshapeInput = reshape.getOperand(); + mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp(); + // If the input is defined by another reshape, bingo! + ConstantOp constantOp = + mlir::dyn_cast_or_null<ConstantOp>(reshapeInputInst); + if (!constantOp) + return matchFailure(); + + auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>(); + if (auto valueAttr = + constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) { + // FIXME Check matching of element count! + // auto oldType = constantOp.getType(); + auto newType = rewriter.getTensorType( + reshapeType.getShape(), valueAttr.getType().getElementType()); + auto newAttr = + mlir::DenseElementsAttr::get(newType, valueAttr.getRawData()); + auto newConstant = rewriter.create<ConstantOp>( + constantOp.getLoc(), reshapeType.getShape(), newAttr); + rewriter.replaceOp(op, {newConstant}); + } else if (auto valueAttr = + constantOp.getAttrOfType<mlir::FloatAttr>("value")) { + // Broadcast + auto dataSize = std::accumulate(reshapeType.getShape().begin(), + reshapeType.getShape().end(), 1, + std::multiplies<int>()); + std::vector<mlir::Attribute> data(dataSize, valueAttr); + auto tensorTy = rewriter.getTensorType(reshapeType.getShape(), + reshapeType.getElementType()); + auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data); + auto newConstant = rewriter.create<ConstantOp>( + constantOp.getLoc(), reshapeType.getShape(), newAttr); + rewriter.replaceOp(op, {newConstant}); + } else { + llvm_unreachable("Unsupported Constant format"); + } + return matchSuccess(); + } +}; + +/// Fold reshape(reshape(x)) -> reshape(x) +struct SimplifyReshapeReshape : public mlir::RewritePattern { + SimplifyReshapeReshape(mlir::MLIRContext *context) + : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, + context) {} + + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + ReshapeOp reshape = op->cast<ReshapeOp>(); + // look through the input to the current reshape + mlir::Value *reshapeInput = reshape.getOperand(); + mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp(); + // If the input is defined by another reshape, bingo! + ReshapeOp reshapeInputOp = + mlir::dyn_cast_or_null<ReshapeOp>(reshapeInputInst); + if (!reshapeInputOp) + return matchFailure(); + + // Use the rewriter to perform the replacement + rewriter.replaceOp(op, {reshapeInputOp}); + return matchSuccess(); + } +}; + +/// Fold reshape(x)) -> x, when input type matches output type +struct SimplifyNullReshape : public mlir::RewritePattern { + SimplifyNullReshape(mlir::MLIRContext *context) + : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, + context) {} + + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + ReshapeOp reshape = op->cast<ReshapeOp>(); + if (reshape.getOperand()->getType() != reshape.getResult()->getType()) + return matchFailure(); + rewriter.replaceOp(reshape, {reshape.getOperand()}); + return matchSuccess(); + } +}; + +} // end anonymous namespace. + +// Register our patterns for rewrite by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns( + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { + results.push_back(llvm::make_unique<SimplifyRedundantTranspose>(context)); +} + +// Register our patterns for rewrite by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns( + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { + results.push_back(llvm::make_unique<SimplifyReshapeConstant>(context)); + results.push_back(llvm::make_unique<SimplifyReshapeReshape>(context)); + results.push_back(llvm::make_unique<SimplifyNullReshape>(context)); +} + +namespace { + +/// Fold type.cast(x) -> x, when input type matches output type +struct SimplifyIdentityTypeCast : public mlir::RewritePattern { + SimplifyIdentityTypeCast(mlir::MLIRContext *context) + : RewritePattern(TypeCastOp::getOperationName(), /* benefit = */ 1, + context) {} + + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + TypeCastOp typeCast = op->cast<TypeCastOp>(); + auto resTy = typeCast.getResult()->getType(); + auto *candidateOp = op; + while (candidateOp && candidateOp->isa<TypeCastOp>()) { + if (resTy == candidateOp->getOperand(0)->getType()) { + rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)}); + return matchSuccess(); + } + candidateOp = candidateOp->getOperand(0)->getDefiningOp(); + } + return matchFailure(); + } +}; + +} // end anonymous namespace. + +void TypeCastOp::getCanonicalizationPatterns( + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { + results.push_back(llvm::make_unique<SimplifyIdentityTypeCast>(context)); +} + +} // namespace toy |