summaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
Diffstat (limited to 'include')
-rw-r--r--include/linalg1/Analysis.h49
-rw-r--r--include/linalg1/Common.h120
-rw-r--r--include/linalg1/ConvertToLLVMDialect.h66
-rw-r--r--include/linalg1/Dialect.h42
-rw-r--r--include/linalg1/Intrinsics.h32
-rw-r--r--include/linalg1/LLVMIntrinsics.h41
-rw-r--r--include/linalg1/Ops.h26
-rw-r--r--include/linalg1/RangeOp.h56
-rw-r--r--include/linalg1/RangeType.h49
-rw-r--r--include/linalg1/SliceOp.h91
-rw-r--r--include/linalg1/Types.h36
-rw-r--r--include/linalg1/Utils.h37
-rw-r--r--include/linalg1/ViewOp.h67
-rw-r--r--include/linalg1/ViewType.h57
-rw-r--r--include/linalg2/Analysis.h23
-rw-r--r--include/linalg2/Intrinsics.h32
-rw-r--r--include/linalg2/Ops.h24
-rw-r--r--include/linalg2/TensorOps-inl.h120
-rw-r--r--include/linalg2/TensorOps.h287
-rw-r--r--include/linalg2/Transforms.h36
-rw-r--r--include/linalg3/Analysis.h37
-rw-r--r--include/linalg3/ConvertToLLVMDialect.h29
-rw-r--r--include/linalg3/Intrinsics.h31
-rw-r--r--include/linalg3/LoadStoreOps.h89
-rw-r--r--include/linalg3/Ops.h25
-rw-r--r--include/linalg3/TensorOps-inl.h145
-rw-r--r--include/linalg3/TensorOps.h54
-rw-r--r--include/linalg3/Transforms.h80
-rw-r--r--include/toy/AST.h256
-rw-r--r--include/toy/Dialect.h393
-rw-r--r--include/toy/Lexer.h239
-rw-r--r--include/toy/Lowering.h45
-rw-r--r--include/toy/MLIRGen.h42
-rw-r--r--include/toy/Parser.h494
-rw-r--r--include/toy/Passes.h33
35 files changed, 3283 insertions, 0 deletions
diff --git a/include/linalg1/Analysis.h b/include/linalg1/Analysis.h
new file mode 100644
index 0000000..ef8fb98
--- /dev/null
+++ b/include/linalg1/Analysis.h
@@ -0,0 +1,49 @@
+//===- Analysis.h - Linalg dialect Analysis function definitions ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_ANALYSIS_H_
+#define LINALG1_ANALYSIS_H_
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Value;
+} // namespace mlir
+
+namespace linalg {
+class ViewOp;
+
+/// Walks the chain of SliceOp until the unique base ViewOp.
+ViewOp getViewBaseViewOp(mlir::Value *view);
+
+/// Walks the chain of SliceOp until the unique base ViewOp and returns the
+/// MemRef upon which the ViewOp is laid.
+mlir::Value *getViewSupportingMemRef(mlir::Value *view);
+
+/// Extract the indexing from the root ViewOp that this slice constrins along
+/// `dim`. To achieve this, it walks back the chain of SliceOp and determine the
+/// first slice that constrains `dim`.
+/// Note that the dimension in the original ViewOp may shift due to
+/// rank-reducing operations.
+/// Returns a pair, with the indexing as the first element and the actual
+/// dimension, in the root ViewOp, as the second element.
+std::pair<mlir::Value *, unsigned> getViewRootIndexing(mlir::Value *view,
+ unsigned dim);
+
+} // namespace linalg
+
+#endif // LINALG1_ANALYSIS_H_
diff --git a/include/linalg1/Common.h b/include/linalg1/Common.h
new file mode 100644
index 0000000..6573c72
--- /dev/null
+++ b/include/linalg1/Common.h
@@ -0,0 +1,120 @@
+//===- Common.h - Linalg dialect RangeOp operation -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_COMMON_H_
+#define LINALG1_COMMON_H_
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace linalg {
+namespace common {
+
+////////////////////////////////////////////////////////////////////////////////
+// Define a few boilerplate objects used across all linalg examples.
+////////////////////////////////////////////////////////////////////////////////
+
+/// A 2-D abstraction over a flat contiguous memory region of f32 with symbolic
+/// sizes.
+template <int N>
+inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context,
+ unsigned memorySpace = 0) {
+ llvm::SmallVector<int64_t, 4> shape(N, -1);
+ auto f32 = mlir::FloatType::getF32(context);
+ return mlir::MemRefType::get(shape, f32, {}, memorySpace);
+}
+
+/// A basic function builder
+inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name,
+ llvm::ArrayRef<mlir::Type> types,
+ llvm::ArrayRef<mlir::Type> resultTypes) {
+ auto *context = module.getContext();
+ auto *function = new mlir::Function(
+ mlir::UnknownLoc::get(context), name,
+ mlir::FunctionType::get({types}, resultTypes, context));
+ function->addEntryBlock();
+ module.getFunctions().push_back(function);
+ return function;
+}
+
+/// A basic pass manager pre-populated with cleanup passes.
+inline std::unique_ptr<mlir::PassManager> cleanupPassManager() {
+ std::unique_ptr<mlir::PassManager> pm(new mlir::PassManager());
+ pm->addPass(mlir::createCanonicalizerPass());
+ pm->addPass(mlir::createSimplifyAffineStructuresPass());
+ pm->addPass(mlir::createCSEPass());
+ pm->addPass(mlir::createCanonicalizerPass());
+ return pm;
+}
+
+/// A simple function to verify and cleanup the IR before printing it to
+/// llvm::outs() for FileCheck'ing.
+/// If an error occurs, dump to llvm::errs() and do not print to llvm::outs()
+/// which will make the associated FileCheck test fail.
+inline void cleanupAndPrintFunction(mlir::Function *f) {
+ bool printToOuts = true;
+ auto check = [f, &printToOuts](mlir::LogicalResult result) {
+ if (failed(result)) {
+ f->getContext()->emitError(f->getLoc(),
+ "Verification and cleanup passes failed");
+ printToOuts = false;
+ }
+ };
+ auto pm = cleanupPassManager();
+ check(f->getModule()->verify());
+ check(pm->run(f->getModule()));
+ if (printToOuts)
+ f->print(llvm::outs());
+}
+
+/// Helper class to sugar building loop nests from indexings that appear in
+/// ViewOp and SliceOp.
+class LoopNestRangeBuilder {
+public:
+ LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
+ llvm::ArrayRef<mlir::edsc::ValueHandle> indexings);
+ LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
+ llvm::ArrayRef<mlir::Value *> indexings);
+ mlir::edsc::ValueHandle
+ operator()(llvm::ArrayRef<mlir::edsc::CapturableHandle> stmts);
+
+private:
+ llvm::SmallVector<mlir::edsc::LoopBuilder, 4> loops;
+};
+
+} // namespace common
+} // namespace linalg
+
+#endif // LINALG1_COMMON_H_
diff --git a/include/linalg1/ConvertToLLVMDialect.h b/include/linalg1/ConvertToLLVMDialect.h
new file mode 100644
index 0000000..8e5a7ce
--- /dev/null
+++ b/include/linalg1/ConvertToLLVMDialect.h
@@ -0,0 +1,66 @@
+//===- ConvertToLLVMDialect.h - conversion from Linalg to LLVM --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_CONVERTTOLLVMDIALECT_H_
+#define LINALG1_CONVERTTOLLVMDIALECT_H_
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Allocator.h"
+
+#include <memory>
+
+namespace mlir {
+class DialectConversion;
+class DialectOpConversion;
+class MLIRContext;
+class Module;
+class Type;
+namespace LLVM {
+class LLVMType;
+} // end namespace LLVM
+} // end namespace mlir
+
+namespace linalg {
+/// Convert the given Linalg dialect type `t` into an LLVM IR dialect type.
+/// Keep all other types unmodified.
+mlir::Type convertLinalgType(mlir::Type t);
+
+/// Allocate the conversion patterns for RangeOp, ViewOp and SliceOp from the
+/// Linalg dialect to the LLVM IR dialect. The converters are allocated in the
+/// `allocator` using the provided `context`. The latter must have the LLVM IR
+/// dialect registered.
+/// This function can be used to apply multiple conversion patterns in the same
+/// pass. It does not have to be called explicitly before the conversion.
+llvm::DenseSet<mlir::DialectOpConversion *>
+allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator,
+ mlir::MLIRContext *context);
+
+/// Create a DialectConversion from the Linalg dialect to the LLVM IR dialect.
+/// The conversion is set up to convert types and function signatures using
+/// `convertLinalgType` and obtains operation converters by calling `initer`.
+std::unique_ptr<mlir::DialectConversion> makeLinalgToLLVMLowering(
+ std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
+ llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
+ initer);
+
+/// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations
+/// to the LLVM IR dialect types and operations in the given `module`. This is
+/// the main entry point to the conversion.
+void convertToLLVM(mlir::Module &module);
+} // end namespace linalg
+
+#endif // LINALG1_CONVERTTOLLVMDIALECT_H_
diff --git a/include/linalg1/Dialect.h b/include/linalg1/Dialect.h
new file mode 100644
index 0000000..70023e1
--- /dev/null
+++ b/include/linalg1/Dialect.h
@@ -0,0 +1,42 @@
+//===- Dialect.h - Definition of the Linalg dialect -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_DIALECT_H_
+#define LINALG1_DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+
+namespace linalg {
+
+/// The Linalg Dialect is not exposed to the outside world. It is registered by
+/// linking and accessed via generic MLIR accessors.
+class LinalgDialect : public mlir::Dialect {
+public:
+ /// Create a new Dialect that is registered on construction and adds the
+ /// relevant types and operations.
+ explicit LinalgDialect(mlir::MLIRContext *context);
+
+ /// Parse a type registered to this dialect.
+ mlir::Type parseType(llvm::StringRef spec, mlir::Location loc) const override;
+
+ /// Print a type registered to this dialect.
+ void printType(mlir::Type type, llvm::raw_ostream &os) const override;
+};
+
+} // namespace linalg
+
+#endif // LINALG1_DIALECT_H_
diff --git a/include/linalg1/Intrinsics.h b/include/linalg1/Intrinsics.h
new file mode 100644
index 0000000..305e3f3
--- /dev/null
+++ b/include/linalg1/Intrinsics.h
@@ -0,0 +1,32 @@
+//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_INTRINSICS_H_
+#define LINALG1_INTRINSICS_H_
+
+#include "linalg1/Ops.h"
+#include "mlir/EDSC/Intrinsics.h"
+
+namespace linalg {
+namespace intrinsics {
+using range = mlir::edsc::intrinsics::ValueBuilder<RangeOp>;
+using slice = mlir::edsc::intrinsics::ValueBuilder<SliceOp>;
+using view = mlir::edsc::intrinsics::ValueBuilder<ViewOp>;
+} // namespace intrinsics
+} // namespace linalg
+
+#endif // LINALG1_INTRINSICS_H_
diff --git a/include/linalg1/LLVMIntrinsics.h b/include/linalg1/LLVMIntrinsics.h
new file mode 100644
index 0000000..577981b
--- /dev/null
+++ b/include/linalg1/LLVMIntrinsics.h
@@ -0,0 +1,41 @@
+//===- LLVMIntrinsics.h - declarative builders for LLVM dialect -*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_LLVMINTRINSICS_H_
+#define LINALG1_LLVMINTRINSICS_H_
+
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+
+// Expose some LLVM IR instructions to declarative builders.
+namespace intrinsics {
+using undef = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::UndefOp>;
+using insertvalue =
+ mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::InsertValueOp>;
+using extractvalue =
+ mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::ExtractValueOp>;
+using constant = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::ConstantOp>;
+using add = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::AddOp>;
+using sub = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::SubOp>;
+using mul = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::MulOp>;
+using load = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::LoadOp>;
+using store = mlir::edsc::intrinsics::OperationBuilder<mlir::LLVM::StoreOp>;
+using gep = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::GEPOp>;
+} // end namespace intrinsics
+
+#endif // LINALG1_LLVMINTRINSICS_H_
diff --git a/include/linalg1/Ops.h b/include/linalg1/Ops.h
new file mode 100644
index 0000000..2e662cf
--- /dev/null
+++ b/include/linalg1/Ops.h
@@ -0,0 +1,26 @@
+//===- Ops.h - Linalg Ops single entry point ------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_OPS_H_
+#define LINALG1_OPS_H_
+
+#include "linalg1/Types.h"
+#include "linalg1/RangeOp.h"
+#include "linalg1/SliceOp.h"
+#include "linalg1/ViewOp.h"
+
+#endif // LINALG1_OPS_H_
diff --git a/include/linalg1/RangeOp.h b/include/linalg1/RangeOp.h
new file mode 100644
index 0000000..9652f51
--- /dev/null
+++ b/include/linalg1/RangeOp.h
@@ -0,0 +1,56 @@
+//===- RangeOp.h - Linalg dialect RangeOp operation definition ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_RANGEOP_H_
+#define LINALG1_RANGEOP_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace linalg {
+
+/// A RangeOp is used to create a value of RangeType from 3 values of type index
+/// that represent the min, max and step values of the range.
+/// Note: step must be an mlir::ConstantIndexOp for now due to current
+/// `affine.for` limitations.
+class RangeOp : public mlir::Op<RangeOp, mlir::OpTrait::NOperands<3>::Impl,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+public:
+ using Op::Op;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this op.
+ //////////////////////////////////////////////////////////////////////////////
+ static llvm::StringRef getOperationName() { return "linalg.range"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *min, mlir::Value *max, mlir::Value *step);
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Op-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ mlir::Value *getMin() { return getOperand(0); }
+ mlir::Value *getMax() { return getOperand(1); }
+ mlir::Value *getStep() { return getOperand(2); }
+};
+
+} // namespace linalg
+
+#endif // LINALG1_RANGEOP_H_
diff --git a/include/linalg1/RangeType.h b/include/linalg1/RangeType.h
new file mode 100644
index 0000000..d17c058
--- /dev/null
+++ b/include/linalg1/RangeType.h
@@ -0,0 +1,49 @@
+//===- RangeType.h - Linalg RangeType definition --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_RANGETYPE_H_
+#define LINALG1_RANGETYPE_H_
+
+#include "linalg1/Types.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+class MLIRContext;
+}
+
+namespace linalg {
+
+/// A RangeType is the simplest possible form of a type in MLIR. It represents
+/// a minimal range abstraction (min, max, step). Since RangeType is constructed
+/// without any additional argument, this example illustrates the minimal
+/// amount of information required to implement a new custom MLIR type.
+class RangeType : public mlir::Type::TypeBase<RangeType, mlir::Type> {
+public:
+ // Used to implement llvm-style cast.
+ using Base::Base;
+ /// Construction hook.
+ static RangeType get(mlir::MLIRContext *context) {
+ /// Custom, uniqu'ed construction in the mlir::MLIRContext.
+ return Base::get(context, LinalgTypes::Range);
+ }
+ /// Used to implement llvm-style cast.
+ static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
+};
+
+} // namespace linalg
+
+#endif // LINALG1_RANGETYPE_H_
diff --git a/include/linalg1/SliceOp.h b/include/linalg1/SliceOp.h
new file mode 100644
index 0000000..1d79784
--- /dev/null
+++ b/include/linalg1/SliceOp.h
@@ -0,0 +1,91 @@
+//===- SliceOp.h - Linalg dialect SliceOp operation definition ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_SLICEOP_H_
+#define LINALG1_SLICEOP_H_
+
+#include "linalg1/Types.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace linalg {
+
+/// A SliceOp is used to create a "sub-View" from a ViewType. It results in a
+/// new ViewType which is contained within its parent ViewType.
+class SliceOp : public mlir::Op<SliceOp, mlir::OpTrait::NOperands<2>::Impl,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+public:
+ using Op::Op;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this op.
+ //////////////////////////////////////////////////////////////////////////////
+ static llvm::StringRef getOperationName() { return "linalg.slice"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *view, mlir::Value *indexing, unsigned dim);
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Op-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ enum { FirstIndexingOperand = 1 };
+ /// Returns the attribute name that describes which dimension of the input
+ /// view that this SliceOp slices.
+ static llvm::StringRef getSlicingDimAttrName() { return "dim"; }
+ /// Returns the unique result of the parent SliceOp of ViewOp instruction that
+ /// created the view on which this SliceOp operates.
+ mlir::Value *getParentView() { return getOperand(0); }
+ /// Returns the indexing operand of the current SliceOp.
+ /// This operands may either be:
+ /// 1. A range, in which case the operand comes from a RangeOp. This SliceOp
+ /// does not reduce the dimension of the input ViewType.
+ /// 2. An index, in which case the operand comes from any possible producer
+ /// of an index. This SliceOp reduces the dimension of the input ViewType
+ /// by 1.
+ mlir::Value *getIndexing() { return getOperand(1); }
+ /// Returns the dim of the parent ViewType that is sliced by this SliceOp.
+ unsigned getSlicingDim() {
+ return getAttrOfType<mlir::IntegerAttr>(getSlicingDimAttrName()).getInt();
+ }
+ /// Returns the ViewType resulting from this SliceOp.
+ ViewType getViewType();
+ /// Returns the rank of the current ViewType.
+ unsigned getRank();
+ /// Return the element type of the current ViewType.
+ mlir::Type getElementType();
+
+ /// Returns the ViewType of `getParentView()`.
+ ViewType getParentViewType();
+ /// Returns the rank of the ViewType of `getParentView()`.
+ unsigned getParentRank();
+ /// Returns the element Type of the ViewType of `getParentView()`.
+ mlir::Type getParentElementType();
+
+ /// Returns true if the rank of the part view is greater than the rank of
+ /// the child view.
+ bool isRankDecreasing();
+
+ // Get all the indexings in this slice.
+ mlir::Operation::operand_range getIndexings();
+};
+
+} // namespace linalg
+
+#endif // LINALG1_SLICEOP_H_
diff --git a/include/linalg1/Types.h b/include/linalg1/Types.h
new file mode 100644
index 0000000..5032e96
--- /dev/null
+++ b/include/linalg1/Types.h
@@ -0,0 +1,36 @@
+//===- Types.h - Linalg Types forward declarations ------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_TYPES_H_
+#define LINALG1_TYPES_H_
+
+#include "mlir/IR/Types.h"
+
+namespace linalg {
+
+enum LinalgTypes {
+ Range = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
+ View,
+ FIRST_PRIVATE_EXPERIMENTAL_0_TYPE = View,
+};
+
+} // namespace linalg
+
+#include "linalg1/RangeType.h"
+#include "linalg1/ViewType.h"
+
+#endif // LINALG1_TYPES_H_
diff --git a/include/linalg1/Utils.h b/include/linalg1/Utils.h
new file mode 100644
index 0000000..3f7bb76
--- /dev/null
+++ b/include/linalg1/Utils.h
@@ -0,0 +1,37 @@
+//===- Utils.h - Linalg dialect utility functions definitions -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_UTILS_H_
+#define LINALG1_UTILS_H_
+
+namespace mlir {
+class Value;
+} // namespace mlir
+
+namespace linalg {
+class ViewOp;
+
+/// Asserts `view` is of ViewType and returns its rank.
+unsigned getViewRank(mlir::Value *view);
+
+/// Helper function to emit and return a new ViewOp from `memRef` that is
+/// assumed to be of MemRefType. This needs to be called under a ScopedContext.
+ViewOp emitAndReturnViewOpFromMemRef(mlir::Value *memRef);
+
+} // namespace linalg
+
+#endif // LINALG1_UTILS_H_
diff --git a/include/linalg1/ViewOp.h b/include/linalg1/ViewOp.h
new file mode 100644
index 0000000..fcda553
--- /dev/null
+++ b/include/linalg1/ViewOp.h
@@ -0,0 +1,67 @@
+//===- ViewOp.h - Linalg dialect ViewOp operation definition ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_VIEWOP_H_
+#define LINALG1_VIEWOP_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace linalg {
+
+class ViewType;
+
+/// A `ViewOp` produces a `ViewType` which is a multi-dimensional range
+/// abstraction on top of an underlying data type. For now we use the existing
+/// mlir::MemRef for the underlying data type.
+class ViewOp : public mlir::Op<ViewOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+public:
+ using Op::Op;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this op.
+ //////////////////////////////////////////////////////////////////////////////
+ static llvm::StringRef getOperationName() { return "linalg.view"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *memRef,
+ llvm::ArrayRef<mlir::Value *> indexings);
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Op-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ enum { FirstIndexingOperand = 1 };
+ unsigned getRank();
+ mlir::Type getElementType();
+ ViewType getViewType();
+ // May be something else than a MemRef in the future.
+ mlir::Value *getSupportingMemRef();
+ // Get the underlying indexing at a given rank.
+ mlir::Value *getIndexing(unsigned rank);
+ // Get all the indexings of type RangeOp.
+ llvm::SmallVector<mlir::Value *, 8> getRanges();
+ // Get all the indexings in this view.
+ mlir::Operation::operand_range getIndexings();
+};
+
+} // namespace linalg
+
+#endif // LINALG1_VIEWOP_H_
diff --git a/include/linalg1/ViewType.h b/include/linalg1/ViewType.h
new file mode 100644
index 0000000..c58e12c
--- /dev/null
+++ b/include/linalg1/ViewType.h
@@ -0,0 +1,57 @@
+//===- ViewType.h - Linalg ViewType definition --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_VIEWTYPE_H_
+#define LINALG1_VIEWTYPE_H_
+
+#include "linalg1/Types.h"
+#include "mlir/IR/Types.h"
+
+namespace linalg {
+
+class ViewTypeStorage;
+
+/// A ViewType represents a range abstraction on top of an underlying storage
+/// type. It is parameterizable by the underlying element type and the rank of
+/// the view.
+class ViewType
+ : public mlir::Type::TypeBase<ViewType, mlir::Type, ViewTypeStorage> {
+public:
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this type.
+ //////////////////////////////////////////////////////////////////////////////
+ // Used to implement llvm-style cast.
+ using Base::Base;
+ // Used to implement llvm-style cast.
+ static bool kindof(unsigned kind) { return kind == LinalgTypes::View; }
+ /// Construction hook.
+ static ViewType get(mlir::MLIRContext *context, mlir::Type elementType,
+ unsigned rank);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Type-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ /// Return the underlying elemental type.
+ mlir::Type getElementType();
+ /// Return the rank of the view.
+ /// This is the number of indexings needed to reach an underlying element.
+ unsigned getRank();
+};
+
+} // namespace linalg
+
+#endif // LINALG1_VIEWTYPE_H_
diff --git a/include/linalg2/Analysis.h b/include/linalg2/Analysis.h
new file mode 100644
index 0000000..43acd95
--- /dev/null
+++ b/include/linalg2/Analysis.h
@@ -0,0 +1,23 @@
+//===- Analysis.h - Linalg dialect Analysis function definitions ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG2_ANALYSIS_H_
+#define LINALG2_ANALYSIS_H_
+
+#include "linalg1/Analysis.h"
+
+#endif // LINALG2_ANALYSIS_H_
diff --git a/include/linalg2/Intrinsics.h b/include/linalg2/Intrinsics.h
new file mode 100644
index 0000000..e74e059
--- /dev/null
+++ b/include/linalg2/Intrinsics.h
@@ -0,0 +1,32 @@
+//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG2_INTRINSICS_H_
+#define LINALG2_INTRINSICS_H_
+
+#include "linalg1/Intrinsics.h"
+#include "linalg2/Ops.h"
+
+namespace linalg {
+namespace intrinsics {
+using dot = mlir::edsc::intrinsics::OperationBuilder<DotOp>;
+using matmul = mlir::edsc::intrinsics::OperationBuilder<MatmulOp>;
+using matvec = mlir::edsc::intrinsics::OperationBuilder<MatvecOp>;
+} // namespace intrinsics
+} // namespace linalg
+
+#endif // LINALG2_INTRINSICS_H_
diff --git a/include/linalg2/Ops.h b/include/linalg2/Ops.h
new file mode 100644
index 0000000..141b1d0
--- /dev/null
+++ b/include/linalg2/Ops.h
@@ -0,0 +1,24 @@
+//===- Ops.h - Linalg Ops single entry point ------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG2_OPS_H_
+#define LINALG2_OPS_H_
+
+#include "linalg1/Ops.h"
+#include "linalg2/TensorOps.h"
+
+#endif // LINALG2_OPS_H_
diff --git a/include/linalg2/TensorOps-inl.h b/include/linalg2/TensorOps-inl.h
new file mode 100644
index 0000000..940f8d7
--- /dev/null
+++ b/include/linalg2/TensorOps-inl.h
@@ -0,0 +1,120 @@
+//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#ifndef LINALG2_TENSOROPS_INL_H_
+#define LINALG2_TENSOROPS_INL_H_
+
+#include "linalg2/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace linalg {
+
+template <class ConcreteOp>
+mlir::Operation::operand_range
+linalg::TensorContractionBase<ConcreteOp>::getInputs() {
+ auto *op = static_cast<ConcreteOp *>(this)->getOperation();
+ return {op->operand_begin(), op->operand_begin() + getNumInputs()};
+}
+
+template <class ConcreteOp>
+mlir::Operation::operand_range
+linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
+ auto *op = static_cast<ConcreteOp *>(this)->getOperation();
+ return {op->operand_begin() + getNumInputs(),
+ op->operand_begin() + getNumInputs() + getNumOutputs()};
+}
+
+template <class ConcreteOp>
+mlir::Operation::operand_range
+linalg::TensorContractionBase<ConcreteOp>::getInputsAndOutputs() {
+ return {getInputs().begin(), getOutputs().end()};
+}
+
+template <class ConcreteOp>
+mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
+ auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
+ if (getNumInputs() <= 0)
+ concreteOp->emitOpError("expected at least one input");
+ if (getNumOutputs() <= 0)
+ concreteOp->emitOpError("expected at least one output");
+ if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) {
+ concreteOp->emitOpError("expected " +
+ llvm::Twine(getNumInputs() + getNumOutputs()) +
+ " operands");
+ }
+ for (unsigned i = 0, e = getNumInputs(); i < e; ++i) {
+ if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
+ return concreteOp->emitOpError("operand " + llvm::Twine(i) +
+ " not a ViewType");
+ }
+ for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e;
+ ++i) {
+ auto viewType =
+ concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
+ if (!viewType)
+ return concreteOp->emitOpError("operand " + llvm::Twine(i) +
+ " not a ViewType");
+ if (viewType.getRank() != getNumParallelDims())
+ return concreteOp->emitOpError("operand " + llvm::Twine(i) +
+ " must be of rank " +
+ llvm::Twine(getNumParallelDims()));
+ }
+ return mlir::success();
+}
+
+template <class ConcreteOp>
+bool linalg::TensorContractionBase<ConcreteOp>::parse(
+ mlir::OpAsmParser *parser, mlir::OperationState *result) {
+ llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
+}
+
+// A TensorContraction prints as:
+//
+// ```{.mlir}
+// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types
+// ```
+//
+// for example:
+//
+// ```
+// linalg.matmul(%0, %1, %2) : view<?x?xf32>
+// ```
+//
+// Where %0, %1 and %2 are ssa-values of type ViewType.
+template <class ConcreteOp>
+void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter *p) {
+ *p << static_cast<ConcreteOp *>(this)->getOperationName() << "(";
+ auto *last = *std::prev(getInputsAndOutputs().end());
+ for (auto *i : getInputsAndOutputs()) {
+ *p << *i << ((i == last) ? "" : ", ");
+ }
+ *p << ") : ";
+ auto *lastOutput = *std::prev(getOutputs().end());
+ for (auto *o : getOutputs()) {
+ *p << o->getType() << ((o == lastOutput) ? "" : ",");
+ }
+}
+
+} // namespace linalg
+
+#endif // LINALG2_TENSOROPS_INL_H_
diff --git a/include/linalg2/TensorOps.h b/include/linalg2/TensorOps.h
new file mode 100644
index 0000000..39e51f0
--- /dev/null
+++ b/include/linalg2/TensorOps.h
@@ -0,0 +1,287 @@
+//===- TensorOps.h - Linalg dialect TensorOps operation definition --------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG2_TENSOROPS_H_
+#define LINALG2_TENSOROPS_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class AffineForOp;
+} // namespace mlir
+
+namespace linalg {
+
+/// A generic TensorContraction base class which captures the generic behavior
+/// of tensor contraction operations (with broadcast).
+template <class ConcreteOp> class TensorContractionBase {
+protected:
+ using TensorContractionBaseType = TensorContractionBase<ConcreteOp>;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this op.
+ //////////////////////////////////////////////////////////////////////////////
+ /// Generic implementation of hooks that should be called from `ConcreteType`s
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+public:
+ //////////////////////////////////////////////////////////////////////////////
+ // Op-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ TensorContractionBase() = default;
+ mlir::Operation::operand_range getInputs();
+ mlir::Operation::operand_range getOutputs();
+ mlir::Operation::operand_range getInputsAndOutputs();
+
+ /// These are better as methods calling into the ConcreteOp instead of
+ /// template parameters because methods allow more generic behavior and avoid
+ /// specializing for number of arguments. All derived classes have
+ /// `VariadicOperands` and a build method from both an ArrayRef<mlirValue*>
+ /// and the proper number of mlir::Value*.
+ unsigned getNumInputs() {
+ return static_cast<ConcreteOp *>(this)->numInputs;
+ };
+ unsigned getNumOutputs() {
+ return static_cast<ConcreteOp *>(this)->numOutputs;
+ };
+ unsigned getNumParallelDims() {
+ return static_cast<ConcreteOp *>(this)->numParallelDims;
+ };
+ unsigned getNumReductionDims() {
+ return static_cast<ConcreteOp *>(this)->numReductionDims;
+ };
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Used in Linalg3 and later.
+ //////////////////////////////////////////////////////////////////////////////
+ mlir::Value *getInputView(unsigned viewIndex);
+ mlir::Value *getOutputView(unsigned viewIndex);
+ mlir::Value *getView(unsigned viewIndex) {
+ return viewIndex < getNumInputs()
+ ? getInputView(viewIndex)
+ : getOutputView(viewIndex - getNumInputs());
+ }
+
+ /// Each op is responsible for declaring how it lowers itself to scalar form,
+ /// given the enclosing parallel and reduction induction variables.
+ /// `emitScalarImplementation` emits the scalar IR for the op in the nesting
+ /// context of the innermost enclosing loop(i.e. `reductionIvs.back()` or
+ /// `parallel.back()`).
+ void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs);
+
+ /// Represents a mapping from the loops to all the ranges of the operands.
+ /// The operands and their ranges are in the order defined by the particular
+ /// ConcreteOp implementation, the resulting map must match those.
+ /// In favorable cases, this can be calculated by an analysis but specifying
+ /// it explicitly is not expensive and generalizes to cases where an analysis
+ /// is not available. For details, see the description of
+ /// loopsToOperandRangeMaps in each ConcreteOp.
+ llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+};
+
+/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
+class DotOp : public TensorContractionBase<DotOp>,
+ public mlir::Op<DotOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::ZeroResult> {
+public:
+ using Op::Op;
+ using TensorContractionBaseType =
+ TensorContractionBase::TensorContractionBaseType;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this op.
+ //////////////////////////////////////////////////////////////////////////////
+ static llvm::StringRef getOperationName() { return "linalg.dot"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ llvm::ArrayRef<mlir::Value *> operands);
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *A, mlir::Value *B, mlir::Value *C) {
+ return build(b, result, {A, B, C});
+ }
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Op-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ static constexpr unsigned numInputs = 2;
+ static constexpr unsigned numOutputs = 1;
+ static constexpr unsigned numParallelDims = 0;
+ static constexpr unsigned numReductionDims = 1;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Used in Linalg3 and later.
+ //////////////////////////////////////////////////////////////////////////////
+ /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
+ /// loop over matvec). Does nothing by default.
+ void writeAsFinerGrainTensorContraction();
+
+ /// Inputs to this map will be (%k) coming from enclosing loops.
+ /// Therefore, the mapping to get back to A(K), B(K), C() is:
+ /// (d0) -> (d0, d0)(%k)
+ /// And the operands ranges are:
+ /// (%k, %k)
+ llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+
+ /// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
+ /// to:
+ /// 1. conditionally assign scalarC to 0.0f on the first iteration or load
+ /// C[] from memory (0-D tensor)
+ /// 2. multiply A[r_i] by B[r_i] and add to scalarC
+ /// 3. store back scalarC at C[]
+ ///
+ /// In some compact index notation this could be written:
+ /// cond = (r_i == zero)
+ /// scalarC = select(cond, zerof, C[]);
+ /// C[] = scalarC + A[r_i] * B[r_i];
+ void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs);
+};
+
+/// Implements C = A * B where A is a 2-D matrix and X and Y are 1-D vectors.
+class MatvecOp : public TensorContractionBase<MatvecOp>,
+ public mlir::Op<MatvecOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::ZeroResult> {
+public:
+ using Op::Op;
+ using TensorContractionBaseType =
+ TensorContractionBase::TensorContractionBaseType;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this op.
+ //////////////////////////////////////////////////////////////////////////////
+ static llvm::StringRef getOperationName() { return "linalg.matvec"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ llvm::ArrayRef<mlir::Value *> operands);
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *A, mlir::Value *B, mlir::Value *C) {
+ return build(b, result, {A, B, C});
+ }
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Op-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ static constexpr unsigned numInputs = 2;
+ static constexpr unsigned numOutputs = 1;
+ static constexpr unsigned numParallelDims = 1;
+ static constexpr unsigned numReductionDims = 1;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Used in Linalg3 and later.
+ //////////////////////////////////////////////////////////////////////////////
+ /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
+ /// loop over matvec). Does nothing by default.
+ void writeAsFinerGrainTensorContraction();
+
+ /// Inputs to this map will be (%m, %k) coming from enclosing loops.
+ /// Therefore, the mapping to get back to A(M, K), B(K), C(M) is:
+ /// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
+ /// And the operands ranges are:
+ /// (%m, %k, %k, %m)
+ llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+
+ /// Given an enclosing parallel loop with iv `i` and an enclosing parallel
+ /// loop with iv `r_j`, emits MLIR corresponding to:
+ /// 1. conditionally assign scalarC to 0.0f on the first iteration or load
+ /// C[i]
+ /// 2. multiply A[i, r_j] by B[r_j] and add to scalarC
+ /// 3. store back scalarC at C[i]
+ ///
+ /// In some compact index notation this could be written:
+ /// cond = (r_j == zero)
+ /// scalarC = select(cond, zerof, C(i));
+ /// C(i) = scalarC + A(i, r_j) * B(r_j);
+ void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs);
+};
+
+/// Implements C = A * B on 2-D matrices.
+class MatmulOp : public TensorContractionBase<MatmulOp>,
+ public mlir::Op<MatmulOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::ZeroResult> {
+public:
+ using Op::Op;
+ using TensorContractionBaseType =
+ TensorContractionBase::TensorContractionBaseType;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this op.
+ //////////////////////////////////////////////////////////////////////////////
+ static llvm::StringRef getOperationName() { return "linalg.matmul"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ llvm::ArrayRef<mlir::Value *> operands);
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *A, mlir::Value *B, mlir::Value *C) {
+ return build(b, result, {A, B, C});
+ }
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Op-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ static constexpr unsigned numInputs = 2;
+ static constexpr unsigned numOutputs = 1;
+ static constexpr unsigned numParallelDims = 2;
+ static constexpr unsigned numReductionDims = 1;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Used in Linalg3 and later.
+ //////////////////////////////////////////////////////////////////////////////
+ /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
+ /// loop over matvec). Does nothing by default.
+ void writeAsFinerGrainTensorContraction();
+
+ /// Inputs to this map will be (%m, %n, %k) coming from enclosing loops.
+ /// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is:
+ /// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
+ /// And the operands ranges are:
+ /// (%m, %k, %k, %n, %m, %n)
+ llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+
+ /// Given a enclosing parallel loops with ivs `i` and `j`, and an enclosing
+ /// reduction loop with iv `r_k`, emits MLIR corresponding to:
+ /// 1. conditionally assign scalarC to 0.0f on the first iteration or load
+ /// C[i, j]
+ /// 2. multiply A[i, r_k] by B[r_k, j] and add to scalarC
+ /// 3. store back scalarC at C[i, j]
+ ///
+ /// In some compact index notation this could be written:
+ /// cond = (r_k == zero)
+ /// scalarC = select(cond, zerof, C[i, j]);
+ /// C[i, j] = scalarC + A[i, r_k] * B[r_k, j];
+ void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs);
+};
+
+} // namespace linalg
+
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#include "linalg2/TensorOps-inl.h"
+
+#endif // LINALG2_TENSOROPS_H_
diff --git a/include/linalg2/Transforms.h b/include/linalg2/Transforms.h
new file mode 100644
index 0000000..c55f863
--- /dev/null
+++ b/include/linalg2/Transforms.h
@@ -0,0 +1,36 @@
+//===- Transforms.h - Linalg dialect Transformations definition -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG2_TRANSFORMS_H_
+#define LINALG2_TRANSFORMS_H_
+
+namespace mlir {
+class Value;
+} // namespace mlir
+
+namespace linalg {
+
+class ViewOp;
+
+/// Takes a `view` of type ViewType (i.e. either a ViewOp or a SliceOp) and
+/// composes away all the SliceOp to return a single ViewOp.
+/// Inserts the required operations after `view`.
+ViewOp emitAndReturnFullyComposedView(mlir::Value *v);
+
+} // namespace linalg
+
+#endif // LINALG2_TRANSFORMS_H_
diff --git a/include/linalg3/Analysis.h b/include/linalg3/Analysis.h
new file mode 100644
index 0000000..813fc37
--- /dev/null
+++ b/include/linalg3/Analysis.h
@@ -0,0 +1,37 @@
+//===- Analysis.h - Linalg dialect Analysis function definitions ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_ANALYSIS_H_
+#define LINALG3_ANALYSIS_H_
+
+#include "linalg2/Analysis.h"
+
+namespace mlir {
+class AffineMap;
+} // namespace mlir
+
+namespace linalg {
+
+/// Given a `map` specification and a subset of its results
+/// `[beginResult, endResult)`, returns the inverse map that maps result
+/// positions to dim positions.
+mlir::AffineMap inverseSubMap(mlir::AffineMap map, unsigned beginResult = 0,
+ unsigned endResult = 0);
+
+} // namespace linalg
+
+#endif // LINALG3_ANALYSIS_H_
diff --git a/include/linalg3/ConvertToLLVMDialect.h b/include/linalg3/ConvertToLLVMDialect.h
new file mode 100644
index 0000000..8f122e0
--- /dev/null
+++ b/include/linalg3/ConvertToLLVMDialect.h
@@ -0,0 +1,29 @@
+//===- ConvertToLLVMDialect.h - conversion from Linalg to LLVM --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_CONVERTTOLLVMDIALECT_H_
+#define LINALG3_CONVERTTOLLVMDIALECT_H_
+
+namespace mlir {
+class Module;
+} // end namespace mlir
+
+namespace linalg {
+void convertLinalg3ToLLVM(mlir::Module &module);
+} // end namespace linalg
+
+#endif // LINALG3_CONVERTTOLLVMDIALECT_H_
diff --git a/include/linalg3/Intrinsics.h b/include/linalg3/Intrinsics.h
new file mode 100644
index 0000000..75a0417
--- /dev/null
+++ b/include/linalg3/Intrinsics.h
@@ -0,0 +1,31 @@
+//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_INTRINSICS_H_
+#define LINALG3_INTRINSICS_H_
+
+#include "linalg2/Intrinsics.h"
+#include "linalg3/Ops.h"
+
+namespace linalg {
+namespace intrinsics {
+using load = mlir::edsc::intrinsics::ValueBuilder<LoadOp>;
+using store = mlir::edsc::intrinsics::OperationBuilder<StoreOp>;
+} // namespace intrinsics
+} // namespace linalg
+
+#endif // LINALG3_INTRINSICS_H_
diff --git a/include/linalg3/LoadStoreOps.h b/include/linalg3/LoadStoreOps.h
new file mode 100644
index 0000000..b77e702
--- /dev/null
+++ b/include/linalg3/LoadStoreOps.h
@@ -0,0 +1,89 @@
+//===- LoadStoreOps.h - Linalg dialect Load/Store operation definitions ---===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_LOADSTOREOP_H_
+#define LINALG3_LOADSTOREOP_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace linalg {
+
+class ViewType;
+
+/// A linalg.LoadOp is the counterpart of affine.load but operating on ViewType
+/// instead of MemRefType.
+class LoadOp : public mlir::Op<LoadOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::OneResult> {
+public:
+ using Op::Op;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this op.
+ //////////////////////////////////////////////////////////////////////////////
+ static llvm::StringRef getOperationName() { return "linalg.load"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *view,
+ mlir::ArrayRef<mlir::Value *> indices = {});
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Op-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ unsigned getRank();
+ ViewType getViewType();
+ mlir::Value *getView() { return getOperand(0); }
+ mlir::Operation::operand_range getIndices() {
+ return {operand_begin() + 1, operand_end()};
+ }
+};
+
+/// A linalg.StoreOp is the counterpart of affine.store but operating on
+/// ViewType instead of MemRefType.
+class StoreOp : public mlir::Op<StoreOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::ZeroResult> {
+public:
+ using Op::Op;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this op.
+ //////////////////////////////////////////////////////////////////////////////
+ static llvm::StringRef getOperationName() { return "linalg.store"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *valueToStore, mlir::Value *view,
+ mlir::ArrayRef<mlir::Value *> indices = {});
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Op-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ unsigned getRank();
+ ViewType getViewType();
+ mlir::Value *getValueToStore() { return getOperand(0); }
+ mlir::Value *getView() { return getOperand(1); }
+ mlir::Operation::operand_range getIndices() {
+ return {operand_begin() + 2, operand_end()};
+ }
+};
+
+} // namespace linalg
+
+#endif // LINALG3_LOADSTOREOP_H_
diff --git a/include/linalg3/Ops.h b/include/linalg3/Ops.h
new file mode 100644
index 0000000..813cbff
--- /dev/null
+++ b/include/linalg3/Ops.h
@@ -0,0 +1,25 @@
+//===- Ops.h - Linalg Ops single entry point ------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_OPS_H_
+#define LINALG3_OPS_H_
+
+#include "linalg2/Ops.h"
+#include "linalg3/LoadStoreOps.h"
+#include "linalg3/TensorOps.h"
+
+#endif // LINALG3_OPS_H_
diff --git a/include/linalg3/TensorOps-inl.h b/include/linalg3/TensorOps-inl.h
new file mode 100644
index 0000000..b651053
--- /dev/null
+++ b/include/linalg3/TensorOps-inl.h
@@ -0,0 +1,145 @@
+//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#ifndef LINALG3_TENSOROPS_INL_H_
+#define LINALG3_TENSOROPS_INL_H_
+
+#include "linalg1/Common.h"
+#include "linalg1/Utils.h"
+#include "linalg2/TensorOps.h"
+#include "linalg3/Analysis.h"
+#include "linalg3/Ops.h"
+
+template <class ConcreteOp>
+mlir::Value *
+linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned viewIndex) {
+ return *(getInputs().begin() + viewIndex);
+}
+
+template <class ConcreteOp>
+mlir::Value *
+linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned viewIndex) {
+ return *(getOutputs().begin() + viewIndex);
+}
+
+template <class ConcreteOp>
+llvm::SmallVector<mlir::AffineMap, 8>
+linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangeMaps() {
+ return static_cast<ConcreteOp *>(this)->loopsToOperandRangeMaps();
+}
+
+template <class ConcreteOp>
+void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
+ llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs) {
+ static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs,
+ reductionIvs);
+}
+
+template <class ConcreteOp>
+mlir::AffineMap linalg::operandRangesToLoopsMap(
+ linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+ mlir::AffineMap current;
+ // Individual submaps may not be invertible but their union must be invertible
+ // by construction.
+ for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
+ if (!m)
+ continue;
+ if (!current) {
+ current = m;
+ continue;
+ }
+ llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
+ current.getResults().end());
+ results.append(m.getResults().begin(), m.getResults().end());
+ current = mlir::AffineMap::get(
+ std::max(current.getNumDims(), m.getNumDims()),
+ current.getNumSymbols() + m.getNumSymbols(), results, {});
+ }
+ return inverseSubMap(current);
+}
+
+// Extract the ranges from a given ViewOp or SliceOp.
+//
+// In the case of a ViewOp, things are simple: just traverse the indexings and
+// get all the ranges (i.e. drop the indices).
+//
+// In the case of a SliceOp, things are trickier because we need to handle a
+// potential rank-reduction:
+// 1. Examine the indexing to determine if it is rank-reducing.
+// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such
+// that `d >= slicingDim`. This is to account for the rank reduction.
+// `getRootIndex` is then called on the **parent** view
+static llvm::SmallVector<mlir::Value *, 8>
+extractRangesFromViewOrSliceOp(mlir::Value *view) {
+ // This expects a viewType which must come from either ViewOp or SliceOp.
+ assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
+ if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
+ return viewOp.getRanges();
+
+ auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
+ unsigned slicingDim = sliceOp.getSlicingDim();
+ auto *indexing = *(sliceOp.getIndexings().begin());
+ bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
+ unsigned offset = 0;
+ llvm::SmallVector<mlir::Value *, 8> res;
+ res.reserve(sliceOp.getRank());
+ for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
+ if (d == slicingDim && isRankReducing)
+ offset = 1;
+ auto *parentView = sliceOp.getParentView();
+ auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
+ res.push_back(indexingPosPair.first);
+ }
+ return res;
+}
+
+template <class ConcreteOp>
+static llvm::SmallVector<mlir::Value *, 8>
+getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+ llvm::SmallVector<mlir::Value *, 8> res;
+ for (auto *in : tensorContraction.getInputs()) {
+ auto subres = extractRangesFromViewOrSliceOp(in);
+ res.append(subres.begin(), subres.end());
+ }
+ return res;
+}
+
+template <class ConcreteOp>
+static llvm::SmallVector<mlir::Value *, 8>
+getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+ llvm::SmallVector<mlir::Value *, 8> res;
+ for (auto *out : tensorContraction.getOutputs()) {
+ auto subres = extractRangesFromViewOrSliceOp(out);
+ res.append(subres.begin(), subres.end());
+ }
+ return res;
+}
+
+template <class ConcreteOp>
+llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
+ linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+ llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
+ llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
+ res.append(tmp.begin(), tmp.end());
+ return res;
+}
+
+#endif // LINALG3_TENSOROPS_INL_H_
diff --git a/include/linalg3/TensorOps.h b/include/linalg3/TensorOps.h
new file mode 100644
index 0000000..bf5a377
--- /dev/null
+++ b/include/linalg3/TensorOps.h
@@ -0,0 +1,54 @@
+//===- TensorOps.h - Linalg dialect TensorOps operation definition --------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_TENSOROPS_H_
+#define LINALG3_TENSOROPS_H_
+
+#include "linalg2/TensorOps.h"
+
+namespace linalg {
+
+///
+/// Ideally all these functions would go in an Analysis but as long as
+/// TensorContractionBase is templated, they need to remain close enough.
+///
+
+/// Takes a `tensorContraction` and a returns an AffineMap that can be used to
+/// map ranges to enclosing loops for all the operands' ranges.
+template <class ConcreteOp>
+mlir::AffineMap operandRangesToLoopsMap(
+ linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
+
+/// Takes a `tensorContraction` and returns the ranges of all its operands.
+/// When an operand comes from a ViewOp, things are simple:
+/// just traverse the indexings and get all the ranges
+/// (i.e. drop the rank-reducing indices).
+/// In the case of a SliceOp, things are more involved because we need to handle
+/// potential rank-reductions.
+/// This function abstracts this complexity away and returns all the ranges.
+template <class ConcreteOp>
+llvm::SmallVector<mlir::Value *, 8>
+getRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
+
+} // namespace linalg
+
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#include "linalg3/TensorOps-inl.h"
+
+#endif // LINALG3_TENSOROPS_H_
diff --git a/include/linalg3/Transforms.h b/include/linalg3/Transforms.h
new file mode 100644
index 0000000..9af528e
--- /dev/null
+++ b/include/linalg3/Transforms.h
@@ -0,0 +1,80 @@
+//===- Transforms.h - Linalg dialect Transformations definition -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_TRANSFORMS_H_
+#define LINALG3_TRANSFORMS_H_
+
+#include "linalg2/Transforms.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+
+namespace mlir {
+class AffineForOp;
+class AffineMap;
+class Function;
+class FunctionPassBase;
+class Operation;
+class Value;
+} // namespace mlir
+
+namespace linalg {
+
+struct RangeParts {
+ explicit RangeParts(unsigned reserved);
+ RangeParts(llvm::ArrayRef<mlir::Value *> ranges);
+ llvm::SmallVector<mlir::Value *, 4> makeRanges();
+
+ llvm::SmallVector<mlir::Value *, 4> mins;
+ llvm::SmallVector<mlir::Value *, 4> maxes;
+ llvm::SmallVector<mlir::Value *, 4> steps;
+};
+
+mlir::Value *
+makeFoldedComposedAffineApply(mlir::AffineMap map,
+ llvm::ArrayRef<mlir::Value *> operandsRef);
+
+llvm::SmallVector<mlir::Value *, 4>
+makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps,
+ llvm::ArrayRef<mlir::Value *> ranges,
+ llvm::ArrayRef<mlir::Value *> tileSizes = {});
+
+/// Traverses `f` and rewrites linalg.slice, and the operations it depends on,
+/// to only use linalg.view operations.
+void composeSliceOps(mlir::Function *f);
+
+/// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec)
+/// as linalg.matvec(resp. linalg.dot).
+void lowerToFinerGrainedTensorContraction(mlir::Function *f);
+
+/// Operation-wise writing of linalg operations to loop form.
+/// It is the caller's responsibility to erase the `op` if necessary.
+/// This returns the enclosing loops around the body of `op` for further
+/// composition of transformations.
+llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 4>>
+writeAsLoops(mlir::Operation *op);
+
+/// Traverses `f` and rewrites linalg operations in loop form.
+void lowerToLoops(mlir::Function *f);
+
+/// Creates a pass that rewrites linalg.load and linalg.store to affine.load and
+/// affine.store operations.
+mlir::FunctionPassBase *createLowerLinalgLoadStorePass();
+
+} // namespace linalg
+
+#endif // LINALG3_TRANSFORMS_H_
diff --git a/include/toy/AST.h b/include/toy/AST.h
new file mode 100644
index 0000000..456a323
--- /dev/null
+++ b/include/toy/AST.h
@@ -0,0 +1,256 @@
+//===- AST.h - Node definition for the Toy AST ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the AST for the Toy language. It is optimized for
+// simplicity, not efficiency. The AST forms a tree structure where each node
+// references its children using std::unique_ptr<>.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_AST_H_
+#define MLIR_TUTORIAL_TOY_AST_H_
+
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <vector>
+
+namespace toy {
+
+/// A variable
+struct VarType {
+ enum { TY_FLOAT, TY_INT } elt_ty;
+ std::vector<int> shape;
+};
+
+/// Base class for all expression nodes.
+class ExprAST {
+public:
+ enum ExprASTKind {
+ Expr_VarDecl,
+ Expr_Return,
+ Expr_Num,
+ Expr_Literal,
+ Expr_Var,
+ Expr_BinOp,
+ Expr_Call,
+ Expr_Print, // builtin
+ Expr_If,
+ Expr_For,
+ };
+
+ ExprAST(ExprASTKind kind, Location location)
+ : kind(kind), location(location) {}
+
+ virtual ~ExprAST() = default;
+
+ ExprASTKind getKind() const { return kind; }
+
+ const Location &loc() { return location; }
+
+private:
+ const ExprASTKind kind;
+ Location location;
+};
+
+/// A block-list of expressions.
+using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
+
+/// Expression class for numeric literals like "1.0".
+class NumberExprAST : public ExprAST {
+ double Val;
+
+public:
+ NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
+
+ double getValue() { return Val; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
+};
+
+///
+class LiteralExprAST : public ExprAST {
+ std::vector<std::unique_ptr<ExprAST>> values;
+ std::vector<int64_t> dims;
+
+public:
+ LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
+ std::vector<int64_t> dims)
+ : ExprAST(Expr_Literal, loc), values(std::move(values)),
+ dims(std::move(dims)) {}
+
+ std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
+ std::vector<int64_t> &getDims() { return dims; }
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
+};
+
+/// Expression class for referencing a variable, like "a".
+class VariableExprAST : public ExprAST {
+ std::string name;
+
+public:
+ VariableExprAST(Location loc, const std::string &name)
+ : ExprAST(Expr_Var, loc), name(name) {}
+
+ llvm::StringRef getName() { return name; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
+};
+
+///
+class VarDeclExprAST : public ExprAST {
+ std::string name;
+ VarType type;
+ std::unique_ptr<ExprAST> initVal;
+
+public:
+ VarDeclExprAST(Location loc, const std::string &name, VarType type,
+ std::unique_ptr<ExprAST> initVal)
+ : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
+ initVal(std::move(initVal)) {}
+
+ llvm::StringRef getName() { return name; }
+ ExprAST *getInitVal() { return initVal.get(); }
+ VarType &getType() { return type; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
+};
+
+///
+class ReturnExprAST : public ExprAST {
+ llvm::Optional<std::unique_ptr<ExprAST>> expr;
+
+public:
+ ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
+ : ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
+
+ llvm::Optional<ExprAST *> getExpr() {
+ if (expr.hasValue())
+ return expr->get();
+ return llvm::NoneType();
+ }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
+};
+
+/// Expression class for a binary operator.
+class BinaryExprAST : public ExprAST {
+ char Op;
+ std::unique_ptr<ExprAST> LHS, RHS;
+
+public:
+ char getOp() { return Op; }
+ ExprAST *getLHS() { return LHS.get(); }
+ ExprAST *getRHS() { return RHS.get(); }
+
+ BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
+ std::unique_ptr<ExprAST> RHS)
+ : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
+ RHS(std::move(RHS)) {}
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
+};
+
+/// Expression class for function calls.
+class CallExprAST : public ExprAST {
+ std::string Callee;
+ std::vector<std::unique_ptr<ExprAST>> Args;
+
+public:
+ CallExprAST(Location loc, const std::string &Callee,
+ std::vector<std::unique_ptr<ExprAST>> Args)
+ : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
+
+ llvm::StringRef getCallee() { return Callee; }
+ llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
+};
+
+/// Expression class for builtin print calls.
+class PrintExprAST : public ExprAST {
+ std::unique_ptr<ExprAST> Arg;
+
+public:
+ PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
+ : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
+
+ ExprAST *getArg() { return Arg.get(); }
+
+ /// LLVM style RTTI
+ static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
+};
+
+/// This class represents the "prototype" for a function, which captures its
+/// name, and its argument names (thus implicitly the number of arguments the
+/// function takes).
+class PrototypeAST {
+ Location location;
+ std::string name;
+ std::vector<std::unique_ptr<VariableExprAST>> args;
+
+public:
+ PrototypeAST(Location location, const std::string &name,
+ std::vector<std::unique_ptr<VariableExprAST>> args)
+ : location(location), name(name), args(std::move(args)) {}
+
+ const Location &loc() { return location; }
+ const std::string &getName() const { return name; }
+ const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
+ return args;
+ }
+};
+
+/// This class represents a function definition itself.
+class FunctionAST {
+ std::unique_ptr<PrototypeAST> Proto;
+ std::unique_ptr<ExprASTList> Body;
+
+public:
+ FunctionAST(std::unique_ptr<PrototypeAST> Proto,
+ std::unique_ptr<ExprASTList> Body)
+ : Proto(std::move(Proto)), Body(std::move(Body)) {}
+ PrototypeAST *getProto() { return Proto.get(); }
+ ExprASTList *getBody() { return Body.get(); }
+};
+
+/// This class represents a list of functions to be processed together
+class ModuleAST {
+ std::vector<FunctionAST> functions;
+
+public:
+ ModuleAST(std::vector<FunctionAST> functions)
+ : functions(std::move(functions)) {}
+
+ auto begin() -> decltype(functions.begin()) { return functions.begin(); }
+ auto end() -> decltype(functions.end()) { return functions.end(); }
+};
+
+void dump(ModuleAST &);
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_AST_H_
diff --git a/include/toy/Dialect.h b/include/toy/Dialect.h
new file mode 100644
index 0000000..9d7f82d
--- /dev/null
+++ b/include/toy/Dialect.h
@@ -0,0 +1,393 @@
+//===- Dialect.h - Dialect definition for the Toy IR ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the IR Dialect for the Toy language.
+// See g3doc/Tutorials/Toy/Ch-3.md for more information.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
+#define MLIR_TUTORIAL_TOY_DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+class Builder;
+}
+
+namespace toy {
+
+/// This is the definition of the Toy dialect. A dialect inherits from
+/// mlir::Dialect and register custom operations and types (in its constructor).
+/// It can also overridding general behavior of dialects exposed as virtual
+/// method, for example regarding verification and parsing/printing.
+class ToyDialect : public mlir::Dialect {
+public:
+ explicit ToyDialect(mlir::MLIRContext *ctx);
+
+ /// Parse a type registered to this dialect. Overridding this method is
+ /// required for dialects that have custom types.
+ /// Technically this is only needed to be able to round-trip to textual IR.
+ mlir::Type parseType(llvm::StringRef tyData,
+ mlir::Location loc) const override;
+
+ /// Print a type registered to this dialect. Overridding this method is
+ /// only required for dialects that have custom types.
+ /// Technically this is only needed to be able to round-trip to textual IR.
+ void printType(mlir::Type type, llvm::raw_ostream &os) const override;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+/////////////////////// Custom Types for the Dialect ///////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+struct ToyArrayTypeStorage;
+}
+
+/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
+enum ToyTypeKind {
+ // The enum starts at the range reserved for this dialect.
+ TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
+ TOY_ARRAY,
+};
+
+/// Type for Toy arrays.
+/// In MLIR Types are reference to immutable and uniqued objects owned by the
+/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued
+/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
+/// provides the public facade API to interact with the type.
+class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
+ detail::ToyArrayTypeStorage> {
+public:
+ using Base::Base;
+
+ /// Returns the dimensions for this array, or and empty range for a generic
+ /// array.
+ llvm::ArrayRef<int64_t> getShape();
+
+ /// Predicate to test if this array is generic (shape haven't been inferred
+ /// yet).
+ bool isGeneric() { return getShape().empty(); }
+
+ /// Return the rank of this array (0 if it is generic).
+ int getRank() { return getShape().size(); }
+
+ /// Return the type of individual elements in the array.
+ mlir::Type getElementType();
+
+ /// Get a MemRef equivalent to this array type.
+ mlir::MemRefType toMemref();
+
+ /// Get the unique instance of this Type from the context.
+ /// A ToyArrayType is only defined by the shape of the array.
+ static ToyArrayType get(mlir::MLIRContext *context,
+ llvm::ArrayRef<int64_t> shape = {});
+
+ /// Support method to enable LLVM-style RTTI type casting.
+ static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+//////////////////// Custom Operations for the Dialect /////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+/// Constant operation turns a literal into an SSA value. The data is attached
+/// to the operation as an attribute. For example:
+///
+/// %0 = "toy.constant"()
+/// {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
+/// : () -> !toy<"array<2, 3>">
+///
+/// An operation inherits from `class Op` and specifies optional traits. Here we
+/// indicate that `toy.constant` does not have any operands and returns a single
+/// result. The traits provide some utilities methods for the operation, for
+/// instance we will be able to use `getResult()`, but `getOperand()` won't be
+/// available.
+class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+public:
+ /// This is the name used by MLIR to match an operation to this class during
+ /// parsing.
+ static llvm::StringRef getOperationName() { return "toy.constant"; }
+
+ /// The operation can have extra verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::Builder::create<PrintOp>(...)
+ /// This method populates the `state` that MLIR uses to create operations.
+ /// The `toy.constant` operation does not have arguments but attaches a
+ /// constant array as an attribute and returns it as an SSA value.
+ static void build(mlir::Builder *builder, mlir::OperationState *state,
+ llvm::ArrayRef<int64_t> shape,
+ mlir::DenseElementsAttr value);
+
+ /// Similar to the one above, but takes a single float and returns a
+ /// !toy<"array<1>">.
+ static void build(mlir::Builder *builder, mlir::OperationState *state,
+ mlir::FloatAttr value);
+
+ mlir::DenseElementsAttr getValue() {
+ return getAttr("value").cast<mlir::DenseElementsAttr>();
+ }
+
+ /// Inherit constructor.
+ using Op::Op;
+};
+
+/// Generic calls represent calls to a user defined function that needs to
+/// be specialized for the shape of its arguments. The callee name is attached
+/// as a literal string as an attribute. The arguments list must match the
+/// arguments expected by the callee. For example:
+///
+/// %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
+/// : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
+///
+/// This is only valid if a function named "my_func" exists and takes two
+/// arguments.
+class GenericCallOp
+ : public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::OneResult> {
+public:
+ /// MLIR will use this to register the operation with the parser/printer.
+ static llvm::StringRef getOperationName() { return "toy.generic_call"; }
+
+ /// Operations can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to the builder to allow:
+ /// mlir::Builder::create<GenericCallOp>(...)
+ /// This method populate the `state` that MLIR use to create operations.
+ /// The `toy.generic_call` operation accepts a callee name and a list of
+ /// arguments for the call.
+ static void build(mlir::Builder *builder, mlir::OperationState *state,
+ llvm::StringRef callee,
+ llvm::ArrayRef<mlir::Value *> arguments);
+
+ /// Return the name of the callee.
+ llvm::StringRef getCalleeName();
+
+ /// Inherit constructor.
+ using Op::Op;
+};
+
+/// Return operations terminate blocks (and functions as well). They take a
+/// single argument and the type must match the function return type.
+class ReturnOp
+ : public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.return"; }
+
+ /// Operations can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::Builder::create<PrintOp>(...)
+ /// This method populate the `state` that MLIR use to create operations.
+ /// The `toy.return` operation accepts an optional single array as an argument
+ /// and does not have any returned value.
+ static void build(mlir::Builder *builder, mlir::OperationState *state,
+ mlir::Value *value = nullptr);
+
+ /// Return true if there is a returned value.
+ bool hasOperand() { return 0 != getNumOperands(); }
+
+ /// Helper to return the optional operand. Caller must check if the operand
+ /// is present before calling this.
+ mlir::Value *getOperand() { return getOperation()->getOperand(0); }
+
+ /// Inherit constructor.
+ using Op::Op;
+};
+
+/// The print builtin takes a single array argument and does not return any.
+class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand,
+ mlir::OpTrait::ZeroResult> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.print"; }
+
+ /// Operations can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::Builder::create<PrintOp>(...)
+ /// This method populate the `state` that MLIR use to create operations.
+ /// The `toy.print` operation accepts a single array as argument and does
+ /// not have any returned value.
+ static void build(mlir::Builder *builder, mlir::OperationState *state,
+ mlir::Value *value);
+
+ /// Inherit constructor.
+ using Op::Op;
+};
+
+class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.transpose"; }
+
+ /// Operation can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::Builder::create<TransposeOp>(...)
+ /// This method populate the `state` that MLIR use to create operations.
+ /// The `toy.transpose` operation accepts a single array as argument and
+ /// returns the transposed array as its only result.
+ static void build(mlir::Builder *builder, mlir::OperationState *state,
+ mlir::Value *value);
+
+ // Register our patterns for rewrite by the Canonicalization framework.
+ static void
+ getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
+ mlir::MLIRContext *context);
+
+ /// Inherit constructor.
+ using Op::Op;
+};
+
+/// Reshape operation is transforming its input array into a new array with the
+/// same number of elements but different shapes. For example:
+///
+/// %0 = "toy.transpose"(%arg1) : (!toy<"array<10>">) -> !toy<"array<5, 2>">
+///
+class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.reshape"; }
+
+ /// Operation can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::Builder::create<ReshapeOp>(...)
+ /// This method populate the `state` that MLIR use to create operations.
+ /// The `toy.reshape` operation accepts a single array as argument and
+ /// returns the array with the specified reshapedType as its only result.
+ static void build(mlir::Builder *builder, mlir::OperationState *state,
+ mlir::Value *value, ToyArrayType reshapedType);
+
+ // Register our patterns for rewrite by the Canonicalization framework.
+ static void
+ getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
+ mlir::MLIRContext *context);
+
+ /// Inherit constructor.
+ using Op::Op;
+};
+
+/// Binary operation implementing a multiplication. For two-dimensional array
+/// a matrix multiplication is implemented, while for one dimensional array a
+/// dot product is performed.
+class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.mul"; }
+
+ /// Operation can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::Builder::create<PrintOp>(...)
+ /// This method populate the `state` that MLIR use to create operations.
+ /// The `toy.mul` operation accepts two operands as argument and returns
+ /// a single value.
+ static void build(mlir::Builder *builder, mlir::OperationState *state,
+ mlir::Value *lhs, mlir::Value *rhs);
+
+ /// Convenience accessor for LHS of the expression.
+ mlir::Value *getLHS() { return getOperand(0); }
+
+ /// Convenience accessor for RHS of the expression.
+ mlir::Value *getRHS() { return getOperand(1); }
+
+ /// Inherit constructor.
+ using Op::Op;
+};
+
+/// Element wise addition of two arrays. The shape must match.
+class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.add"; }
+
+ /// Operation can add custom verification beyond the traits they define.
+ mlir::LogicalResult verify();
+
+ /// Interface to mlir::Builder::create<PrintOp>(...)
+ /// This method populate the `state` that MLIR use to create operations.
+ /// The `toy.mul` operation accepts two operands as argument and returns
+ /// a single value.
+ static void build(mlir::Builder *builder, mlir::OperationState *state,
+ mlir::Value *lhs, mlir::Value *rhs);
+
+ /// Convenience accessor for LHS of the expression.
+ mlir::Value *getLHS() { return getOperand(0); }
+
+ /// Convenience accessor for RHS of the expression.
+ mlir::Value *getRHS() { return getOperand(1); }
+
+ /// Inherit constructor.
+ using Op::Op;
+};
+
+/// AllocOp is a temporary operation for buffer allocation, created as part of
+/// partial lowering.
+class AllocOp : public mlir::Op<AllocOp, mlir::OpTrait::ZeroOperands,
+ mlir::OpTrait::OneResult> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.alloc"; }
+
+ /// Interface to mlir::Builder::create<AllocOp>(...)
+ /// This method populate the `state` that MLIR use to create operations.
+ /// `toy.alloc` does not have any argument and returns a toy array.
+ static void build(mlir::Builder *builder, mlir::OperationState *state,
+ mlir::Type retType);
+
+ /// Inherit constructor.
+ using Op::Op;
+};
+
+/// FIXME: should be in std?
+class TypeCastOp : public mlir::Op<TypeCastOp, mlir::OpTrait::OneOperand,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+public:
+ static llvm::StringRef getOperationName() { return "toy.cast"; }
+
+ static void build(mlir::Builder *builder, mlir::OperationState *state,
+ mlir::Value *value, mlir::Type destTy);
+
+ // Register our patterns for rewrite by the Canonicalization framework.
+ static void
+ getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
+ mlir::MLIRContext *context);
+
+ /// Inherit constructor.
+ using Op::Op;
+};
+
+} // end namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_DIALECT_H_
diff --git a/include/toy/Lexer.h b/include/toy/Lexer.h
new file mode 100644
index 0000000..d73adb9
--- /dev/null
+++ b/include/toy/Lexer.h
@@ -0,0 +1,239 @@
+//===- Lexer.h - Lexer for the Toy language -------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple Lexer for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_LEXER_H_
+#define MLIR_TUTORIAL_TOY_LEXER_H_
+
+#include "llvm/ADT/StringRef.h"
+
+#include <memory>
+#include <string>
+
+namespace toy {
+
+/// Structure definition a location in a file.
+struct Location {
+ std::shared_ptr<std::string> file; ///< filename
+ int line; ///< line number.
+ int col; ///< column number.
+};
+
+// List of Token returned by the lexer.
+enum Token : int {
+ tok_semicolon = ';',
+ tok_parenthese_open = '(',
+ tok_parenthese_close = ')',
+ tok_bracket_open = '{',
+ tok_bracket_close = '}',
+ tok_sbracket_open = '[',
+ tok_sbracket_close = ']',
+
+ tok_eof = -1,
+
+ // commands
+ tok_return = -2,
+ tok_var = -3,
+ tok_def = -4,
+
+ // primary
+ tok_identifier = -5,
+ tok_number = -6,
+};
+
+/// The Lexer is an abstract base class providing all the facilities that the
+/// Parser expects. It goes through the stream one token at a time and keeps
+/// track of the location in the file for debugging purpose.
+/// It relies on a subclass to provide a `readNextLine()` method. The subclass
+/// can proceed by reading the next line from the standard input or from a
+/// memory mapped file.
+class Lexer {
+public:
+ /// Create a lexer for the given filename. The filename is kept only for
+ /// debugging purpose (attaching a location to a Token).
+ Lexer(std::string filename)
+ : lastLocation(
+ {std::make_shared<std::string>(std::move(filename)), 0, 0}) {}
+ virtual ~Lexer() = default;
+
+ /// Look at the current token in the stream.
+ Token getCurToken() { return curTok; }
+
+ /// Move to the next token in the stream and return it.
+ Token getNextToken() { return curTok = getTok(); }
+
+ /// Move to the next token in the stream, asserting on the current token
+ /// matching the expectation.
+ void consume(Token tok) {
+ assert(tok == curTok && "consume Token mismatch expectation");
+ getNextToken();
+ }
+
+ /// Return the current identifier (prereq: getCurToken() == tok_identifier)
+ llvm::StringRef getId() {
+ assert(curTok == tok_identifier);
+ return IdentifierStr;
+ }
+
+ /// Return the current number (prereq: getCurToken() == tok_number)
+ double getValue() {
+ assert(curTok == tok_number);
+ return NumVal;
+ }
+
+ /// Return the location for the beginning of the current token.
+ Location getLastLocation() { return lastLocation; }
+
+ // Return the current line in the file.
+ int getLine() { return curLineNum; }
+
+ // Return the current column in the file.
+ int getCol() { return curCol; }
+
+private:
+ /// Delegate to a derived class fetching the next line. Returns an empty
+ /// string to signal end of file (EOF). Lines are expected to always finish
+ /// with "\n"
+ virtual llvm::StringRef readNextLine() = 0;
+
+ /// Return the next character from the stream. This manages the buffer for the
+ /// current line and request the next line buffer to the derived class as
+ /// needed.
+ int getNextChar() {
+ // The current line buffer should not be empty unless it is the end of file.
+ if (curLineBuffer.empty())
+ return EOF;
+ ++curCol;
+ auto nextchar = curLineBuffer.front();
+ curLineBuffer = curLineBuffer.drop_front();
+ if (curLineBuffer.empty())
+ curLineBuffer = readNextLine();
+ if (nextchar == '\n') {
+ ++curLineNum;
+ curCol = 0;
+ }
+ return nextchar;
+ }
+
+ /// Return the next token from standard input.
+ Token getTok() {
+ // Skip any whitespace.
+ while (isspace(LastChar))
+ LastChar = Token(getNextChar());
+
+ // Save the current location before reading the token characters.
+ lastLocation.line = curLineNum;
+ lastLocation.col = curCol;
+
+ if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
+ IdentifierStr = (char)LastChar;
+ while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
+ IdentifierStr += (char)LastChar;
+
+ if (IdentifierStr == "return")
+ return tok_return;
+ if (IdentifierStr == "def")
+ return tok_def;
+ if (IdentifierStr == "var")
+ return tok_var;
+ return tok_identifier;
+ }
+
+ if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
+ std::string NumStr;
+ do {
+ NumStr += LastChar;
+ LastChar = Token(getNextChar());
+ } while (isdigit(LastChar) || LastChar == '.');
+
+ NumVal = strtod(NumStr.c_str(), nullptr);
+ return tok_number;
+ }
+
+ if (LastChar == '#') {
+ // Comment until end of line.
+ do
+ LastChar = Token(getNextChar());
+ while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
+
+ if (LastChar != EOF)
+ return getTok();
+ }
+
+ // Check for end of file. Don't eat the EOF.
+ if (LastChar == EOF)
+ return tok_eof;
+
+ // Otherwise, just return the character as its ascii value.
+ Token ThisChar = Token(LastChar);
+ LastChar = Token(getNextChar());
+ return ThisChar;
+ }
+
+ /// The last token read from the input.
+ Token curTok = tok_eof;
+
+ /// Location for `curTok`.
+ Location lastLocation;
+
+ /// If the current Token is an identifier, this string contains the value.
+ std::string IdentifierStr;
+
+ /// If the current Token is a number, this contains the value.
+ double NumVal = 0;
+
+ /// The last value returned by getNextChar(). We need to keep it around as we
+ /// always need to read ahead one character to decide when to end a token and
+ /// we can't put it back in the stream after reading from it.
+ Token LastChar = Token(' ');
+
+ /// Keep track of the current line number in the input stream
+ int curLineNum = 0;
+
+ /// Keep track of the current column number in the input stream
+ int curCol = 0;
+
+ /// Buffer supplied by the derived class on calls to `readNextLine()`
+ llvm::StringRef curLineBuffer = "\n";
+};
+
+/// A lexer implementation operating on a buffer in memory.
+class LexerBuffer final : public Lexer {
+public:
+ LexerBuffer(const char *begin, const char *end, std::string filename)
+ : Lexer(std::move(filename)), current(begin), end(end) {}
+
+private:
+ /// Provide one line at a time to the Lexer, return an empty string when
+ /// reaching the end of the buffer.
+ llvm::StringRef readNextLine() override {
+ auto *begin = current;
+ while (current <= end && *current && *current != '\n')
+ ++current;
+ if (current <= end && *current)
+ ++current;
+ llvm::StringRef result{begin, static_cast<size_t>(current - begin)};
+ return result;
+ }
+ const char *current, *end;
+};
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_LEXER_H_
diff --git a/include/toy/Lowering.h b/include/toy/Lowering.h
new file mode 100644
index 0000000..362a342
--- /dev/null
+++ b/include/toy/Lowering.h
@@ -0,0 +1,45 @@
+//===- Lowering.h - Lexer for the Toy language ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file exposes the interface to the lowering for Toy. It is divided in
+// two parts: an *early lowering* that emits operations in the `Linalg`
+// dialects for a subset of the Toy IR, and a *late lowering* that materializes
+// buffers and converts all operations and type to the LLVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXAMPLES_TOY_LOWERING_H_
+#define MLIR_EXAMPLES_TOY_LOWERING_H_
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class DialectConversion;
+} // namespace mlir
+
+namespace toy {
+/// Create a pass for lowering operations in the `Linalg` dialects, for a subset
+/// of the Toy IR (matmul).
+mlir::Pass *createEarlyLoweringPass();
+
+/// Create a pass for the late lowering toward LLVM dialect.
+mlir::Pass *createLateLoweringPass();
+
+} // namespace toy
+
+#endif // MLIR_EXAMPLES_TOY_LOWERING_H_
diff --git a/include/toy/MLIRGen.h b/include/toy/MLIRGen.h
new file mode 100644
index 0000000..21637bc
--- /dev/null
+++ b/include/toy/MLIRGen.h
@@ -0,0 +1,42 @@
+//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares a simple interface to perform IR generation targeting MLIR
+// from a Module AST for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_
+#define MLIR_TUTORIAL_TOY_MLIRGEN_H_
+
+#include <memory>
+
+namespace mlir {
+class MLIRContext;
+class Module;
+} // namespace mlir
+
+namespace toy {
+class ModuleAST;
+
+/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
+/// or nullptr on failure.
+std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
+ ModuleAST &moduleAST);
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
diff --git a/include/toy/Parser.h b/include/toy/Parser.h
new file mode 100644
index 0000000..bc7aa52
--- /dev/null
+++ b/include/toy/Parser.h
@@ -0,0 +1,494 @@
+//===- Parser.h - Toy Language Parser -------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the parser for the Toy language. It processes the Token
+// provided by the Lexer and returns an AST.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PARSER_H
+#define MLIR_TUTORIAL_TOY_PARSER_H
+
+#include "toy/AST.h"
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <map>
+#include <utility>
+#include <vector>
+
+namespace toy {
+
+/// This is a simple recursive parser for the Toy language. It produces a well
+/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
+/// or symbol resolution is performed. For example, variables are referenced by
+/// string and the code could reference an undeclared variable and the parsing
+/// succeeds.
+class Parser {
+public:
+ /// Create a Parser for the supplied lexer.
+ Parser(Lexer &lexer) : lexer(lexer) {}
+
+ /// Parse a full Module. A module is a list of function definitions.
+ std::unique_ptr<ModuleAST> ParseModule() {
+ lexer.getNextToken(); // prime the lexer
+
+ // Parse functions one at a time and accumulate in this vector.
+ std::vector<FunctionAST> functions;
+ while (auto F = ParseDefinition()) {
+ functions.push_back(std::move(*F));
+ if (lexer.getCurToken() == tok_eof)
+ break;
+ }
+ // If we didn't reach EOF, there was an error during parsing
+ if (lexer.getCurToken() != tok_eof)
+ return parseError<ModuleAST>("nothing", "at end of module");
+
+ return llvm::make_unique<ModuleAST>(std::move(functions));
+ }
+
+private:
+ Lexer &lexer;
+
+ /// Parse a return statement.
+ /// return :== return ; | return expr ;
+ std::unique_ptr<ReturnExprAST> ParseReturn() {
+ auto loc = lexer.getLastLocation();
+ lexer.consume(tok_return);
+
+ // return takes an optional argument
+ llvm::Optional<std::unique_ptr<ExprAST>> expr;
+ if (lexer.getCurToken() != ';') {
+ expr = ParseExpression();
+ if (!expr)
+ return nullptr;
+ }
+ return llvm::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
+ }
+
+ /// Parse a literal number.
+ /// numberexpr ::= number
+ std::unique_ptr<ExprAST> ParseNumberExpr() {
+ auto loc = lexer.getLastLocation();
+ auto Result =
+ llvm::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
+ lexer.consume(tok_number);
+ return std::move(Result);
+ }
+
+ /// Parse a literal array expression.
+ /// tensorLiteral ::= [ literalList ] | number
+ /// literalList ::= tensorLiteral | tensorLiteral, literalList
+ std::unique_ptr<ExprAST> ParseTensorLitteralExpr() {
+ auto loc = lexer.getLastLocation();
+ lexer.consume(Token('['));
+
+ // Hold the list of values at this nesting level.
+ std::vector<std::unique_ptr<ExprAST>> values;
+ // Hold the dimensions for all the nesting inside this level.
+ std::vector<int64_t> dims;
+ do {
+ // We can have either another nested array or a number literal.
+ if (lexer.getCurToken() == '[') {
+ values.push_back(ParseTensorLitteralExpr());
+ if (!values.back())
+ return nullptr; // parse error in the nested array.
+ } else {
+ if (lexer.getCurToken() != tok_number)
+ return parseError<ExprAST>("<num> or [", "in literal expression");
+ values.push_back(ParseNumberExpr());
+ }
+
+ // End of this list on ']'
+ if (lexer.getCurToken() == ']')
+ break;
+
+ // Elements are separated by a comma.
+ if (lexer.getCurToken() != ',')
+ return parseError<ExprAST>("] or ,", "in literal expression");
+
+ lexer.getNextToken(); // eat ,
+ } while (true);
+ if (values.empty())
+ return parseError<ExprAST>("<something>", "to fill literal expression");
+ lexer.getNextToken(); // eat ]
+ /// Fill in the dimensions now. First the current nesting level:
+ dims.push_back(values.size());
+ /// If there is any nested array, process all of them and ensure that
+ /// dimensions are uniform.
+ if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
+ return llvm::isa<LiteralExprAST>(expr.get());
+ })) {
+ auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
+ if (!firstLiteral)
+ return parseError<ExprAST>("uniform well-nested dimensions",
+ "inside literal expession");
+
+ // Append the nested dimensions to the current level
+ auto &firstDims = firstLiteral->getDims();
+ dims.insert(dims.end(), firstDims.begin(), firstDims.end());
+
+ // Sanity check that shape is uniform across all elements of the list.
+ for (auto &expr : values) {
+ auto *exprLiteral = llvm::cast<LiteralExprAST>(expr.get());
+ if (!exprLiteral)
+ return parseError<ExprAST>("uniform well-nested dimensions",
+ "inside literal expession");
+ if (exprLiteral->getDims() != firstDims)
+ return parseError<ExprAST>("uniform well-nested dimensions",
+ "inside literal expession");
+ }
+ }
+ return llvm::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
+ std::move(dims));
+ }
+
+ /// parenexpr ::= '(' expression ')'
+ std::unique_ptr<ExprAST> ParseParenExpr() {
+ lexer.getNextToken(); // eat (.
+ auto V = ParseExpression();
+ if (!V)
+ return nullptr;
+
+ if (lexer.getCurToken() != ')')
+ return parseError<ExprAST>(")", "to close expression with parentheses");
+ lexer.consume(Token(')'));
+ return V;
+ }
+
+ /// identifierexpr
+ /// ::= identifier
+ /// ::= identifier '(' expression ')'
+ std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+ std::string name = lexer.getId();
+
+ auto loc = lexer.getLastLocation();
+ lexer.getNextToken(); // eat identifier.
+
+ if (lexer.getCurToken() != '(') // Simple variable ref.
+ return llvm::make_unique<VariableExprAST>(std::move(loc), name);
+
+ // This is a function call.
+ lexer.consume(Token('('));
+ std::vector<std::unique_ptr<ExprAST>> Args;
+ if (lexer.getCurToken() != ')') {
+ while (true) {
+ if (auto Arg = ParseExpression())
+ Args.push_back(std::move(Arg));
+ else
+ return nullptr;
+
+ if (lexer.getCurToken() == ')')
+ break;
+
+ if (lexer.getCurToken() != ',')
+ return parseError<ExprAST>(", or )", "in argument list");
+ lexer.getNextToken();
+ }
+ }
+ lexer.consume(Token(')'));
+
+ // It can be a builtin call to print
+ if (name == "print") {
+ if (Args.size() != 1)
+ return parseError<ExprAST>("<single arg>", "as argument to print()");
+
+ return llvm::make_unique<PrintExprAST>(std::move(loc),
+ std::move(Args[0]));
+ }
+
+ // Call to a user-defined function
+ return llvm::make_unique<CallExprAST>(std::move(loc), name,
+ std::move(Args));
+ }
+
+ /// primary
+ /// ::= identifierexpr
+ /// ::= numberexpr
+ /// ::= parenexpr
+ /// ::= tensorliteral
+ std::unique_ptr<ExprAST> ParsePrimary() {
+ switch (lexer.getCurToken()) {
+ default:
+ llvm::errs() << "unknown token '" << lexer.getCurToken()
+ << "' when expecting an expression\n";
+ return nullptr;
+ case tok_identifier:
+ return ParseIdentifierExpr();
+ case tok_number:
+ return ParseNumberExpr();
+ case '(':
+ return ParseParenExpr();
+ case '[':
+ return ParseTensorLitteralExpr();
+ case ';':
+ return nullptr;
+ case '}':
+ return nullptr;
+ }
+ }
+
+ /// Recursively parse the right hand side of a binary expression, the ExprPrec
+ /// argument indicates the precedence of the current binary operator.
+ ///
+ /// binoprhs ::= ('+' primary)*
+ std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
+ std::unique_ptr<ExprAST> LHS) {
+ // If this is a binop, find its precedence.
+ while (true) {
+ int TokPrec = GetTokPrecedence();
+
+ // If this is a binop that binds at least as tightly as the current binop,
+ // consume it, otherwise we are done.
+ if (TokPrec < ExprPrec)
+ return LHS;
+
+ // Okay, we know this is a binop.
+ int BinOp = lexer.getCurToken();
+ lexer.consume(Token(BinOp));
+ auto loc = lexer.getLastLocation();
+
+ // Parse the primary expression after the binary operator.
+ auto RHS = ParsePrimary();
+ if (!RHS)
+ return parseError<ExprAST>("expression", "to complete binary operator");
+
+ // If BinOp binds less tightly with RHS than the operator after RHS, let
+ // the pending operator take RHS as its LHS.
+ int NextPrec = GetTokPrecedence();
+ if (TokPrec < NextPrec) {
+ RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
+ if (!RHS)
+ return nullptr;
+ }
+
+ // Merge LHS/RHS.
+ LHS = llvm::make_unique<BinaryExprAST>(std::move(loc), BinOp,
+ std::move(LHS), std::move(RHS));
+ }
+ }
+
+ /// expression::= primary binoprhs
+ std::unique_ptr<ExprAST> ParseExpression() {
+ auto LHS = ParsePrimary();
+ if (!LHS)
+ return nullptr;
+
+ return ParseBinOpRHS(0, std::move(LHS));
+ }
+
+ /// type ::= < shape_list >
+ /// shape_list ::= num | num , shape_list
+ std::unique_ptr<VarType> ParseType() {
+ if (lexer.getCurToken() != '<')
+ return parseError<VarType>("<", "to begin type");
+ lexer.getNextToken(); // eat <
+
+ auto type = llvm::make_unique<VarType>();
+
+ while (lexer.getCurToken() == tok_number) {
+ type->shape.push_back(lexer.getValue());
+ lexer.getNextToken();
+ if (lexer.getCurToken() == ',')
+ lexer.getNextToken();
+ }
+
+ if (lexer.getCurToken() != '>')
+ return parseError<VarType>(">", "to end type");
+ lexer.getNextToken(); // eat >
+ return type;
+ }
+
+ /// Parse a variable declaration, it starts with a `var` keyword followed by
+ /// and identifier and an optional type (shape specification) before the
+ /// initializer.
+ /// decl ::= var identifier [ type ] = expr
+ std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+ if (lexer.getCurToken() != tok_var)
+ return parseError<VarDeclExprAST>("var", "to begin declaration");
+ auto loc = lexer.getLastLocation();
+ lexer.getNextToken(); // eat var
+
+ if (lexer.getCurToken() != tok_identifier)
+ return parseError<VarDeclExprAST>("identified",
+ "after 'var' declaration");
+ std::string id = lexer.getId();
+ lexer.getNextToken(); // eat id
+
+ std::unique_ptr<VarType> type; // Type is optional, it can be inferred
+ if (lexer.getCurToken() == '<') {
+ type = ParseType();
+ if (!type)
+ return nullptr;
+ }
+
+ if (!type)
+ type = llvm::make_unique<VarType>();
+ lexer.consume(Token('='));
+ auto expr = ParseExpression();
+ return llvm::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
+ std::move(*type), std::move(expr));
+ }
+
+ /// Parse a block: a list of expression separated by semicolons and wrapped in
+ /// curly braces.
+ ///
+ /// block ::= { expression_list }
+ /// expression_list ::= block_expr ; expression_list
+ /// block_expr ::= decl | "return" | expr
+ std::unique_ptr<ExprASTList> ParseBlock() {
+ if (lexer.getCurToken() != '{')
+ return parseError<ExprASTList>("{", "to begin block");
+ lexer.consume(Token('{'));
+
+ auto exprList = llvm::make_unique<ExprASTList>();
+
+ // Ignore empty expressions: swallow sequences of semicolons.
+ while (lexer.getCurToken() == ';')
+ lexer.consume(Token(';'));
+
+ while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
+ if (lexer.getCurToken() == tok_var) {
+ // Variable declaration
+ auto varDecl = ParseDeclaration();
+ if (!varDecl)
+ return nullptr;
+ exprList->push_back(std::move(varDecl));
+ } else if (lexer.getCurToken() == tok_return) {
+ // Return statement
+ auto ret = ParseReturn();
+ if (!ret)
+ return nullptr;
+ exprList->push_back(std::move(ret));
+ } else {
+ // General expression
+ auto expr = ParseExpression();
+ if (!expr)
+ return nullptr;
+ exprList->push_back(std::move(expr));
+ }
+ // Ensure that elements are separated by a semicolon.
+ if (lexer.getCurToken() != ';')
+ return parseError<ExprASTList>(";", "after expression");
+
+ // Ignore empty expressions: swallow sequences of semicolons.
+ while (lexer.getCurToken() == ';')
+ lexer.consume(Token(';'));
+ }
+
+ if (lexer.getCurToken() != '}')
+ return parseError<ExprASTList>("}", "to close block");
+
+ lexer.consume(Token('}'));
+ return exprList;
+ }
+
+ /// prototype ::= def id '(' decl_list ')'
+ /// decl_list ::= identifier | identifier, decl_list
+ std::unique_ptr<PrototypeAST> ParsePrototype() {
+ auto loc = lexer.getLastLocation();
+ lexer.consume(tok_def);
+ if (lexer.getCurToken() != tok_identifier)
+ return parseError<PrototypeAST>("function name", "in prototype");
+
+ std::string FnName = lexer.getId();
+ lexer.consume(tok_identifier);
+
+ if (lexer.getCurToken() != '(')
+ return parseError<PrototypeAST>("(", "in prototype");
+ lexer.consume(Token('('));
+
+ std::vector<std::unique_ptr<VariableExprAST>> args;
+ if (lexer.getCurToken() != ')') {
+ do {
+ std::string name = lexer.getId();
+ auto loc = lexer.getLastLocation();
+ lexer.consume(tok_identifier);
+ auto decl = llvm::make_unique<VariableExprAST>(std::move(loc), name);
+ args.push_back(std::move(decl));
+ if (lexer.getCurToken() != ',')
+ break;
+ lexer.consume(Token(','));
+ if (lexer.getCurToken() != tok_identifier)
+ return parseError<PrototypeAST>(
+ "identifier", "after ',' in function parameter list");
+ } while (true);
+ }
+ if (lexer.getCurToken() != ')')
+ return parseError<PrototypeAST>("}", "to end function prototype");
+
+ // success.
+ lexer.consume(Token(')'));
+ return llvm::make_unique<PrototypeAST>(std::move(loc), FnName,
+ std::move(args));
+ }
+
+ /// Parse a function definition, we expect a prototype initiated with the
+ /// `def` keyword, followed by a block containing a list of expressions.
+ ///
+ /// definition ::= prototype block
+ std::unique_ptr<FunctionAST> ParseDefinition() {
+ auto Proto = ParsePrototype();
+ if (!Proto)
+ return nullptr;
+
+ if (auto block = ParseBlock())
+ return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(block));
+ return nullptr;
+ }
+
+ /// Get the precedence of the pending binary operator token.
+ int GetTokPrecedence() {
+ if (!isascii(lexer.getCurToken()))
+ return -1;
+
+ // 1 is lowest precedence.
+ switch (static_cast<char>(lexer.getCurToken())) {
+ case '-':
+ return 20;
+ case '+':
+ return 20;
+ case '*':
+ return 40;
+ default:
+ return -1;
+ }
+ }
+
+ /// Helper function to signal errors while parsing, it takes an argument
+ /// indicating the expected token and another argument giving more context.
+ /// Location is retrieved from the lexer to enrich the error message.
+ template <typename R, typename T, typename U = const char *>
+ std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
+ auto curToken = lexer.getCurToken();
+ llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
+ << lexer.getLastLocation().col << "): expected '" << expected
+ << "' " << context << " but has Token " << curToken;
+ if (isprint(curToken))
+ llvm::errs() << " '" << (char)curToken << "'";
+ llvm::errs() << "\n";
+ return nullptr;
+ }
+};
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_PARSER_H
diff --git a/include/toy/Passes.h b/include/toy/Passes.h
new file mode 100644
index 0000000..dd73b95
--- /dev/null
+++ b/include/toy/Passes.h
@@ -0,0 +1,33 @@
+//===- Passes.h - Toy Passes Definition -----------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file exposes the entry points to create compiler passes for Toy.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PASSES_H
+#define MLIR_TUTORIAL_TOY_PASSES_H
+
+namespace mlir {
+class Pass;
+} // namespace mlir
+
+namespace toy {
+mlir::Pass *createShapeInferencePass();
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_PASSES_H