//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements the dialect for the Toy IR: custom type parsing and // operation verification. // //===----------------------------------------------------------------------===// #include "toy/Dialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" using llvm::ArrayRef; using llvm::raw_ostream; using llvm::raw_string_ostream; using llvm::SmallVector; using llvm::StringRef; using llvm::Twine; namespace toy { namespace detail { /// This class holds the implementation of the ToyArrayType. /// It is intended to be uniqued based on its content and owned by the context. struct ToyArrayTypeStorage : public mlir::TypeStorage { /// This defines how we unique this type in the context: our key contains /// only the shape, a more complex type would have multiple entries in the /// tuple here. /// The element of the tuples usually matches 1-1 the arguments from the /// public `get()` method arguments from the facade. using KeyTy = std::tuple>; static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(std::get<0>(key)); } /// When the key hash hits an existing type, we compare the shape themselves /// to confirm we have the right type. bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); } /// This is a factory method to create our type storage. It is only /// invoked after looking up the type in the context using the key and not /// finding it. static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator, const KeyTy &key) { // Copy the shape array into the bumpptr allocator owned by the context. ArrayRef shape = allocator.copyInto(std::get<0>(key)); // Allocate the instance for the ToyArrayTypeStorage itself auto *storage = allocator.allocate(); // Initialize the instance using placement new. return new (storage) ToyArrayTypeStorage(shape); } ArrayRef getShape() const { return shape; } private: ArrayRef shape; /// Constructor is only invoked from the `construct()` method above. ToyArrayTypeStorage(ArrayRef shape) : shape(shape) {} }; } // namespace detail mlir::Type ToyArrayType::getElementType() { return mlir::FloatType::getF64(getContext()); } ToyArrayType ToyArrayType::get(mlir::MLIRContext *context, ArrayRef shape) { return Base::get(context, ToyTypeKind::TOY_ARRAY, shape); } ArrayRef ToyArrayType::getShape() { return getImpl()->getShape(); } mlir::MemRefType ToyArrayType::toMemref() { auto memRefType = mlir::MemRefType::get(getShape(), getElementType(), {}, 0); return memRefType; } /// Dialect creation, the instance will be owned by the context. This is the /// point of registration of custom types and operations for the dialect. ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { addOperations(); addTypes(); } /// Parse a type registered to this dialect, we expect only Toy arrays. mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const { // Sanity check: we only support array or array<...> if (!tyData.startswith("array")) { getContext()->emitError(loc, "Invalid Toy type '" + tyData + "', array expected"); return nullptr; } // Drop the "array" prefix from the type name, we expect either an empty // string or just the shape. tyData = tyData.drop_front(StringRef("array").size()); // This is the generic array case without shape, early return it. if (tyData.empty()) return ToyArrayType::get(getContext()); // Use a regex to parse the shape (for efficient we should store this regex in // the dialect itself). SmallVector matches; auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$"); if (!shapeRegex.match(tyData, &matches)) { getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'"); return nullptr; } SmallVector shape; // Iterate through the captures, skip the first one which is the full string. for (auto dimStr : llvm::make_range(std::next(matches.begin()), matches.end())) { if (dimStr.startswith(",")) continue; // POSIX misses non-capturing groups. if (dimStr.empty()) continue; // '*' makes it an optional group capture // Convert the capture to an integer unsigned long long dim; if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) { getContext()->emitError( loc, "Couldn't parse dimension as integer, matched: " + dimStr); return mlir::Type(); } shape.push_back(dim); } // Finally we collected all the dimensions in the shape, // create the array type. return ToyArrayType::get(getContext(), shape); } /// Print a Toy array type, for example `array<2, 3, 4>` void ToyDialect::printType(mlir::Type type, raw_ostream &os) const { auto arrayTy = type.dyn_cast(); if (!arrayTy) { os << "unknown toy type"; return; } os << "array"; if (!arrayTy.getShape().empty()) { os << "<"; mlir::interleaveComma(arrayTy.getShape(), os); os << ">"; } } //////////////////////////////////////////////////////////////////////////////// //////////////////// Custom Operations for the Dialect ///////////////////////// //////////////////////////////////////////////////////////////////////////////// /// Helper to verify that the result of an operation is a Toy array type. template static mlir::LogicalResult verifyToyReturnArray(T *op) { if (!op->getResult()->getType().template isa()) { std::string msg; raw_string_ostream os(msg); os << "expects a Toy Array for its argument, got " << op->getResult()->getType(); return op->emitOpError(os.str()); } return mlir::success(); } /// Helper to verify that the two operands of a binary operation are Toy /// arrays.. template static mlir::LogicalResult verifyToyBinOperands(T *op) { if (!op->getOperand(0)->getType().template isa()) { std::string msg; raw_string_ostream os(msg); os << "expects a Toy Array for its LHS, got " << op->getOperand(0)->getType(); return op->emitOpError(os.str()); } if (!op->getOperand(1)->getType().template isa()) { std::string msg; raw_string_ostream os(msg); os << "expects a Toy Array for its LHS, got " << op->getOperand(0)->getType(); return op->emitOpError(os.str()); } return mlir::success(); } /// Build a constant operation. /// The builder is passed as an argument, so is the state that this method is /// expected to fill in order to build the operation. void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state, ArrayRef shape, mlir::DenseElementsAttr value) { state->types.push_back(ToyArrayType::get(builder->getContext(), shape)); auto dataAttribute = builder->getNamedAttr("value", value); state->attributes.push_back(dataAttribute); } /// Build a constant operation. /// The builder is passed as an argument, so is the state that this method is /// expected to fill in order to build the operation. void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state, mlir::FloatAttr value) { // Broadcast and forward to the other build factory mlir::Type elementType = mlir::FloatType::getF64(builder->getContext()); auto dataType = builder->getTensorType({1}, elementType); auto dataAttribute = builder->getDenseElementsAttr(dataType, {value}) .cast(); ConstantOp::build(builder, state, {1}, dataAttribute); } /// Verifier for constant operation. mlir::LogicalResult ConstantOp::verify() { // Ensure that the return type is a Toy array if (failed(verifyToyReturnArray(this))) return mlir::failure(); // We expect the constant itself to be stored as an attribute. auto dataAttr = getAttr("value").dyn_cast(); if (!dataAttr) { return emitOpError( "missing valid `value` DenseElementsAttribute on toy.constant()"); } auto attrType = dataAttr.getType().dyn_cast(); if (!attrType) { return emitOpError( "missing valid `value` DenseElementsAttribute on toy.constant()"); } // If the return type of the constant is not a generic array, the shape must // match the shape of the attribute holding the data. auto resultType = getResult()->getType().cast(); if (!resultType.isGeneric()) { if (attrType.getRank() != resultType.getRank()) { return emitOpError("The rank of the toy.constant return type must match " "the one of the attached value attribute: " + Twine(attrType.getRank()) + " != " + Twine(resultType.getRank())); } for (int dim = 0; dim < attrType.getRank(); ++dim) { if (attrType.getShape()[dim] != resultType.getShape()[dim]) { std::string msg; raw_string_ostream os(msg); return emitOpError( "Shape mismatch between toy.constant return type and its " "attribute at dimension " + Twine(dim) + ": " + Twine(attrType.getShape()[dim]) + " != " + Twine(resultType.getShape()[dim])); } } } return mlir::success(); } void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState *state, StringRef callee, ArrayRef arguments) { // Generic call always returns a generic ToyArray initially state->types.push_back(ToyArrayType::get(builder->getContext())); state->operands.assign(arguments.begin(), arguments.end()); auto calleeAttr = builder->getStringAttr(callee); state->attributes.push_back(builder->getNamedAttr("callee", calleeAttr)); } mlir::LogicalResult GenericCallOp::verify() { // Verify that every operand is a Toy Array for (int opId = 0, num = getNumOperands(); opId < num; ++opId) { if (!getOperand(opId)->getType().template isa()) { std::string msg; raw_string_ostream os(msg); os << "expects a Toy Array for its " << opId << " operand, got " << getOperand(opId)->getType(); return emitOpError(os.str()); } } return mlir::success(); } /// Return the name of the callee. StringRef GenericCallOp::getCalleeName() { return getAttr("callee").cast().getValue(); } template static mlir::LogicalResult verifyToySingleOperand(T *op) { if (!op->getOperand()->getType().template isa()) { std::string msg; raw_string_ostream os(msg); os << "expects a Toy Array for its argument, got " << op->getOperand()->getType(); return op->emitOpError(os.str()); } return mlir::success(); } void ReturnOp::build(mlir::Builder *builder, mlir::OperationState *state, mlir::Value *value) { // Return does not return any value and has an optional single argument if (value) state->operands.push_back(value); } mlir::LogicalResult ReturnOp::verify() { if (getNumOperands() > 1) return emitOpError("expects zero or one operand, got " + Twine(getNumOperands())); if (hasOperand() && failed(verifyToySingleOperand(this))) return mlir::failure(); return mlir::success(); } void PrintOp::build(mlir::Builder *builder, mlir::OperationState *state, mlir::Value *value) { // Print does not return any value and has a single argument state->operands.push_back(value); } mlir::LogicalResult PrintOp::verify() { if (failed(verifyToySingleOperand(this))) return mlir::failure(); return mlir::success(); } void TransposeOp::build(mlir::Builder *builder, mlir::OperationState *state, mlir::Value *value) { state->types.push_back(ToyArrayType::get(builder->getContext())); state->operands.push_back(value); } mlir::LogicalResult TransposeOp::verify() { if (failed(verifyToySingleOperand(this))) return mlir::failure(); return mlir::success(); } void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState *state, mlir::Value *value, ToyArrayType reshapedType) { state->types.push_back(reshapedType); state->operands.push_back(value); } mlir::LogicalResult ReshapeOp::verify() { if (failed(verifyToySingleOperand(this))) return mlir::failure(); auto retTy = getResult()->getType().dyn_cast(); if (!retTy) return emitOpError("toy.reshape is expected to produce a Toy array"); if (retTy.isGeneric()) return emitOpError("toy.reshape is expected to produce a shaped Toy array, " "got a generic one."); return mlir::success(); } void AddOp::build(mlir::Builder *builder, mlir::OperationState *state, mlir::Value *lhs, mlir::Value *rhs) { state->types.push_back(ToyArrayType::get(builder->getContext())); state->operands.push_back(lhs); state->operands.push_back(rhs); } mlir::LogicalResult AddOp::verify() { if (failed(verifyToyBinOperands(this))) return mlir::failure(); return mlir::success(); } void MulOp::build(mlir::Builder *builder, mlir::OperationState *state, mlir::Value *lhs, mlir::Value *rhs) { state->types.push_back(ToyArrayType::get(builder->getContext())); state->operands.push_back(lhs); state->operands.push_back(rhs); } mlir::LogicalResult MulOp::verify() { if (failed(verifyToyBinOperands(this))) return mlir::failure(); return mlir::success(); } void AllocOp::build(mlir::Builder *builder, mlir::OperationState *state, mlir::Type retType) { state->types.push_back(retType); } void TypeCastOp::build(mlir::Builder *builder, mlir::OperationState *state, mlir::Value *value, mlir::Type destTy) { state->operands.push_back(value); state->types.push_back(destTy); } } // namespace toy