Diffstat (limited to 'mlir/ShapeInferencePass.cpp')
-rw-r--r-- | mlir/ShapeInferencePass.cpp | 387
1 file changed, 0 insertions, 387 deletions
diff --git a/mlir/ShapeInferencePass.cpp b/mlir/ShapeInferencePass.cpp
deleted file mode 100644
index 7e3ea3f..0000000
--- a/mlir/ShapeInferencePass.cpp
+++ /dev/null
@@ -1,387 +0,0 @@
-//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements a Module-level pass performing interprocedural
-// propagation of array shapes through function specialization.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/StandardOps/Ops.h"
-#include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringSet.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/raw_ostream.h"
-#include <algorithm>
-
-#define DEBUG_TYPE "toy-shape-inference"
-
-using namespace toy;
-using llvm::MutableArrayRef;
-using llvm::SmallVector;
-using llvm::SmallVectorImpl;
-using llvm::StringRef;
-using llvm::Twine;
-
-/// Create a mangled name for a function specialization. We simply append the
-/// shapes of the arguments to the function name. For example, calling
-///
-///   "toy.generic_call"(%1, %3) {callee: "foo"}
-///       : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
-///
-/// would be mangled foo_2x3_2x3. This mangling isn't robust, since the user
-/// could have provided a function with a clashing name. But we claim this as
-/// a feature: it allows the user to provide custom specializations!
-static std::string mangle(StringRef funcName,
-                          MutableArrayRef<mlir::OpOperand> operands) {
-  std::string mangledName;
-  mangledName.reserve(funcName.size() + operands.size() * 6);
-  mangledName = funcName;
-  for (auto &operand : operands) {
-    auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
-    mangledName += "_";
-    const char *sep = "";
-    for (auto dim : arrayTy.getShape()) {
-      mangledName += (sep + Twine(dim)).str();
-      sep = "x";
-    }
-  }
-  return mangledName;
-}
-
-namespace {
-
-/// The ShapeInferencePass is a ModulePass: it runs on the Module as a whole.
-/// MLIR also supports FunctionPass, which is restricted to modifying a
-/// single function at a time. This pass couldn't be a function pass due to
-/// the nature of its interprocedural transformations.
-///
-/// The algorithm has two levels. First, intra-procedurally:
-///
-/// 1) Build a worklist containing all the operations that return a generic
-///    Toy array: these are the operations that need shape inference.
-/// 2) Iterate on the worklist:
-///    a) find an operation to process: the next ready operation in the
-///       worklist has all of its arguments non-generic,
-///    b) if no operation is found, break out of the loop,
-///    c) remove the operation from the worklist,
-///    d) infer the shape of its output from the argument types.
-/// 3) If the worklist is empty, the algorithm succeeded and we infer the
-///    return type for the function from the return operation.
-///
-/// There is a twist though: when a call to a generic function is
-/// encountered, shape inference requires the return type of the callee to be
-/// inferred first. At that point we need to specialize the callee by cloning
-/// it. Here is the inter-procedural flow:
-///
-/// 1) Keep a worklist of functions to process. Start with function "main".
-/// 2) While the worklist isn't empty:
-///    a) Take the last inserted function in the worklist.
-///    b) Run the intra-procedural shape inference on this function.
-///    c) If the intra-procedural shape inference can't complete, it returns
-///       a Function that needs to be inferred first. In this case, queue
-///       this new function and continue. Otherwise the inference succeeded
-///       and we can pop from the queue.
-///
-class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
-public:
-  // One entry in the inter-procedural worklist. It keeps track of the
-  // function to process, the mangled name for this specialization, and the
-  // types of the arguments on which to specialize.
-  struct FunctionToSpecialize {
-    mlir::Function *function;
-    std::string mangledName;
-    std::vector<mlir::Type> argumentsType;
-  };
-
-  void runOnModule() override {
-    auto &module = getModule();
-    auto *main = module.getNamedFunction("main");
-    if (!main) {
-      module.getContext()->emitError(
-          mlir::UnknownLoc::get(module.getContext()),
-          "Shape inference failed: can't find a main function\n");
-      signalPassFailure();
-      return;
-    }
-
-    /// Inter-procedural loop: initialize with `main` and iterate until we
-    /// successfully infer the full call-graph reachable from main.
-    SmallVector<FunctionToSpecialize, 8> worklist;
-    worklist.push_back({main, "", {}});
-    while (!worklist.empty()) {
-      if (failed(specialize(worklist)))
-        return;
-    }
-
-    // Delete any generic function left.
-    // FIXME: we may want this as a separate pass.
-    for (mlir::Function &function : llvm::make_early_inc_range(module)) {
-      if (auto genericAttr =
-              function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
-        if (genericAttr.getValue())
-          function.erase();
-      }
-    }
-  }
-
-  /// Run inference on a function. If a mangledName is provided, we need to
-  /// specialize the function: to this end we clone it first.
-  mlir::LogicalResult
-  specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
-    FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
-    mlir::Function *f = functionToSpecialize.function;
-
-    // Check if cloning for specialization is needed (usually anything but
-    // main). We will create a new function with the concrete types for the
-    // parameters and clone the body into it.
-    if (!functionToSpecialize.mangledName.empty()) {
-      if (getModule().getNamedFunction(functionToSpecialize.mangledName)) {
-        funcWorklist.pop_back();
-        // Function already specialized, move on.
-        return mlir::success();
-      }
-      // Create a new function with a generic array return type; it will be
-      // updated when the inference for the function body completes.
-      auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
-                                          {ToyArrayType::get(&getContext())},
-                                          &getContext());
-      auto *newFunction = new mlir::Function(
-          f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
-      getModule().getFunctions().push_back(newFunction);
-
-      // Clone the function body.
-      mlir::BlockAndValueMapping mapper;
-      f->cloneInto(newFunction, mapper);
-      LLVM_DEBUG({
-        llvm::dbgs() << "====== Cloned : \n";
-        f->dump();
-        llvm::dbgs() << "====== Into : \n";
-        newFunction->dump();
-      });
-      f = newFunction;
-      f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
-      // Remap the entry-block arguments.
-      // FIXME: this seems like a bug in `cloneInto()` above?
-      auto &entryBlock = f->getBlocks().front();
-      int blockArgSize = entryBlock.getArguments().size();
-      assert(blockArgSize == f->getType().getInputs().size());
-      entryBlock.addArguments(f->getType().getInputs());
-      auto argList = entryBlock.getArguments();
-      for (int argNum = 0; argNum < blockArgSize; ++argNum) {
-        argList[0]->replaceAllUsesWith(argList[blockArgSize]);
-        entryBlock.eraseArgument(0);
-      }
-      assert(succeeded(f->verify()));
-    }
-    LLVM_DEBUG(llvm::dbgs()
-               << "Run shape inference on : '" << f->getName() << "'\n");
-
-    auto *toyDialect = getContext().getRegisteredDialect("toy");
-    if (!toyDialect) {
-      getContext().emitError(mlir::UnknownLoc::get(&getContext()),
-                             "Toy dialect is not registered");
-      signalPassFailure();
-      return mlir::failure();
-    }
-
-    // Populate the worklist with the operations that need shape inference:
-    // these are the Toy operations that return a generic array.
-    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
-    f->walk([&](mlir::Operation *op) {
-      if (op->getDialect() == toyDialect) {
-        if (op->getNumResults() == 1 &&
-            op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
-          opWorklist.insert(op);
-      }
-    });
-
-    // Iterate on the operations in the worklist until all operations have
-    // been inferred or no change happens (fix point).
-    while (!opWorklist.empty()) {
-      // Find the next operation ready for inference, that is an operation
-      // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
-        return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
-          return !v->getType().cast<ToyArrayType>().isGeneric();
-        });
-      });
-      if (nextop == opWorklist.end())
-        break; // failure: no operations can be inferred.
-
-      mlir::Operation *op = *nextop;
-      opWorklist.erase(op);
-      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
-
-      // The add operation is trivial: propagate the input type as is.
-      if (auto addOp = op->dyn_cast<AddOp>()) {
-        op->getResult(0)->setType(op->getOperand(0)->getType());
-        continue;
-      }
-
-      // Transpose is easy: just invert the dimensions.
-      if (op->getName().getStringRef() == "toy.transpose") {
-        SmallVector<int64_t, 2> dims;
-        auto arrayTy = op->getOperand(0)->getType().cast<ToyArrayType>();
-        dims.insert(dims.end(), arrayTy.getShape().begin(),
-                    arrayTy.getShape().end());
-        if (dims.size() == 2)
-          std::swap(dims[0], dims[1]);
-        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
-        continue;
-      }
-
-      // Multiplication is a bit trickier: handle rank 1 as a dot product and
-      // rank 2 as a matrix multiplication.
-      // We need to be careful about rank mismatch here: the verifier could
-      // catch it, but shape inference earlier in the pass could generate
-      // invalid IR (from an invalid Toy input, of course) and we wouldn't
-      // want to crash here.
-      if (auto mulOp = op->dyn_cast<MulOp>()) {
-        auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
-        auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
-        auto lhsRank = lhs.getShape().size();
-        auto rhsRank = rhs.getShape().size();
-        if (lhsRank != rhsRank) {
-          op->emitError("Shape mismatch: LHS and RHS must have the same "
-                        "rank for multiplication, got " +
-                        Twine(lhsRank) + " vs " + Twine(rhsRank));
-          return mlir::failure();
-        }
-        SmallVector<int64_t, 2> dims;
-        if (lhsRank == 1) {
-          // Dot product, result shape is <1>.
-          dims.push_back(1);
-        } else {
-          if (lhsRank != 2) {
-            op->emitError(
-                "Shape mismatch: expect rank 1 or 2 for mul operands, got " +
-                Twine(lhsRank));
-            return mlir::failure();
-          }
-          dims.push_back(lhs.getShape()[0]);
-          dims.push_back(rhs.getShape()[1]);
-        }
-        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
-        continue;
-      }
-
-      // Process calls: look up the callee after mangling the name with the
-      // argument shapes. If the callee does not exist, we stop the inference
-      // for this function, queue the callee in the inter-procedural work
-      // list, and return. The current function stays in the work list and
-      // will restart after the callee is processed.
-      if (auto callOp = op->dyn_cast<GenericCallOp>()) {
-        auto calleeName = callOp.getCalleeName();
-        auto *callee = getModule().getNamedFunction(calleeName);
-        if (!callee) {
-          f->emitError(
-              llvm::Twine("Shape inference failed, call to unknown '") +
-              calleeName + "'");
-          signalPassFailure();
-          return mlir::failure();
-        }
-        auto mangledName = mangle(calleeName, op->getOpOperands());
-        LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
-                                << "', mangled: '" << mangledName << "'\n");
-        auto *mangledCallee = getModule().getNamedFunction(mangledName);
-        if (!mangledCallee) {
-          // Can't find the target: this is where we queue the request for
-          // the callee and stop the inference for the current function now.
-          std::vector<mlir::Type> funcArgs;
-          for (auto operand : op->getOperands())
-            funcArgs.push_back(operand->getType());
-          funcWorklist.push_back(
-              {callee, std::move(mangledName), std::move(funcArgs)});
-          return mlir::success();
-        }
-        // Found a specialized callee! Let's turn this into a normal call
-        // operation.
-        SmallVector<mlir::Value *, 8> operands;
-        for (mlir::Value *v : op->getOperands())
-          operands.push_back(v);
-        mlir::FuncBuilder builder(f);
-        builder.setInsertionPoint(op);
-        auto newCall = builder.create<mlir::CallOp>(op->getLoc(),
-                                                    mangledCallee, operands);
-        if (newCall.getNumResults()) {
-          op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
-          op->erase();
-          continue;
-        }
-      }
-    }
-
-    // Done with inference on this function, remove it from the worklist.
-    funcWorklist.pop_back();
-    // Mark the function as non-generic now that inference has succeeded.
-    f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
-
-    // If the operation worklist isn't empty, this indicates a failure.
-    if (!opWorklist.empty()) {
-      std::string str;
-      llvm::raw_string_ostream errorMsg(str);
-      errorMsg << "Shape inference failed, " << opWorklist.size()
-               << " operations couldn't be inferred\n";
-      for (auto *ope : opWorklist)
-        errorMsg << " - " << *ope << "\n";
-      f->emitError(errorMsg.str());
-      signalPassFailure();
-      return mlir::failure();
-    }
-
-    // Finally, update the return type of the function based on the argument
-    // to the return operation.
-    for (auto &block : f->getBlocks()) {
-      auto ret = block.getTerminator()->cast<ReturnOp>();
-      if (!ret)
-        continue;
-      if (ret.getNumOperands() &&
-          f->getType().getResult(0) == ret.getOperand()->getType())
-        // Type match, we're done.
-        break;
-      SmallVector<mlir::Type, 1> retTy;
-      if (ret.getNumOperands())
-        retTy.push_back(ret.getOperand()->getType());
-      mlir::Type elementType = mlir::FloatType::getF64(&getContext());
-      std::vector<mlir::Type> argumentsType;
-      for (auto arg : f->getArguments())
-        argumentsType.push_back(arg->getType());
-      auto newType =
-          mlir::FunctionType::get(argumentsType, retTy, &getContext());
-      f->setType(newType);
-      assert(succeeded(f->verify()));
-      break;
-    }
-    return mlir::success();
-  }
-};
-} // end anonymous namespace
-
-namespace toy {
-mlir::Pass *createShapeInferencePass() { return new ShapeInferencePass(); }
-} // namespace toy
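
The mangling scheme the deleted pass uses is easy to reproduce outside of MLIR. Below is a minimal standalone sketch in plain C++: the shape vectors stand in for ToyArrayType::getShape(), and mangleName is a hypothetical helper written for illustration, not part of the pass.

    #include <cstdint>
    #include <string>
    #include <vector>

    // Append "_<d0>x<d1>..." for each argument shape, mirroring mangle() above.
    std::string mangleName(const std::string &funcName,
                           const std::vector<std::vector<int64_t>> &argShapes) {
      std::string mangled = funcName;
      for (const auto &shape : argShapes) {
        mangled += "_";
        const char *sep = "";
        for (int64_t dim : shape) {
          mangled += sep + std::to_string(dim);
          sep = "x";
        }
      }
      return mangled;
    }

    // Example: mangleName("foo", {{2, 3}, {2, 3}}) yields "foo_2x3_2x3",
    // matching the foo_2x3_2x3 example in the doc comment above.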
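The intra-procedural fixpoint described in the pass documentation can likewise be sketched independently of the MLIR APIs. In the sketch below, Op, operands, resolved, and inferShape are placeholder names invented for illustration, not MLIR constructs: the loop repeatedly picks an operation whose operands are all resolved, infers it, and stops when either the worklist drains (success) or no operation is ready (stuck at a fix point).

    #include <algorithm>
    #include <vector>

    struct Op {
      std::vector<Op *> operands; // producers of this op's inputs
      bool resolved = false;      // true once the result shape is known
      void inferShape() { resolved = true; } // stand-in for the per-op rules
    };

    // Returns true if every op could be inferred, false if the loop got stuck
    // with unresolved ops left, mirroring the loop in specialize().
    bool inferAll(std::vector<Op *> worklist) {
      while (!worklist.empty()) {
        auto ready = std::find_if(worklist.begin(), worklist.end(), [](Op *op) {
          return std::all_of(op->operands.begin(), op->operands.end(),
                             [](Op *def) { return def->resolved; });
        });
        if (ready == worklist.end())
          return false; // no ready operation: no further progress is possible
        (*ready)->inferShape();
        worklist.erase(ready);
      }
      return true;
    }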
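Finally, the per-operation shape rules themselves are small pure functions over shapes. Here is a sketch of the multiplication rule under the same convention the pass uses (rank 1 is a dot product, rank 2 a matrix multiply); inferMulShape is a hypothetical helper, and a nullopt result corresponds to the rank errors the pass reports via emitError.

    #include <cstdint>
    #include <optional>
    #include <vector>

    using Shape = std::vector<int64_t>;

    // Rank 1: dot product, result shape <1>. Rank 2: matrix multiply, result
    // shape <lhs[0], rhs[1]>. Anything else is rejected, as in the pass.
    std::optional<Shape> inferMulShape(const Shape &lhs, const Shape &rhs) {
      if (lhs.size() != rhs.size())
        return std::nullopt; // rank mismatch between LHS and RHS
      if (lhs.size() == 1)
        return Shape{1};
      if (lhs.size() == 2)
        return Shape{lhs[0], rhs[1]};
      return std::nullopt; // only rank 1 or 2 is supported
    }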