From bd44d8402e09c8b1c6068550abc640fdf5c41ef9 Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Mon, 20 Jan 2020 14:46:54 -0500 Subject: [PATCH] Add verifier function for checking negative perms. --- src/dialect/onnx/onnx_ops.cpp | 28 ++++++++++++++++++++-------- src/dialect/onnx/onnxop.inc | 2 ++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 73d76b5..cef90cb 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/SetVector.h" @@ -413,21 +414,32 @@ void ONNXTransposeOp::inferShapes() { if (auto permutation = getAttrOfType( ONNXTransposeOp::getPermAttrName())) { // Perform transposition according to perm attribute. - for (auto perm : permutation.getValue()) { - int32_t index = perm.cast().getInt(); - if (index < 0) - emitError("Cannot tranpose when permutation contains negative index."); - dims.emplace_back(arrayTy.getShape()[index]); - } + for (auto perm : permutation.getValue()) + dims.emplace_back(arrayTy.getShape()[perm.cast().getInt()]); } else { // Default - for (auto shape : llvm::reverse(arrayTy.getShape())) - dims.emplace_back(shape); + for (auto dim : llvm::reverse(arrayTy.getShape())) + dims.emplace_back(dim); } getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); } +LogicalResult verify(ONNXTransposeOp op) { + auto module = op.getParentOfType(); + if (!module) + op.emitError("Expected to belong to a module."); + + if (auto permutation = op.getAttrOfType( + ONNXTransposeOp::getPermAttrName())) { + for (auto perm : permutation.getValue()) + if (perm.cast().getInt() < 0) + op.emitError("Cannot tranpose, permuation contains negative index."); + } + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index 5d22346..fc2714e 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -3102,6 +3102,8 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", let extraClassDeclaration = [{ static StringRef getPermAttrName() { return "perm"; } }]; + + let verifier = [{ return ::verify(*this); }]; } def ONNXUniqueOp:ONNX_Op<"Unique",