diff options
Diffstat (limited to 'mlir/ShapeInferencePass.cpp')
-rw-r--r-- | mlir/ShapeInferencePass.cpp | 387 |
1 files changed, 387 insertions, 0 deletions
diff --git a/mlir/ShapeInferencePass.cpp b/mlir/ShapeInferencePass.cpp new file mode 100644 index 0000000..7e3ea3f --- /dev/null +++ b/mlir/ShapeInferencePass.cpp @@ -0,0 +1,387 @@ +//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a Module level pass performing interprocedural +// propagation of array shapes through function specialization. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/StandardOps/Ops.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> + +#define DEBUG_TYPE "toy-shape-inference" + +using namespace toy; +using llvm::MutableArrayRef; +using llvm::SmallVector; +using llvm::SmallVectorImpl; +using llvm::StringRef; +using llvm::Twine; + +/// Create mangled name for function specialization. We will simply append the +/// shape of the arguments to the function name. For example calling +/// +/// "toy.generic_call"(%1, %3) {callee: "foo"} +/// : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array"> +/// +/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could +/// have provide a function with a similar name. But we will claim this as a +/// feature: this allow the user to provide custom specialization! +static std::string mangle(StringRef funcName, + MutableArrayRef<mlir::OpOperand> operands) { + std::string mangledName; + mangledName.reserve(funcName.size() + operands.size() * 6); + mangledName = funcName; + for (auto &operand : operands) { + auto arrayTy = operand.get()->getType().cast<ToyArrayType>(); + mangledName += "_"; + const char *sep = ""; + for (auto dim : arrayTy.getShape()) { + mangledName += (sep + Twine(dim)).str(); + sep = "x"; + } + } + return mangledName; +} + +namespace { + +/// The ShapeInferencePass is a ModulePass: it will run on the Module as a +/// whole. MLIR also supports FunctionPass which are restricted to modify a +/// single function at a time. This pass couldn't be a function pass due the +/// nature of its interprocedural transformations. +/// +/// The algorithm has two levels, first intra-procedurally: +/// +/// 1) Build a worklist containing all the operations that are returning +/// a generic Toy array: these are the operations that need shape +/// inference. +/// 2) Iterate on the worklist: +/// a) find an operation to process: the next ready operation in the +/// worklist has all of its arguments non-generic, +/// b) if no operation is found, break out of the loop, +/// c) remove the operation from the worklist, +/// d) infer the shape of its output from the arguments type. +/// 3) If the worklist is empty, the algorithm succeeded and we infer the +/// return type for the function from the return operation. +/// +/// There is a twist though: when a call to a generic function is encountered, +/// shape inference requires the return type of the callee to be inferred first. +/// At this point we need to run specialize the callee by cloning it. Here is +/// the inter-procedural flow: +/// +/// 1) Keep a worklist of function to process. Start with function "main". +/// 2) While the worklist isn't empty: +/// a) Take the last inserted function in the worklist. +/// b) Run the intra-procedural shape inference on this function. +/// c) If the intra-procedural shape inference can't complete, it returns +/// a Function that needs to be inferred first. In this case, queue this +/// new function and continue. Otherwise the inference succeeded and we +/// can pop from the queue. +/// +class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> { +public: + // One entry in the inter-procedural worklist. It keeps track of the + // function to process, the mangled name for this specialization, and the + // types of the arguments on which to specialize. + struct FunctionToSpecialize { + mlir::Function *function; + std::string mangledName; + std::vector<mlir::Type> argumentsType; + }; + + void runOnModule() override { + auto &module = getModule(); + auto *main = module.getNamedFunction("main"); + if (!main) { + module.getContext()->emitError( + mlir::UnknownLoc::get(module.getContext()), + "Shape inference failed: can't find a main function\n"); + signalPassFailure(); + return; + } + + /// Inter-procedural loop, initialize with `main` and iterate till + /// successfully infer the full reachable call-graph from main. + SmallVector<FunctionToSpecialize, 8> worklist; + worklist.push_back({main, "", {}}); + while (!worklist.empty()) { + if (failed(specialize(worklist))) + return; + } + + // Delete any generic function left + // FIXME: we may want this as a separate pass. + for (mlir::Function &function : llvm::make_early_inc_range(module)) { + if (auto genericAttr = + function.getAttrOfType<mlir::BoolAttr>("toy.generic")) { + if (genericAttr.getValue()) + function.erase(); + } + } + } + + /// Run inference on a function. If a mangledName is provided, we need to + /// specialize the function: to this end clone it first. + mlir::LogicalResult + specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) { + FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); + mlir::Function *f = functionToSpecialize.function; + + // Check if cloning for specialization is needed (usually anything but main) + // We will create a new function with the concrete types for the parameters + // and clone the body into it. + if (!functionToSpecialize.mangledName.empty()) { + if (getModule().getNamedFunction(functionToSpecialize.mangledName)) { + funcWorklist.pop_back(); + // Function already specialized, move on. + return mlir::success(); + } + // Create a new function with a generic array return type, it will be + // updated when the inference for the function body completes. + auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType, + {ToyArrayType::get(&getContext())}, + &getContext()); + auto *newFunction = new mlir::Function( + f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs()); + getModule().getFunctions().push_back(newFunction); + + // Clone the function body + mlir::BlockAndValueMapping mapper; + f->cloneInto(newFunction, mapper); + LLVM_DEBUG({ + llvm::dbgs() << "====== Cloned : \n"; + f->dump(); + llvm::dbgs() << "====== Into : \n"; + newFunction->dump(); + }); + f = newFunction; + f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + // Remap the entry-block arguments + // FIXME: this seems like a bug in `cloneInto()` above? + auto &entryBlock = f->getBlocks().front(); + int blockArgSize = entryBlock.getArguments().size(); + assert(blockArgSize == f->getType().getInputs().size()); + entryBlock.addArguments(f->getType().getInputs()); + auto argList = entryBlock.getArguments(); + for (int argNum = 0; argNum < blockArgSize; ++argNum) { + argList[0]->replaceAllUsesWith(argList[blockArgSize]); + entryBlock.eraseArgument(0); + } + assert(succeeded(f->verify())); + } + LLVM_DEBUG(llvm::dbgs() + << "Run shape inference on : '" << f->getName() << "'\n"); + + auto *toyDialect = getContext().getRegisteredDialect("toy"); + if (!toyDialect) { + getContext().emitError(mlir::UnknownLoc::get(&getContext()), + "Toy dialect is not registered"); + signalPassFailure(); + return mlir::failure(); + } + + // Populate the worklist with the operations that need shape inference: + // these are the Toy operations that return a generic array. + llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist; + f->walk([&](mlir::Operation *op) { + if (op->getDialect() == toyDialect) { + if (op->getNumResults() == 1 && + op->getResult(0)->getType().cast<ToyArrayType>().isGeneric()) + opWorklist.insert(op); + } + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) { + return llvm::all_of(op->getOperands(), [](mlir::Value *v) { + return !v->getType().cast<ToyArrayType>().isGeneric(); + }); + }); + if (nextop == opWorklist.end()) + break; // failure: no operations can be inferred. + + mlir::Operation *op = *nextop; + opWorklist.erase(op); + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + + // The add operation is trivial: propagate the input type as is. + if (auto addOp = op->dyn_cast<AddOp>()) { + op->getResult(0)->setType(op->getOperand(0)->getType()); + continue; + } + + // Transpose is easy: just invert the dimensions. + if (op->getName().getStringRef() == "toy.transpose") { + SmallVector<int64_t, 2> dims; + auto arrayTy = op->getOperand(0)->getType().cast<ToyArrayType>(); + dims.insert(dims.end(), arrayTy.getShape().begin(), + arrayTy.getShape().end()); + if (dims.size() == 2) + std::swap(dims[0], dims[1]); + op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims)); + continue; + } + + // Multiplication is a bit trickier, handle rank 1 as dot product and rank + // 2 as matrix multiplications. + // We need to be careful about rank mismatch here: the verifier could + // catch it but shape inference earlier in the pass could generate an + // invalid IR (from an invalid Toy input of course) and we wouldn't want + // to crash here. + if (auto mulOp = op->dyn_cast<MulOp>()) { + auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>(); + auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>(); + auto lhsRank = lhs.getShape().size(); + auto rhsRank = rhs.getShape().size(); + if (lhsRank != rhsRank) { + op->emitError("Shape mismatch: LHS and RHS must have the same " + "rank for multiplication, got " + + Twine(lhsRank) + " vs " + Twine(lhsRank)); + return mlir::failure(); + } + SmallVector<int64_t, 2> dims; + if (lhsRank == 1) { + // dot product, result shape is <1> + dims.push_back(1); + } else { + if (lhsRank != 2) { + op->emitError( + "Shape mismatch: expect rank 1 or 2 for mul operands, got " + + Twine(lhsRank)); + return mlir::failure(); + } + dims.push_back(lhs.getShape()[0]); + dims.push_back(rhs.getShape()[1]); + } + op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims)); + continue; + } + + // Process calls: lookup the callee after mangling the name with the + // argument shapes. If the callee does not exist, we stop the inference + // for this function, queue the callee in the inter-procedural work list, + // and return. The current function stays in the work list and will + // restart after the callee is processed. + if (auto callOp = op->dyn_cast<GenericCallOp>()) { + auto calleeName = callOp.getCalleeName(); + auto *callee = getModule().getNamedFunction(calleeName); + if (!callee) { + f->emitError( + llvm::Twine("Shape inference failed, call to unknown '") + + calleeName + "'"); + signalPassFailure(); + return mlir::failure(); + } + auto mangledName = mangle(calleeName, op->getOpOperands()); + LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName + << "', mangled: '" << mangledName << "'\n"); + auto *mangledCallee = getModule().getNamedFunction(mangledName); + if (!mangledCallee) { + // Can't find the target, this is where we queue the request for the + // callee and stop the inference for the current function now. + std::vector<mlir::Type> funcArgs; + for (auto operand : op->getOperands()) + funcArgs.push_back(operand->getType()); + funcWorklist.push_back( + {callee, std::move(mangledName), std::move(funcArgs)}); + return mlir::success(); + } + // Found a specialized callee! Let's turn this into a normal call + // operation. + SmallVector<mlir::Value *, 8> operands; + for (mlir::Value *v : op->getOperands()) + operands.push_back(v); + mlir::FuncBuilder builder(f); + builder.setInsertionPoint(op); + auto newCall = + builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands); + if (newCall.getNumResults()) { + op->getResult(0)->replaceAllUsesWith(newCall.getResult(0)); + op->erase(); + continue; + } + } + } + + // Done with inference on this function, removing it from the worklist. + funcWorklist.pop_back(); + // Mark the function as non-generic now that inference has succeeded + f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + std::string str; + llvm::raw_string_ostream errorMsg(str); + errorMsg << "Shape inference failed, " << opWorklist.size() + << " operations couldn't be inferred\n"; + for (auto *ope : opWorklist) + errorMsg << " - " << *ope << "\n"; + f->emitError(errorMsg.str()); + signalPassFailure(); + return mlir::failure(); + } + + // Finally, update the return type of the function based on the argument to + // the return operation. + for (auto &block : f->getBlocks()) { + auto ret = block.getTerminator()->cast<ReturnOp>(); + if (!ret) + continue; + if (ret.getNumOperands() && + f->getType().getResult(0) == ret.getOperand()->getType()) + // type match, we're done + break; + SmallVector<mlir::Type, 1> retTy; + if (ret.getNumOperands()) + retTy.push_back(ret.getOperand()->getType()); + mlir::Type elementType = mlir::FloatType::getF64(&getContext()); + std::vector<mlir::Type> argumentsType; + for (auto arg : f->getArguments()) + argumentsType.push_back(arg->getType()); + auto newType = + mlir::FunctionType::get(argumentsType, retTy, &getContext()); + f->setType(newType); + assert(succeeded(f->verify())); + break; + } + return mlir::success(); + } +}; +} // end anonymous namespace + +namespace toy { +mlir::Pass *createShapeInferencePass() { return new ShapeInferencePass(); } +} // namespace toy |