From 22bb32ed1b9505ae49145ca7765def6398f4803d Mon Sep 17 00:00:00 2001 From: Tuowen Zhao Date: Wed, 24 Apr 2019 10:53:07 -0600 Subject: Initial commit --- .gitignore | 13 + CMakeLists.txt | 58 ++++ include/linalg1/Analysis.h | 49 ++++ include/linalg1/Common.h | 120 ++++++++ include/linalg1/ConvertToLLVMDialect.h | 66 +++++ include/linalg1/Dialect.h | 42 +++ include/linalg1/Intrinsics.h | 32 +++ include/linalg1/LLVMIntrinsics.h | 41 +++ include/linalg1/Ops.h | 26 ++ include/linalg1/RangeOp.h | 56 ++++ include/linalg1/RangeType.h | 49 ++++ include/linalg1/SliceOp.h | 91 ++++++ include/linalg1/Types.h | 36 +++ include/linalg1/Utils.h | 37 +++ include/linalg1/ViewOp.h | 67 +++++ include/linalg1/ViewType.h | 57 ++++ include/linalg2/Analysis.h | 23 ++ include/linalg2/Intrinsics.h | 32 +++ include/linalg2/Ops.h | 24 ++ include/linalg2/TensorOps-inl.h | 120 ++++++++ include/linalg2/TensorOps.h | 287 +++++++++++++++++++ include/linalg2/Transforms.h | 36 +++ include/linalg3/Analysis.h | 37 +++ include/linalg3/ConvertToLLVMDialect.h | 29 ++ include/linalg3/Intrinsics.h | 31 +++ include/linalg3/LoadStoreOps.h | 89 ++++++ include/linalg3/Ops.h | 25 ++ include/linalg3/TensorOps-inl.h | 145 ++++++++++ include/linalg3/TensorOps.h | 54 ++++ include/linalg3/Transforms.h | 80 ++++++ include/toy/AST.h | 256 +++++++++++++++++ include/toy/Dialect.h | 393 ++++++++++++++++++++++++++ include/toy/Lexer.h | 239 ++++++++++++++++ include/toy/Lowering.h | 45 +++ include/toy/MLIRGen.h | 42 +++ include/toy/Parser.h | 494 +++++++++++++++++++++++++++++++++ include/toy/Passes.h | 33 +++ mlir/EarlyLowering.cpp | 158 +++++++++++ mlir/LateLowering.cpp | 452 ++++++++++++++++++++++++++++++ mlir/MLIRGen.cpp | 480 ++++++++++++++++++++++++++++++++ mlir/ShapeInferencePass.cpp | 387 ++++++++++++++++++++++++++ mlir/ToyCombine.cpp | 209 ++++++++++++++ mlir/ToyDialect.cpp | 405 +++++++++++++++++++++++++++ parser/AST.cpp | 263 ++++++++++++++++++ toyc.cpp | 325 ++++++++++++++++++++++ 45 files changed, 6033 
insertions(+) create mode 100644 .gitignore create mode 100644 CMakeLists.txt create mode 100644 include/linalg1/Analysis.h create mode 100644 include/linalg1/Common.h create mode 100644 include/linalg1/ConvertToLLVMDialect.h create mode 100644 include/linalg1/Dialect.h create mode 100644 include/linalg1/Intrinsics.h create mode 100644 include/linalg1/LLVMIntrinsics.h create mode 100644 include/linalg1/Ops.h create mode 100644 include/linalg1/RangeOp.h create mode 100644 include/linalg1/RangeType.h create mode 100644 include/linalg1/SliceOp.h create mode 100644 include/linalg1/Types.h create mode 100644 include/linalg1/Utils.h create mode 100644 include/linalg1/ViewOp.h create mode 100644 include/linalg1/ViewType.h create mode 100644 include/linalg2/Analysis.h create mode 100644 include/linalg2/Intrinsics.h create mode 100644 include/linalg2/Ops.h create mode 100644 include/linalg2/TensorOps-inl.h create mode 100644 include/linalg2/TensorOps.h create mode 100644 include/linalg2/Transforms.h create mode 100644 include/linalg3/Analysis.h create mode 100644 include/linalg3/ConvertToLLVMDialect.h create mode 100644 include/linalg3/Intrinsics.h create mode 100644 include/linalg3/LoadStoreOps.h create mode 100644 include/linalg3/Ops.h create mode 100644 include/linalg3/TensorOps-inl.h create mode 100644 include/linalg3/TensorOps.h create mode 100644 include/linalg3/Transforms.h create mode 100644 include/toy/AST.h create mode 100644 include/toy/Dialect.h create mode 100644 include/toy/Lexer.h create mode 100644 include/toy/Lowering.h create mode 100644 include/toy/MLIRGen.h create mode 100644 include/toy/Parser.h create mode 100644 include/toy/Passes.h create mode 100644 mlir/EarlyLowering.cpp create mode 100644 mlir/LateLowering.cpp create mode 100644 mlir/MLIRGen.cpp create mode 100644 mlir/ShapeInferencePass.cpp create mode 100644 mlir/ToyCombine.cpp create mode 100644 mlir/ToyDialect.cpp create mode 100644 parser/AST.cpp create mode 100644 toyc.cpp diff --git 
a/.gitignore b/.gitignore new file mode 100644 index 0000000..c6e96c3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +# CLion stuff +.idea +cmake-build-* + +# Vim stuff +*.swp + +# My CMake build +build + +# Test artifacts +*.mlir +*.toy diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..0c4cdfc --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,58 @@ +cmake_minimum_required(VERSION 3.12) +project(toyc) + +set(CMAKE_CXX_STANDARD 14) + +find_package(LLVM CONFIG PATHS "${MLIR_INSTALL_PREFIX}" NO_DEFAULT_PATH) + +if (LLVM_FOUND) + message(STATUS "Using LLVM ${LLVM_PACKAGE_VERSION} at ${LLVM_INSTALL_PREFIX}") +else() + message(FATAL_ERROR "LLVM not found; it is derived from MLIR_INSTALL_PREFIX which has value of ${MLIR_INSTALL_PREFIX}") +endif() + +llvm_map_components_to_libnames(llvm_libs support) + +if (NOT ${LLVM_ENABLE_RTTI}) + set(CMAKE_CXX_FLAGS "-fno-rtti ${CMAKE_CXX_FLAGS}") +endif() + +include_directories( + include + ${LLVM_INCLUDE_DIR}) + +link_directories(${LLVM_LIBRARY_DIR}) + +add_definitions(${LLVM_DEFINITIONS}) + +add_executable(toyc + toyc.cpp + parser/AST.cpp + mlir/EarlyLowering.cpp + mlir/LateLowering.cpp + mlir/MLIRGen.cpp + mlir/ShapeInferencePass.cpp + mlir/ToyDialect.cpp + mlir/ToyCombine.cpp + ) + +SET(WHOLE_ARCHIEVE -Wl,--whole-archive MLIRStandardOps MLIRAffineOps -Wl,--no-whole-archive) +target_link_libraries(toyc + PRIVATE + ${llvm_libs} + Linalg3DialectConstruction + Linalg3 + Linalg2 + Linalg1 + ${WHOLE_ARCHIEVE} + MLIRAnalysis + MLIREDSC + MLIRExecutionEngine + MLIRIR + MLIRLLVMIR + MLIRParser + MLIRPass + MLIRTargetLLVMIR + MLIRTransforms + MLIRSupport + ) diff --git a/include/linalg1/Analysis.h b/include/linalg1/Analysis.h new file mode 100644 index 0000000..ef8fb98 --- /dev/null +++ b/include/linalg1/Analysis.h @@ -0,0 +1,49 @@ +//===- Analysis.h - Linalg dialect Analysis function definitions ----------===// +// +// Copyright 2019 The MLIR Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_ANALYSIS_H_ +#define LINALG1_ANALYSIS_H_ + +#include "mlir/Support/LLVM.h" + +namespace mlir { +class Value; +} // namespace mlir + +namespace linalg { +class ViewOp; + +/// Walks the chain of SliceOp until the unique base ViewOp. +ViewOp getViewBaseViewOp(mlir::Value *view); + +/// Walks the chain of SliceOp until the unique base ViewOp and returns the +/// MemRef upon which the ViewOp is laid. +mlir::Value *getViewSupportingMemRef(mlir::Value *view); + +/// Extract the indexing from the root ViewOp that this slice constrains along +/// `dim`. To achieve this, it walks back the chain of SliceOp and determine the +/// first slice that constrains `dim`. +/// Note that the dimension in the original ViewOp may shift due to +/// rank-reducing operations. +/// Returns a pair, with the indexing as the first element and the actual +/// dimension, in the root ViewOp, as the second element. +std::pair getViewRootIndexing(mlir::Value *view, + unsigned dim); + +} // namespace linalg + +#endif // LINALG1_ANALYSIS_H_ diff --git a/include/linalg1/Common.h b/include/linalg1/Common.h new file mode 100644 index 0000000..6573c72 --- /dev/null +++ b/include/linalg1/Common.h @@ -0,0 +1,120 @@ +//===- Common.h - Linalg dialect RangeOp operation -----------------------===// +// +// Copyright 2019 The MLIR Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_COMMON_H_ +#define LINALG1_COMMON_H_ + +#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/StandardOps/Ops.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/LoopUtils.h" +#include "mlir/Transforms/Passes.h" + +namespace linalg { +namespace common { + +//////////////////////////////////////////////////////////////////////////////// +// Define a few boilerplate objects used across all linalg examples. +//////////////////////////////////////////////////////////////////////////////// + +/// A 2-D abstraction over a flat contiguous memory region of f32 with symbolic +/// sizes. 
+template +inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context, + unsigned memorySpace = 0) { + llvm::SmallVector shape(N, -1); + auto f32 = mlir::FloatType::getF32(context); + return mlir::MemRefType::get(shape, f32, {}, memorySpace); +} + +/// A basic function builder +inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name, + llvm::ArrayRef types, + llvm::ArrayRef resultTypes) { + auto *context = module.getContext(); + auto *function = new mlir::Function( + mlir::UnknownLoc::get(context), name, + mlir::FunctionType::get({types}, resultTypes, context)); + function->addEntryBlock(); + module.getFunctions().push_back(function); + return function; +} + +/// A basic pass manager pre-populated with cleanup passes. +inline std::unique_ptr cleanupPassManager() { + std::unique_ptr pm(new mlir::PassManager()); + pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mlir::createSimplifyAffineStructuresPass()); + pm->addPass(mlir::createCSEPass()); + pm->addPass(mlir::createCanonicalizerPass()); + return pm; +} + +/// A simple function to verify and cleanup the IR before printing it to +/// llvm::outs() for FileCheck'ing. +/// If an error occurs, dump to llvm::errs() and do not print to llvm::outs() +/// which will make the associated FileCheck test fail. +inline void cleanupAndPrintFunction(mlir::Function *f) { + bool printToOuts = true; + auto check = [f, &printToOuts](mlir::LogicalResult result) { + if (failed(result)) { + f->getContext()->emitError(f->getLoc(), + "Verification and cleanup passes failed"); + printToOuts = false; + } + }; + auto pm = cleanupPassManager(); + check(f->getModule()->verify()); + check(pm->run(f->getModule())); + if (printToOuts) + f->print(llvm::outs()); +} + +/// Helper class to sugar building loop nests from indexings that appear in +/// ViewOp and SliceOp. 
+class LoopNestRangeBuilder { +public: + LoopNestRangeBuilder(llvm::ArrayRef ivs, + llvm::ArrayRef indexings); + LoopNestRangeBuilder(llvm::ArrayRef ivs, + llvm::ArrayRef indexings); + mlir::edsc::ValueHandle + operator()(llvm::ArrayRef stmts); + +private: + llvm::SmallVector loops; +}; + +} // namespace common +} // namespace linalg + +#endif // LINALG1_COMMON_H_ diff --git a/include/linalg1/ConvertToLLVMDialect.h b/include/linalg1/ConvertToLLVMDialect.h new file mode 100644 index 0000000..8e5a7ce --- /dev/null +++ b/include/linalg1/ConvertToLLVMDialect.h @@ -0,0 +1,66 @@ +//===- ConvertToLLVMDialect.h - conversion from Linalg to LLVM --*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_CONVERTTOLLVMDIALECT_H_ +#define LINALG1_CONVERTTOLLVMDIALECT_H_ + +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/Allocator.h" + +#include + +namespace mlir { +class DialectConversion; +class DialectOpConversion; +class MLIRContext; +class Module; +class Type; +namespace LLVM { +class LLVMType; +} // end namespace LLVM +} // end namespace mlir + +namespace linalg { +/// Convert the given Linalg dialect type `t` into an LLVM IR dialect type. +/// Keep all other types unmodified. 
+mlir::Type convertLinalgType(mlir::Type t); + +/// Allocate the conversion patterns for RangeOp, ViewOp and SliceOp from the +/// Linalg dialect to the LLVM IR dialect. The converters are allocated in the +/// `allocator` using the provided `context`. The latter must have the LLVM IR +/// dialect registered. +/// This function can be used to apply multiple conversion patterns in the same +/// pass. It does not have to be called explicitly before the conversion. +llvm::DenseSet +allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator, + mlir::MLIRContext *context); + +/// Create a DialectConversion from the Linalg dialect to the LLVM IR dialect. +/// The conversion is set up to convert types and function signatures using +/// `convertLinalgType` and obtains operation converters by calling `initer`. +std::unique_ptr makeLinalgToLLVMLowering( + std::function( + llvm::BumpPtrAllocator *, mlir::MLIRContext *context)> + initer); + +/// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations +/// to the LLVM IR dialect types and operations in the given `module`. This is +/// the main entry point to the conversion. +void convertToLLVM(mlir::Module &module); +} // end namespace linalg + +#endif // LINALG1_CONVERTTOLLVMDIALECT_H_ diff --git a/include/linalg1/Dialect.h b/include/linalg1/Dialect.h new file mode 100644 index 0000000..70023e1 --- /dev/null +++ b/include/linalg1/Dialect.h @@ -0,0 +1,42 @@ +//===- Dialect.h - Definition of the Linalg dialect -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_DIALECT_H_ +#define LINALG1_DIALECT_H_ + +#include "mlir/IR/Dialect.h" + +namespace linalg { + +/// The Linalg Dialect is not exposed to the outside world. It is registered by +/// linking and accessed via generic MLIR accessors. +class LinalgDialect : public mlir::Dialect { +public: + /// Create a new Dialect that is registered on construction and adds the + /// relevant types and operations. + explicit LinalgDialect(mlir::MLIRContext *context); + + /// Parse a type registered to this dialect. + mlir::Type parseType(llvm::StringRef spec, mlir::Location loc) const override; + + /// Print a type registered to this dialect. + void printType(mlir::Type type, llvm::raw_ostream &os) const override; +}; + +} // namespace linalg + +#endif // LINALG1_DIALECT_H_ diff --git a/include/linalg1/Intrinsics.h b/include/linalg1/Intrinsics.h new file mode 100644 index 0000000..305e3f3 --- /dev/null +++ b/include/linalg1/Intrinsics.h @@ -0,0 +1,32 @@ +//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_INTRINSICS_H_ +#define LINALG1_INTRINSICS_H_ + +#include "linalg1/Ops.h" +#include "mlir/EDSC/Intrinsics.h" + +namespace linalg { +namespace intrinsics { +using range = mlir::edsc::intrinsics::ValueBuilder; +using slice = mlir::edsc::intrinsics::ValueBuilder; +using view = mlir::edsc::intrinsics::ValueBuilder; +} // namespace intrinsics +} // namespace linalg + +#endif // LINALG1_INTRINSICS_H_ diff --git a/include/linalg1/LLVMIntrinsics.h b/include/linalg1/LLVMIntrinsics.h new file mode 100644 index 0000000..577981b --- /dev/null +++ b/include/linalg1/LLVMIntrinsics.h @@ -0,0 +1,41 @@ +//===- LLVMIntrinsics.h - declarative builders for LLVM dialect -*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef LINALG1_LLVMINTRINSICS_H_ +#define LINALG1_LLVMINTRINSICS_H_ + +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/LLVMIR/LLVMDialect.h" + +// Expose some LLVM IR instructions to declarative builders. +namespace intrinsics { +using undef = mlir::edsc::intrinsics::ValueBuilder; +using insertvalue = + mlir::edsc::intrinsics::ValueBuilder; +using extractvalue = + mlir::edsc::intrinsics::ValueBuilder; +using constant = mlir::edsc::intrinsics::ValueBuilder; +using add = mlir::edsc::intrinsics::ValueBuilder; +using sub = mlir::edsc::intrinsics::ValueBuilder; +using mul = mlir::edsc::intrinsics::ValueBuilder; +using load = mlir::edsc::intrinsics::ValueBuilder; +using store = mlir::edsc::intrinsics::OperationBuilder; +using gep = mlir::edsc::intrinsics::ValueBuilder; +} // end namespace intrinsics + +#endif // LINALG1_LLVMINTRINSICS_H_ diff --git a/include/linalg1/Ops.h b/include/linalg1/Ops.h new file mode 100644 index 0000000..2e662cf --- /dev/null +++ b/include/linalg1/Ops.h @@ -0,0 +1,26 @@ +//===- Ops.h - Linalg Ops single entry point ------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef LINALG1_OPS_H_ +#define LINALG1_OPS_H_ + +#include "linalg1/Types.h" +#include "linalg1/RangeOp.h" +#include "linalg1/SliceOp.h" +#include "linalg1/ViewOp.h" + +#endif // LINALG1_OPS_H_ diff --git a/include/linalg1/RangeOp.h b/include/linalg1/RangeOp.h new file mode 100644 index 0000000..9652f51 --- /dev/null +++ b/include/linalg1/RangeOp.h @@ -0,0 +1,56 @@ +//===- RangeOp.h - Linalg dialect RangeOp operation definition ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_RANGEOP_H_ +#define LINALG1_RANGEOP_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +/// A RangeOp is used to create a value of RangeType from 3 values of type index +/// that represent the min, max and step values of the range. +/// Note: step must be an mlir::ConstantIndexOp for now due to current +/// `affine.for` limitations. +class RangeOp : public mlir::Op::Impl, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. 
+ ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.range"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *min, mlir::Value *max, mlir::Value *step); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + mlir::Value *getMin() { return getOperand(0); } + mlir::Value *getMax() { return getOperand(1); } + mlir::Value *getStep() { return getOperand(2); } +}; + +} // namespace linalg + +#endif // LINALG1_RANGEOP_H_ diff --git a/include/linalg1/RangeType.h b/include/linalg1/RangeType.h new file mode 100644 index 0000000..d17c058 --- /dev/null +++ b/include/linalg1/RangeType.h @@ -0,0 +1,49 @@ +//===- RangeType.h - Linalg RangeType definition --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef LINALG1_RANGETYPE_H_ +#define LINALG1_RANGETYPE_H_ + +#include "linalg1/Types.h" +#include "mlir/IR/Types.h" + +namespace mlir { +class MLIRContext; +} + +namespace linalg { + +/// A RangeType is the simplest possible form of a type in MLIR. It represents +/// a minimal range abstraction (min, max, step). Since RangeType is constructed +/// without any additional argument, this example illustrates the minimal +/// amount of information required to implement a new custom MLIR type. +class RangeType : public mlir::Type::TypeBase { +public: + // Used to implement llvm-style cast. + using Base::Base; + /// Construction hook. + static RangeType get(mlir::MLIRContext *context) { + /// Custom, uniqu'ed construction in the mlir::MLIRContext. + return Base::get(context, LinalgTypes::Range); + } + /// Used to implement llvm-style cast. + static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; } +}; + +} // namespace linalg + +#endif // LINALG1_RANGETYPE_H_ diff --git a/include/linalg1/SliceOp.h b/include/linalg1/SliceOp.h new file mode 100644 index 0000000..1d79784 --- /dev/null +++ b/include/linalg1/SliceOp.h @@ -0,0 +1,91 @@ +//===- SliceOp.h - Linalg dialect SliceOp operation definition ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef LINALG1_SLICEOP_H_ +#define LINALG1_SLICEOP_H_ + +#include "linalg1/Types.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +/// A SliceOp is used to create a "sub-View" from a ViewType. It results in a +/// new ViewType which is contained within its parent ViewType. +class SliceOp : public mlir::Op::Impl, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.slice"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *view, mlir::Value *indexing, unsigned dim); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + enum { FirstIndexingOperand = 1 }; + /// Returns the attribute name that describes which dimension of the input + /// view that this SliceOp slices. + static llvm::StringRef getSlicingDimAttrName() { return "dim"; } + /// Returns the unique result of the parent SliceOp or ViewOp instruction that + /// created the view on which this SliceOp operates. + mlir::Value *getParentView() { return getOperand(0); } + /// Returns the indexing operand of the current SliceOp. + /// This operand may either be: + /// 1. A range, in which case the operand comes from a RangeOp. This SliceOp + /// does not reduce the dimension of the input ViewType. + /// 2. 
An index, in which case the operand comes from any possible producer + /// of an index. This SliceOp reduces the dimension of the input ViewType + /// by 1. + mlir::Value *getIndexing() { return getOperand(1); } + /// Returns the dim of the parent ViewType that is sliced by this SliceOp. + unsigned getSlicingDim() { + return getAttrOfType(getSlicingDimAttrName()).getInt(); + } + /// Returns the ViewType resulting from this SliceOp. + ViewType getViewType(); + /// Returns the rank of the current ViewType. + unsigned getRank(); + /// Return the element type of the current ViewType. + mlir::Type getElementType(); + + /// Returns the ViewType of `getParentView()`. + ViewType getParentViewType(); + /// Returns the rank of the ViewType of `getParentView()`. + unsigned getParentRank(); + /// Returns the element Type of the ViewType of `getParentView()`. + mlir::Type getParentElementType(); + + /// Returns true if the rank of the parent view is greater than the rank of + /// the child view. + bool isRankDecreasing(); + + // Get all the indexings in this slice. + mlir::Operation::operand_range getIndexings(); +}; + +} // namespace linalg + +#endif // LINALG1_SLICEOP_H_ diff --git a/include/linalg1/Types.h b/include/linalg1/Types.h new file mode 100644 index 0000000..5032e96 --- /dev/null +++ b/include/linalg1/Types.h @@ -0,0 +1,36 @@ +//===- Types.h - Linalg Types forward declarations ------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_TYPES_H_ +#define LINALG1_TYPES_H_ + +#include "mlir/IR/Types.h" + +namespace linalg { + +enum LinalgTypes { + Range = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE, + View, + FIRST_PRIVATE_EXPERIMENTAL_0_TYPE = View, +}; + +} // namespace linalg + +#include "linalg1/RangeType.h" +#include "linalg1/ViewType.h" + +#endif // LINALG1_TYPES_H_ diff --git a/include/linalg1/Utils.h b/include/linalg1/Utils.h new file mode 100644 index 0000000..3f7bb76 --- /dev/null +++ b/include/linalg1/Utils.h @@ -0,0 +1,37 @@ +//===- Utils.h - Linalg dialect utility functions definitions -------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_UTILS_H_ +#define LINALG1_UTILS_H_ + +namespace mlir { +class Value; +} // namespace mlir + +namespace linalg { +class ViewOp; + +/// Asserts `view` is of ViewType and returns its rank. +unsigned getViewRank(mlir::Value *view); + +/// Helper function to emit and return a new ViewOp from `memRef` that is +/// assumed to be of MemRefType. This needs to be called under a ScopedContext. 
+ViewOp emitAndReturnViewOpFromMemRef(mlir::Value *memRef); + +} // namespace linalg + +#endif // LINALG1_UTILS_H_ diff --git a/include/linalg1/ViewOp.h b/include/linalg1/ViewOp.h new file mode 100644 index 0000000..fcda553 --- /dev/null +++ b/include/linalg1/ViewOp.h @@ -0,0 +1,67 @@ +//===- ViewOp.h - Linalg dialect ViewOp operation definition ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_VIEWOP_H_ +#define LINALG1_VIEWOP_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +class ViewType; + +/// A `ViewOp` produces a `ViewType` which is a multi-dimensional range +/// abstraction on top of an underlying data type. For now we use the existing +/// mlir::MemRef for the underlying data type. +class ViewOp : public mlir::Op { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. 
+ ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.view"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *memRef, + llvm::ArrayRef indexings); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + enum { FirstIndexingOperand = 1 }; + unsigned getRank(); + mlir::Type getElementType(); + ViewType getViewType(); + // May be something else than a MemRef in the future. + mlir::Value *getSupportingMemRef(); + // Get the underlying indexing at a given rank. + mlir::Value *getIndexing(unsigned rank); + // Get all the indexings of type RangeOp. + llvm::SmallVector getRanges(); + // Get all the indexings in this view. + mlir::Operation::operand_range getIndexings(); +}; + +} // namespace linalg + +#endif // LINALG1_VIEWOP_H_ diff --git a/include/linalg1/ViewType.h b/include/linalg1/ViewType.h new file mode 100644 index 0000000..c58e12c --- /dev/null +++ b/include/linalg1/ViewType.h @@ -0,0 +1,57 @@ +//===- ViewType.h - Linalg ViewType definition --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_VIEWTYPE_H_ +#define LINALG1_VIEWTYPE_H_ + +#include "linalg1/Types.h" +#include "mlir/IR/Types.h" + +namespace linalg { + +class ViewTypeStorage; + +/// A ViewType represents a range abstraction on top of an underlying storage +/// type. It is parameterizable by the underlying element type and the rank of +/// the view. +class ViewType + : public mlir::Type::TypeBase { +public: + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this type. + ////////////////////////////////////////////////////////////////////////////// + // Used to implement llvm-style cast. + using Base::Base; + // Used to implement llvm-style cast. + static bool kindof(unsigned kind) { return kind == LinalgTypes::View; } + /// Construction hook. + static ViewType get(mlir::MLIRContext *context, mlir::Type elementType, + unsigned rank); + + ////////////////////////////////////////////////////////////////////////////// + // Type-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + /// Return the underlying elemental type. + mlir::Type getElementType(); + /// Return the rank of the view. + /// This is the number of indexings needed to reach an underlying element. + unsigned getRank(); +}; + +} // namespace linalg + +#endif // LINALG1_VIEWTYPE_H_ diff --git a/include/linalg2/Analysis.h b/include/linalg2/Analysis.h new file mode 100644 index 0000000..43acd95 --- /dev/null +++ b/include/linalg2/Analysis.h @@ -0,0 +1,23 @@ +//===- Analysis.h - Linalg dialect Analysis function definitions ----------===// +// +// Copyright 2019 The MLIR Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG2_ANALYSIS_H_ +#define LINALG2_ANALYSIS_H_ + +#include "linalg1/Analysis.h" + +#endif // LINALG2_ANALYSIS_H_ diff --git a/include/linalg2/Intrinsics.h b/include/linalg2/Intrinsics.h new file mode 100644 index 0000000..e74e059 --- /dev/null +++ b/include/linalg2/Intrinsics.h @@ -0,0 +1,32 @@ +//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef LINALG2_INTRINSICS_H_ +#define LINALG2_INTRINSICS_H_ + +#include "linalg1/Intrinsics.h" +#include "linalg2/Ops.h" + +namespace linalg { +namespace intrinsics { +using dot = mlir::edsc::intrinsics::OperationBuilder; +using matmul = mlir::edsc::intrinsics::OperationBuilder; +using matvec = mlir::edsc::intrinsics::OperationBuilder; +} // namespace intrinsics +} // namespace linalg + +#endif // LINALG2_INTRINSICS_H_ diff --git a/include/linalg2/Ops.h b/include/linalg2/Ops.h new file mode 100644 index 0000000..141b1d0 --- /dev/null +++ b/include/linalg2/Ops.h @@ -0,0 +1,24 @@ +//===- Ops.h - Linalg Ops single entry point ------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG2_OPS_H_ +#define LINALG2_OPS_H_ + +#include "linalg1/Ops.h" +#include "linalg2/TensorOps.h" + +#endif // LINALG2_OPS_H_ diff --git a/include/linalg2/TensorOps-inl.h b/include/linalg2/TensorOps-inl.h new file mode 100644 index 0000000..940f8d7 --- /dev/null +++ b/include/linalg2/TensorOps-inl.h @@ -0,0 +1,120 @@ +//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===// +// +// Copyright 2019 The MLIR Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#ifndef LINALG2_TENSOROPS_INL_H_ +#define LINALG2_TENSOROPS_INL_H_ + +#include "linalg2/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" + +namespace linalg { + +template +mlir::Operation::operand_range +linalg::TensorContractionBase::getInputs() { + auto *op = static_cast(this)->getOperation(); + return {op->operand_begin(), op->operand_begin() + getNumInputs()}; +} + +template +mlir::Operation::operand_range +linalg::TensorContractionBase::getOutputs() { + auto *op = static_cast(this)->getOperation(); + return {op->operand_begin() + getNumInputs(), + op->operand_begin() + getNumInputs() + getNumOutputs()}; +} + +template +mlir::Operation::operand_range +linalg::TensorContractionBase::getInputsAndOutputs() { + return {getInputs().begin(), getOutputs().end()}; +} + +template +mlir::LogicalResult linalg::TensorContractionBase::verify() { + auto *concreteOp = static_cast(this)->getOperation(); + if (getNumInputs() <= 0) + return concreteOp->emitOpError("expected at least one input"); + if (getNumOutputs() <= 0) + return concreteOp->emitOpError("expected at 
least one output"); + if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) { + return concreteOp->emitOpError("expected " + + llvm::Twine(getNumInputs() + getNumOutputs()) + + " operands"); + } + for (unsigned i = 0, e = getNumInputs(); i < e; ++i) { + if (!concreteOp->getOperand(i)->getType().template isa()) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " not a ViewType"); + } + for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e; + ++i) { + auto viewType = + concreteOp->getOperand(i)->getType().template dyn_cast(); + if (!viewType) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " not a ViewType"); + if (viewType.getRank() != getNumParallelDims()) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " must be of rank " + + llvm::Twine(getNumParallelDims())); + } + return mlir::success(); +} + +template +bool linalg::TensorContractionBase::parse( + mlir::OpAsmParser *parser, mlir::OperationState *result) { + llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); +} + +// A TensorContraction prints as: +// +// ```{.mlir} +// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types +// ``` +// +// for example: +// +// ``` +// linalg.matmul(%0, %1, %2) : view +// ``` +// +// Where %0, %1 and %2 are ssa-values of type ViewType. +template +void linalg::TensorContractionBase::print(mlir::OpAsmPrinter *p) { + *p << static_cast(this)->getOperationName() << "("; + auto *last = *std::prev(getInputsAndOutputs().end()); + for (auto *i : getInputsAndOutputs()) { + *p << *i << ((i == last) ? "" : ", "); + } + *p << ") : "; + auto *lastOutput = *std::prev(getOutputs().end()); + for (auto *o : getOutputs()) { + *p << o->getType() << ((o == lastOutput) ? 
"" : ","); + } +} + +} // namespace linalg + +#endif // LINALG2_TENSOROPS_INL_H_ diff --git a/include/linalg2/TensorOps.h b/include/linalg2/TensorOps.h new file mode 100644 index 0000000..39e51f0 --- /dev/null +++ b/include/linalg2/TensorOps.h @@ -0,0 +1,287 @@ +//===- TensorOps.h - Linalg dialect TensorOps operation definition --------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG2_TENSOROPS_H_ +#define LINALG2_TENSOROPS_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +class AffineForOp; +} // namespace mlir + +namespace linalg { + +/// A generic TensorContraction base class which captures the generic behavior +/// of tensor contraction operations (with broadcast). +template class TensorContractionBase { +protected: + using TensorContractionBaseType = TensorContractionBase; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. 
+ ////////////////////////////////////////////////////////////////////////////// + /// Generic implementation of hooks that should be called from `ConcreteType`s + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + +public: + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + TensorContractionBase() = default; + mlir::Operation::operand_range getInputs(); + mlir::Operation::operand_range getOutputs(); + mlir::Operation::operand_range getInputsAndOutputs(); + + /// These are better as methods calling into the ConcreteOp instead of + /// template parameters because methods allow more generic behavior and avoid + /// specializing for number of arguments. All derived classes have + /// `VariadicOperands` and a build method from both an ArrayRef + /// and the proper number of mlir::Value*. + unsigned getNumInputs() { + return static_cast(this)->numInputs; + }; + unsigned getNumOutputs() { + return static_cast(this)->numOutputs; + }; + unsigned getNumParallelDims() { + return static_cast(this)->numParallelDims; + }; + unsigned getNumReductionDims() { + return static_cast(this)->numReductionDims; + }; + + ////////////////////////////////////////////////////////////////////////////// + // Used in Linalg3 and later. + ////////////////////////////////////////////////////////////////////////////// + mlir::Value *getInputView(unsigned viewIndex); + mlir::Value *getOutputView(unsigned viewIndex); + mlir::Value *getView(unsigned viewIndex) { + return viewIndex < getNumInputs() + ? getInputView(viewIndex) + : getOutputView(viewIndex - getNumInputs()); + } + + /// Each op is responsible for declaring how it lowers itself to scalar form, + /// given the enclosing parallel and reduction induction variables. 
+ /// `emitScalarImplementation` emits the scalar IR for the op in the nesting + /// context of the innermost enclosing loop (i.e. `reductionIvs.back()` or + /// `parallelIvs.back()`). + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); + + /// Represents a mapping from the loops to all the ranges of the operands. + /// The operands and their ranges are in the order defined by the particular + /// ConcreteOp implementation, the resulting map must match those. + /// In favorable cases, this can be calculated by an analysis but specifying + /// it explicitly is not expensive and generalizes to cases where an analysis + /// is not available. For details, see the description of + /// loopsToOperandRangeMaps in each ConcreteOp. + llvm::SmallVector loopsToOperandRangeMaps(); +}; + +/// Implements c = A * B where c is a scalar and A and B are 1-D vectors. +class DotOp : public TensorContractionBase, + public mlir::Op { +public: + using Op::Op; + using TensorContractionBaseType = + TensorContractionBase::TensorContractionBaseType; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.dot"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + llvm::ArrayRef operands); + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *A, mlir::Value *B, mlir::Value *C) { + return build(b, result, {A, B, C}); + } + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. 
+ ////////////////////////////////////////////////////////////////////////////// + static constexpr unsigned numInputs = 2; + static constexpr unsigned numOutputs = 1; + static constexpr unsigned numParallelDims = 0; + static constexpr unsigned numReductionDims = 1; + + ////////////////////////////////////////////////////////////////////////////// + // Used in Linalg3 and later. + ////////////////////////////////////////////////////////////////////////////// + /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a + /// loop over matvec). Does nothing by default. + void writeAsFinerGrainTensorContraction(); + + /// Inputs to this map will be (%k) coming from enclosing loops. + /// Therefore, the mapping to get back to A(K), B(K), C() is: + /// (d0) -> (d0, d0)(%k) + /// And the operands ranges are: + /// (%k, %k) + llvm::SmallVector loopsToOperandRangeMaps(); + + /// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding + /// to: + /// 1. conditionally assign scalarC to 0.0f on the first iteration or load + /// C[] from memory (0-D tensor) + /// 2. multiply A[r_i] by B[r_i] and add to scalarC + /// 3. store back scalarC at C[] + /// + /// In some compact index notation this could be written: + /// cond = (r_i == zero) + /// scalarC = select(cond, zerof, C[]); + /// C[] = scalarC + A[r_i] * B[r_i]; + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); +}; + +/// Implements C = A * B where A is a 2-D matrix and B and C are 1-D vectors. +class MatvecOp : public TensorContractionBase, + public mlir::Op { +public: + using Op::Op; + using TensorContractionBaseType = + TensorContractionBase::TensorContractionBaseType; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. 
+ ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.matvec"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + llvm::ArrayRef operands); + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *A, mlir::Value *B, mlir::Value *C) { + return build(b, result, {A, B, C}); + } + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + static constexpr unsigned numInputs = 2; + static constexpr unsigned numOutputs = 1; + static constexpr unsigned numParallelDims = 1; + static constexpr unsigned numReductionDims = 1; + + ////////////////////////////////////////////////////////////////////////////// + // Used in Linalg3 and later. + ////////////////////////////////////////////////////////////////////////////// + /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a + /// loop over matvec). Does nothing by default. + void writeAsFinerGrainTensorContraction(); + + /// Inputs to this map will be (%m, %k) coming from enclosing loops. + /// Therefore, the mapping to get back to A(M, K), B(K), C(M) is: + /// (d0, d1) -> (d0, d1, d1, d0)(%m, %k) + /// And the operands ranges are: + /// (%m, %k, %k, %m) + llvm::SmallVector loopsToOperandRangeMaps(); + + /// Given an enclosing parallel loop with iv `i` and an enclosing reduction + /// loop with iv `r_j`, emits MLIR corresponding to: + /// 1. conditionally assign scalarC to 0.0f on the first iteration or load + /// C[i] + /// 2. multiply A[i, r_j] by B[r_j] and add to scalarC + /// 3. 
store back scalarC at C[i] + /// + /// In some compact index notation this could be written: + /// cond = (r_j == zero) + /// scalarC = select(cond, zerof, C(i)); + /// C(i) = scalarC + A(i, r_j) * B(r_j); + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); +}; + +/// Implements C = A * B on 2-D matrices. +class MatmulOp : public TensorContractionBase, + public mlir::Op { +public: + using Op::Op; + using TensorContractionBaseType = + TensorContractionBase::TensorContractionBaseType; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.matmul"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + llvm::ArrayRef operands); + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *A, mlir::Value *B, mlir::Value *C) { + return build(b, result, {A, B, C}); + } + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + static constexpr unsigned numInputs = 2; + static constexpr unsigned numOutputs = 1; + static constexpr unsigned numParallelDims = 2; + static constexpr unsigned numReductionDims = 1; + + ////////////////////////////////////////////////////////////////////////////// + // Used in Linalg3 and later. + ////////////////////////////////////////////////////////////////////////////// + /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a + /// loop over matvec). Does nothing by default. 
+ void writeAsFinerGrainTensorContraction(); + + /// Inputs to this map will be (%m, %n, %k) coming from enclosing loops. + /// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is: + /// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k) + /// And the operands ranges are: + /// (%m, %k, %k, %n, %m, %n) + llvm::SmallVector loopsToOperandRangeMaps(); + + /// Given enclosing parallel loops with ivs `i` and `j`, and an enclosing + /// reduction loop with iv `r_k`, emits MLIR corresponding to: + /// 1. conditionally assign scalarC to 0.0f on the first iteration or load + /// C[i, j] + /// 2. multiply A[i, r_k] by B[r_k, j] and add to scalarC + /// 3. store back scalarC at C[i, j] + /// + /// In some compact index notation this could be written: + /// cond = (r_k == zero) + /// scalarC = select(cond, zerof, C[i, j]); + /// C[i, j] = scalarC + A[i, r_k] * B[r_k, j]; + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); +}; + +} // namespace linalg + +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#include "linalg2/TensorOps-inl.h" + +#endif // LINALG2_TENSOROPS_H_ diff --git a/include/linalg2/Transforms.h b/include/linalg2/Transforms.h new file mode 100644 index 0000000..c55f863 --- /dev/null +++ b/include/linalg2/Transforms.h @@ -0,0 +1,36 @@ +//===- Transforms.h - Linalg dialect Transformations definition -----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG2_TRANSFORMS_H_ +#define LINALG2_TRANSFORMS_H_ + +namespace mlir { +class Value; +} // namespace mlir + +namespace linalg { + +class ViewOp; + +/// Takes a `view` of type ViewType (i.e. either a ViewOp or a SliceOp) and +/// composes away all the SliceOp to return a single ViewOp. +/// Inserts the required operations after `view`. +ViewOp emitAndReturnFullyComposedView(mlir::Value *v); + +} // namespace linalg + +#endif // LINALG2_TRANSFORMS_H_ diff --git a/include/linalg3/Analysis.h b/include/linalg3/Analysis.h new file mode 100644 index 0000000..813fc37 --- /dev/null +++ b/include/linalg3/Analysis.h @@ -0,0 +1,37 @@ +//===- Analysis.h - Linalg dialect Analysis function definitions ----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef LINALG3_ANALYSIS_H_ +#define LINALG3_ANALYSIS_H_ + +#include "linalg2/Analysis.h" + +namespace mlir { +class AffineMap; +} // namespace mlir + +namespace linalg { + +/// Given a `map` specification and a subset of its results +/// `[beginResult, endResult)`, returns the inverse map that maps result +/// positions to dim positions. +mlir::AffineMap inverseSubMap(mlir::AffineMap map, unsigned beginResult = 0, + unsigned endResult = 0); + +} // namespace linalg + +#endif // LINALG3_ANALYSIS_H_ diff --git a/include/linalg3/ConvertToLLVMDialect.h b/include/linalg3/ConvertToLLVMDialect.h new file mode 100644 index 0000000..8f122e0 --- /dev/null +++ b/include/linalg3/ConvertToLLVMDialect.h @@ -0,0 +1,29 @@ +//===- ConvertToLLVMDialect.h - conversion from Linalg to LLVM --*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef LINALG3_CONVERTTOLLVMDIALECT_H_ +#define LINALG3_CONVERTTOLLVMDIALECT_H_ + +namespace mlir { +class Module; +} // end namespace mlir + +namespace linalg { +void convertLinalg3ToLLVM(mlir::Module &module); +} // end namespace linalg + +#endif // LINALG3_CONVERTTOLLVMDIALECT_H_ diff --git a/include/linalg3/Intrinsics.h b/include/linalg3/Intrinsics.h new file mode 100644 index 0000000..75a0417 --- /dev/null +++ b/include/linalg3/Intrinsics.h @@ -0,0 +1,31 @@ +//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef LINALG3_INTRINSICS_H_ +#define LINALG3_INTRINSICS_H_ + +#include "linalg2/Intrinsics.h" +#include "linalg3/Ops.h" + +namespace linalg { +namespace intrinsics { +using load = mlir::edsc::intrinsics::ValueBuilder; +using store = mlir::edsc::intrinsics::OperationBuilder; +} // namespace intrinsics +} // namespace linalg + +#endif // LINALG3_INTRINSICS_H_ diff --git a/include/linalg3/LoadStoreOps.h b/include/linalg3/LoadStoreOps.h new file mode 100644 index 0000000..b77e702 --- /dev/null +++ b/include/linalg3/LoadStoreOps.h @@ -0,0 +1,89 @@ +//===- LoadStoreOps.h - Linalg dialect Load/Store operation definitions ---===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_LOADSTOREOP_H_ +#define LINALG3_LOADSTOREOP_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +class ViewType; + +/// A linalg.LoadOp is the counterpart of affine.load but operating on ViewType +/// instead of MemRefType. +class LoadOp : public mlir::Op { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. 
+ ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.load"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *view, + mlir::ArrayRef indices = {}); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + unsigned getRank(); + ViewType getViewType(); + mlir::Value *getView() { return getOperand(0); } + mlir::Operation::operand_range getIndices() { + return {operand_begin() + 1, operand_end()}; + } +}; + +/// A linalg.StoreOp is the counterpart of affine.store but operating on +/// ViewType instead of MemRefType. +class StoreOp : public mlir::Op { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.store"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *valueToStore, mlir::Value *view, + mlir::ArrayRef indices = {}); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. 
+ ////////////////////////////////////////////////////////////////////////////// + unsigned getRank(); + ViewType getViewType(); + mlir::Value *getValueToStore() { return getOperand(0); } + mlir::Value *getView() { return getOperand(1); } + mlir::Operation::operand_range getIndices() { + return {operand_begin() + 2, operand_end()}; + } +}; + +} // namespace linalg + +#endif // LINALG3_LOADSTOREOP_H_ diff --git a/include/linalg3/Ops.h b/include/linalg3/Ops.h new file mode 100644 index 0000000..813cbff --- /dev/null +++ b/include/linalg3/Ops.h @@ -0,0 +1,25 @@ +//===- Ops.h - Linalg Ops single entry point ------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_OPS_H_ +#define LINALG3_OPS_H_ + +#include "linalg2/Ops.h" +#include "linalg3/LoadStoreOps.h" +#include "linalg3/TensorOps.h" + +#endif // LINALG3_OPS_H_ diff --git a/include/linalg3/TensorOps-inl.h b/include/linalg3/TensorOps-inl.h new file mode 100644 index 0000000..b651053 --- /dev/null +++ b/include/linalg3/TensorOps-inl.h @@ -0,0 +1,145 @@ +//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#ifndef LINALG3_TENSOROPS_INL_H_ +#define LINALG3_TENSOROPS_INL_H_ + +#include "linalg1/Common.h" +#include "linalg1/Utils.h" +#include "linalg2/TensorOps.h" +#include "linalg3/Analysis.h" +#include "linalg3/Ops.h" + +template +mlir::Value * +linalg::TensorContractionBase::getInputView(unsigned viewIndex) { + return *(getInputs().begin() + viewIndex); +} + +template +mlir::Value * +linalg::TensorContractionBase::getOutputView(unsigned viewIndex) { + return *(getOutputs().begin() + viewIndex); +} + +template +llvm::SmallVector +linalg::TensorContractionBase::loopsToOperandRangeMaps() { + return static_cast(this)->loopsToOperandRangeMaps(); +} + +template +void linalg::TensorContractionBase::emitScalarImplementation( + llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs) { + static_cast(this)->emitScalarImplementation(parallelIvs, + reductionIvs); +} + +template +mlir::AffineMap linalg::operandRangesToLoopsMap( + linalg::TensorContractionBase &tensorContraction) { + mlir::AffineMap current; + // Individual submaps may not be invertible but their union must be invertible + // by construction. 
+ for (auto m : tensorContraction.loopsToOperandRangeMaps()) { + if (!m) + continue; + if (!current) { + current = m; + continue; + } + llvm::SmallVector results(current.getResults().begin(), + current.getResults().end()); + results.append(m.getResults().begin(), m.getResults().end()); + current = mlir::AffineMap::get( + std::max(current.getNumDims(), m.getNumDims()), + current.getNumSymbols() + m.getNumSymbols(), results, {}); + } + return inverseSubMap(current); +} + +// Extract the ranges from a given ViewOp or SliceOp. +// +// In the case of a ViewOp, things are simple: just traverse the indexings and +// get all the ranges (i.e. drop the indices). +// +// In the case of a SliceOp, things are trickier because we need to handle a +// potential rank-reduction: +// 1. Examine the indexing to determine if it is rank-reducing. +// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such +// that `d >= slicingDim`. This is to account for the rank reduction. +// `getRootIndex` is then called on the **parent** view +static llvm::SmallVector +extractRangesFromViewOrSliceOp(mlir::Value *view) { + // This expects a viewType which must come from either ViewOp or SliceOp. 
+ assert(view->getType().isa() && "expected ViewType"); + if (auto viewOp = view->getDefiningOp()->dyn_cast()) + return viewOp.getRanges(); + + auto sliceOp = view->getDefiningOp()->cast(); + unsigned slicingDim = sliceOp.getSlicingDim(); + auto *indexing = *(sliceOp.getIndexings().begin()); + bool isRankReducing = indexing->getType().isa(); + unsigned offset = 0; + llvm::SmallVector res; + res.reserve(sliceOp.getRank()); + for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) { + if (d == slicingDim && isRankReducing) + offset = 1; + auto *parentView = sliceOp.getParentView(); + auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset); + res.push_back(indexingPosPair.first); + } + return res; +} + +template +static llvm::SmallVector +getInputRanges(linalg::TensorContractionBase &tensorContraction) { + llvm::SmallVector res; + for (auto *in : tensorContraction.getInputs()) { + auto subres = extractRangesFromViewOrSliceOp(in); + res.append(subres.begin(), subres.end()); + } + return res; +} + +template +static llvm::SmallVector +getOutputRanges(linalg::TensorContractionBase &tensorContraction) { + llvm::SmallVector res; + for (auto *out : tensorContraction.getOutputs()) { + auto subres = extractRangesFromViewOrSliceOp(out); + res.append(subres.begin(), subres.end()); + } + return res; +} + +template +llvm::SmallVector linalg::getRanges( + linalg::TensorContractionBase &tensorContraction) { + llvm::SmallVector res = getInputRanges(tensorContraction); + llvm::SmallVector tmp = getOutputRanges(tensorContraction); + res.append(tmp.begin(), tmp.end()); + return res; +} + +#endif // LINALG3_TENSOROPS_INL_H_ diff --git a/include/linalg3/TensorOps.h b/include/linalg3/TensorOps.h new file mode 100644 index 0000000..bf5a377 --- /dev/null +++ b/include/linalg3/TensorOps.h @@ -0,0 +1,54 @@ +//===- TensorOps.h - Linalg dialect TensorOps operation definition --------===// +// +// Copyright 2019 The MLIR Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_TENSOROPS_H_ +#define LINALG3_TENSOROPS_H_ + +#include "linalg2/TensorOps.h" + +namespace linalg { + +/// +/// Ideally all these functions would go in an Analysis but as long as +/// TensorContractionBase is templated, they need to remain close enough. +/// + +/// Takes a `tensorContraction` and a returns an AffineMap that can be used to +/// map ranges to enclosing loops for all the operands' ranges. +template +mlir::AffineMap operandRangesToLoopsMap( + linalg::TensorContractionBase &tensorContraction); + +/// Takes a `tensorContraction` and returns the ranges of all its operands. +/// When an operand comes from a ViewOp, things are simple: +/// just traverse the indexings and get all the ranges +/// (i.e. drop the rank-reducing indices). +/// In the case of a SliceOp, things are more involved because we need to handle +/// potential rank-reductions. +/// This function abstracts this complexity away and returns all the ranges. +template +llvm::SmallVector +getRanges(linalg::TensorContractionBase &tensorContraction); + +} // namespace linalg + +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. 
+#include "linalg3/TensorOps-inl.h" + +#endif // LINALG3_TENSOROPS_H_ diff --git a/include/linalg3/Transforms.h b/include/linalg3/Transforms.h new file mode 100644 index 0000000..9af528e --- /dev/null +++ b/include/linalg3/Transforms.h @@ -0,0 +1,80 @@ +//===- Transforms.h - Linalg dialect Transformations definition -----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_TRANSFORMS_H_ +#define LINALG3_TRANSFORMS_H_ + +#include "linalg2/Transforms.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" + +namespace mlir { +class AffineForOp; +class AffineMap; +class Function; +class FunctionPassBase; +class Operation; +class Value; +} // namespace mlir + +namespace linalg { + +struct RangeParts { + explicit RangeParts(unsigned reserved); + RangeParts(llvm::ArrayRef ranges); + llvm::SmallVector makeRanges(); + + llvm::SmallVector mins; + llvm::SmallVector maxes; + llvm::SmallVector steps; +}; + +mlir::Value * +makeFoldedComposedAffineApply(mlir::AffineMap map, + llvm::ArrayRef operandsRef); + +llvm::SmallVector +makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps, + llvm::ArrayRef ranges, + llvm::ArrayRef tileSizes = {}); + +/// Traverses `f` and rewrites linalg.slice, and the operations it depends on, +/// to only use linalg.view operations. 
+void composeSliceOps(mlir::Function *f); + +/// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec) +/// as linalg.matvec(resp. linalg.dot). +void lowerToFinerGrainedTensorContraction(mlir::Function *f); + +/// Operation-wise writing of linalg operations to loop form. +/// It is the caller's responsibility to erase the `op` if necessary. +/// This returns the enclosing loops around the body of `op` for further +/// composition of transformations. +llvm::Optional> +writeAsLoops(mlir::Operation *op); + +/// Traverses `f` and rewrites linalg operations in loop form. +void lowerToLoops(mlir::Function *f); + +/// Creates a pass that rewrites linalg.load and linalg.store to affine.load and +/// affine.store operations. +mlir::FunctionPassBase *createLowerLinalgLoadStorePass(); + +} // namespace linalg + +#endif // LINALG3_TRANSFORMS_H_ diff --git a/include/toy/AST.h b/include/toy/AST.h new file mode 100644 index 0000000..456a323 --- /dev/null +++ b/include/toy/AST.h @@ -0,0 +1,256 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include + +namespace toy { + +/// A variable +struct VarType { + enum { TY_FLOAT, TY_INT } elt_ty; + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, // builtin + Expr_If, + Expr_For, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } +}; + +/// +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + std::vector> &getValues() { return values; } + std::vector &getDims() { return dims; } + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". 
+class VariableExprAST : public ExprAST {
+  std::string name;
+
+public:
+  VariableExprAST(Location loc, const std::string &name)
+      : ExprAST(Expr_Var, loc), name(name) {}
+
+  llvm::StringRef getName() { return name; }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
+};
+
+/// Expression class for a variable declaration: a name, a type and an
+/// initializer expression owned by this node.
+class VarDeclExprAST : public ExprAST {
+  std::string name;
+  VarType type;
+  std::unique_ptr<ExprAST> initVal;
+
+public:
+  VarDeclExprAST(Location loc, const std::string &name, VarType type,
+                 std::unique_ptr<ExprAST> initVal)
+      : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
+        initVal(std::move(initVal)) {}
+
+  llvm::StringRef getName() { return name; }
+  ExprAST *getInitVal() { return initVal.get(); }
+  VarType &getType() { return type; }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
+};
+
+/// Expression class for a return statement with an optional returned
+/// expression.
+class ReturnExprAST : public ExprAST {
+  llvm::Optional<std::unique_ptr<ExprAST>> expr;
+
+public:
+  ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
+      : ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
+
+  /// Returns a non-owning pointer to the returned expression, or None when the
+  /// return statement has no operand.
+  llvm::Optional<ExprAST *> getExpr() {
+    if (expr.hasValue())
+      return expr->get();
+    return llvm::NoneType();
+  }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
+};
+
+/// Expression class for a binary operator.
+class BinaryExprAST : public ExprAST {
+  char Op;
+  std::unique_ptr<ExprAST> LHS, RHS;
+
+public:
+  char getOp() { return Op; }
+  ExprAST *getLHS() { return LHS.get(); }
+  ExprAST *getRHS() { return RHS.get(); }
+
+  BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
+                std::unique_ptr<ExprAST> RHS)
+      : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
+        RHS(std::move(RHS)) {}
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
+};
+
+/// Expression class for function calls.
+class CallExprAST : public ExprAST {
+  std::string Callee;
+  std::vector<std::unique_ptr<ExprAST>> Args;
+
+public:
+  CallExprAST(Location loc, const std::string &Callee,
+              std::vector<std::unique_ptr<ExprAST>> Args)
+      : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
+
+  llvm::StringRef getCallee() { return Callee; }
+  /// Non-owning view over the argument expressions.
+  llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
+};
+
+/// Expression class for builtin print calls.
+class PrintExprAST : public ExprAST {
+  std::unique_ptr<ExprAST> Arg;
+
+public:
+  PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
+      : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
+
+  ExprAST *getArg() { return Arg.get(); }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
+};
+
+/// This class represents the "prototype" for a function, which captures its
+/// name, and its argument names (thus implicitly the number of arguments the
+/// function takes).
+// NOTE(review): the element type of `args` was lost in extraction; the doc
+// above ("argument names") suggests VariableExprAST - confirm.
+class PrototypeAST {
+  Location location;
+  std::string name;
+  std::vector<std::unique_ptr<VariableExprAST>> args;
+
+public:
+  PrototypeAST(Location location, const std::string &name,
+               std::vector<std::unique_ptr<VariableExprAST>> args)
+      : location(location), name(name), args(std::move(args)) {}
+
+  const Location &loc() { return location; }
+  const std::string &getName() const { return name; }
+  const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
+    return args;
+  }
+};
+
+/// This class represents a function definition itself.
+class FunctionAST { + std::unique_ptr Proto; + std::unique_ptr Body; + +public: + FunctionAST(std::unique_ptr Proto, + std::unique_ptr Body) + : Proto(std::move(Proto)), Body(std::move(Body)) {} + PrototypeAST *getProto() { return Proto.get(); } + ExprASTList *getBody() { return Body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() -> decltype(functions.begin()) { return functions.begin(); } + auto end() -> decltype(functions.end()) { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/include/toy/Dialect.h b/include/toy/Dialect.h new file mode 100644 index 0000000..9d7f82d --- /dev/null +++ b/include/toy/Dialect.h @@ -0,0 +1,393 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the IR Dialect for the Toy language. +// See g3doc/Tutorials/Toy/Ch-3.md for more information. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
+#define MLIR_TUTORIAL_TOY_DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+class Builder;
+}
+
+namespace toy {
+
+/// This is the definition of the Toy dialect. A dialect inherits from
+/// mlir::Dialect and registers custom operations and types (in its
+/// constructor). It can also override general behavior of dialects exposed as
+/// virtual methods, for example regarding verification and parsing/printing.
+class ToyDialect : public mlir::Dialect {
+public:
+  explicit ToyDialect(mlir::MLIRContext *ctx);
+
+  /// Parse a type registered to this dialect. Overriding this method is
+  /// required for dialects that have custom types.
+  /// Technically this is only needed to be able to round-trip to textual IR.
+  mlir::Type parseType(llvm::StringRef tyData,
+                       mlir::Location loc) const override;
+
+  /// Print a type registered to this dialect. Overriding this method is
+  /// only required for dialects that have custom types.
+  /// Technically this is only needed to be able to round-trip to textual IR.
+  void printType(mlir::Type type, llvm::raw_ostream &os) const override;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+/////////////////////// Custom Types for the Dialect ///////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+struct ToyArrayTypeStorage;
+}
+
+/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
+enum ToyTypeKind {
+  // The enum starts at the range reserved for this dialect.
+  TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
+  TOY_ARRAY,
+};
+
+/// Type for Toy arrays.
+/// In MLIR, Types are references to immutable and uniqued objects owned by the
+/// MLIRContext. As such `ToyArrayType` only wraps a pointer to a uniqued
+/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
+/// provides the public facade API to interact with the type.
+// NOTE(review): the TypeBase arguments were lost in extraction; they are
+// reconstructed from the forward-declared detail::ToyArrayTypeStorage - confirm.
+class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
+                                                 detail::ToyArrayTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Returns the dimensions for this array, or an empty range for a generic
+  /// array.
+  llvm::ArrayRef<int64_t> getShape();
+
+  /// Predicate to test if this array is generic (shape hasn't been inferred
+  /// yet).
+  bool isGeneric() { return getShape().empty(); }
+
+  /// Return the rank of this array (0 if it is generic).
+  int getRank() { return getShape().size(); }
+
+  /// Return the type of individual elements in the array.
+  mlir::Type getElementType();
+
+  /// Get a MemRef equivalent to this array type.
+  mlir::MemRefType toMemref();
+
+  /// Get the unique instance of this Type from the context.
+  /// A ToyArrayType is only defined by the shape of the array.
+  static ToyArrayType get(mlir::MLIRContext *context,
+                          llvm::ArrayRef<int64_t> shape = {});
+
+  /// Support method to enable LLVM-style RTTI type casting.
+  static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+//////////////////// Custom Operations for the Dialect /////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+/// Constant operation turns a literal into an SSA value. The data is attached
+/// to the operation as an attribute. For example:
+///
+///   %0 = "toy.constant"()
+///       {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
+///     : () -> !toy<"array<2, 3>">
+///
+/// An operation inherits from `class Op` and specifies optional traits. Here we
+/// indicate that `toy.constant` does not have any operands and returns a single
+/// result.
The traits provide some utility methods for the operation, for
+/// instance we will be able to use `getResult()`, but `getOperand()` won't be
+/// available.
+// NOTE(review): the trait list was lost in extraction. ZeroOperands/OneResult
+// follow from the comment above; HasNoSideEffect is a reconstruction - confirm.
+class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
+                                   mlir::OpTrait::OneResult,
+                                   mlir::OpTrait::HasNoSideEffect> {
+public:
+  /// This is the name used by MLIR to match an operation to this class during
+  /// parsing.
+  static llvm::StringRef getOperationName() { return "toy.constant"; }
+
+  /// The operation can have extra verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to mlir::Builder::create<ConstantOp>(...)
+  /// This method populates the `state` that MLIR uses to create operations.
+  /// The `toy.constant` operation does not have arguments but attaches a
+  /// constant array as an attribute and returns it as an SSA value.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    llvm::ArrayRef<int64_t> shape,
+                    mlir::DenseElementsAttr value);
+
+  /// Similar to the one above, but takes a single float and returns a
+  /// !toy<"array<1>">.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::FloatAttr value);
+
+  /// Returns the constant data stored in the "value" attribute.
+  mlir::DenseElementsAttr getValue() {
+    return getAttr("value").cast<mlir::DenseElementsAttr>();
+  }
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+/// Generic calls represent calls to a user defined function that needs to
+/// be specialized for the shape of its arguments. The callee name is attached
+/// as a literal string as an attribute. The arguments list must match the
+/// arguments expected by the callee. For example:
+///
+///   %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
+///         : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
+///
+/// This is only valid if a function named "my_func" exists and takes two
+/// arguments.
+// NOTE(review): the trait list was lost in extraction; it is reconstructed
+// from the example above (any number of operands, one result) - confirm.
+class GenericCallOp
+    : public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
+                      mlir::OpTrait::OneResult> {
+public:
+  /// MLIR will use this to register the operation with the parser/printer.
+  static llvm::StringRef getOperationName() { return "toy.generic_call"; }
+
+  /// Operations can add custom verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to the builder to allow:
+  ///   mlir::Builder::create<GenericCallOp>(...)
+  /// This method populates the `state` that MLIR uses to create operations.
+  /// The `toy.generic_call` operation accepts a callee name and a list of
+  /// arguments for the call.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    llvm::StringRef callee,
+                    llvm::ArrayRef<mlir::Value *> arguments);
+
+  /// Return the name of the callee.
+  llvm::StringRef getCalleeName();
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+/// Return operations terminate blocks (and functions as well). They take a
+/// single argument and the type must match the function return type.
+// NOTE(review): the trait list was lost in extraction. VariadicOperands and
+// ZeroResult follow from the optional operand documented below; IsTerminator
+// is a reconstruction from "terminate blocks" - confirm.
+class ReturnOp
+    : public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
+                      mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
+public:
+  static llvm::StringRef getOperationName() { return "toy.return"; }
+
+  /// Operations can add custom verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to mlir::Builder::create<ReturnOp>(...)
+  /// This method populates the `state` that MLIR uses to create operations.
+  /// The `toy.return` operation accepts an optional single array as an argument
+  /// and does not have any returned value.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Value *value = nullptr);
+
+  /// Return true if there is a returned value.
+  bool hasOperand() { return 0 != getNumOperands(); }
+
+  /// Helper to return the optional operand. Caller must check if the operand
+  /// is present before calling this.
+  mlir::Value *getOperand() { return getOperation()->getOperand(0); }
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+/// The print builtin takes a single array argument and does not return any.
+class PrintOp : public mlir::Op { +public: + static llvm::StringRef getOperationName() { return "toy.print"; } + + /// Operations can add custom verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create(...) + /// This method populate the `state` that MLIR use to create operations. + /// The `toy.print` operation accepts a single array as argument and does + /// not have any returned value. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value); + + /// Inherit constructor. + using Op::Op; +}; + +class TransposeOp : public mlir::Op { +public: + static llvm::StringRef getOperationName() { return "toy.transpose"; } + + /// Operation can add custom verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create(...) + /// This method populate the `state` that MLIR use to create operations. + /// The `toy.transpose` operation accepts a single array as argument and + /// returns the transposed array as its only result. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value); + + // Register our patterns for rewrite by the Canonicalization framework. + static void + getCanonicalizationPatterns(mlir::OwningRewritePatternList &results, + mlir::MLIRContext *context); + + /// Inherit constructor. + using Op::Op; +}; + +/// Reshape operation is transforming its input array into a new array with the +/// same number of elements but different shapes. For example: +/// +/// %0 = "toy.transpose"(%arg1) : (!toy<"array<10>">) -> !toy<"array<5, 2>"> +/// +class ReshapeOp : public mlir::Op { +public: + static llvm::StringRef getOperationName() { return "toy.reshape"; } + + /// Operation can add custom verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create(...) 
+ /// This method populate the `state` that MLIR use to create operations. + /// The `toy.reshape` operation accepts a single array as argument and + /// returns the array with the specified reshapedType as its only result. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value, ToyArrayType reshapedType); + + // Register our patterns for rewrite by the Canonicalization framework. + static void + getCanonicalizationPatterns(mlir::OwningRewritePatternList &results, + mlir::MLIRContext *context); + + /// Inherit constructor. + using Op::Op; +}; + +/// Binary operation implementing a multiplication. For two-dimensional array +/// a matrix multiplication is implemented, while for one dimensional array a +/// dot product is performed. +class MulOp : public mlir::Op::Impl, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + static llvm::StringRef getOperationName() { return "toy.mul"; } + + /// Operation can add custom verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create(...) + /// This method populate the `state` that MLIR use to create operations. + /// The `toy.mul` operation accepts two operands as argument and returns + /// a single value. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *lhs, mlir::Value *rhs); + + /// Convenience accessor for LHS of the expression. + mlir::Value *getLHS() { return getOperand(0); } + + /// Convenience accessor for RHS of the expression. + mlir::Value *getRHS() { return getOperand(1); } + + /// Inherit constructor. + using Op::Op; +}; + +/// Element wise addition of two arrays. The shape must match. +class AddOp : public mlir::Op::Impl, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + static llvm::StringRef getOperationName() { return "toy.add"; } + + /// Operation can add custom verification beyond the traits they define. 
+ mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create(...) + /// This method populate the `state` that MLIR use to create operations. + /// The `toy.mul` operation accepts two operands as argument and returns + /// a single value. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *lhs, mlir::Value *rhs); + + /// Convenience accessor for LHS of the expression. + mlir::Value *getLHS() { return getOperand(0); } + + /// Convenience accessor for RHS of the expression. + mlir::Value *getRHS() { return getOperand(1); } + + /// Inherit constructor. + using Op::Op; +}; + +/// AllocOp is a temporary operation for buffer allocation, created as part of +/// partial lowering. +class AllocOp : public mlir::Op { +public: + static llvm::StringRef getOperationName() { return "toy.alloc"; } + + /// Interface to mlir::Builder::create(...) + /// This method populate the `state` that MLIR use to create operations. + /// `toy.alloc` does not have any argument and returns a toy array. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Type retType); + + /// Inherit constructor. + using Op::Op; +}; + +/// FIXME: should be in std? +class TypeCastOp : public mlir::Op { +public: + static llvm::StringRef getOperationName() { return "toy.cast"; } + + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value, mlir::Type destTy); + + // Register our patterns for rewrite by the Canonicalization framework. + static void + getCanonicalizationPatterns(mlir::OwningRewritePatternList &results, + mlir::MLIRContext *context); + + /// Inherit constructor. 
+ using Op::Op; +}; + +} // end namespace toy + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/include/toy/Lexer.h b/include/toy/Lexer.h new file mode 100644 index 0000000..d73adb9 --- /dev/null +++ b/include/toy/Lexer.h @@ -0,0 +1,239 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure defining a location in a file. +struct Location { + std::shared_ptr file; ///< filename + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects.
It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return IdentifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return NumVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. 
This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(LastChar)) + LastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. + lastLocation.line = curLineNum; + lastLocation.col = curCol; + + if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]* + IdentifierStr = (char)LastChar; + while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_') + IdentifierStr += (char)LastChar; + + if (IdentifierStr == "return") + return tok_return; + if (IdentifierStr == "def") + return tok_def; + if (IdentifierStr == "var") + return tok_var; + return tok_identifier; + } + + if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ + std::string NumStr; + do { + NumStr += LastChar; + LastChar = Token(getNextChar()); + } while (isdigit(LastChar) || LastChar == '.'); + + NumVal = strtod(NumStr.c_str(), nullptr); + return tok_number; + } + + if (LastChar == '#') { + // Comment until end of line. + do + LastChar = Token(getNextChar()); + while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); + + if (LastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (LastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token ThisChar = Token(LastChar); + LastChar = Token(getNextChar()); + return ThisChar; + } + + /// The last token read from the input. 
+ Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string IdentifierStr; + + /// If the current Token is a number, this contains the value. + double NumVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token LastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer; `end` is one past the last character. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current < end && *current && *current != '\n') + ++current; + if (current < end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_LEXER_H_ diff --git a/include/toy/Lowering.h b/include/toy/Lowering.h new file mode 100644 index 0000000..362a342 --- /dev/null +++ b/include/toy/Lowering.h @@ -0,0 +1,45 @@ +//===- Lowering.h - Lowering passes for the Toy language ------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file exposes the interface to the lowering for Toy. It is divided in +// two parts: an *early lowering* that emits operations in the `Linalg` +// dialects for a subset of the Toy IR, and a *late lowering* that materializes +// buffers and converts all operations and type to the LLVM dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXAMPLES_TOY_LOWERING_H_ +#define MLIR_EXAMPLES_TOY_LOWERING_H_ + +#include + +namespace mlir { +class Pass; +class DialectConversion; +} // namespace mlir + +namespace toy { +/// Create a pass for lowering operations in the `Linalg` dialects, for a subset +/// of the Toy IR (matmul). +mlir::Pass *createEarlyLoweringPass(); + +/// Create a pass for the late lowering toward LLVM dialect. +mlir::Pass *createLateLoweringPass(); + +} // namespace toy + +#endif // MLIR_EXAMPLES_TOY_LOWERING_H_ diff --git a/include/toy/MLIRGen.h b/include/toy/MLIRGen.h new file mode 100644 index 0000000..21637bc --- /dev/null +++ b/include/toy/MLIRGen.h @@ -0,0 +1,42 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_ +#define MLIR_TUTORIAL_TOY_MLIRGEN_H_ + +#include + +namespace mlir { +class MLIRContext; +class Module; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +std::unique_ptr mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST); +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/include/toy/Parser.h b/include/toy/Parser.h new file mode 100644 index 0000000..bc7aa52 --- /dev/null +++ b/include/toy/Parser.h @@ -0,0 +1,494 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PARSER_H +#define MLIR_TUTORIAL_TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr ParseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector functions; + while (auto F = ParseDefinition()) { + functions.push_back(std::move(*F)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return llvm::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. 
+ /// return :== return ; | return expr ; + std::unique_ptr ParseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + llvm::Optional> expr; + if (lexer.getCurToken() != ';') { + expr = ParseExpression(); + if (!expr) + return nullptr; + } + return llvm::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr ParseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto Result = + llvm::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(Result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr ParseTensorLitteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(ParseTensorLitteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(ParseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + /// Fill in the dimensions now. 
First the current nesting level: + dims.push_back(values.size()); + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto &firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::dyn_cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return llvm::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr ParseParenExpr() { + lexer.getNextToken(); // eat (. + auto V = ParseExpression(); + if (!V) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return V; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr ParseIdentifierExpr() { + std::string name = lexer.getId(); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return llvm::make_unique(std::move(loc), name); + + // This is a function call.
+ lexer.consume(Token('(')); + std::vector> Args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto Arg = ParseExpression()) + Args.push_back(std::move(Arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (Args.size() != 1) + return parseError("", "as argument to print()"); + + return llvm::make_unique(std::move(loc), + std::move(Args[0])); + } + + // Call to a user-defined function + return llvm::make_unique(std::move(loc), name, + std::move(Args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr ParsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return ParseIdentifierExpr(); + case tok_number: + return ParseNumberExpr(); + case '(': + return ParseParenExpr(); + case '[': + return ParseTensorLitteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr ParseBinOpRHS(int ExprPrec, + std::unique_ptr LHS) { + // If this is a binop, find its precedence. + while (true) { + int TokPrec = GetTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (TokPrec < ExprPrec) + return LHS; + + // Okay, we know this is a binop. 
+ int BinOp = lexer.getCurToken(); + lexer.consume(Token(BinOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto RHS = ParsePrimary(); + if (!RHS) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with RHS than the operator after RHS, let + // the pending operator take RHS as its LHS. + int NextPrec = GetTokPrecedence(); + if (TokPrec < NextPrec) { + RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); + if (!RHS) + return nullptr; + } + + // Merge LHS/RHS. + LHS = llvm::make_unique(std::move(loc), BinOp, + std::move(LHS), std::move(RHS)); + } + } + + /// expression::= primary binoprhs + std::unique_ptr ParseExpression() { + auto LHS = ParsePrimary(); + if (!LHS) + return nullptr; + + return ParseBinOpRHS(0, std::move(LHS)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr ParseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = llvm::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. 
+ /// decl ::= var identifier [ type ] = expr + std::unique_ptr ParseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identifier", + "after 'var' declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = ParseType(); + if (!type) + return nullptr; + } + + if (!type) + type = llvm::make_unique(); + lexer.consume(Token('=')); + auto expr = ParseExpression(); + return llvm::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr ParseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = llvm::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = ParseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = ParseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = ParseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon.
+ if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr ParsePrototype() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_def); + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string FnName = lexer.getId(); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name = lexer.getId(); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = llvm::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError(")", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return llvm::make_unique(std::move(loc), FnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions.
+ /// + /// definition ::= prototype block + std::unique_ptr ParseDefinition() { + auto Proto = ParsePrototype(); + if (!Proto) + return nullptr; + + if (auto block = ParseBlock()) + return llvm::make_unique(std::move(Proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int GetTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isascii(curToken) && isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PARSER_H diff --git a/include/toy/Passes.h b/include/toy/Passes.h new file mode 100644 index 0000000..dd73b95 --- /dev/null +++ b/include/toy/Passes.h @@ -0,0 +1,33 @@ +//===- Passes.h - Toy Passes Definition -----------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file exposes the entry points to create compiler passes for Toy. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PASSES_H +#define MLIR_TUTORIAL_TOY_PASSES_H + +namespace mlir { +class Pass; +} // namespace mlir + +namespace toy { +mlir::Pass *createShapeInferencePass(); +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PASSES_H diff --git a/mlir/EarlyLowering.cpp b/mlir/EarlyLowering.cpp new file mode 100644 index 0000000..634c72e --- /dev/null +++ b/mlir/EarlyLowering.cpp @@ -0,0 +1,158 @@ +//=======- EarlyLowering.cpp - Toy Lowering to Linear Algebra Dialect -=======// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +// +// This file implements early lowering of Toy IR to Linalg Dialect: we only +// lower the computationally intensive part of the program (matmul...) to a +// dialect specialized for optimizations. +// +// This is intended to showcase how multiple dialects can cohabit in the same +// function. After this lowering, you would still have toy.print in the IR for +// example. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "linalg3/Intrinsics.h" +#include "linalg1/ViewOp.h" +#include "linalg3/TensorOps.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/LLVMIR/LLVMDialect.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" + +#include + +using namespace mlir; + +namespace { +/// Utility function for type casting: this is making the type checker happy, +/// while delaying the actual work involved to convert the type. Most of the +/// time both side of the cast (producer and consumer) will be lowered to a +/// dialect like LLVM and end up with the same LLVM representation, at which +/// point this becomes a no-op and is eliminated. +Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) { + if (val->getType() == destTy) + return val; + return builder.create(val->getLoc(), val, destTy) + .getResult(); +} + +/// Create a type cast to turn a toy.array into a memref. The Toy Array will be +/// lowered to a memref during buffer allocation, at which point the type cast +/// becomes useless. 
+Value *memRefTypeCast(FuncBuilder &builder, Value *val) { + if (val->getType().isa()) + return val; + auto toyArrayTy = val->getType().dyn_cast(); + if (!toyArrayTy) + return val; + return typeCast(builder, val, toyArrayTy.toMemref()); +} + +/// Lower toy.mul to Linalg `matmul`. +/// +/// This class inherits from `DialectOpConversion` and overrides `rewrite`, +/// similarly to the PatternRewriter introduced in the previous chapter. +/// It will be called by the DialectConversion framework (see `EarlyLowering` +/// class below). +class MulOpConversion : public DialectOpConversion { +public: + explicit MulOpConversion(MLIRContext *context) + : DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {} + + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + using namespace edsc; + using intrinsics::constant_index; + using linalg::intrinsics::range; + using linalg::intrinsics::view; + toy::MulOp mul = op->cast(); + auto loc = mul.getLoc(); + Value *result = memRefTypeCast( + rewriter, rewriter.create(loc, mul.getResult()->getType()) + .getResult()); + Value *lhs = memRefTypeCast(rewriter, operands[0]); + auto memrefLHSTy = lhs->getType().cast(); + Value *rhs = memRefTypeCast(rewriter, operands[1]); + auto memrefRHSTy = rhs->getType().cast(); + mlir::edsc::ScopedContext scope(rewriter, loc); + edsc::ValueHandle r0 = + range(constant_index(0), constant_index(memrefLHSTy.getDimSize(0)), + constant_index(1)); + edsc::ValueHandle r1 = + range(constant_index(0), constant_index(memrefLHSTy.getDimSize(1)), + constant_index(1)); + edsc::ValueHandle r2 = + range(constant_index(0), constant_index(memrefRHSTy.getDimSize(1)), + constant_index(1)); + auto lhsView = view(lhs, {r0, r1}); + auto rhsView = view(rhs, {r1, r2}); + auto resultView = view(result, {r0, r2}); + rewriter.create(loc, lhsView, rhsView, resultView); + return {typeCast(rewriter, result, mul.getType())}; + } +}; + +// The conversion class from Toy IR Dialect to a
mix of Linalg and LLVM. +class EarlyLowering : public DialectConversion { +protected: + // Initialize the list of converters. + llvm::DenseSet + initConverters(MLIRContext *context) override { + return ConversionListBuilder::build(&allocator, context); + } + +private: + llvm::BumpPtrAllocator allocator; +}; + +/// This is lowering to Linalg the parts that are computationally intensive +/// (like matmul for example...) while keeping the rest of the code in the Toy +/// dialect. +struct EarlyLoweringPass : public ModulePass { + + void runOnModule() override { + if (failed(EarlyLowering().convert(&getModule()))) { + getModule().getContext()->emitError( + mlir::UnknownLoc::get(getModule().getContext()), + "Error lowering Toy\n"); + signalPassFailure(); + } + } +}; +} // end anonymous namespace + +namespace toy { +Pass *createEarlyLoweringPass() { return new EarlyLoweringPass(); } + +std::unique_ptr makeToyEarlyLowering() { + return llvm::make_unique(); +} + +} // namespace toy diff --git a/mlir/LateLowering.cpp b/mlir/LateLowering.cpp new file mode 100644 index 0000000..eeae6ee --- /dev/null +++ b/mlir/LateLowering.cpp @@ -0,0 +1,452 @@ +//====- LateLowering.cpp - Lowering from Toy+Linalg to LLVM -===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements late lowering of IR mixing Toy and Linalg to LLVM. 
+// It involves intemerdiate steps: +// - +// - a mix of affine and standard dialect. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "linalg3/Intrinsics.h" +#include "linalg1/ViewOp.h" +#include "linalg3/ConvertToLLVMDialect.h" +#include "linalg3/TensorOps.h" +#include "linalg3/Transforms.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/LLVMIR/LLVMDialect.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" + +#include + +using namespace mlir; + +namespace { +/// Utility function for type casting: this is making the type checker happy, +/// while delaying the actual work involved to convert the type. Most of the +/// time both side of the cast (producer and consumer) will be lowered to a +/// dialect like LLVM and end up with the same LLVM representation, at which +/// point this becomes a no-op and is eliminated. +Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) { + if (val->getType() == destTy) + return val; + return builder.create(val->getLoc(), val, destTy) + .getResult(); +} + +/// Create a type cast to turn a toy.array into a memref. The Toy Array will be +/// lowered to a memref during buffer allocation, at which point the type cast +/// becomes useless. +Value *memRefTypeCast(FuncBuilder &builder, Value *val) { + if (val->getType().isa()) + return val; + auto toyArrayTy = val->getType().dyn_cast(); + if (!toyArrayTy) + return val; + return typeCast(builder, val, toyArrayTy.toMemref()); +} + +/// Lower a toy.add to an affine loop nest. 
+///
+/// This class inherits from `DialectOpConversion` and overrides `rewrite`,
+/// similarly to the PatternRewriter introduced in the previous chapter.
+/// It will be called by the DialectConversion framework (see `LateLowering`
+/// class below).
+class AddOpConversion : public DialectOpConversion {
+public:
+  explicit AddOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {}
+
+  /// Lower the `op` by generating IR using the `rewriter` builder. The builder
+  /// is set up with a new function, the `operands` array has been populated
+  /// with the rewritten operands for `op` in the new function.
+  /// The results created by the new IR with the builder are returned, and their
+  /// number must match the number of results of `op`.
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto add = op->cast<toy::AddOp>();
+    auto loc = add.getLoc();
+    // Create a `toy.alloc` operation to allocate the output buffer for this op.
+    Value *result = memRefTypeCast(
+        rewriter, rewriter.create<toy::AllocOp>(loc, add.getResult()->getType())
+                      .getResult());
+    Value *lhs = memRefTypeCast(rewriter, operands[0]);
+    Value *rhs = memRefTypeCast(rewriter, operands[1]);
+
+    using namespace edsc;
+    ScopedContext scope(rewriter, loc);
+    ValueHandle zero = intrinsics::constant_index(0);
+    MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
+    IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
+    IndexHandle i, j, M(vRes.ub(0));
+    if (vRes.rank() == 1) {
+      // Element-wise addition over a single loop.
+      LoopNestBuilder({&i}, {zero}, {M}, {1})({iRes(i) = iLHS(i) + iRHS(i)});
+    } else {
+      assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now");
+      // Element-wise addition over a 2-d loop nest.
+      IndexHandle N(vRes.ub(1));
+      LoopNestBuilder({&i, &j}, {zero, zero}, {M, N},
+                      {1, 1})({iRes(i, j) = iLHS(i, j) + iRHS(i, j)});
+    }
+
+    // Return the newly allocated buffer, with a type.cast to preserve the
+    // consumers.
+    return {typeCast(rewriter, result, add.getType())};
+  }
+};
+
+/// Lowers `toy.print` to a loop nest calling `printf` on every individual
+/// element of the array.
+class PrintOpConversion : public DialectOpConversion {
+public:
+  explicit PrintOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {}
+
+  /// Emit one `printf("%f ")` call per element of the operand, followed by a
+  /// `printf("\n")`: once after the loop for rank 1, once per row for rank 2.
+  /// `toy.print` produces no value, so the returned list is empty.
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    // Get or create the declaration of the printf function in the module.
+    Function *printfFunc = getPrintf(*op->getFunction()->getModule());
+
+    auto print = op->cast<toy::PrintOp>();
+    auto loc = print.getLoc();
+    // We will operate on a MemRef abstraction, we use a type.cast to get one
+    // if our operand is still a Toy array.
+    Value *operand = memRefTypeCast(rewriter, operands[0]);
+    Type retTy = printfFunc->getType().getResult(0);
+
+    // Create our loop nest now
+    using namespace edsc;
+    using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>;
+    ScopedContext scope(rewriter, loc);
+    ValueHandle zero = intrinsics::constant_index(0);
+    ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f "));
+    MemRefView vOp(operand);
+    IndexedValue iOp(operand);
+    IndexHandle i, j, M(vOp.ub(0));
+
+    ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
+    if (vOp.rank() == 1) {
+      // clang-format off
+      LoopBuilder(&i, zero, M, 1)({
+        llvmCall(retTy,
+                 rewriter.getFunctionAttr(printfFunc),
+                 {fmtCst, iOp(i)})
+      });
+      llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
+      // clang-format on
+    } else {
+      // Same restriction as the other lowerings in this file: anything that
+      // is not rank 1 must be rank 2.
+      assert(vOp.rank() == 2 && "only rank 1 and 2 are supported right now");
+      IndexHandle N(vOp.ub(1));
+      // clang-format off
+      LoopBuilder(&i, zero, M, 1)({
+        LoopBuilder(&j, zero, N, 1)({
+          llvmCall(retTy,
+                   rewriter.getFunctionAttr(printfFunc),
+                   {fmtCst, iOp(i, j)})
+        }),
+        llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol})
+      });
+      // clang-format on
+    }
+    return {};
+  }
+
+private:
+  // Turn a string into a toy.alloc (malloc/free abstraction) and a sequence
+  // of stores into the buffer, and return a MemRef into the buffer.
+  // The buffer holds `data.size() + 1` 8-bit integers: the characters of
+  // `data` followed by a terminating zero.
+  Value *getConstantCharBuffer(FuncBuilder &builder, Location loc,
+                               StringRef data) const {
+    auto retTy =
+        builder.getMemRefType(data.size() + 1, builder.getIntegerType(8));
+    Value *result = builder.create<toy::AllocOp>(loc, retTy).getResult();
+    using namespace edsc;
+    using intrinsics::constant_index;
+    using intrinsics::constant_int;
+    ScopedContext scope(builder, loc);
+    MemRefView vOp(result);
+    IndexedValue iOp(result);
+    for (uint64_t i = 0; i < data.size(); ++i) {
+      iOp(constant_index(i)) = constant_int(data[i], 8);
+    }
+    // Null terminator.
+    iOp(constant_index(data.size())) = constant_int(0, 8);
+    return result;
+  }
+
+  /// Return the prototype declaration for printf in the module, create it if
+  /// necessary.
+  Function *getPrintf(Module &module) const {
+    auto *printfFunc = module.getNamedFunction("printf");
+    if (printfFunc)
+      return printfFunc;
+
+    // Create a function declaration for printf, signature is `i32 (i8*, ...)`
+    Builder builder(&module);
+    MLIRContext *context = module.getContext();
+    LLVM::LLVMDialect *llvmDialect = static_cast<LLVM::LLVMDialect *>(
+        context->getRegisteredDialect("llvm"));
+    auto &llvmModule = llvmDialect->getLLVMModule();
+    llvm::IRBuilder<> llvmBuilder(llvmModule.getContext());
+
+    auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32));
+    auto llvmI8PtrTy =
+        LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo());
+    auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
+    printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
+    // It should be variadic, but we don't support it fully just yet.
+    printfFunc->setAttr("std.varargs", builder.getBoolAttr(true));
+    // The new declaration is appended to the module's function list.
+    module.getFunctions().push_back(printfFunc);
+    return printfFunc;
+  }
+};
+
+/// Lowers constant to a sequence of store in a buffer.
+class ConstantOpConversion : public DialectOpConversion {
+public:
+  explicit ConstantOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {}
+
+  /// Allocate a buffer with `toy.alloc` for the constant and emit one store
+  /// per element, reading each element from the value attribute of the op.
+  /// Only rank 1 and rank 2 constants are handled. The allocated buffer
+  /// (after the memref cast) is returned as the single result.
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    toy::ConstantOp cstOp = op->cast<toy::ConstantOp>();
+    auto loc = cstOp.getLoc();
+    auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
+    auto shape = retTy.getShape();
+    assert((shape.size() == 1 || shape.size() == 2) &&
+           "only rank 1 and 2 are supported right now");
+    Value *result = memRefTypeCast(
+        rewriter, rewriter.create<toy::AllocOp>(loc, retTy).getResult());
+
+    auto cstValue = cstOp.getValue();
+    auto f64Ty = rewriter.getF64Type();
+    using namespace edsc;
+    using intrinsics::constant_float;
+    using intrinsics::constant_index;
+    ScopedContext scope(rewriter, loc);
+    MemRefView vOp(result);
+    IndexedValue iOp(result);
+
+    // The element indices are unsigned; convert the extents once so that the
+    // loop conditions do not mix signed and unsigned operands.
+    const uint64_t numRows = static_cast<uint64_t>(shape[0]);
+    if (shape.size() == 1) {
+      for (uint64_t i = 0; i < numRows; ++i) {
+        auto value = cstValue.getValue(ArrayRef<uint64_t>{i})
+                         .cast<FloatAttr>()
+                         .getValue();
+        iOp(constant_index(i)) = constant_float(value, f64Ty);
+      }
+      return {result};
+    }
+
+    const uint64_t numCols = static_cast<uint64_t>(shape[1]);
+    for (uint64_t i = 0; i < numRows; ++i) {
+      for (uint64_t j = 0; j < numCols; ++j) {
+        auto value = cstValue.getValue(ArrayRef<uint64_t>{i, j})
+                         .cast<FloatAttr>()
+                         .getValue();
+        iOp(constant_index(i), constant_index(j)) =
+            constant_float(value, f64Ty);
+      }
+    }
+    return {result};
+  }
+};
+
+/// Lower transpose operation to an affine loop nest.
+class TransposeOpConversion : public DialectOpConversion { +public: + explicit TransposeOpConversion(MLIRContext *context) + : DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {} + + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + auto transpose = op->cast(); + auto loc = transpose.getLoc(); + Value *result = memRefTypeCast( + rewriter, + rewriter.create(loc, transpose.getResult()->getType()) + .getResult()); + Value *operand = memRefTypeCast(rewriter, operands[0]); + + using namespace edsc; + ScopedContext scope(rewriter, loc); + ValueHandle zero = intrinsics::constant_index(0); + MemRefView vRes(result), vOperand(operand); + IndexedValue iRes(result), iOperand(operand); + IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1)); + // clang-format off + LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({ + iRes(i, j) = iOperand(j, i) + }); + // clang-format on + + return {typeCast(rewriter, result, transpose.getType())}; + } +}; + +// Lower toy.return to standard return operation. +class ReturnOpConversion : public DialectOpConversion { +public: + explicit ReturnOpConversion(MLIRContext *context) + : DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {} + + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + auto retOp = op->cast(); + using namespace edsc; + auto loc = retOp.getLoc(); + // Argument is optional, handle both cases. + if (retOp.getNumOperands()) + rewriter.create(loc, operands[0]); + else + rewriter.create(loc); + return {}; + } +}; + +/// This is the main class registering our individual converter classes with +/// the DialectConversion framework in MLIR. +class LateLowering : public DialectConversion { +protected: + /// Initialize the list of converters. 
+ llvm::DenseSet + initConverters(MLIRContext *context) override { + return ConversionListBuilder::build(&allocator, + context); + } + + /// Convert a Toy type, this gets called for block and region arguments, and + /// attributes. + Type convertType(Type t) override { + if (auto array = t.cast()) { + return array.toMemref(); + } + return t; + } + +private: + llvm::BumpPtrAllocator allocator; +}; + +/// This is lowering to Linalg the parts that can be (matmul and add on arrays) +/// and is targeting LLVM otherwise. +struct LateLoweringPass : public ModulePass { + + void runOnModule() override { + // Perform Toy specific lowering + if (failed(LateLowering().convert(&getModule()))) { + getModule().getContext()->emitError( + UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n"); + signalPassFailure(); + } + // At this point the IR is almost using only standard and affine dialects. + // A few things remain before we emit LLVM IR. First to reuse as much of + // MLIR as possible we will try to lower everything to the standard and/or + // affine dialect: they already include conversion to the LLVM dialect. + + // First patch calls type to return memref instead of ToyArray + for (auto &function : getModule()) { + function.walk([&](Operation *op) { + auto callOp = op->dyn_cast(); + if (!callOp) + return; + if (!callOp.getNumResults()) + return; + auto retToyTy = + callOp.getResult(0)->getType().dyn_cast(); + if (!retToyTy) + return; + callOp.getResult(0)->setType(retToyTy.toMemref()); + }); + } + + for (auto &function : getModule()) { + function.walk([&](Operation *op) { + // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free). + if (auto allocOp = op->dyn_cast()) { + auto result = allocTensor(allocOp); + allocOp.replaceAllUsesWith(result); + allocOp.erase(); + return; + } + // Eliminate all type.cast before lowering to LLVM. 
+ if (auto typeCastOp = op->dyn_cast()) { + typeCastOp.replaceAllUsesWith(typeCastOp.getOperand()); + typeCastOp.erase(); + return; + } + }); + } + + // Lower Linalg to affine + for (auto &function : getModule()) + linalg::lowerToLoops(&function); + + getModule().dump(); + + // Finally convert to LLVM Dialect + linalg::convertLinalg3ToLLVM(getModule()); + } + + /// Allocate buffers (malloc/free) for Toy operations. This can't be done as + /// part of dialect conversion framework since we need to insert `dealloc` + /// operations just before the return, but the conversion framework is + /// operating in a brand new function: we don't have the return to hook the + /// dealloc operations. + Value *allocTensor(toy::AllocOp alloc) { + FuncBuilder builder(alloc); + auto retTy = alloc.getResult()->getType(); + + auto memRefTy = retTy.dyn_cast(); + if (!memRefTy) + memRefTy = retTy.cast().toMemref(); + if (!memRefTy) { + alloc.emitOpError("is expected to allocate a Toy array or a MemRef"); + llvm_unreachable("fatal error"); + } + auto loc = alloc.getLoc(); + Value *result = builder.create(loc, memRefTy).getResult(); + + // Insert a `dealloc` operation right before the `return` operations, unless + // it is returned itself in which case the caller is responsible for it. 
+ builder.getFunction()->walk([&](Operation *op) { + auto returnOp = op->dyn_cast(); + if (!returnOp) + return; + if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc) + return; + builder.setInsertionPoint(returnOp); + builder.create(alloc.getLoc(), result); + }); + return result; + } +}; +} // end anonymous namespace + +namespace toy { +Pass *createLateLoweringPass() { return new LateLoweringPass(); } + +std::unique_ptr makeToyLateLowering() { + return llvm::make_unique(); +} + +} // namespace toy diff --git a/mlir/MLIRGen.cpp b/mlir/MLIRGen.cpp new file mode 100644 index 0000000..e2001fb --- /dev/null +++ b/mlir/MLIRGen.cpp @@ -0,0 +1,480 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. 
+// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/StandardOps/Ops.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace toy; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::make_unique; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +/// +/// At this point we take advantage of the "raw" MLIR APIs to create operations +/// that haven't been registered in any way with MLIR. These operations are +/// unknown to MLIR, custom passes could operate by string-matching the name of +/// these operations, but no other type checking or semantic is associated with +/// them natively by MLIR. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : context(context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module. + std::unique_ptr mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = make_unique(&context); + + for (FunctionAST &F : moduleAST) { + auto func = mlirGen(F); + if (!func) + return nullptr; + theModule->getFunctions().push_back(func.release()); + } + + // FIXME: (in the next chapter...) 
without registering a dialect in MLIR, + // this won't do much, but it should at least check some structural + // properties. + if (failed(theModule->verify())) { + context.emitError(mlir::UnknownLoc::get(&context), + "Module verification error"); + return nullptr; + } + + return std::move(theModule); + } + +private: + /// In MLIR (like in LLVM) a "context" object holds the memory allocation and + /// the ownership of many internal structure of the IR and provide a level + /// of "uniquing" across multiple modules (types for instance). + mlir::MLIRContext &context; + + /// A "module" matches a source file: it contains a list of functions. + std::unique_ptr theModule; + + /// The builder is a helper class to create IR inside a function. It is + /// re-initialized every time we enter a function and kept around as a + /// convenience for emitting individual operations. + /// The builder is stateful, in particular it keeeps an "insertion point": + /// this is where the next operations will be introduced. + std::unique_ptr builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::FileLineColLoc loc(Location loc) { + return mlir::FileLineColLoc::get( + mlir::UniquedFilename::get(*loc.file, &context), loc.line, loc.col, + &context); + } + + /// Declare a variable in the current scope, return true if the variable + /// wasn't declared yet. 
+ bool declare(llvm::StringRef var, mlir::Value *value) { + if (symbolTable.count(var)) { + return false; + } + symbolTable.insert(var, value); + return true; + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::Function *mlirGen(PrototypeAST &proto) { + // This is a generic function, the return type will be inferred later. + llvm::SmallVector ret_types; + // Arguments type is uniformly a generic array. + llvm::SmallVector arg_types(proto.getArgs().size(), + getType(VarType{})); + auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); + auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); + + // Mark the function as generic: it'll require type specialization for every + // call site. + if (function->getNumArguments()) + function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + + return function; + } + + /// Emit a new function and add it to the MLIR module. + std::unique_ptr mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + ScopedHashTableScope var_scope(symbolTable); + + // Create an MLIR function for the given prototype. + std::unique_ptr function(mlirGen(*funcAST.getProto())); + if (!function) + return nullptr; + + // Let's start the body of the function now! + // In MLIR the entry block of the function is special: it must have the same + // argument list as the function itself. + function->addEntryBlock(); + + auto &entryBlock = function->front(); + auto &protoArgs = funcAST.getProto()->getArgs(); + // Declare all the function arguments in the symbol table. + for (const auto &name_value : + llvm::zip(protoArgs, entryBlock.getArguments())) { + declare(std::get<0>(name_value)->getName(), std::get<1>(name_value)); + } + + // Create a builder for the function, it will be used throughout the codegen + // to create operations in this function. 
+ builder = llvm::make_unique(function.get()); + + // Emit the body of the function. + if (!mlirGen(*funcAST.getBody())) + return nullptr; + + // Implicitly return void if no return statement was emited. + // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + if (function->getBlocks().back().back().getName().getStringRef() != + "toy.return") { + ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); + mlirGen(fakeRet); + } + + return function; + } + + /// Emit a binary operation + mlir::Value *mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value *L = mlirGen(*binop.getLHS()); + if (!L) + return nullptr; + mlir::Value *R = mlirGen(*binop.getRHS()); + if (!R) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder->create(location, L, R).getResult(); + break; + case '*': + return builder->create(location, L, R).getResult(); + default: + context.emitError(loc(binop.loc()), + Twine("Error: invalid binary operator '") + + Twine(binop.getOp()) + "'"); + return nullptr; + } + } + + // This is a reference to a variable in an expression. 
The variable is + // expected to have been declared and so should have a value in the symbol + // table, otherwise emit an error and return nullptr. + mlir::Value *mlirGen(VariableExprAST &expr) { + if (symbolTable.count(expr.getName())) + return symbolTable.lookup(expr.getName()); + context.emitError(loc(expr.loc()), Twine("Error: unknown variable '") + + expr.getName() + "'"); + return nullptr; + } + + // Emit a return operation, return true on success. + bool mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + // `return` takes an optional expression, we need to account for it here. + if (!ret.getExpr().hasValue()) { + builder->create(location); + return true; + } + auto *expr = mlirGen(*ret.getExpr().getValue()); + if (!expr) + return false; + builder->create(location, expr); + return true; + } + + // Emit a literal/constant array. It will be emitted as a flattened array of + // data in an Attribute attached to a `toy.constant` operation. + // See documentation on [Attributes](LangRef.md#attributes) for more details. + // Here is an excerpt: + // + // Attributes are the mechanism for specifying constant data in MLIR in + // places where a variable is never allowed [...]. They consist of a name + // and a [concrete attribute value](#attribute-values). It is possible to + // attach attributes to operations, functions, and function arguments. The + // set of expected attributes, their structure, and their interpretation + // are all contextually dependent on what they are attached to. 
+ // + // Example, the source level statement: + // var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + // will be converted to: + // %0 = "toy.constant"() {value: dense, + // [[1.000000e+00, 2.000000e+00, 3.000000e+00], + // [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64> + // + mlir::Value *mlirGen(LiteralExprAST &lit) { + auto location = loc(lit.loc()); + // The attribute is a vector with an attribute per element (number) in the + // array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // FIXME: using a tensor type is a HACK here. + // Can we do differently without registering a dialect? Using a string blob? + mlir::Type elementType = mlir::FloatType::getF64(&context); + auto dataType = builder->getTensorType(lit.getDims(), elementType); + + // This is the actual attribute that actually hold the list of values for + // this array literal. + auto dataAttribute = builder->getDenseElementsAttr(dataType, data) + .cast(); + + // Build the MLIR op `toy.constant`, only boilerplate below. + return builder->create(location, lit.getDims(), dataAttribute) + .getResult(); + } + + // Recursive helper function to accumulate the data that compose an array + // literal. It flattens the nested structure in the supplied vector. For + // example with this array: + // [[1, 2], [3, 4]] + // we will generate: + // [ 1, 2, 3, 4 ] + // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`. + // Attributes are the way MLIR attaches constant to operations and functions. 
+ void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + assert(isa(expr) && "expected literal or number expr"); + mlir::Type elementType = mlir::FloatType::getF64(&context); + auto attr = mlir::FloatAttr::getChecked( + elementType, cast(expr).getValue(), loc(expr.loc())); + data.push_back(attr); + } + + // Emit a call expression. It emits specific operations for the `transpose` + // builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value *mlirGen(CallExprAST &call) { + auto location = loc(call.loc()); + std::string callee = call.getCallee(); + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + context.emitError( + location, Twine("MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments")); + return nullptr; + } + mlir::Value *arg = mlirGen(*call.getArgs()[0]); + return builder->create(location, arg).getResult(); + } + + // Codegen the operands first + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto *arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + // Calls to user-defined function are mapped to a custom call that takes + // the callee name as an attribute. + return builder->create(location, call.getCallee(), operands) + .getResult(); + } + + // Emit a call expression. It emits specific operations for two builtins: + // transpose(x) and print(x). Other identifiers are assumed to be user-defined + // functions. Return false on failure. + bool mlirGen(PrintExprAST &call) { + auto *arg = mlirGen(*call.getArg()); + if (!arg) + return false; + auto location = loc(call.loc()); + builder->create(location, arg); + return true; + } + + // Emit a constant for a single number (FIXME: semantic? broadcast?) 
+ mlir::Value *mlirGen(NumberExprAST &num) { + auto location = loc(num.loc()); + mlir::Type elementType = mlir::FloatType::getF64(&context); + auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(), + loc(num.loc())); + return builder->create(location, attr).getResult(); + } + + // Dispatch codegen for the right expression subclass using RTTI. + mlir::Value *mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + context.emitError( + loc(expr.loc()), + Twine("MLIR codegen encountered an unhandled expr kind '") + + Twine(expr.getKind()) + "'"); + return nullptr; + } + } + + // Handle a variable declaration, we'll codegen the expression that forms the + // initializer and record the value in the symbol table before returning it. + // Future expressions will be able to reference this variable through symbol + // table lookup. + mlir::Value *mlirGen(VarDeclExprAST &vardecl) { + mlir::Value *value = nullptr; + auto location = loc(vardecl.loc()); + if (auto init = vardecl.getInitVal()) { + value = mlirGen(*init); + if (!value) + return nullptr; + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. 
+ if (!vardecl.getType().shape.empty()) { + value = builder + ->create( + location, value, + getType(vardecl.getType()).cast()) + .getResult(); + } + } else { + context.emitError(loc(vardecl.loc()), + "Missing initializer in variable declaration"); + return nullptr; + } + // Register the value in the symbol table + declare(vardecl.getName(), value); + return value; + } + + /// Codegen a list of expression, return false if one of them hit an error. + bool mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope var_scope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return false; + continue; + } + if (auto *ret = dyn_cast(expr.get())) { + if (!mlirGen(*ret)) + return false; + return true; + } + if (auto *print = dyn_cast(expr.get())) { + if (!mlirGen(*print)) + return false; + continue; + } + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return false; + } + return true; + } + + /// Build a type from a list of shape dimensions. Types are `array` followed + /// by an optional dimension list, example: array<2, 2> + /// They are wrapped in a `toy` dialect (see next chapter) and get printed: + /// !toy.array<2, 2> + template mlir::Type getType(T shape) { + SmallVector shape64(shape.begin(), shape.end()); + return ToyArrayType::get(&context, shape64); + } + + /// Build an MLIR type from a Toy AST variable type + /// (forward to the generic getType(T) above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. 
+std::unique_ptr mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/mlir/ShapeInferencePass.cpp b/mlir/ShapeInferencePass.cpp new file mode 100644 index 0000000..7e3ea3f --- /dev/null +++ b/mlir/ShapeInferencePass.cpp @@ -0,0 +1,387 @@ +//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a Module level pass performing interprocedural +// propagation of array shapes through function specialization. 
+// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/StandardOps/Ops.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "toy-shape-inference" + +using namespace toy; +using llvm::MutableArrayRef; +using llvm::SmallVector; +using llvm::SmallVectorImpl; +using llvm::StringRef; +using llvm::Twine; + +/// Create mangled name for function specialization. We will simply append the +/// shape of the arguments to the function name. For example calling +/// +/// "toy.generic_call"(%1, %3) {callee: "foo"} +/// : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array"> +/// +/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could +/// have provide a function with a similar name. But we will claim this as a +/// feature: this allow the user to provide custom specialization! +static std::string mangle(StringRef funcName, + MutableArrayRef operands) { + std::string mangledName; + mangledName.reserve(funcName.size() + operands.size() * 6); + mangledName = funcName; + for (auto &operand : operands) { + auto arrayTy = operand.get()->getType().cast(); + mangledName += "_"; + const char *sep = ""; + for (auto dim : arrayTy.getShape()) { + mangledName += (sep + Twine(dim)).str(); + sep = "x"; + } + } + return mangledName; +} + +namespace { + +/// The ShapeInferencePass is a ModulePass: it will run on the Module as a +/// whole. MLIR also supports FunctionPass which are restricted to modify a +/// single function at a time. 
This pass couldn't be a function pass due to the +/// nature of its interprocedural transformations. +/// +/// The algorithm has two levels, first intra-procedurally: +/// +/// 1) Build a worklist containing all the operations that are returning +/// a generic Toy array: these are the operations that need shape +/// inference. +/// 2) Iterate on the worklist: +/// a) find an operation to process: the next ready operation in the +/// worklist has all of its arguments non-generic, +/// b) if no operation is found, break out of the loop, +/// c) remove the operation from the worklist, +/// d) infer the shape of its output from the arguments type. +/// 3) If the worklist is empty, the algorithm succeeded and we infer the +/// return type for the function from the return operation. +/// +/// There is a twist though: when a call to a generic function is encountered, +/// shape inference requires the return type of the callee to be inferred first. +/// At this point we need to specialize the callee by cloning it. Here is +/// the inter-procedural flow: +/// +/// 1) Keep a worklist of functions to process. Start with function "main". +/// 2) While the worklist isn't empty: +/// a) Take the last inserted function in the worklist. +/// b) Run the intra-procedural shape inference on this function. +/// c) If the intra-procedural shape inference can't complete, it returns +/// a Function that needs to be inferred first. In this case, queue this +/// new function and continue. Otherwise the inference succeeded and we +/// can pop from the queue. +/// +class ShapeInferencePass : public mlir::ModulePass { +public: + // One entry in the inter-procedural worklist. It keeps track of the + // function to process, the mangled name for this specialization, and the + // types of the arguments on which to specialize. 
+ struct FunctionToSpecialize { + mlir::Function *function; + std::string mangledName; + std::vector argumentsType; + }; + + void runOnModule() override { + auto &module = getModule(); + auto *main = module.getNamedFunction("main"); + if (!main) { + module.getContext()->emitError( + mlir::UnknownLoc::get(module.getContext()), + "Shape inference failed: can't find a main function\n"); + signalPassFailure(); + return; + } + + /// Inter-procedural loop, initialize with `main` and iterate till + /// successfully infer the full reachable call-graph from main. + SmallVector worklist; + worklist.push_back({main, "", {}}); + while (!worklist.empty()) { + if (failed(specialize(worklist))) + return signalPassFailure(); + } + + // Delete any generic function left + // FIXME: we may want this as a separate pass. + for (mlir::Function &function : llvm::make_early_inc_range(module)) { + if (auto genericAttr = + function.getAttrOfType("toy.generic")) { + if (genericAttr.getValue()) + function.erase(); + } + } + } + + /// Run inference on a function. If a mangledName is provided, we need to + /// specialize the function: to this end clone it first. + mlir::LogicalResult + specialize(SmallVectorImpl &funcWorklist) { + FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); + mlir::Function *f = functionToSpecialize.function; + + // Check if cloning for specialization is needed (usually anything but main) + // We will create a new function with the concrete types for the parameters + // and clone the body into it. + if (!functionToSpecialize.mangledName.empty()) { + if (getModule().getNamedFunction(functionToSpecialize.mangledName)) { + funcWorklist.pop_back(); + // Function already specialized, move on. + return mlir::success(); + } + // Create a new function with a generic array return type, it will be + // updated when the inference for the function body completes. 
+ auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType, + {ToyArrayType::get(&getContext())}, + &getContext()); + auto *newFunction = new mlir::Function( + f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs()); + getModule().getFunctions().push_back(newFunction); + + // Clone the function body + mlir::BlockAndValueMapping mapper; + f->cloneInto(newFunction, mapper); + LLVM_DEBUG({ + llvm::dbgs() << "====== Cloned : \n"; + f->dump(); + llvm::dbgs() << "====== Into : \n"; + newFunction->dump(); + }); + f = newFunction; + f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + // Remap the entry-block arguments + // FIXME: this seems like a bug in `cloneInto()` above? + auto &entryBlock = f->getBlocks().front(); + int blockArgSize = entryBlock.getArguments().size(); + assert(blockArgSize == f->getType().getInputs().size()); + entryBlock.addArguments(f->getType().getInputs()); + auto argList = entryBlock.getArguments(); + for (int argNum = 0; argNum < blockArgSize; ++argNum) { + argList[0]->replaceAllUsesWith(argList[blockArgSize]); + entryBlock.eraseArgument(0); + } + assert(succeeded(f->verify())); + } + LLVM_DEBUG(llvm::dbgs() + << "Run shape inference on : '" << f->getName() << "'\n"); + + auto *toyDialect = getContext().getRegisteredDialect("toy"); + if (!toyDialect) { + getContext().emitError(mlir::UnknownLoc::get(&getContext()), + "Toy dialect is not registered"); + signalPassFailure(); + return mlir::failure(); + } + + // Populate the worklist with the operations that need shape inference: + // these are the Toy operations that return a generic array. + llvm::SmallPtrSet opWorklist; + f->walk([&](mlir::Operation *op) { + if (op->getDialect() == toyDialect) { + if (op->getNumResults() == 1 && + op->getResult(0)->getType().cast().isGeneric()) + opWorklist.insert(op); + } + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). 
+ while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) { + return llvm::all_of(op->getOperands(), [](mlir::Value *v) { + return !v->getType().cast().isGeneric(); + }); + }); + if (nextop == opWorklist.end()) + break; // failure: no operations can be inferred. + + mlir::Operation *op = *nextop; + opWorklist.erase(op); + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + + // The add operation is trivial: propagate the input type as is. + if (auto addOp = op->dyn_cast()) { + op->getResult(0)->setType(op->getOperand(0)->getType()); + continue; + } + + // Transpose is easy: just invert the dimensions. + if (op->getName().getStringRef() == "toy.transpose") { + SmallVector dims; + auto arrayTy = op->getOperand(0)->getType().cast(); + dims.insert(dims.end(), arrayTy.getShape().begin(), + arrayTy.getShape().end()); + if (dims.size() == 2) + std::swap(dims[0], dims[1]); + op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims)); + continue; + } + + // Multiplication is a bit trickier, handle rank 1 as dot product and rank + // 2 as matrix multiplications. + // We need to be careful about rank mismatch here: the verifier could + // catch it but shape inference earlier in the pass could generate an + // invalid IR (from an invalid Toy input of course) and we wouldn't want + // to crash here. 
+ if (auto mulOp = op->dyn_cast()) { + auto lhs = mulOp.getLHS()->getType().cast(); + auto rhs = mulOp.getRHS()->getType().cast(); + auto lhsRank = lhs.getShape().size(); + auto rhsRank = rhs.getShape().size(); + if (lhsRank != rhsRank) { + op->emitError("Shape mismatch: LHS and RHS must have the same " + "rank for multiplication, got " + + Twine(lhsRank) + " vs " + Twine(rhsRank)); + return mlir::failure(); + } + SmallVector dims; + if (lhsRank == 1) { + // dot product, result shape is <1> + dims.push_back(1); + } else { + if (lhsRank != 2) { + op->emitError( + "Shape mismatch: expect rank 1 or 2 for mul operands, got " + + Twine(lhsRank)); + return mlir::failure(); + } + dims.push_back(lhs.getShape()[0]); + dims.push_back(rhs.getShape()[1]); + } + op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims)); + continue; + } + + // Process calls: lookup the callee after mangling the name with the + // argument shapes. If the callee does not exist, we stop the inference + // for this function, queue the callee in the inter-procedural work list, + // and return. The current function stays in the work list and will + // restart after the callee is processed. + if (auto callOp = op->dyn_cast()) { + auto calleeName = callOp.getCalleeName(); + auto *callee = getModule().getNamedFunction(calleeName); + if (!callee) { + f->emitError( + llvm::Twine("Shape inference failed, call to unknown '") + + calleeName + "'"); + signalPassFailure(); + return mlir::failure(); + } + auto mangledName = mangle(calleeName, op->getOpOperands()); + LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName + << "', mangled: '" << mangledName << "'\n"); + auto *mangledCallee = getModule().getNamedFunction(mangledName); + if (!mangledCallee) { + // Can't find the target, this is where we queue the request for the + // callee and stop the inference for the current function now. 
+ std::vector funcArgs; + for (auto operand : op->getOperands()) + funcArgs.push_back(operand->getType()); + funcWorklist.push_back( + {callee, std::move(mangledName), std::move(funcArgs)}); + return mlir::success(); + } + // Found a specialized callee! Let's turn this into a normal call + // operation. + SmallVector operands; + for (mlir::Value *v : op->getOperands()) + operands.push_back(v); + mlir::FuncBuilder builder(f); + builder.setInsertionPoint(op); + auto newCall = + builder.create(op->getLoc(), mangledCallee, operands); + if (newCall.getNumResults()) { + op->getResult(0)->replaceAllUsesWith(newCall.getResult(0)); + op->erase(); + continue; + } + } + } + + // Done with inference on this function, removing it from the worklist. + funcWorklist.pop_back(); + // Mark the function as non-generic now that inference has succeeded + f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + std::string str; + llvm::raw_string_ostream errorMsg(str); + errorMsg << "Shape inference failed, " << opWorklist.size() + << " operations couldn't be inferred\n"; + for (auto *ope : opWorklist) + errorMsg << " - " << *ope << "\n"; + f->emitError(errorMsg.str()); + signalPassFailure(); + return mlir::failure(); + } + + // Finally, update the return type of the function based on the argument to + // the return operation. 
+ for (auto &block : f->getBlocks()) { + auto ret = block.getTerminator()->dyn_cast(); + if (!ret) + continue; + if (ret.getNumOperands() && + f->getType().getResult(0) == ret.getOperand()->getType()) + // type match, we're done + break; + SmallVector retTy; + if (ret.getNumOperands()) + retTy.push_back(ret.getOperand()->getType()); + mlir::Type elementType = mlir::FloatType::getF64(&getContext()); + std::vector argumentsType; + for (auto arg : f->getArguments()) + argumentsType.push_back(arg->getType()); + auto newType = + mlir::FunctionType::get(argumentsType, retTy, &getContext()); + f->setType(newType); + assert(succeeded(f->verify())); + break; + } + return mlir::success(); + } +}; +} // end anonymous namespace + +namespace toy { +mlir::Pass *createShapeInferencePass() { return new ShapeInferencePass(); } +} // namespace toy diff --git a/mlir/ToyCombine.cpp b/mlir/ToyCombine.cpp new file mode 100644 index 0000000..8d6aed6 --- /dev/null +++ b/mlir/ToyCombine.cpp @@ -0,0 +1,209 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple combiner for optimizing pattern in the Toy +// dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" + +#include + +namespace toy { + +namespace { + +/// Fold transpose(transpose(x)) -> x +struct SimplifyRedundantTranspose : public mlir::RewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1, + context) {} + + /// This method is attempting to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. It is expected + /// to interact with it to perform any changes to the IR from here. + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + // We can directly cast the current operation as this will only get invoked + // on TransposeOp. + TransposeOp transpose = op->cast(); + // look through the input to the current transpose + mlir::Value *transposeInput = transpose.getOperand(); + mlir::Operation *transposeInputInst = transposeInput->getDefiningOp(); + // If the input is defined by another Transpose, bingo! + TransposeOp transposeInputOp = + mlir::dyn_cast_or_null(transposeInputInst); + if (!transposeInputOp) + return matchFailure(); + + // Use the rewriter to perform the replacement + rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + return matchSuccess(); + } +}; + +/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place. 
+struct SimplifyReshapeConstant : public mlir::RewritePattern { + SimplifyReshapeConstant(mlir::MLIRContext *context) + : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, + context) {} + + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + ReshapeOp reshape = op->cast(); + // look through the input to the current reshape + mlir::Value *reshapeInput = reshape.getOperand(); + mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp(); + // If the input is defined by a constant, bingo! + ConstantOp constantOp = + mlir::dyn_cast_or_null(reshapeInputInst); + if (!constantOp) + return matchFailure(); + + auto reshapeType = op->getResult(0)->getType().cast(); + if (auto valueAttr = + constantOp.getAttrOfType("value")) { + // FIXME Check matching of element count! + // auto oldType = constantOp.getType(); + auto newType = rewriter.getTensorType( + reshapeType.getShape(), valueAttr.getType().getElementType()); + auto newAttr = + mlir::DenseElementsAttr::get(newType, valueAttr.getRawData()); + auto newConstant = rewriter.create( + constantOp.getLoc(), reshapeType.getShape(), newAttr); + rewriter.replaceOp(op, {newConstant}); + } else if (auto valueAttr = + constantOp.getAttrOfType("value")) { + // Broadcast + auto dataSize = std::accumulate(reshapeType.getShape().begin(), + reshapeType.getShape().end(), 1, + std::multiplies()); + std::vector data(dataSize, valueAttr); + auto tensorTy = rewriter.getTensorType(reshapeType.getShape(), + reshapeType.getElementType()); + auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data); + auto newConstant = rewriter.create( + constantOp.getLoc(), reshapeType.getShape(), newAttr); + rewriter.replaceOp(op, {newConstant}); + } else { + llvm_unreachable("Unsupported Constant format"); + } + return matchSuccess(); + } +}; + +/// Fold reshape(reshape(x)) -> reshape(x) +struct SimplifyReshapeReshape : public mlir::RewritePattern { + 
SimplifyReshapeReshape(mlir::MLIRContext *context) + : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, + context) {} + + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + ReshapeOp reshape = op->cast(); + // look through the input to the current reshape + mlir::Value *reshapeInput = reshape.getOperand(); + mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp(); + // If the input is defined by another reshape, bingo! + ReshapeOp reshapeInputOp = + mlir::dyn_cast_or_null(reshapeInputInst); + if (!reshapeInputOp) + return matchFailure(); + + // Use the rewriter to perform the replacement + rewriter.replaceOp(op, {reshapeInputOp}); + return matchSuccess(); + } +}; + +/// Fold reshape(x) -> x, when input type matches output type +struct SimplifyNullReshape : public mlir::RewritePattern { + SimplifyNullReshape(mlir::MLIRContext *context) + : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, + context) {} + + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + ReshapeOp reshape = op->cast(); + if (reshape.getOperand()->getType() != reshape.getResult()->getType()) + return matchFailure(); + rewriter.replaceOp(reshape, {reshape.getOperand()}); + return matchSuccess(); + } +}; + +} // end anonymous namespace. + +// Register our patterns for rewrite by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns( + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { + results.push_back(llvm::make_unique(context)); +} + +// Register our patterns for rewrite by the Canonicalization framework. 
+void ReshapeOp::getCanonicalizationPatterns( + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { + results.push_back(llvm::make_unique(context)); + results.push_back(llvm::make_unique(context)); + results.push_back(llvm::make_unique(context)); +} + +namespace { + +/// Fold type.cast(x) -> x, when input type matches output type +struct SimplifyIdentityTypeCast : public mlir::RewritePattern { + SimplifyIdentityTypeCast(mlir::MLIRContext *context) + : RewritePattern(TypeCastOp::getOperationName(), /* benefit = */ 1, + context) {} + + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + TypeCastOp typeCast = op->cast(); + auto resTy = typeCast.getResult()->getType(); + auto *candidateOp = op; + while (candidateOp && candidateOp->isa()) { + if (resTy == candidateOp->getOperand(0)->getType()) { + rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)}); + return matchSuccess(); + } + candidateOp = candidateOp->getOperand(0)->getDefiningOp(); + } + return matchFailure(); + } +}; + +} // end anonymous namespace. + +void TypeCastOp::getCanonicalizationPatterns( + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { + results.push_back(llvm::make_unique(context)); +} + +} // namespace toy diff --git a/mlir/ToyDialect.cpp b/mlir/ToyDialect.cpp new file mode 100644 index 0000000..be117f5 --- /dev/null +++ b/mlir/ToyDialect.cpp @@ -0,0 +1,405 @@ +//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/Regex.h" +#include "llvm/Support/raw_ostream.h" + +using llvm::ArrayRef; +using llvm::raw_ostream; +using llvm::raw_string_ostream; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace toy { +namespace detail { + +/// This class holds the implementation of the ToyArrayType. +/// It is intended to be uniqued based on its content and owned by the context. +struct ToyArrayTypeStorage : public mlir::TypeStorage { + /// This defines how we unique this type in the context: our key contains + /// only the shape, a more complex type would have multiple entries in the + /// tuple here. + /// The element of the tuples usually matches 1-1 the arguments from the + /// public `get()` method arguments from the facade. + using KeyTy = std::tuple>; + static unsigned hashKey(const KeyTy &key) { + return llvm::hash_combine(std::get<0>(key)); + } + /// When the key hash hits an existing type, we compare the shape themselves + /// to confirm we have the right type. 
+ bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); } + + /// This is a factory method to create our type storage. It is only + /// invoked after looking up the type in the context using the key and not + /// finding it. + static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + const KeyTy &key) { + // Copy the shape array into the bumpptr allocator owned by the context. + ArrayRef shape = allocator.copyInto(std::get<0>(key)); + + // Allocate the instance for the ToyArrayTypeStorage itself + auto *storage = allocator.allocate(); + // Initialize the instance using placement new. + return new (storage) ToyArrayTypeStorage(shape); + } + + ArrayRef getShape() const { return shape; } + +private: + ArrayRef shape; + + /// Constructor is only invoked from the `construct()` method above. + ToyArrayTypeStorage(ArrayRef shape) : shape(shape) {} +}; + +} // namespace detail + +mlir::Type ToyArrayType::getElementType() { + return mlir::FloatType::getF64(getContext()); +} + +ToyArrayType ToyArrayType::get(mlir::MLIRContext *context, + ArrayRef shape) { + return Base::get(context, ToyTypeKind::TOY_ARRAY, shape); +} + +ArrayRef ToyArrayType::getShape() { return getImpl()->getShape(); } + +mlir::MemRefType ToyArrayType::toMemref() { + auto memRefType = mlir::MemRefType::get(getShape(), getElementType(), {}, 0); + return memRefType; +} + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { + addOperations(); + addTypes(); +} + +/// Parse a type registered to this dialect, we expect only Toy arrays. 
+mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const { + // Sanity check: we only support array or array<...> + if (!tyData.startswith("array")) { + getContext()->emitError(loc, "Invalid Toy type '" + tyData + + "', array expected"); + return nullptr; + } + // Drop the "array" prefix from the type name, we expect either an empty + // string or just the shape. + tyData = tyData.drop_front(StringRef("array").size()); + // This is the generic array case without shape, early return it. + if (tyData.empty()) + return ToyArrayType::get(getContext()); + + // Use a regex to parse the shape (for efficiency, store it in the dialect). + // NOTE(review): a repeated group keeps only its last capture; check rank > 2. + SmallVector matches; + auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$"); + if (!shapeRegex.match(tyData, &matches)) { + getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'"); + return nullptr; + } + SmallVector shape; + // Iterate through the captures, skip the first one which is the full string. + for (auto dimStr : + llvm::make_range(std::next(matches.begin()), matches.end())) { + if (dimStr.startswith(",")) + continue; // POSIX misses non-capturing groups. + if (dimStr.empty()) + continue; // '*' makes it an optional group capture + // Convert the capture to an integer + unsigned long long dim; + if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) { + getContext()->emitError( + loc, "Couldn't parse dimension as integer, matched: " + dimStr); + return mlir::Type(); + } + shape.push_back(dim); + } + // Finally we collected all the dimensions in the shape, + // create the array type. 
+ return ToyArrayType::get(getContext(), shape); +} + +/// Print a Toy array type, for example `array<2, 3, 4>` +void ToyDialect::printType(mlir::Type type, raw_ostream &os) const { + auto arrayTy = type.dyn_cast(); + if (!arrayTy) { + os << "unknown toy type"; + return; + } + os << "array"; + if (!arrayTy.getShape().empty()) { + os << "<"; + mlir::interleaveComma(arrayTy.getShape(), os); + os << ">"; + } +} + +//////////////////////////////////////////////////////////////////////////////// +//////////////////// Custom Operations for the Dialect ///////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Helper to verify that the result of an operation is a Toy array type. +template static mlir::LogicalResult verifyToyReturnArray(T *op) { + if (!op->getResult()->getType().template isa()) { + std::string msg; + raw_string_ostream os(msg); + os << "expects a Toy Array for its argument, got " + << op->getResult()->getType(); + return op->emitOpError(os.str()); + } + return mlir::success(); +} + +/// Helper to verify that the two operands of a binary operation are Toy +/// arrays. +template static mlir::LogicalResult verifyToyBinOperands(T *op) { + if (!op->getOperand(0)->getType().template isa()) { + std::string msg; + raw_string_ostream os(msg); + os << "expects a Toy Array for its LHS, got " + << op->getOperand(0)->getType(); + return op->emitOpError(os.str()); + } + if (!op->getOperand(1)->getType().template isa()) { + std::string msg; + raw_string_ostream os(msg); + os << "expects a Toy Array for its RHS, got " + << op->getOperand(1)->getType(); + return op->emitOpError(os.str()); + } + return mlir::success(); +} + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. 
+void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state, + ArrayRef shape, mlir::DenseElementsAttr value) { + state->types.push_back(ToyArrayType::get(builder->getContext(), shape)); + auto dataAttribute = builder->getNamedAttr("value", value); + state->attributes.push_back(dataAttribute); +} + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state, + mlir::FloatAttr value) { + // Broadcast and forward to the other build factory + mlir::Type elementType = mlir::FloatType::getF64(builder->getContext()); + auto dataType = builder->getTensorType({1}, elementType); + auto dataAttribute = builder->getDenseElementsAttr(dataType, {value}) + .cast(); + + ConstantOp::build(builder, state, {1}, dataAttribute); +} + +/// Verifier for constant operation. +mlir::LogicalResult ConstantOp::verify() { + // Ensure that the return type is a Toy array + if (failed(verifyToyReturnArray(this))) + return mlir::failure(); + + // We expect the constant itself to be stored as an attribute. + auto dataAttr = getAttr("value").dyn_cast(); + if (!dataAttr) { + return emitOpError( + "missing valid `value` DenseElementsAttribute on toy.constant()"); + } + auto attrType = dataAttr.getType().dyn_cast(); + if (!attrType) { + return emitOpError( + "missing valid `value` DenseElementsAttribute on toy.constant()"); + } + + // If the return type of the constant is not a generic array, the shape must + // match the shape of the attribute holding the data. 
+ auto resultType = getResult()->getType().cast(); + if (!resultType.isGeneric()) { + if (attrType.getRank() != resultType.getRank()) { + return emitOpError("The rank of the toy.constant return type must match " + "the one of the attached value attribute: " + + Twine(attrType.getRank()) + + " != " + Twine(resultType.getRank())); + } + for (int dim = 0; dim < attrType.getRank(); ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + std::string msg; + raw_string_ostream os(msg); + return emitOpError( + "Shape mismatch between toy.constant return type and its " + "attribute at dimension " + + Twine(dim) + ": " + Twine(attrType.getShape()[dim]) + + " != " + Twine(resultType.getShape()[dim])); + } + } + } + return mlir::success(); +} + +void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState *state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns a generic ToyArray initially + state->types.push_back(ToyArrayType::get(builder->getContext())); + state->operands.assign(arguments.begin(), arguments.end()); + auto calleeAttr = builder->getStringAttr(callee); + state->attributes.push_back(builder->getNamedAttr("callee", calleeAttr)); +} + +mlir::LogicalResult GenericCallOp::verify() { + // Verify that every operand is a Toy Array + for (int opId = 0, num = getNumOperands(); opId < num; ++opId) { + if (!getOperand(opId)->getType().template isa()) { + std::string msg; + raw_string_ostream os(msg); + os << "expects a Toy Array for its " << opId << " operand, got " + << getOperand(opId)->getType(); + return emitOpError(os.str()); + } + } + return mlir::success(); +} + +/// Return the name of the callee. 
+StringRef GenericCallOp::getCalleeName() { + return getAttr("callee").cast().getValue(); +} + +template static mlir::LogicalResult verifyToySingleOperand(T *op) { + if (!op->getOperand()->getType().template isa()) { + std::string msg; + raw_string_ostream os(msg); + os << "expects a Toy Array for its argument, got " + << op->getOperand()->getType(); + return op->emitOpError(os.str()); + } + return mlir::success(); +} + +void ReturnOp::build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value) { + // Return does not return any value and has an optional single argument + if (value) + state->operands.push_back(value); +} + +mlir::LogicalResult ReturnOp::verify() { + if (getNumOperands() > 1) + return emitOpError("expects zero or one operand, got " + + Twine(getNumOperands())); + if (hasOperand() && failed(verifyToySingleOperand(this))) + return mlir::failure(); + return mlir::success(); +} + +void PrintOp::build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value) { + // Print does not return any value and has a single argument + state->operands.push_back(value); +} + +mlir::LogicalResult PrintOp::verify() { + if (failed(verifyToySingleOperand(this))) + return mlir::failure(); + return mlir::success(); +} + +void TransposeOp::build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value) { + state->types.push_back(ToyArrayType::get(builder->getContext())); + state->operands.push_back(value); +} + +mlir::LogicalResult TransposeOp::verify() { + if (failed(verifyToySingleOperand(this))) + return mlir::failure(); + return mlir::success(); +} + +void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value, ToyArrayType reshapedType) { + state->types.push_back(reshapedType); + state->operands.push_back(value); +} + +mlir::LogicalResult ReshapeOp::verify() { + if (failed(verifyToySingleOperand(this))) + return mlir::failure(); + auto retTy = getResult()->getType().dyn_cast(); + 
if (!retTy) + return emitOpError("toy.reshape is expected to produce a Toy array"); + if (retTy.isGeneric()) + return emitOpError("toy.reshape is expected to produce a shaped Toy array, " + "got a generic one."); + return mlir::success(); +} + +void AddOp::build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *lhs, mlir::Value *rhs) { + state->types.push_back(ToyArrayType::get(builder->getContext())); + state->operands.push_back(lhs); + state->operands.push_back(rhs); +} + +mlir::LogicalResult AddOp::verify() { + if (failed(verifyToyBinOperands(this))) + return mlir::failure(); + return mlir::success(); +} + +void MulOp::build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *lhs, mlir::Value *rhs) { + state->types.push_back(ToyArrayType::get(builder->getContext())); + state->operands.push_back(lhs); + state->operands.push_back(rhs); +} + +mlir::LogicalResult MulOp::verify() { + if (failed(verifyToyBinOperands(this))) + return mlir::failure(); + return mlir::success(); +} + +void AllocOp::build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Type retType) { + state->types.push_back(retType); +} + +void TypeCastOp::build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value, mlir::Type destTy) { + state->operands.push_back(value); + state->types.push_back(destTy); +} + +} // namespace toy diff --git a/parser/AST.cpp b/parser/AST.cpp new file mode 100644 index 0000000..869f2ef --- /dev/null +++ b/parser/AST.cpp @@ -0,0 +1,263 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. 
+class ASTDumper { +public: + void dump(ModuleAST *Node); + +private: + void dump(VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *Node); + void dump(VariableExprAST *Node); + void dump(ReturnExprAST *Node); + void dump(BinaryExprAST *Node); + void dump(CallExprAST *Node); + void dump(PrintExprAST *Node); + void dump(PrototypeAST *Node); + void dump(FunctionAST *Node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template static std::string loc(T *Node) { + const auto &loc = Node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { +#define dispatch(CLASS) \ + if (CLASS *node = llvm::dyn_cast(expr)) \ + return dump(node); + dispatch(VarDeclExprAST); + dispatch(LiteralExprAST); + dispatch(NumberExprAST); + dispatch(VariableExprAST); + dispatch(ReturnExprAST); + dispatch(BinaryExprAST); + dispatch(CallExprAST); + dispatch(PrintExprAST); + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. 
+void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recurisvely a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *lit_or_num) { + // Inside a literal expression we can have either a number or another literal + if (auto num = llvm::dyn_cast(lit_or_num)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(lit_or_num); + + // Print the dimension for this literal first + llvm::errs() << "<"; + { + const char *sep = ""; + for (auto dim : literal->getDims()) { + llvm::errs() << sep << dim; + sep = ", "; + } + } + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + const char *sep = ""; + for (auto &elt : literal->getValues()) { + llvm::errs() << sep; + printLitHelper(elt.get()); + sep = ", "; + } + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *Node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(Node); + llvm::errs() << " " << loc(Node) << "\n"; +} + +/// Print a variable reference (just a name). 
+void ASTDumper::dump(VariableExprAST *Node) { + INDENT(); + llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *Node) { + INDENT(); + llvm::errs() << "Return\n"; + if (Node->getExpr().hasValue()) + return dump(*Node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *Node) { + INDENT(); + llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n"; + dump(Node->getLHS()); + dump(Node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *Node) { + INDENT(); + llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n"; + for (auto &arg : Node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *Node) { + INDENT(); + llvm::errs() << "Print [ " << loc(Node) << "\n"; + dump(Node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(VarType &type) { + llvm::errs() << "<"; + const char *sep = ""; + for (auto shape : type.shape) { + llvm::errs() << sep << shape; + sep = ", "; + } + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. 
+void ASTDumper::dump(PrototypeAST *Node) { + INDENT(); + llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n"; + indent(); + llvm::errs() << "Params: ["; + const char *sep = ""; + for (auto &arg : Node->getArgs()) { + llvm::errs() << sep << arg->getName(); + sep = ", "; + } + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *Node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(Node->getProto()); + dump(Node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *Node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &F : *Node) + dump(&F); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/toyc.cpp b/toyc.cpp new file mode 100644 index 0000000..6c50191 --- /dev/null +++ b/toyc.cpp @@ -0,0 +1,325 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the entry point for the Toy compiler. 
+// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/Lowering.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "linalg1/Dialect.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { + None, + DumpAST, + DumpMLIR, + DumpMLIRLinalg, + DumpLLVMDialect, + DumpLLVMIR, + RunJIT +}; +} +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), + cl::values(clEnumValN(DumpMLIRLinalg, "mlir-linalg", + "output the MLIR dump after linalg lowering")), + cl::values(clEnumValN(DumpLLVMDialect, "llvm-dialect", + "output the LLVM MLIR Dialect dump")), + cl::values(clEnumValN(DumpLLVMIR, "llvm-ir", "output the LLVM IR dump")), + cl::values( + clEnumValN(RunJIT, "jit", + "JIT the code and run it by 
invoking the main function"))); + +static cl::opt EnableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> FileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code EC = FileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return nullptr; + } + auto buffer = FileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), filename); + Parser parser(lexer); + return parser.ParseModule(); +} + +mlir::LogicalResult optimize(mlir::Module &module) { + mlir::PassManager pm; + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(createShapeInferencePass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + + // Apply any generic pass manager command line options. + applyPassManagerCLOptions(pm); + + return pm.run(&module); +} + +mlir::LogicalResult lowerDialect(mlir::Module &module, bool OnlyLinalg) { + mlir::PassManager pm; + pm.addPass(createEarlyLoweringPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + if (!OnlyLinalg) { + pm.addPass(createLateLoweringPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + } + // Apply any generic pass manager command line options. + applyPassManagerCLOptions(pm); + + return pm.run(&module); +} + +mlir::LogicalResult lowerLLVMModule(mlir::Module &module) { + mlir::PassManager pm; + pm.addPass(createEarlyLoweringPass()); + pm.addPass(createLateLoweringPass()); + + // Apply any generic pass manager command line options. 
+ applyPassManagerCLOptions(pm); + + return pm.run(&module); +} + +std::unique_ptr loadFileAndProcessModule( + mlir::MLIRContext &context, bool EnableLinalgLowering = false, + bool EnableLLVMLowering = false, bool EnableOpt = false) { + + std::unique_ptr module; + if (inputType == InputType::MLIR || + llvm::StringRef(inputFilename).endswith(".mlir")) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return nullptr; + } + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module.reset(mlir::parseSourceFile(sourceMgr, &context)); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return nullptr; + } + if (failed(module->verify())) { + llvm::errs() << "Error verifying MLIR module\n"; + return nullptr; + } + } else { + auto moduleAST = parseInputFile(inputFilename); + module = mlirGen(context, *moduleAST); + } + if (!module) + return nullptr; + if (EnableOpt) { + if (failed(optimize(*module))) { + llvm::errs() << "Module optimization failed\n"; + return nullptr; + } + } + if (EnableLLVMLowering || EnableLinalgLowering) { + if (failed(lowerDialect(*module, !EnableLLVMLowering))) { + llvm::errs() << "Module lowering failed\n"; + return nullptr; + } + } + return module; +} + +int dumpMLIR() { + mlir::MLIRContext context; + auto module = + loadFileAndProcessModule(context, /*EnableLinalgLowering=*/false, + /*EnableLLVMLowering=*/false, EnableOpt); + if (!module) + return -1; + module->dump(); + return 0; +} + +int dumpMLIRLinalg() { + mlir::MLIRContext context; + auto module = loadFileAndProcessModule(context, /*EnableLinalgLowering=*/true, + /*EnableLLVMLowering=*/false, + /* EnableOpt=*/true); + if (!module) + return -1; + module->dump(); + return 0; +} + +int dumpLLVMDialect() { + mlir::MLIRContext context; + auto module = 
loadFileAndProcessModule( + context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true, + /* EnableOpt=*/true); + if (!module) { + llvm::errs() << "Failed to load/lower MLIR module\n"; + return -1; + } + module->dump(); + return 0; +} + +int dumpLLVMIR() { + mlir::MLIRContext context; + auto module = loadFileAndProcessModule( + context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true, + /* EnableOpt=*/true); + if (!module) { + llvm::errs() << "Failed to load/lower MLIR module\n"; + return -1; + } + auto llvmModule = translateModuleToLLVMIR(*module); + if (!llvmModule) { + llvm::errs() << "Failed to emit LLVM IR\n"; + return -1; + } + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::ExecutionEngine::setupTargetTriple(llvmModule.get()); + auto optPipeline = mlir::makeOptimizingTransformer( + /* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0); + if (auto err = optPipeline(llvmModule.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + return -1; + } + llvm::errs() << *llvmModule << "\n"; + return 0; +} + +int runJit() { + mlir::MLIRContext context; + auto module = loadFileAndProcessModule( + context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true, + /* EnableOpt=*/true); + + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // Create an MLIR execution engine. Note that it takes a null pass manager + // to make sure it won't run "default" passes on the MLIR that would trigger + // a second conversion to LLVM IR. The execution engine eagerly JIT-compiles + // the module. + auto optPipeline = mlir::makeOptimizingTransformer( + /* optLevel=*/EnableOpt ? 
3 : 0, /* sizeLevel=*/0); + auto maybeEngine = + mlir::ExecutionEngine::create(module.get(), /*pm=*/nullptr, optPipeline); + assert(maybeEngine && "failed to construct an execution engine"); + auto &engine = maybeEngine.get(); + + // Invoke the JIT-compiled function with the arguments. Note that, for API + // uniformity reasons, it takes a list of type-erased pointers to arguments. + auto invocationResult = engine->invoke("main"); + if (invocationResult) { + llvm::errs() << "JIT invocation failed\n"; + return -1; + } + + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int main(int argc, char **argv) { + // Register our Dialects with MLIR + mlir::registerDialect(); + mlir::registerDialect(); + + mlir::registerPassManagerCLOptions(); + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + switch (emitAction) { + case Action::DumpAST: + return dumpAST(); + case Action::DumpMLIR: + return dumpMLIR(); + case Action::DumpMLIRLinalg: + return dumpMLIRLinalg(); + case Action::DumpLLVMDialect: + return dumpLLVMDialect(); + case Action::DumpLLVMIR: + return dumpLLVMIR(); + case Action::RunJIT: + return runJit(); + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + return -1; + } + + return 0; +} -- cgit v1.2.3-70-g09d2