diff options
Diffstat (limited to 'include')
35 files changed, 3283 insertions, 0 deletions
diff --git a/include/linalg1/Analysis.h b/include/linalg1/Analysis.h new file mode 100644 index 0000000..ef8fb98 --- /dev/null +++ b/include/linalg1/Analysis.h @@ -0,0 +1,49 @@ +//===- Analysis.h - Linalg dialect Analysis function definitions ----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_ANALYSIS_H_ +#define LINALG1_ANALYSIS_H_ + +#include "mlir/Support/LLVM.h" + +namespace mlir { +class Value; +} // namespace mlir + +namespace linalg { +class ViewOp; + +/// Walks the chain of SliceOp until the unique base ViewOp. +ViewOp getViewBaseViewOp(mlir::Value *view); + +/// Walks the chain of SliceOp until the unique base ViewOp and returns the +/// MemRef upon which the ViewOp is laid. +mlir::Value *getViewSupportingMemRef(mlir::Value *view); + +/// Extract the indexing from the root ViewOp that this slice constrins along +/// `dim`. To achieve this, it walks back the chain of SliceOp and determine the +/// first slice that constrains `dim`. +/// Note that the dimension in the original ViewOp may shift due to +/// rank-reducing operations. +/// Returns a pair, with the indexing as the first element and the actual +/// dimension, in the root ViewOp, as the second element. +std::pair<mlir::Value *, unsigned> getViewRootIndexing(mlir::Value *view, + unsigned dim); + +} // namespace linalg + +#endif // LINALG1_ANALYSIS_H_ diff --git a/include/linalg1/Common.h b/include/linalg1/Common.h new file mode 100644 index 0000000..6573c72 --- /dev/null +++ b/include/linalg1/Common.h @@ -0,0 +1,120 @@ +//===- Common.h - Linalg dialect RangeOp operation -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_COMMON_H_ +#define LINALG1_COMMON_H_ + +#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/StandardOps/Ops.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/LoopUtils.h" +#include "mlir/Transforms/Passes.h" + +namespace linalg { +namespace common { + +//////////////////////////////////////////////////////////////////////////////// +// Define a few boilerplate objects used across all linalg examples. +//////////////////////////////////////////////////////////////////////////////// + +/// A 2-D abstraction over a flat contiguous memory region of f32 with symbolic +/// sizes. +template <int N> +inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context, + unsigned memorySpace = 0) { + llvm::SmallVector<int64_t, 4> shape(N, -1); + auto f32 = mlir::FloatType::getF32(context); + return mlir::MemRefType::get(shape, f32, {}, memorySpace); +} + +/// A basic function builder +inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name, + llvm::ArrayRef<mlir::Type> types, + llvm::ArrayRef<mlir::Type> resultTypes) { + auto *context = module.getContext(); + auto *function = new mlir::Function( + mlir::UnknownLoc::get(context), name, + mlir::FunctionType::get({types}, resultTypes, context)); + function->addEntryBlock(); + module.getFunctions().push_back(function); + return function; +} + +/// A basic pass manager pre-populated with cleanup passes. +inline std::unique_ptr<mlir::PassManager> cleanupPassManager() { + std::unique_ptr<mlir::PassManager> pm(new mlir::PassManager()); + pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mlir::createSimplifyAffineStructuresPass()); + pm->addPass(mlir::createCSEPass()); + pm->addPass(mlir::createCanonicalizerPass()); + return pm; +} + +/// A simple function to verify and cleanup the IR before printing it to +/// llvm::outs() for FileCheck'ing. +/// If an error occurs, dump to llvm::errs() and do not print to llvm::outs() +/// which will make the associated FileCheck test fail. +inline void cleanupAndPrintFunction(mlir::Function *f) { + bool printToOuts = true; + auto check = [f, &printToOuts](mlir::LogicalResult result) { + if (failed(result)) { + f->getContext()->emitError(f->getLoc(), + "Verification and cleanup passes failed"); + printToOuts = false; + } + }; + auto pm = cleanupPassManager(); + check(f->getModule()->verify()); + check(pm->run(f->getModule())); + if (printToOuts) + f->print(llvm::outs()); +} + +/// Helper class to sugar building loop nests from indexings that appear in +/// ViewOp and SliceOp. +class LoopNestRangeBuilder { +public: + LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs, + llvm::ArrayRef<mlir::edsc::ValueHandle> indexings); + LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs, + llvm::ArrayRef<mlir::Value *> indexings); + mlir::edsc::ValueHandle + operator()(llvm::ArrayRef<mlir::edsc::CapturableHandle> stmts); + +private: + llvm::SmallVector<mlir::edsc::LoopBuilder, 4> loops; +}; + +} // namespace common +} // namespace linalg + +#endif // LINALG1_COMMON_H_ diff --git a/include/linalg1/ConvertToLLVMDialect.h b/include/linalg1/ConvertToLLVMDialect.h new file mode 100644 index 0000000..8e5a7ce --- /dev/null +++ b/include/linalg1/ConvertToLLVMDialect.h @@ -0,0 +1,66 @@ +//===- ConvertToLLVMDialect.h - conversion from Linalg to LLVM --*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_CONVERTTOLLVMDIALECT_H_ +#define LINALG1_CONVERTTOLLVMDIALECT_H_ + +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/Allocator.h" + +#include <memory> + +namespace mlir { +class DialectConversion; +class DialectOpConversion; +class MLIRContext; +class Module; +class Type; +namespace LLVM { +class LLVMType; +} // end namespace LLVM +} // end namespace mlir + +namespace linalg { +/// Convert the given Linalg dialect type `t` into an LLVM IR dialect type. +/// Keep all other types unmodified. +mlir::Type convertLinalgType(mlir::Type t); + +/// Allocate the conversion patterns for RangeOp, ViewOp and SliceOp from the +/// Linalg dialect to the LLVM IR dialect. The converters are allocated in the +/// `allocator` using the provided `context`. The latter must have the LLVM IR +/// dialect registered. +/// This function can be used to apply multiple conversion patterns in the same +/// pass. It does not have to be called explicitly before the conversion. +llvm::DenseSet<mlir::DialectOpConversion *> +allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator, + mlir::MLIRContext *context); + +/// Create a DialectConversion from the Linalg dialect to the LLVM IR dialect. +/// The conversion is set up to convert types and function signatures using +/// `convertLinalgType` and obtains operation converters by calling `initer`. +std::unique_ptr<mlir::DialectConversion> makeLinalgToLLVMLowering( + std::function<llvm::DenseSet<mlir::DialectOpConversion *>( + llvm::BumpPtrAllocator *, mlir::MLIRContext *context)> + initer); + +/// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations +/// to the LLVM IR dialect types and operations in the given `module`. This is +/// the main entry point to the conversion. +void convertToLLVM(mlir::Module &module); +} // end namespace linalg + +#endif // LINALG1_CONVERTTOLLVMDIALECT_H_ diff --git a/include/linalg1/Dialect.h b/include/linalg1/Dialect.h new file mode 100644 index 0000000..70023e1 --- /dev/null +++ b/include/linalg1/Dialect.h @@ -0,0 +1,42 @@ +//===- Dialect.h - Definition of the Linalg dialect -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_DIALECT_H_ +#define LINALG1_DIALECT_H_ + +#include "mlir/IR/Dialect.h" + +namespace linalg { + +/// The Linalg Dialect is not exposed to the outside world. It is registered by +/// linking and accessed via generic MLIR accessors. +class LinalgDialect : public mlir::Dialect { +public: + /// Create a new Dialect that is registered on construction and adds the + /// relevant types and operations. + explicit LinalgDialect(mlir::MLIRContext *context); + + /// Parse a type registered to this dialect. + mlir::Type parseType(llvm::StringRef spec, mlir::Location loc) const override; + + /// Print a type registered to this dialect. + void printType(mlir::Type type, llvm::raw_ostream &os) const override; +}; + +} // namespace linalg + +#endif // LINALG1_DIALECT_H_ diff --git a/include/linalg1/Intrinsics.h b/include/linalg1/Intrinsics.h new file mode 100644 index 0000000..305e3f3 --- /dev/null +++ b/include/linalg1/Intrinsics.h @@ -0,0 +1,32 @@ +//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_INTRINSICS_H_ +#define LINALG1_INTRINSICS_H_ + +#include "linalg1/Ops.h" +#include "mlir/EDSC/Intrinsics.h" + +namespace linalg { +namespace intrinsics { +using range = mlir::edsc::intrinsics::ValueBuilder<RangeOp>; +using slice = mlir::edsc::intrinsics::ValueBuilder<SliceOp>; +using view = mlir::edsc::intrinsics::ValueBuilder<ViewOp>; +} // namespace intrinsics +} // namespace linalg + +#endif // LINALG1_INTRINSICS_H_ diff --git a/include/linalg1/LLVMIntrinsics.h b/include/linalg1/LLVMIntrinsics.h new file mode 100644 index 0000000..577981b --- /dev/null +++ b/include/linalg1/LLVMIntrinsics.h @@ -0,0 +1,41 @@ +//===- LLVMIntrinsics.h - declarative builders for LLVM dialect -*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_LLVMINTRINSICS_H_ +#define LINALG1_LLVMINTRINSICS_H_ + +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/LLVMIR/LLVMDialect.h" + +// Expose some LLVM IR instructions to declarative builders. +namespace intrinsics { +using undef = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::UndefOp>; +using insertvalue = + mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::InsertValueOp>; +using extractvalue = + mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::ExtractValueOp>; +using constant = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::ConstantOp>; +using add = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::AddOp>; +using sub = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::SubOp>; +using mul = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::MulOp>; +using load = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::LoadOp>; +using store = mlir::edsc::intrinsics::OperationBuilder<mlir::LLVM::StoreOp>; +using gep = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::GEPOp>; +} // end namespace intrinsics + +#endif // LINALG1_LLVMINTRINSICS_H_ diff --git a/include/linalg1/Ops.h b/include/linalg1/Ops.h new file mode 100644 index 0000000..2e662cf --- /dev/null +++ b/include/linalg1/Ops.h @@ -0,0 +1,26 @@ +//===- Ops.h - Linalg Ops single entry point ------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_OPS_H_ +#define LINALG1_OPS_H_ + +#include "linalg1/Types.h" +#include "linalg1/RangeOp.h" +#include "linalg1/SliceOp.h" +#include "linalg1/ViewOp.h" + +#endif // LINALG1_OPS_H_ diff --git a/include/linalg1/RangeOp.h b/include/linalg1/RangeOp.h new file mode 100644 index 0000000..9652f51 --- /dev/null +++ b/include/linalg1/RangeOp.h @@ -0,0 +1,56 @@ +//===- RangeOp.h - Linalg dialect RangeOp operation definition ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_RANGEOP_H_ +#define LINALG1_RANGEOP_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +/// A RangeOp is used to create a value of RangeType from 3 values of type index +/// that represent the min, max and step values of the range. +/// Note: step must be an mlir::ConstantIndexOp for now due to current +/// `affine.for` limitations. +class RangeOp : public mlir::Op<RangeOp, mlir::OpTrait::NOperands<3>::Impl, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.range"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *min, mlir::Value *max, mlir::Value *step); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + mlir::Value *getMin() { return getOperand(0); } + mlir::Value *getMax() { return getOperand(1); } + mlir::Value *getStep() { return getOperand(2); } +}; + +} // namespace linalg + +#endif // LINALG1_RANGEOP_H_ diff --git a/include/linalg1/RangeType.h b/include/linalg1/RangeType.h new file mode 100644 index 0000000..d17c058 --- /dev/null +++ b/include/linalg1/RangeType.h @@ -0,0 +1,49 @@ +//===- RangeType.h - Linalg RangeType definition --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_RANGETYPE_H_ +#define LINALG1_RANGETYPE_H_ + +#include "linalg1/Types.h" +#include "mlir/IR/Types.h" + +namespace mlir { +class MLIRContext; +} + +namespace linalg { + +/// A RangeType is the simplest possible form of a type in MLIR. It represents +/// a minimal range abstraction (min, max, step). Since RangeType is constructed +/// without any additional argument, this example illustrates the minimal +/// amount of information required to implement a new custom MLIR type. +class RangeType : public mlir::Type::TypeBase<RangeType, mlir::Type> { +public: + // Used to implement llvm-style cast. + using Base::Base; + /// Construction hook. + static RangeType get(mlir::MLIRContext *context) { + /// Custom, uniqu'ed construction in the mlir::MLIRContext. + return Base::get(context, LinalgTypes::Range); + } + /// Used to implement llvm-style cast. + static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; } +}; + +} // namespace linalg + +#endif // LINALG1_RANGETYPE_H_ diff --git a/include/linalg1/SliceOp.h b/include/linalg1/SliceOp.h new file mode 100644 index 0000000..1d79784 --- /dev/null +++ b/include/linalg1/SliceOp.h @@ -0,0 +1,91 @@ +//===- SliceOp.h - Linalg dialect SliceOp operation definition ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_SLICEOP_H_ +#define LINALG1_SLICEOP_H_ + +#include "linalg1/Types.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +/// A SliceOp is used to create a "sub-View" from a ViewType. It results in a +/// new ViewType which is contained within its parent ViewType. +class SliceOp : public mlir::Op<SliceOp, mlir::OpTrait::NOperands<2>::Impl, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.slice"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *view, mlir::Value *indexing, unsigned dim); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + enum { FirstIndexingOperand = 1 }; + /// Returns the attribute name that describes which dimension of the input + /// view that this SliceOp slices. + static llvm::StringRef getSlicingDimAttrName() { return "dim"; } + /// Returns the unique result of the parent SliceOp of ViewOp instruction that + /// created the view on which this SliceOp operates. + mlir::Value *getParentView() { return getOperand(0); } + /// Returns the indexing operand of the current SliceOp. + /// This operands may either be: + /// 1. A range, in which case the operand comes from a RangeOp. This SliceOp + /// does not reduce the dimension of the input ViewType. + /// 2. An index, in which case the operand comes from any possible producer + /// of an index. This SliceOp reduces the dimension of the input ViewType + /// by 1. + mlir::Value *getIndexing() { return getOperand(1); } + /// Returns the dim of the parent ViewType that is sliced by this SliceOp. + unsigned getSlicingDim() { + return getAttrOfType<mlir::IntegerAttr>(getSlicingDimAttrName()).getInt(); + } + /// Returns the ViewType resulting from this SliceOp. + ViewType getViewType(); + /// Returns the rank of the current ViewType. + unsigned getRank(); + /// Return the element type of the current ViewType. + mlir::Type getElementType(); + + /// Returns the ViewType of `getParentView()`. + ViewType getParentViewType(); + /// Returns the rank of the ViewType of `getParentView()`. + unsigned getParentRank(); + /// Returns the element Type of the ViewType of `getParentView()`. + mlir::Type getParentElementType(); + + /// Returns true if the rank of the part view is greater than the rank of + /// the child view. + bool isRankDecreasing(); + + // Get all the indexings in this slice. + mlir::Operation::operand_range getIndexings(); +}; + +} // namespace linalg + +#endif // LINALG1_SLICEOP_H_ diff --git a/include/linalg1/Types.h b/include/linalg1/Types.h new file mode 100644 index 0000000..5032e96 --- /dev/null +++ b/include/linalg1/Types.h @@ -0,0 +1,36 @@ +//===- Types.h - Linalg Types forward declarations ------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_TYPES_H_ +#define LINALG1_TYPES_H_ + +#include "mlir/IR/Types.h" + +namespace linalg { + +enum LinalgTypes { + Range = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE, + View, + FIRST_PRIVATE_EXPERIMENTAL_0_TYPE = View, +}; + +} // namespace linalg + +#include "linalg1/RangeType.h" +#include "linalg1/ViewType.h" + +#endif // LINALG1_TYPES_H_ diff --git a/include/linalg1/Utils.h b/include/linalg1/Utils.h new file mode 100644 index 0000000..3f7bb76 --- /dev/null +++ b/include/linalg1/Utils.h @@ -0,0 +1,37 @@ +//===- Utils.h - Linalg dialect utility functions definitions -------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_UTILS_H_ +#define LINALG1_UTILS_H_ + +namespace mlir { +class Value; +} // namespace mlir + +namespace linalg { +class ViewOp; + +/// Asserts `view` is of ViewType and returns its rank. +unsigned getViewRank(mlir::Value *view); + +/// Helper function to emit and return a new ViewOp from `memRef` that is +/// assumed to be of MemRefType. This needs to be called under a ScopedContext. +ViewOp emitAndReturnViewOpFromMemRef(mlir::Value *memRef); + +} // namespace linalg + +#endif // LINALG1_UTILS_H_ diff --git a/include/linalg1/ViewOp.h b/include/linalg1/ViewOp.h new file mode 100644 index 0000000..fcda553 --- /dev/null +++ b/include/linalg1/ViewOp.h @@ -0,0 +1,67 @@ +//===- ViewOp.h - Linalg dialect ViewOp operation definition ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_VIEWOP_H_ +#define LINALG1_VIEWOP_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +class ViewType; + +/// A `ViewOp` produces a `ViewType` which is a multi-dimensional range +/// abstraction on top of an underlying data type. For now we use the existing +/// mlir::MemRef for the underlying data type. +class ViewOp : public mlir::Op<ViewOp, mlir::OpTrait::VariadicOperands, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.view"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *memRef, + llvm::ArrayRef<mlir::Value *> indexings); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + enum { FirstIndexingOperand = 1 }; + unsigned getRank(); + mlir::Type getElementType(); + ViewType getViewType(); + // May be something else than a MemRef in the future. + mlir::Value *getSupportingMemRef(); + // Get the underlying indexing at a given rank. + mlir::Value *getIndexing(unsigned rank); + // Get all the indexings of type RangeOp. + llvm::SmallVector<mlir::Value *, 8> getRanges(); + // Get all the indexings in this view. + mlir::Operation::operand_range getIndexings(); +}; + +} // namespace linalg + +#endif // LINALG1_VIEWOP_H_ diff --git a/include/linalg1/ViewType.h b/include/linalg1/ViewType.h new file mode 100644 index 0000000..c58e12c --- /dev/null +++ b/include/linalg1/ViewType.h @@ -0,0 +1,57 @@ +//===- ViewType.h - Linalg ViewType definition --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_VIEWTYPE_H_ +#define LINALG1_VIEWTYPE_H_ + +#include "linalg1/Types.h" +#include "mlir/IR/Types.h" + +namespace linalg { + +class ViewTypeStorage; + +/// A ViewType represents a range abstraction on top of an underlying storage +/// type. It is parameterizable by the underlying element type and the rank of +/// the view. +class ViewType + : public mlir::Type::TypeBase<ViewType, mlir::Type, ViewTypeStorage> { +public: + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this type. + ////////////////////////////////////////////////////////////////////////////// + // Used to implement llvm-style cast. + using Base::Base; + // Used to implement llvm-style cast. + static bool kindof(unsigned kind) { return kind == LinalgTypes::View; } + /// Construction hook. + static ViewType get(mlir::MLIRContext *context, mlir::Type elementType, + unsigned rank); + + ////////////////////////////////////////////////////////////////////////////// + // Type-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + /// Return the underlying elemental type. + mlir::Type getElementType(); + /// Return the rank of the view. + /// This is the number of indexings needed to reach an underlying element. + unsigned getRank(); +}; + +} // namespace linalg + +#endif // LINALG1_VIEWTYPE_H_ diff --git a/include/linalg2/Analysis.h b/include/linalg2/Analysis.h new file mode 100644 index 0000000..43acd95 --- /dev/null +++ b/include/linalg2/Analysis.h @@ -0,0 +1,23 @@ +//===- Analysis.h - Linalg dialect Analysis function definitions ----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG2_ANALYSIS_H_ +#define LINALG2_ANALYSIS_H_ + +#include "linalg1/Analysis.h" + +#endif // LINALG2_ANALYSIS_H_ diff --git a/include/linalg2/Intrinsics.h b/include/linalg2/Intrinsics.h new file mode 100644 index 0000000..e74e059 --- /dev/null +++ b/include/linalg2/Intrinsics.h @@ -0,0 +1,32 @@ +//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG2_INTRINSICS_H_ +#define LINALG2_INTRINSICS_H_ + +#include "linalg1/Intrinsics.h" +#include "linalg2/Ops.h" + +namespace linalg { +namespace intrinsics { +using dot = mlir::edsc::intrinsics::OperationBuilder<DotOp>; +using matmul = mlir::edsc::intrinsics::OperationBuilder<MatmulOp>; +using matvec = mlir::edsc::intrinsics::OperationBuilder<MatvecOp>; +} // namespace intrinsics +} // namespace linalg + +#endif // LINALG2_INTRINSICS_H_ diff --git a/include/linalg2/Ops.h b/include/linalg2/Ops.h new file mode 100644 index 0000000..141b1d0 --- /dev/null +++ b/include/linalg2/Ops.h @@ -0,0 +1,24 @@ +//===- Ops.h - Linalg Ops single entry point ------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG2_OPS_H_ +#define LINALG2_OPS_H_ + +#include "linalg1/Ops.h" +#include "linalg2/TensorOps.h" + +#endif // LINALG2_OPS_H_ diff --git a/include/linalg2/TensorOps-inl.h b/include/linalg2/TensorOps-inl.h new file mode 100644 index 0000000..940f8d7 --- /dev/null +++ b/include/linalg2/TensorOps-inl.h @@ -0,0 +1,120 @@ +//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#ifndef LINALG2_TENSOROPS_INL_H_ +#define LINALG2_TENSOROPS_INL_H_ + +#include "linalg2/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" + +namespace linalg { + +template <class ConcreteOp> +mlir::Operation::operand_range +linalg::TensorContractionBase<ConcreteOp>::getInputs() { + auto *op = static_cast<ConcreteOp *>(this)->getOperation(); + return {op->operand_begin(), op->operand_begin() + getNumInputs()}; +} + +template <class ConcreteOp> +mlir::Operation::operand_range +linalg::TensorContractionBase<ConcreteOp>::getOutputs() { + auto *op = static_cast<ConcreteOp *>(this)->getOperation(); + return {op->operand_begin() + getNumInputs(), + op->operand_begin() + getNumInputs() + getNumOutputs()}; +} + +template <class ConcreteOp> +mlir::Operation::operand_range +linalg::TensorContractionBase<ConcreteOp>::getInputsAndOutputs() { + return {getInputs().begin(), getOutputs().end()}; +} + +template <class ConcreteOp> +mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() { + auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation(); + if (getNumInputs() <= 0) + concreteOp->emitOpError("expected at least one input"); + if (getNumOutputs() <= 0) + concreteOp->emitOpError("expected at least one output"); + if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) { + concreteOp->emitOpError("expected " + + llvm::Twine(getNumInputs() + getNumOutputs()) + + " operands"); + } + for (unsigned i = 0, e = getNumInputs(); i < e; ++i) { + if (!concreteOp->getOperand(i)->getType().template isa<ViewType>()) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " not a ViewType"); + } + for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e; + ++i) { + auto viewType = + concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>(); + if (!viewType) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " not a ViewType"); + if (viewType.getRank() != getNumParallelDims()) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " must be of rank " + + llvm::Twine(getNumParallelDims())); + } + return mlir::success(); +} + +template <class ConcreteOp> +bool linalg::TensorContractionBase<ConcreteOp>::parse( + mlir::OpAsmParser *parser, mlir::OperationState *result) { + llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); +} + +// A TensorContraction prints as: +// +// ```{.mlir} +// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types +// ``` +// +// for example: +// +// ``` +// linalg.matmul(%0, %1, %2) : view<?x?xf32> +// ``` +// +// Where %0, %1 and %2 are ssa-values of type ViewType. +template <class ConcreteOp> +void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter *p) { + *p << static_cast<ConcreteOp *>(this)->getOperationName() << "("; + auto *last = *std::prev(getInputsAndOutputs().end()); + for (auto *i : getInputsAndOutputs()) { + *p << *i << ((i == last) ? "" : ", "); + } + *p << ") : "; + auto *lastOutput = *std::prev(getOutputs().end()); + for (auto *o : getOutputs()) { + *p << o->getType() << ((o == lastOutput) ? "" : ","); + } +} + +} // namespace linalg + +#endif // LINALG2_TENSOROPS_INL_H_ diff --git a/include/linalg2/TensorOps.h b/include/linalg2/TensorOps.h new file mode 100644 index 0000000..39e51f0 --- /dev/null +++ b/include/linalg2/TensorOps.h @@ -0,0 +1,287 @@ +//===- TensorOps.h - Linalg dialect TensorOps operation definition --------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG2_TENSOROPS_H_ +#define LINALG2_TENSOROPS_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +class AffineForOp; +} // namespace mlir + +namespace linalg { + +/// A generic TensorContraction base class which captures the generic behavior +/// of tensor contraction operations (with broadcast). +template <class ConcreteOp> class TensorContractionBase { +protected: + using TensorContractionBaseType = TensorContractionBase<ConcreteOp>; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + /// Generic implementation of hooks that should be called from `ConcreteType`s + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + +public: + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + TensorContractionBase() = default; + mlir::Operation::operand_range getInputs(); + mlir::Operation::operand_range getOutputs(); + mlir::Operation::operand_range getInputsAndOutputs(); + + /// These are better as methods calling into the ConcreteOp instead of + /// template parameters because methods allow more generic behavior and avoid + /// specializing for number of arguments. All derived classes have + /// `VariadicOperands` and a build method from both an ArrayRef<mlirValue*> + /// and the proper number of mlir::Value*. + unsigned getNumInputs() { + return static_cast<ConcreteOp *>(this)->numInputs; + }; + unsigned getNumOutputs() { + return static_cast<ConcreteOp *>(this)->numOutputs; + }; + unsigned getNumParallelDims() { + return static_cast<ConcreteOp *>(this)->numParallelDims; + }; + unsigned getNumReductionDims() { + return static_cast<ConcreteOp *>(this)->numReductionDims; + }; + + ////////////////////////////////////////////////////////////////////////////// + // Used in Linalg3 and later. + ////////////////////////////////////////////////////////////////////////////// + mlir::Value *getInputView(unsigned viewIndex); + mlir::Value *getOutputView(unsigned viewIndex); + mlir::Value *getView(unsigned viewIndex) { + return viewIndex < getNumInputs() + ? getInputView(viewIndex) + : getOutputView(viewIndex - getNumInputs()); + } + + /// Each op is responsible for declaring how it lowers itself to scalar form, + /// given the enclosing parallel and reduction induction variables. + /// `emitScalarImplementation` emits the scalar IR for the op in the nesting + /// context of the innermost enclosing loop(i.e. `reductionIvs.back()` or + /// `parallel.back()`). + void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs, + llvm::ArrayRef<mlir::Value *> reductionIvs); + + /// Represents a mapping from the loops to all the ranges of the operands. + /// The operands and their ranges are in the order defined by the particular + /// ConcreteOp implementation, the resulting map must match those. + /// In favorable cases, this can be calculated by an analysis but specifying + /// it explicitly is not expensive and generalizes to cases where an analysis + /// is not available. For details, see the description of + /// loopsToOperandRangeMaps in each ConcreteOp. + llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps(); +}; + +/// Implements c = A * B where c is a scalar and A and B are 1-D vectors. +class DotOp : public TensorContractionBase<DotOp>, + public mlir::Op<DotOp, mlir::OpTrait::VariadicOperands, + mlir::OpTrait::ZeroResult> { +public: + using Op::Op; + using TensorContractionBaseType = + TensorContractionBase::TensorContractionBaseType; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.dot"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + llvm::ArrayRef<mlir::Value *> operands); + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *A, mlir::Value *B, mlir::Value *C) { + return build(b, result, {A, B, C}); + } + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + static constexpr unsigned numInputs = 2; + static constexpr unsigned numOutputs = 1; + static constexpr unsigned numParallelDims = 0; + static constexpr unsigned numReductionDims = 1; + + ////////////////////////////////////////////////////////////////////////////// + // Used in Linalg3 and later. + ////////////////////////////////////////////////////////////////////////////// + /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a + /// loop over matvec). Does nothing by default. + void writeAsFinerGrainTensorContraction(); + + /// Inputs to this map will be (%k) coming from enclosing loops. + /// Therefore, the mapping to get back to A(K), B(K), C() is: + /// (d0) -> (d0, d0)(%k) + /// And the operands ranges are: + /// (%k, %k) + llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps(); + + /// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding + /// to: + /// 1. conditionally assign scalarC to 0.0f on the first iteration or load + /// C[] from memory (0-D tensor) + /// 2. multiply A[r_i] by B[r_i] and add to scalarC + /// 3. store back scalarC at C[] + /// + /// In some compact index notation this could be written: + /// cond = (r_i == zero) + /// scalarC = select(cond, zerof, C[]); + /// C[] = scalarC + A[r_i] * B[r_i]; + void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs, + llvm::ArrayRef<mlir::Value *> reductionIvs); +}; + +/// Implements C = A * B where A is a 2-D matrix and X and Y are 1-D vectors. +class MatvecOp : public TensorContractionBase<MatvecOp>, + public mlir::Op<MatvecOp, mlir::OpTrait::VariadicOperands, + mlir::OpTrait::ZeroResult> { +public: + using Op::Op; + using TensorContractionBaseType = + TensorContractionBase::TensorContractionBaseType; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.matvec"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + llvm::ArrayRef<mlir::Value *> operands); + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *A, mlir::Value *B, mlir::Value *C) { + return build(b, result, {A, B, C}); + } + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + static constexpr unsigned numInputs = 2; + static constexpr unsigned numOutputs = 1; + static constexpr unsigned numParallelDims = 1; + static constexpr unsigned numReductionDims = 1; + + ////////////////////////////////////////////////////////////////////////////// + // Used in Linalg3 and later. + ////////////////////////////////////////////////////////////////////////////// + /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a + /// loop over matvec). Does nothing by default. + void writeAsFinerGrainTensorContraction(); + + /// Inputs to this map will be (%m, %k) coming from enclosing loops. + /// Therefore, the mapping to get back to A(M, K), B(K), C(M) is: + /// (d0, d1) -> (d0, d1, d1, d0)(%m, %k) + /// And the operands ranges are: + /// (%m, %k, %k, %m) + llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps(); + + /// Given an enclosing parallel loop with iv `i` and an enclosing parallel + /// loop with iv `r_j`, emits MLIR corresponding to: + /// 1. conditionally assign scalarC to 0.0f on the first iteration or load + /// C[i] + /// 2. multiply A[i, r_j] by B[r_j] and add to scalarC + /// 3. store back scalarC at C[i] + /// + /// In some compact index notation this could be written: + /// cond = (r_j == zero) + /// scalarC = select(cond, zerof, C(i)); + /// C(i) = scalarC + A(i, r_j) * B(r_j); + void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs, + llvm::ArrayRef<mlir::Value *> reductionIvs); +}; + +/// Implements C = A * B on 2-D matrices. +class MatmulOp : public TensorContractionBase<MatmulOp>, + public mlir::Op<MatmulOp, mlir::OpTrait::VariadicOperands, + mlir::OpTrait::ZeroResult> { +public: + using Op::Op; + using TensorContractionBaseType = + TensorContractionBase::TensorContractionBaseType; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.matmul"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + llvm::ArrayRef<mlir::Value *> operands); + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *A, mlir::Value *B, mlir::Value *C) { + return build(b, result, {A, B, C}); + } + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + static constexpr unsigned numInputs = 2; + static constexpr unsigned numOutputs = 1; + static constexpr unsigned numParallelDims = 2; + static constexpr unsigned numReductionDims = 1; + + ////////////////////////////////////////////////////////////////////////////// + // Used in Linalg3 and later. + ////////////////////////////////////////////////////////////////////////////// + /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a + /// loop over matvec). Does nothing by default. + void writeAsFinerGrainTensorContraction(); + + /// Inputs to this map will be (%m, %n, %k) coming from enclosing loops. + /// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is: + /// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k) + /// And the operands ranges are: + /// (%m, %k, %k, %n, %m, %n) + llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps(); + + /// Given a enclosing parallel loops with ivs `i` and `j`, and an enclosing + /// reduction loop with iv `r_k`, emits MLIR corresponding to: + /// 1. conditionally assign scalarC to 0.0f on the first iteration or load + /// C[i, j] + /// 2. multiply A[i, r_k] by B[r_k, j] and add to scalarC + /// 3. store back scalarC at C[i, j] + /// + /// In some compact index notation this could be written: + /// cond = (r_k == zero) + /// scalarC = select(cond, zerof, C[i, j]); + /// C[i, j] = scalarC + A[i, r_k] * B[r_k, j]; + void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs, + llvm::ArrayRef<mlir::Value *> reductionIvs); +}; + +} // namespace linalg + +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#include "linalg2/TensorOps-inl.h" + +#endif // LINALG2_TENSOROPS_H_ diff --git a/include/linalg2/Transforms.h b/include/linalg2/Transforms.h new file mode 100644 index 0000000..c55f863 --- /dev/null +++ b/include/linalg2/Transforms.h @@ -0,0 +1,36 @@ +//===- Transforms.h - Linalg dialect Transformations definition -----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG2_TRANSFORMS_H_ +#define LINALG2_TRANSFORMS_H_ + +namespace mlir { +class Value; +} // namespace mlir + +namespace linalg { + +class ViewOp; + +/// Takes a `view` of type ViewType (i.e. either a ViewOp or a SliceOp) and +/// composes away all the SliceOp to return a single ViewOp. +/// Inserts the required operations after `view`. +ViewOp emitAndReturnFullyComposedView(mlir::Value *v); + +} // namespace linalg + +#endif // LINALG2_TRANSFORMS_H_ diff --git a/include/linalg3/Analysis.h b/include/linalg3/Analysis.h new file mode 100644 index 0000000..813fc37 --- /dev/null +++ b/include/linalg3/Analysis.h @@ -0,0 +1,37 @@ +//===- Analysis.h - Linalg dialect Analysis function definitions ----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_ANALYSIS_H_ +#define LINALG3_ANALYSIS_H_ + +#include "linalg2/Analysis.h" + +namespace mlir { +class AffineMap; +} // namespace mlir + +namespace linalg { + +/// Given a `map` specification and a subset of its results +/// `[beginResult, endResult)`, returns the inverse map that maps result +/// positions to dim positions. +mlir::AffineMap inverseSubMap(mlir::AffineMap map, unsigned beginResult = 0, + unsigned endResult = 0); + +} // namespace linalg + +#endif // LINALG3_ANALYSIS_H_ diff --git a/include/linalg3/ConvertToLLVMDialect.h b/include/linalg3/ConvertToLLVMDialect.h new file mode 100644 index 0000000..8f122e0 --- /dev/null +++ b/include/linalg3/ConvertToLLVMDialect.h @@ -0,0 +1,29 @@ +//===- ConvertToLLVMDialect.h - conversion from Linalg to LLVM --*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_CONVERTTOLLVMDIALECT_H_ +#define LINALG3_CONVERTTOLLVMDIALECT_H_ + +namespace mlir { +class Module; +} // end namespace mlir + +namespace linalg { +void convertLinalg3ToLLVM(mlir::Module &module); +} // end namespace linalg + +#endif // LINALG3_CONVERTTOLLVMDIALECT_H_ diff --git a/include/linalg3/Intrinsics.h b/include/linalg3/Intrinsics.h new file mode 100644 index 0000000..75a0417 --- /dev/null +++ b/include/linalg3/Intrinsics.h @@ -0,0 +1,31 @@ +//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_INTRINSICS_H_ +#define LINALG3_INTRINSICS_H_ + +#include "linalg2/Intrinsics.h" +#include "linalg3/Ops.h" + +namespace linalg { +namespace intrinsics { +using load = mlir::edsc::intrinsics::ValueBuilder<LoadOp>; +using store = mlir::edsc::intrinsics::OperationBuilder<StoreOp>; +} // namespace intrinsics +} // namespace linalg + +#endif // LINALG3_INTRINSICS_H_ diff --git a/include/linalg3/LoadStoreOps.h b/include/linalg3/LoadStoreOps.h new file mode 100644 index 0000000..b77e702 --- /dev/null +++ b/include/linalg3/LoadStoreOps.h @@ -0,0 +1,89 @@ +//===- LoadStoreOps.h - Linalg dialect Load/Store operation definitions ---===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_LOADSTOREOP_H_ +#define LINALG3_LOADSTOREOP_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +class ViewType; + +/// A linalg.LoadOp is the counterpart of affine.load but operating on ViewType +/// instead of MemRefType. +class LoadOp : public mlir::Op<LoadOp, mlir::OpTrait::VariadicOperands, + mlir::OpTrait::OneResult> { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.load"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *view, + mlir::ArrayRef<mlir::Value *> indices = {}); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + unsigned getRank(); + ViewType getViewType(); + mlir::Value *getView() { return getOperand(0); } + mlir::Operation::operand_range getIndices() { + return {operand_begin() + 1, operand_end()}; + } +}; + +/// A linalg.StoreOp is the counterpart of affine.store but operating on +/// ViewType instead of MemRefType. +class StoreOp : public mlir::Op<StoreOp, mlir::OpTrait::VariadicOperands, + mlir::OpTrait::ZeroResult> { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.store"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *valueToStore, mlir::Value *view, + mlir::ArrayRef<mlir::Value *> indices = {}); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + unsigned getRank(); + ViewType getViewType(); + mlir::Value *getValueToStore() { return getOperand(0); } + mlir::Value *getView() { return getOperand(1); } + mlir::Operation::operand_range getIndices() { + return {operand_begin() + 2, operand_end()}; + } +}; + +} // namespace linalg + +#endif // LINALG3_LOADSTOREOP_H_ diff --git a/include/linalg3/Ops.h b/include/linalg3/Ops.h new file mode 100644 index 0000000..813cbff --- /dev/null +++ b/include/linalg3/Ops.h @@ -0,0 +1,25 @@ +//===- Ops.h - Linalg Ops single entry point ------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_OPS_H_ +#define LINALG3_OPS_H_ + +#include "linalg2/Ops.h" +#include "linalg3/LoadStoreOps.h" +#include "linalg3/TensorOps.h" + +#endif // LINALG3_OPS_H_ diff --git a/include/linalg3/TensorOps-inl.h b/include/linalg3/TensorOps-inl.h new file mode 100644 index 0000000..b651053 --- /dev/null +++ b/include/linalg3/TensorOps-inl.h @@ -0,0 +1,145 @@ +//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#ifndef LINALG3_TENSOROPS_INL_H_ +#define LINALG3_TENSOROPS_INL_H_ + +#include "linalg1/Common.h" +#include "linalg1/Utils.h" +#include "linalg2/TensorOps.h" +#include "linalg3/Analysis.h" +#include "linalg3/Ops.h" + +template <class ConcreteOp> +mlir::Value * +linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned viewIndex) { + return *(getInputs().begin() + viewIndex); +} + +template <class ConcreteOp> +mlir::Value * +linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned viewIndex) { + return *(getOutputs().begin() + viewIndex); +} + +template <class ConcreteOp> +llvm::SmallVector<mlir::AffineMap, 8> +linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangeMaps() { + return static_cast<ConcreteOp *>(this)->loopsToOperandRangeMaps(); +} + +template <class ConcreteOp> +void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation( + llvm::ArrayRef<mlir::Value *> parallelIvs, + llvm::ArrayRef<mlir::Value *> reductionIvs) { + static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs, + reductionIvs); +} + +template <class ConcreteOp> +mlir::AffineMap linalg::operandRangesToLoopsMap( + linalg::TensorContractionBase<ConcreteOp> &tensorContraction) { + mlir::AffineMap current; + // Individual submaps may not be invertible but their union must be invertible + // by construction. + for (auto m : tensorContraction.loopsToOperandRangeMaps()) { + if (!m) + continue; + if (!current) { + current = m; + continue; + } + llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(), + current.getResults().end()); + results.append(m.getResults().begin(), m.getResults().end()); + current = mlir::AffineMap::get( + std::max(current.getNumDims(), m.getNumDims()), + current.getNumSymbols() + m.getNumSymbols(), results, {}); + } + return inverseSubMap(current); +} + +// Extract the ranges from a given ViewOp or SliceOp. +// +// In the case of a ViewOp, things are simple: just traverse the indexings and +// get all the ranges (i.e. drop the indices). +// +// In the case of a SliceOp, things are trickier because we need to handle a +// potential rank-reduction: +// 1. Examine the indexing to determine if it is rank-reducing. +// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such +// that `d >= slicingDim`. This is to account for the rank reduction. +// `getRootIndex` is then called on the **parent** view +static llvm::SmallVector<mlir::Value *, 8> +extractRangesFromViewOrSliceOp(mlir::Value *view) { + // This expects a viewType which must come from either ViewOp or SliceOp. + assert(view->getType().isa<linalg::ViewType>() && "expected ViewType"); + if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>()) + return viewOp.getRanges(); + + auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>(); + unsigned slicingDim = sliceOp.getSlicingDim(); + auto *indexing = *(sliceOp.getIndexings().begin()); + bool isRankReducing = indexing->getType().isa<mlir::IndexType>(); + unsigned offset = 0; + llvm::SmallVector<mlir::Value *, 8> res; + res.reserve(sliceOp.getRank()); + for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) { + if (d == slicingDim && isRankReducing) + offset = 1; + auto *parentView = sliceOp.getParentView(); + auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset); + res.push_back(indexingPosPair.first); + } + return res; +} + +template <class ConcreteOp> +static llvm::SmallVector<mlir::Value *, 8> +getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) { + llvm::SmallVector<mlir::Value *, 8> res; + for (auto *in : tensorContraction.getInputs()) { + auto subres = extractRangesFromViewOrSliceOp(in); + res.append(subres.begin(), subres.end()); + } + return res; +} + +template <class ConcreteOp> +static llvm::SmallVector<mlir::Value *, 8> +getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) { + llvm::SmallVector<mlir::Value *, 8> res; + for (auto *out : tensorContraction.getOutputs()) { + auto subres = extractRangesFromViewOrSliceOp(out); + res.append(subres.begin(), subres.end()); + } + return res; +} + +template <class ConcreteOp> +llvm::SmallVector<mlir::Value *, 8> linalg::getRanges( + linalg::TensorContractionBase<ConcreteOp> &tensorContraction) { + llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction); + llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction); + res.append(tmp.begin(), tmp.end()); + return res; +} + +#endif // LINALG3_TENSOROPS_INL_H_ diff --git a/include/linalg3/TensorOps.h b/include/linalg3/TensorOps.h new file mode 100644 index 0000000..bf5a377 --- /dev/null +++ b/include/linalg3/TensorOps.h @@ -0,0 +1,54 @@ +//===- TensorOps.h - Linalg dialect TensorOps operation definition --------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_TENSOROPS_H_ +#define LINALG3_TENSOROPS_H_ + +#include "linalg2/TensorOps.h" + +namespace linalg { + +/// +/// Ideally all these functions would go in an Analysis but as long as +/// TensorContractionBase is templated, they need to remain close enough. +/// + +/// Takes a `tensorContraction` and a returns an AffineMap that can be used to +/// map ranges to enclosing loops for all the operands' ranges. +template <class ConcreteOp> +mlir::AffineMap operandRangesToLoopsMap( + linalg::TensorContractionBase<ConcreteOp> &tensorContraction); + +/// Takes a `tensorContraction` and returns the ranges of all its operands. +/// When an operand comes from a ViewOp, things are simple: +/// just traverse the indexings and get all the ranges +/// (i.e. drop the rank-reducing indices). +/// In the case of a SliceOp, things are more involved because we need to handle +/// potential rank-reductions. +/// This function abstracts this complexity away and returns all the ranges. +template <class ConcreteOp> +llvm::SmallVector<mlir::Value *, 8> +getRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction); + +} // namespace linalg + +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#include "linalg3/TensorOps-inl.h" + +#endif // LINALG3_TENSOROPS_H_ diff --git a/include/linalg3/Transforms.h b/include/linalg3/Transforms.h new file mode 100644 index 0000000..9af528e --- /dev/null +++ b/include/linalg3/Transforms.h @@ -0,0 +1,80 @@ +//===- Transforms.h - Linalg dialect Transformations definition -----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_TRANSFORMS_H_ +#define LINALG3_TRANSFORMS_H_ + +#include "linalg2/Transforms.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" + +namespace mlir { +class AffineForOp; +class AffineMap; +class Function; +class FunctionPassBase; +class Operation; +class Value; +} // namespace mlir + +namespace linalg { + +struct RangeParts { + explicit RangeParts(unsigned reserved); + RangeParts(llvm::ArrayRef<mlir::Value *> ranges); + llvm::SmallVector<mlir::Value *, 4> makeRanges(); + + llvm::SmallVector<mlir::Value *, 4> mins; + llvm::SmallVector<mlir::Value *, 4> maxes; + llvm::SmallVector<mlir::Value *, 4> steps; +}; + +mlir::Value * +makeFoldedComposedAffineApply(mlir::AffineMap map, + llvm::ArrayRef<mlir::Value *> operandsRef); + +llvm::SmallVector<mlir::Value *, 4> +makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps, + llvm::ArrayRef<mlir::Value *> ranges, + llvm::ArrayRef<mlir::Value *> tileSizes = {}); + +/// Traverses `f` and rewrites linalg.slice, and the operations it depends on, +/// to only use linalg.view operations. +void composeSliceOps(mlir::Function *f); + +/// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec) +/// as linalg.matvec(resp. linalg.dot). +void lowerToFinerGrainedTensorContraction(mlir::Function *f); + +/// Operation-wise writing of linalg operations to loop form. +/// It is the caller's responsibility to erase the `op` if necessary. +/// This returns the enclosing loops around the body of `op` for further +/// composition of transformations. +llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 4>> +writeAsLoops(mlir::Operation *op); + +/// Traverses `f` and rewrites linalg operations in loop form. +void lowerToLoops(mlir::Function *f); + +/// Creates a pass that rewrites linalg.load and linalg.store to affine.load and +/// affine.store operations. +mlir::FunctionPassBase *createLowerLinalgLoadStorePass(); + +} // namespace linalg + +#endif // LINALG3_TRANSFORMS_H_ diff --git a/include/toy/AST.h b/include/toy/AST.h new file mode 100644 index 0000000..456a323 --- /dev/null +++ b/include/toy/AST.h @@ -0,0 +1,256 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include <vector> + +namespace toy { + +/// A variable +struct VarType { + enum { TY_FLOAT, TY_INT } elt_ty; + std::vector<int> shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, // builtin + Expr_If, + Expr_For, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector<std::unique_ptr<ExprAST>>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } +}; + +/// +class LiteralExprAST : public ExprAST { + std::vector<std::unique_ptr<ExprAST>> values; + std::vector<int64_t> dims; + +public: + LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values, + std::vector<int64_t> dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; } + std::vector<int64_t> &getDims() { return dims; } + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, const std::string &name) + : ExprAST(Expr_Var, loc), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } +}; + +/// +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr<ExprAST> initVal; + +public: + VarDeclExprAST(Location loc, const std::string &name, VarType type, + std::unique_ptr<ExprAST> initVal) + : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), + initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; } +}; + +/// +class ReturnExprAST : public ExprAST { + llvm::Optional<std::unique_ptr<ExprAST>> expr; + +public: + ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr) + : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} + + llvm::Optional<ExprAST *> getExpr() { + if (expr.hasValue()) + return expr->get(); + return llvm::NoneType(); + } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char Op; + std::unique_ptr<ExprAST> LHS, RHS; + +public: + char getOp() { return Op; } + ExprAST *getLHS() { return LHS.get(); } + ExprAST *getRHS() { return RHS.get(); } + + BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS, + std::unique_ptr<ExprAST> RHS) + : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)), + RHS(std::move(RHS)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. +class CallExprAST : public ExprAST { + std::string Callee; + std::vector<std::unique_ptr<ExprAST>> Args; + +public: + CallExprAST(Location loc, const std::string &Callee, + std::vector<std::unique_ptr<ExprAST>> Args) + : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {} + + llvm::StringRef getCallee() { return Callee; } + llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr<ExprAST> Arg; + +public: + PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg) + : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {} + + ExprAST *getArg() { return Arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector<std::unique_ptr<VariableExprAST>> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector<std::unique_ptr<VariableExprAST>> args) + : location(location), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + const std::string &getName() const { return name; } + const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() { + return args; + } +}; + +/// This class represents a function definition itself. +class FunctionAST { + std::unique_ptr<PrototypeAST> Proto; + std::unique_ptr<ExprASTList> Body; + +public: + FunctionAST(std::unique_ptr<PrototypeAST> Proto, + std::unique_ptr<ExprASTList> Body) + : Proto(std::move(Proto)), Body(std::move(Body)) {} + PrototypeAST *getProto() { return Proto.get(); } + ExprASTList *getBody() { return Body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector<FunctionAST> functions; + +public: + ModuleAST(std::vector<FunctionAST> functions) + : functions(std::move(functions)) {} + + auto begin() -> decltype(functions.begin()) { return functions.begin(); } + auto end() -> decltype(functions.end()) { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/include/toy/Dialect.h b/include/toy/Dialect.h new file mode 100644 index 0000000..9d7f82d --- /dev/null +++ b/include/toy/Dialect.h @@ -0,0 +1,393 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the IR Dialect for the Toy language. +// See g3doc/Tutorials/Toy/Ch-3.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +namespace mlir { +class Builder; +} + +namespace toy { + +/// This is the definition of the Toy dialect. A dialect inherits from +/// mlir::Dialect and register custom operations and types (in its constructor). +/// It can also overridding general behavior of dialects exposed as virtual +/// method, for example regarding verification and parsing/printing. +class ToyDialect : public mlir::Dialect { +public: + explicit ToyDialect(mlir::MLIRContext *ctx); + + /// Parse a type registered to this dialect. Overridding this method is + /// required for dialects that have custom types. + /// Technically this is only needed to be able to round-trip to textual IR. + mlir::Type parseType(llvm::StringRef tyData, + mlir::Location loc) const override; + + /// Print a type registered to this dialect. Overridding this method is + /// only required for dialects that have custom types. + /// Technically this is only needed to be able to round-trip to textual IR. + void printType(mlir::Type type, llvm::raw_ostream &os) const override; +}; + +//////////////////////////////////////////////////////////////////////////////// +/////////////////////// Custom Types for the Dialect /////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { +struct ToyArrayTypeStorage; +} + +/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa. +enum ToyTypeKind { + // The enum starts at the range reserved for this dialect. + TOY_TYPE = mlir::Type::FIRST_TOY_TYPE, + TOY_ARRAY, +}; + +/// Type for Toy arrays. +/// In MLIR Types are reference to immutable and uniqued objects owned by the +/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued +/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and +/// provides the public facade API to interact with the type. +class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type, + detail::ToyArrayTypeStorage> { +public: + using Base::Base; + + /// Returns the dimensions for this array, or and empty range for a generic + /// array. + llvm::ArrayRef<int64_t> getShape(); + + /// Predicate to test if this array is generic (shape haven't been inferred + /// yet). + bool isGeneric() { return getShape().empty(); } + + /// Return the rank of this array (0 if it is generic). + int getRank() { return getShape().size(); } + + /// Return the type of individual elements in the array. + mlir::Type getElementType(); + + /// Get a MemRef equivalent to this array type. + mlir::MemRefType toMemref(); + + /// Get the unique instance of this Type from the context. + /// A ToyArrayType is only defined by the shape of the array. + static ToyArrayType get(mlir::MLIRContext *context, + llvm::ArrayRef<int64_t> shape = {}); + + /// Support method to enable LLVM-style RTTI type casting. + static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; } +}; + +//////////////////////////////////////////////////////////////////////////////// +//////////////////// Custom Operations for the Dialect ///////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Constant operation turns a literal into an SSA value. The data is attached +/// to the operation as an attribute. For example: +/// +/// %0 = "toy.constant"() +/// {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>} +/// : () -> !toy<"array<2, 3>"> +/// +/// An operation inherits from `class Op` and specifies optional traits. Here we +/// indicate that `toy.constant` does not have any operands and returns a single +/// result. The traits provide some utilities methods for the operation, for +/// instance we will be able to use `getResult()`, but `getOperand()` won't be +/// available. +class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + /// This is the name used by MLIR to match an operation to this class during + /// parsing. + static llvm::StringRef getOperationName() { return "toy.constant"; } + + /// The operation can have extra verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create<PrintOp>(...) + /// This method populates the `state` that MLIR uses to create operations. + /// The `toy.constant` operation does not have arguments but attaches a + /// constant array as an attribute and returns it as an SSA value. + static void build(mlir::Builder *builder, mlir::OperationState *state, + llvm::ArrayRef<int64_t> shape, + mlir::DenseElementsAttr value); + + /// Similar to the one above, but takes a single float and returns a + /// !toy<"array<1>">. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::FloatAttr value); + + mlir::DenseElementsAttr getValue() { + return getAttr("value").cast<mlir::DenseElementsAttr>(); + } + + /// Inherit constructor. + using Op::Op; +}; + +/// Generic calls represent calls to a user defined function that needs to +/// be specialized for the shape of its arguments. The callee name is attached +/// as a literal string as an attribute. The arguments list must match the +/// arguments expected by the callee. For example: +/// +/// %4 = "toy.generic_call"(%1, %3) {callee: "my_func"} +/// : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array"> +/// +/// This is only valid if a function named "my_func" exists and takes two +/// arguments. +class GenericCallOp + : public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands, + mlir::OpTrait::OneResult> { +public: + /// MLIR will use this to register the operation with the parser/printer. + static llvm::StringRef getOperationName() { return "toy.generic_call"; } + + /// Operations can add custom verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to the builder to allow: + /// mlir::Builder::create<GenericCallOp>(...) + /// This method populate the `state` that MLIR use to create operations. + /// The `toy.generic_call` operation accepts a callee name and a list of + /// arguments for the call. + static void build(mlir::Builder *builder, mlir::OperationState *state, + llvm::StringRef callee, + llvm::ArrayRef<mlir::Value *> arguments); + + /// Return the name of the callee. + llvm::StringRef getCalleeName(); + + /// Inherit constructor. + using Op::Op; +}; + +/// Return operations terminate blocks (and functions as well). They take a +/// single argument and the type must match the function return type. +class ReturnOp + : public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands, + mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> { +public: + static llvm::StringRef getOperationName() { return "toy.return"; } + + /// Operations can add custom verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create<PrintOp>(...) + /// This method populate the `state` that MLIR use to create operations. + /// The `toy.return` operation accepts an optional single array as an argument + /// and does not have any returned value. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value = nullptr); + + /// Return true if there is a returned value. + bool hasOperand() { return 0 != getNumOperands(); } + + /// Helper to return the optional operand. Caller must check if the operand + /// is present before calling this. + mlir::Value *getOperand() { return getOperation()->getOperand(0); } + + /// Inherit constructor. + using Op::Op; +}; + +/// The print builtin takes a single array argument and does not return any. +class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand, + mlir::OpTrait::ZeroResult> { +public: + static llvm::StringRef getOperationName() { return "toy.print"; } + + /// Operations can add custom verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create<PrintOp>(...) + /// This method populate the `state` that MLIR use to create operations. + /// The `toy.print` operation accepts a single array as argument and does + /// not have any returned value. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value); + + /// Inherit constructor. + using Op::Op; +}; + +class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + static llvm::StringRef getOperationName() { return "toy.transpose"; } + + /// Operation can add custom verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create<TransposeOp>(...) + /// This method populate the `state` that MLIR use to create operations. + /// The `toy.transpose` operation accepts a single array as argument and + /// returns the transposed array as its only result. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value); + + // Register our patterns for rewrite by the Canonicalization framework. + static void + getCanonicalizationPatterns(mlir::OwningRewritePatternList &results, + mlir::MLIRContext *context); + + /// Inherit constructor. + using Op::Op; +}; + +/// Reshape operation is transforming its input array into a new array with the +/// same number of elements but different shapes. For example: +/// +/// %0 = "toy.transpose"(%arg1) : (!toy<"array<10>">) -> !toy<"array<5, 2>"> +/// +class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + static llvm::StringRef getOperationName() { return "toy.reshape"; } + + /// Operation can add custom verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create<ReshapeOp>(...) + /// This method populate the `state` that MLIR use to create operations. + /// The `toy.reshape` operation accepts a single array as argument and + /// returns the array with the specified reshapedType as its only result. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value, ToyArrayType reshapedType); + + // Register our patterns for rewrite by the Canonicalization framework. + static void + getCanonicalizationPatterns(mlir::OwningRewritePatternList &results, + mlir::MLIRContext *context); + + /// Inherit constructor. + using Op::Op; +}; + +/// Binary operation implementing a multiplication. For two-dimensional array +/// a matrix multiplication is implemented, while for one dimensional array a +/// dot product is performed. +class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + static llvm::StringRef getOperationName() { return "toy.mul"; } + + /// Operation can add custom verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create<PrintOp>(...) + /// This method populate the `state` that MLIR use to create operations. + /// The `toy.mul` operation accepts two operands as argument and returns + /// a single value. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *lhs, mlir::Value *rhs); + + /// Convenience accessor for LHS of the expression. + mlir::Value *getLHS() { return getOperand(0); } + + /// Convenience accessor for RHS of the expression. + mlir::Value *getRHS() { return getOperand(1); } + + /// Inherit constructor. + using Op::Op; +}; + +/// Element wise addition of two arrays. The shape must match. +class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + static llvm::StringRef getOperationName() { return "toy.add"; } + + /// Operation can add custom verification beyond the traits they define. + mlir::LogicalResult verify(); + + /// Interface to mlir::Builder::create<PrintOp>(...) + /// This method populate the `state` that MLIR use to create operations. + /// The `toy.mul` operation accepts two operands as argument and returns + /// a single value. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *lhs, mlir::Value *rhs); + + /// Convenience accessor for LHS of the expression. + mlir::Value *getLHS() { return getOperand(0); } + + /// Convenience accessor for RHS of the expression. + mlir::Value *getRHS() { return getOperand(1); } + + /// Inherit constructor. + using Op::Op; +}; + +/// AllocOp is a temporary operation for buffer allocation, created as part of +/// partial lowering. +class AllocOp : public mlir::Op<AllocOp, mlir::OpTrait::ZeroOperands, + mlir::OpTrait::OneResult> { +public: + static llvm::StringRef getOperationName() { return "toy.alloc"; } + + /// Interface to mlir::Builder::create<AllocOp>(...) + /// This method populate the `state` that MLIR use to create operations. + /// `toy.alloc` does not have any argument and returns a toy array. + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Type retType); + + /// Inherit constructor. + using Op::Op; +}; + +/// FIXME: should be in std? +class TypeCastOp : public mlir::Op<TypeCastOp, mlir::OpTrait::OneOperand, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + static llvm::StringRef getOperationName() { return "toy.cast"; } + + static void build(mlir::Builder *builder, mlir::OperationState *state, + mlir::Value *value, mlir::Type destTy); + + // Register our patterns for rewrite by the Canonicalization framework. + static void + getCanonicalizationPatterns(mlir::OwningRewritePatternList &results, + mlir::MLIRContext *context); + + /// Inherit constructor. + using Op::Op; +}; + +} // end namespace toy + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/include/toy/Lexer.h b/include/toy/Lexer.h new file mode 100644 index 0000000..d73adb9 --- /dev/null +++ b/include/toy/Lexer.h @@ -0,0 +1,239 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include <memory> +#include <string> + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr<std::string> file; ///< filename + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared<std::string>(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return IdentifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return NumVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(LastChar)) + LastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. + lastLocation.line = curLineNum; + lastLocation.col = curCol; + + if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]* + IdentifierStr = (char)LastChar; + while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_') + IdentifierStr += (char)LastChar; + + if (IdentifierStr == "return") + return tok_return; + if (IdentifierStr == "def") + return tok_def; + if (IdentifierStr == "var") + return tok_var; + return tok_identifier; + } + + if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ + std::string NumStr; + do { + NumStr += LastChar; + LastChar = Token(getNextChar()); + } while (isdigit(LastChar) || LastChar == '.'); + + NumVal = strtod(NumStr.c_str(), nullptr); + return tok_number; + } + + if (LastChar == '#') { + // Comment until end of line. + do + LastChar = Token(getNextChar()); + while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); + + if (LastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (LastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token ThisChar = Token(LastChar); + LastChar = Token(getNextChar()); + return ThisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string IdentifierStr; + + /// If the current Token is a number, this contains the value. + double NumVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token LastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast<size_t>(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_LEXER_H_ diff --git a/include/toy/Lowering.h b/include/toy/Lowering.h new file mode 100644 index 0000000..362a342 --- /dev/null +++ b/include/toy/Lowering.h @@ -0,0 +1,45 @@ +//===- Lowering.h - Lexer for the Toy language ----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file exposes the interface to the lowering for Toy. It is divided in +// two parts: an *early lowering* that emits operations in the `Linalg` +// dialects for a subset of the Toy IR, and a *late lowering* that materializes +// buffers and converts all operations and type to the LLVM dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXAMPLES_TOY_LOWERING_H_ +#define MLIR_EXAMPLES_TOY_LOWERING_H_ + +#include <memory> + +namespace mlir { +class Pass; +class DialectConversion; +} // namespace mlir + +namespace toy { +/// Create a pass for lowering operations in the `Linalg` dialects, for a subset +/// of the Toy IR (matmul). +mlir::Pass *createEarlyLoweringPass(); + +/// Create a pass for the late lowering toward LLVM dialect. +mlir::Pass *createLateLoweringPass(); + +} // namespace toy + +#endif // MLIR_EXAMPLES_TOY_LOWERING_H_ diff --git a/include/toy/MLIRGen.h b/include/toy/MLIRGen.h new file mode 100644 index 0000000..21637bc --- /dev/null +++ b/include/toy/MLIRGen.h @@ -0,0 +1,42 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_ +#define MLIR_TUTORIAL_TOY_MLIRGEN_H_ + +#include <memory> + +namespace mlir { +class MLIRContext; +class Module; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST); +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/include/toy/Parser.h b/include/toy/Parser.h new file mode 100644 index 0000000..bc7aa52 --- /dev/null +++ b/include/toy/Parser.h @@ -0,0 +1,494 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PARSER_H +#define MLIR_TUTORIAL_TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include <map> +#include <utility> +#include <vector> + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr<ModuleAST> ParseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector<FunctionAST> functions; + while (auto F = ParseDefinition()) { + functions.push_back(std::move(*F)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError<ModuleAST>("nothing", "at end of module"); + + return llvm::make_unique<ModuleAST>(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr<ReturnExprAST> ParseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + llvm::Optional<std::unique_ptr<ExprAST>> expr; + if (lexer.getCurToken() != ';') { + expr = ParseExpression(); + if (!expr) + return nullptr; + } + return llvm::make_unique<ReturnExprAST>(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr<ExprAST> ParseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto Result = + llvm::make_unique<NumberExprAST>(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(Result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr<ExprAST> ParseTensorLitteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector<std::unique_ptr<ExprAST>> values; + // Hold the dimensions for all the nesting inside this level. + std::vector<int64_t> dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(ParseTensorLitteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError<ExprAST>("<num> or [", "in literal expression"); + values.push_back(ParseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError<ExprAST>("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError<ExprAST>("<something>", "to fill literal expression"); + lexer.getNextToken(); // eat ] + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) { + return llvm::isa<LiteralExprAST>(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get()); + if (!firstLiteral) + return parseError<ExprAST>("uniform well-nested dimensions", + "inside literal expession"); + + // Append the nested dimensions to the current level + auto &firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast<LiteralExprAST>(expr.get()); + if (!exprLiteral) + return parseError<ExprAST>("uniform well-nested dimensions", + "inside literal expession"); + if (exprLiteral->getDims() != firstDims) + return parseError<ExprAST>("uniform well-nested dimensions", + "inside literal expession"); + } + } + return llvm::make_unique<LiteralExprAST>(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr<ExprAST> ParseParenExpr() { + lexer.getNextToken(); // eat (. + auto V = ParseExpression(); + if (!V) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError<ExprAST>(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return V; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr<ExprAST> ParseIdentifierExpr() { + std::string name = lexer.getId(); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return llvm::make_unique<VariableExprAST>(std::move(loc), name); + + // This is a function call. + lexer.consume(Token('(')); + std::vector<std::unique_ptr<ExprAST>> Args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto Arg = ParseExpression()) + Args.push_back(std::move(Arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError<ExprAST>(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (Args.size() != 1) + return parseError<ExprAST>("<single arg>", "as argument to print()"); + + return llvm::make_unique<PrintExprAST>(std::move(loc), + std::move(Args[0])); + } + + // Call to a user-defined function + return llvm::make_unique<CallExprAST>(std::move(loc), name, + std::move(Args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr<ExprAST> ParsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return ParseIdentifierExpr(); + case tok_number: + return ParseNumberExpr(); + case '(': + return ParseParenExpr(); + case '[': + return ParseTensorLitteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, + std::unique_ptr<ExprAST> LHS) { + // If this is a binop, find its precedence. + while (true) { + int TokPrec = GetTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (TokPrec < ExprPrec) + return LHS; + + // Okay, we know this is a binop. + int BinOp = lexer.getCurToken(); + lexer.consume(Token(BinOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto RHS = ParsePrimary(); + if (!RHS) + return parseError<ExprAST>("expression", "to complete binary operator"); + + // If BinOp binds less tightly with RHS than the operator after RHS, let + // the pending operator take RHS as its LHS. + int NextPrec = GetTokPrecedence(); + if (TokPrec < NextPrec) { + RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); + if (!RHS) + return nullptr; + } + + // Merge LHS/RHS. + LHS = llvm::make_unique<BinaryExprAST>(std::move(loc), BinOp, + std::move(LHS), std::move(RHS)); + } + } + + /// expression::= primary binoprhs + std::unique_ptr<ExprAST> ParseExpression() { + auto LHS = ParsePrimary(); + if (!LHS) + return nullptr; + + return ParseBinOpRHS(0, std::move(LHS)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr<VarType> ParseType() { + if (lexer.getCurToken() != '<') + return parseError<VarType>("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = llvm::make_unique<VarType>(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError<VarType>(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. + /// decl ::= var identifier [ type ] = expr + std::unique_ptr<VarDeclExprAST> ParseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError<VarDeclExprAST>("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError<VarDeclExprAST>("identified", + "after 'var' declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + std::unique_ptr<VarType> type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = ParseType(); + if (!type) + return nullptr; + } + + if (!type) + type = llvm::make_unique<VarType>(); + lexer.consume(Token('=')); + auto expr = ParseExpression(); + return llvm::make_unique<VarDeclExprAST>(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr<ExprASTList> ParseBlock() { + if (lexer.getCurToken() != '{') + return parseError<ExprASTList>("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = llvm::make_unique<ExprASTList>(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = ParseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = ParseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = ParseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. + if (lexer.getCurToken() != ';') + return parseError<ExprASTList>(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError<ExprASTList>("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr<PrototypeAST> ParsePrototype() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_def); + if (lexer.getCurToken() != tok_identifier) + return parseError<PrototypeAST>("function name", "in prototype"); + + std::string FnName = lexer.getId(); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError<PrototypeAST>("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector<std::unique_ptr<VariableExprAST>> args; + if (lexer.getCurToken() != ')') { + do { + std::string name = lexer.getId(); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = llvm::make_unique<VariableExprAST>(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError<PrototypeAST>( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError<PrototypeAST>("}", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return llvm::make_unique<PrototypeAST>(std::move(loc), FnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. + /// + /// definition ::= prototype block + std::unique_ptr<FunctionAST> ParseDefinition() { + auto Proto = ParsePrototype(); + if (!Proto) + return nullptr; + + if (auto block = ParseBlock()) + return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int GetTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast<char>(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template <typename R, typename T, typename U = const char *> + std::unique_ptr<R> parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PARSER_H diff --git a/include/toy/Passes.h b/include/toy/Passes.h new file mode 100644 index 0000000..dd73b95 --- /dev/null +++ b/include/toy/Passes.h @@ -0,0 +1,33 @@ +//===- Passes.h - Toy Passes Definition -----------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file exposes the entry points to create compiler passes for Toy. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PASSES_H +#define MLIR_TUTORIAL_TOY_PASSES_H + +namespace mlir { +class Pass; +} // namespace mlir + +namespace toy { +mlir::Pass *createShapeInferencePass(); +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PASSES_H |