summaryrefslogtreecommitdiff
path: root/toy/ToyDialect.cpp
blob: be117f56de34566c3b6c4961e2ec8d10548e0fe6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the dialect for the Toy IR: custom type parsing and
// operation verification.
//
//===----------------------------------------------------------------------===//

#include "toy/Dialect.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"

using llvm::ArrayRef;
using llvm::raw_ostream;
using llvm::raw_string_ostream;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;

namespace toy {
namespace detail {

/// This class holds the implementation of the ToyArrayType.
/// It is intended to be uniqued based on its content and owned by the context.
struct ToyArrayTypeStorage : public mlir::TypeStorage {
  /// This defines how we unique this type in the context: our key contains
  /// only the shape, a more complex type would have multiple entries in the
  /// tuple here.
  /// The element of the tuples usually matches 1-1 the arguments from the
  /// public `get()` method arguments from the facade.
  using KeyTy = std::tuple<ArrayRef<int64_t>>;
  static unsigned hashKey(const KeyTy &key) {
    return llvm::hash_combine(std::get<0>(key));
  }
  /// When the key hash hits an existing type, we compare the shape themselves
  /// to confirm we have the right type.
  bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }

  /// This is a factory method to create our type storage. It is only
  /// invoked after looking up the type in the context using the key and not
  /// finding it.
  static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
                                        const KeyTy &key) {
    // Copy the shape array into the bumpptr allocator owned by the context.
    ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));

    // Allocate the instance for the ToyArrayTypeStorage itself
    auto *storage = allocator.allocate<ToyArrayTypeStorage>();
    // Initialize the instance using placement new.
    return new (storage) ToyArrayTypeStorage(shape);
  }

  ArrayRef<int64_t> getShape() const { return shape; }

private:
  ArrayRef<int64_t> shape;

  /// Constructor is only invoked from the `construct()` method above.
  ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
};

} // namespace detail

/// Element type of a Toy array: always the F64 float type of the context
/// this type belongs to.
mlir::Type ToyArrayType::getElementType() {
  mlir::MLIRContext *ctx = getContext();
  return mlir::FloatType::getF64(ctx);
}

/// Get (or create) the uniqued ToyArrayType for `shape` in `context`.
/// The lookup is delegated to the base class, keyed on the TOY_ARRAY kind and
/// the shape.
ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
                               ArrayRef<int64_t> shape) {
  const auto kind = ToyTypeKind::TOY_ARRAY;
  return Base::get(context, kind, shape);
}

/// Shape of this array type, as recorded in its storage object.
ArrayRef<int64_t> ToyArrayType::getShape() {
  auto *storage = getImpl();
  return storage->getShape();
}

/// Build the MemRef type with the same shape and element type as this array.
/// The two trailing arguments are an empty list and 0.
/// NOTE(review): presumably the layout maps and the memory space - confirm
/// against the MemRefType::get declaration.
mlir::MemRefType ToyArrayType::toMemref() {
  return mlir::MemRefType::get(getShape(), getElementType(), {}, 0);
}

/// Construct the Toy dialect under the "toy" namespace. The instance is owned
/// by the context; this constructor is where the dialect registers its custom
/// operations and its custom type.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  // Operations, registered in the same order as a single combined list.
  addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp>();
  addOperations<MulOp, AddOp, ReturnOp, AllocOp, TypeCastOp>();
  // The only custom type of the dialect.
  addTypes<ToyArrayType>();
}

/// Parse a type registered to this dialect, we expect only Toy arrays.
///
/// Accepted spellings are `array` (the generic array, without shape) and
/// `array<d0, d1, ...>`, where every dimension is a decimal integer and
/// consecutive dimensions are separated by exactly ", ".
///
/// \param tyData the textual form of the type to parse.
/// \param loc    the location used when emitting diagnostics.
/// \returns the parsed ToyArrayType, or a null type after an error has been
///          emitted on the context.
mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
  // Sanity check: we only support array or array<...>
  if (!tyData.startswith("array")) {
    getContext()->emitError(loc, "Invalid Toy type '" + tyData +
                                     "', array expected");
    return nullptr;
  }
  // Drop the "array" prefix from the type name, we expect either an empty
  // string or just the shape.
  tyData = tyData.drop_front(StringRef("array").size());
  // This is the generic array case without shape, early return it.
  if (tyData.empty())
    return ToyArrayType::get(getContext());

  // Use a regex to validate the shape syntax (for efficiency we should store
  // this regex in the dialect itself).
  auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
  if (!shapeRegex.match(tyData)) {
    getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'");
    return nullptr;
  }

  // The regex is only used for validation: a repeated POSIX group reports
  // just its last iteration, so reading the dimensions from the captures
  // would silently drop the middle ones for shapes of rank 3 or more.
  // Instead, strip the surrounding '<' and '>' and split on the separator
  // that the regex has just validated.
  SmallVector<StringRef, 4> dimStrs;
  tyData.drop_front().drop_back().split(dimStrs, ", ");

  SmallVector<int64_t, 4> shape;
  for (StringRef dimStr : dimStrs) {
    // Parse directly into the shape element type so that a value that does
    // not fit is reported instead of being silently narrowed.
    int64_t dim;
    if (dimStr.getAsInteger(/* Radix = */ 10, dim)) {
      getContext()->emitError(
          loc, "Couldn't parse dimension as integer, matched: " + dimStr);
      return nullptr;
    }
    shape.push_back(dim);
  }
  // Finally we collected all the dimensions in the shape,
  // create the array type.
  return ToyArrayType::get(getContext(), shape);
}

/// Print a Toy array type to `os`, for example `array<2, 3, 4>`.
/// An array with an empty shape is printed as plain `array`; any type that is
/// not a ToyArrayType is printed as `unknown toy type`.
void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
  auto arrayTy = type.dyn_cast<ToyArrayType>();
  if (!arrayTy) {
    os << "unknown toy type";
    return;
  }

  os << "array";
  ArrayRef<int64_t> dims = arrayTy.getShape();
  // Nothing more to print for an array without shape.
  if (dims.empty())
    return;

  os << "<";
  mlir::interleaveComma(dims, os);
  os << ">";
}

////////////////////////////////////////////////////////////////////////////////
//////////////////// Custom Operations for the Dialect /////////////////////////
////////////////////////////////////////////////////////////////////////////////

/// Helper to verify that the result of an operation is a Toy array type.
///
/// \param op the operation to check; it must expose `getResult()` and
///           `emitOpError()`.
/// \returns success when the result type is a ToyArrayType, otherwise the
///          outcome of emitting an error on the operation.
template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
  auto resultType = op->getResult()->getType();
  if (resultType.template isa<ToyArrayType>())
    return mlir::success();

  // The diagnostic names the result, since that is what is being checked.
  std::string msg;
  raw_string_ostream os(msg);
  os << "expects a Toy Array for its result, got " << resultType;
  return op->emitOpError(os.str());
}

/// Helper to verify that the two operands of a binary operation are Toy
/// arrays.
///
/// \param op the operation to check; it must expose `getOperand(index)` and
///           `emitOpError()`.
/// \returns success when both operands are ToyArrayType values, otherwise the
///          outcome of emitting an error naming the first offending operand
///          (LHS or RHS) together with its actual type.
template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
  // Both operands go through the same check so that the diagnostic always
  // names, and prints the type of, the operand that actually failed.
  static const char *const operandNames[] = {"LHS", "RHS"};
  for (unsigned idx = 0; idx < 2; ++idx) {
    auto operandType = op->getOperand(idx)->getType();
    if (operandType.template isa<ToyArrayType>())
      continue;

    std::string msg;
    raw_string_ostream os(msg);
    os << "expects a Toy Array for its " << operandNames[idx] << ", got "
       << operandType;
    return op->emitOpError(os.str());
  }
  return mlir::success();
}

/// Build a constant operation from a shape and its dense data.
/// The builder is provided by the caller, as is the state that this method
/// fills in: the data is attached as the `value` attribute, and the single
/// result type is the Toy array of the requested shape.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state,
                       ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
  state->attributes.push_back(builder->getNamedAttr("value", value));
  auto resultTy = ToyArrayType::get(builder->getContext(), shape);
  state->types.push_back(resultTy);
}

/// Build a constant operation from a single float value.
/// The builder is provided by the caller, as is the state that this method
/// fills in. The value is wrapped into a dense attribute over a tensor type of
/// shape {1} with F64 elements, and the work is then forwarded to the
/// shape-based build method.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state,
                       mlir::FloatAttr value) {
  mlir::MLIRContext *ctx = builder->getContext();
  mlir::Type f64Ty = mlir::FloatType::getF64(ctx);

  // One-element tensor type holding the data.
  auto tensorTy = builder->getTensorType({1}, f64Ty);
  auto denseValue = builder->getDenseElementsAttr(tensorTy, {value})
                        .cast<mlir::DenseElementsAttr>();

  // Forward to the other build factory.
  ConstantOp::build(builder, state, {1}, denseValue);
}

/// Verifier for constant operation.
///
/// Checks, in order:
///  - the result is a Toy array;
///  - a `value` attribute is present and is a DenseElementsAttr;
///  - the type of that attribute is a tensor type;
///  - unless the result is a generic array, the rank and every dimension of
///    the result match those of the attribute.
///
/// \returns success when every check passes, otherwise the outcome of the
///          first failing check.
mlir::LogicalResult ConstantOp::verify() {
  // Ensure that the return type is a Toy array
  if (failed(verifyToyReturnArray(this)))
    return mlir::failure();

  // We expect the constant itself to be stored as an attribute. Test for a
  // missing attribute before inspecting its kind: casting helpers must not be
  // invoked on a null attribute.
  auto valueAttr = getAttr("value");
  if (!valueAttr || !valueAttr.isa<mlir::DenseElementsAttr>()) {
    return emitOpError(
        "missing valid `value` DenseElementsAttribute on toy.constant()");
  }
  auto dataAttr = valueAttr.cast<mlir::DenseElementsAttr>();
  auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
  if (!attrType) {
    return emitOpError(
        "missing valid `value` DenseElementsAttribute on toy.constant()");
  }

  // If the return type of the constant is not a generic array, the shape must
  // match the shape of the attribute holding the data.
  auto resultType = getResult()->getType().cast<ToyArrayType>();
  if (!resultType.isGeneric()) {
    if (attrType.getRank() != resultType.getRank()) {
      return emitOpError("The rank of the toy.constant return type must match "
                         "the one of the attached value attribute: " +
                         Twine(attrType.getRank()) +
                         " != " + Twine(resultType.getRank()));
    }
    for (int dim = 0; dim < attrType.getRank(); ++dim) {
      if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
        return emitOpError(
            "Shape mismatch between toy.constant return type and its "
            "attribute at dimension " +
            Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
            " != " + Twine(resultType.getShape()[dim]));
      }
    }
  }
  return mlir::success();
}

/// Build a generic call to `callee` with the given arguments.
/// The arguments become the operands, the callee name is stored as the
/// `callee` string attribute, and the single result is a generic ToyArray.
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState *state,
                          StringRef callee, ArrayRef<mlir::Value *> arguments) {
  state->operands.assign(arguments.begin(), arguments.end());
  state->attributes.push_back(
      builder->getNamedAttr("callee", builder->getStringAttr(callee)));
  // Generic call always returns a generic ToyArray initially
  auto genericTy = ToyArrayType::get(builder->getContext());
  state->types.push_back(genericTy);
}

/// Verifier for the generic call: every operand must be a Toy array. The
/// first operand that is not one is reported with its index and its type.
mlir::LogicalResult GenericCallOp::verify() {
  const int operandCount = getNumOperands();
  for (int opId = 0; opId < operandCount; ++opId) {
    auto operandTy = getOperand(opId)->getType();
    if (operandTy.isa<ToyArrayType>())
      continue;

    std::string msg;
    raw_string_ostream os(msg);
    os << "expects a Toy Array for its " << opId << " operand, got "
       << operandTy;
    return emitOpError(os.str());
  }
  return mlir::success();
}

/// Name of the called function, read from the `callee` string attribute.
StringRef GenericCallOp::getCalleeName() {
  auto calleeAttr = getAttr("callee").cast<mlir::StringAttr>();
  return calleeAttr.getValue();
}

/// Helper to verify that the single operand of an operation is a Toy array.
/// On a mismatch an error mentioning the actual operand type is emitted on
/// the operation.
template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
  auto operandTy = op->getOperand()->getType();
  if (operandTy.template isa<ToyArrayType>())
    return mlir::success();

  std::string msg;
  raw_string_ostream os(msg);
  os << "expects a Toy Array for its argument, got " << operandTy;
  return op->emitOpError(os.str());
}

/// Build a return operation. It produces no result and takes an optional
/// single operand: a null `value` builds a return without operand. The
/// builder is not used.
void ReturnOp::build(mlir::Builder *builder, mlir::OperationState *state,
                     mlir::Value *value) {
  if (!value)
    return;
  state->operands.push_back(value);
}

/// Verifier for return: at most one operand is allowed, and when one is
/// present it must be a Toy array.
mlir::LogicalResult ReturnOp::verify() {
  const auto operandCount = getNumOperands();
  if (operandCount > 1)
    return emitOpError("expects zero or one operand, got " +
                       Twine(operandCount));
  // Without operand there is nothing left to check.
  if (!hasOperand())
    return mlir::success();
  return verifyToySingleOperand(this);
}

/// Build a print operation: a single operand (the value to print) and no
/// result. The builder is not used.
void PrintOp::build(mlir::Builder *builder, mlir::OperationState *state,
                    mlir::Value *value) {
  auto &operands = state->operands;
  operands.push_back(value);
}

/// Verifier for print: the single operand must be a Toy array.
mlir::LogicalResult PrintOp::verify() {
  return verifyToySingleOperand(this);
}

/// Build a transpose operation of `value`. The result type is the generic
/// Toy array obtained from `ToyArrayType::get` without a shape.
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState *state,
                        mlir::Value *value) {
  state->operands.push_back(value);
  auto genericTy = ToyArrayType::get(builder->getContext());
  state->types.push_back(genericTy);
}

/// Verifier for transpose: the single operand must be a Toy array.
mlir::LogicalResult TransposeOp::verify() {
  return verifyToySingleOperand(this);
}

/// Build a reshape of `value` producing `reshapedType`. The builder is not
/// used.
void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState *state,
                      mlir::Value *value, ToyArrayType reshapedType) {
  state->operands.push_back(value);
  state->types.push_back(reshapedType);
}

/// Verifier for reshape: the operand must be a Toy array and the result must
/// be a Toy array that is not generic.
mlir::LogicalResult ReshapeOp::verify() {
  if (failed(verifyToySingleOperand(this)))
    return mlir::failure();

  auto resultTy = getResult()->getType().dyn_cast<ToyArrayType>();
  if (!resultTy)
    return emitOpError("toy.reshape is expected to produce a Toy array");
  if (!resultTy.isGeneric())
    return mlir::success();
  return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
                     "got a generic one.");
}

/// Build an addition of `lhs` and `rhs`, in that operand order. The result
/// type is the generic Toy array obtained from `ToyArrayType::get` without a
/// shape.
void AddOp::build(mlir::Builder *builder, mlir::OperationState *state,
                  mlir::Value *lhs, mlir::Value *rhs) {
  auto &operands = state->operands;
  operands.push_back(lhs);
  operands.push_back(rhs);
  auto genericTy = ToyArrayType::get(builder->getContext());
  state->types.push_back(genericTy);
}

/// Verifier for add: delegates to the shared binary-operand check.
mlir::LogicalResult AddOp::verify() {
  return verifyToyBinOperands(this);
}

/// Build a multiplication of `lhs` and `rhs`, in that operand order. The
/// result type is the generic Toy array obtained from `ToyArrayType::get`
/// without a shape.
void MulOp::build(mlir::Builder *builder, mlir::OperationState *state,
                  mlir::Value *lhs, mlir::Value *rhs) {
  auto &operands = state->operands;
  operands.push_back(lhs);
  operands.push_back(rhs);
  auto genericTy = ToyArrayType::get(builder->getContext());
  state->types.push_back(genericTy);
}

/// Verifier for mul: delegates to the shared binary-operand check.
mlir::LogicalResult MulOp::verify() {
  return verifyToyBinOperands(this);
}

/// Build an alloc operation: no operand, and a single result of type
/// `retType`. The builder is not used.
void AllocOp::build(mlir::Builder *builder, mlir::OperationState *state,
                    mlir::Type retType) {
  auto &resultTypes = state->types;
  resultTypes.push_back(retType);
}

/// Build a type cast of `value` to `destTy`: one operand and one result of
/// the destination type. The builder is not used.
void TypeCastOp::build(mlir::Builder *builder, mlir::OperationState *state,
                       mlir::Value *value, mlir::Type destTy) {
  state->types.push_back(destTy);
  state->operands.push_back(value);
}

} // namespace toy