Add verifier function for checking negative perms.

This commit is contained in:
Doru Bercea 2020-01-20 14:46:54 -05:00
parent f0b484c0bc
commit bd44d8402e
2 changed files with 22 additions and 8 deletions

View File

@ -13,6 +13,7 @@
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h" #include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
@ -413,21 +414,32 @@ void ONNXTransposeOp::inferShapes() {
if (auto permutation = getAttrOfType<ArrayAttr>( if (auto permutation = getAttrOfType<ArrayAttr>(
ONNXTransposeOp::getPermAttrName())) { ONNXTransposeOp::getPermAttrName())) {
// Perform transposition according to perm attribute. // Perform transposition according to perm attribute.
for (auto perm : permutation.getValue()) { for (auto perm : permutation.getValue())
int32_t index = perm.cast<IntegerAttr>().getInt(); dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
if (index < 0)
emitError("Cannot tranpose when permutation contains negative index.");
dims.emplace_back(arrayTy.getShape()[index]);
}
} else { } else {
// Default // Default
for (auto shape : llvm::reverse(arrayTy.getShape())) for (auto dim : llvm::reverse(arrayTy.getShape()))
dims.emplace_back(shape); dims.emplace_back(dim);
} }
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
} }
LogicalResult verify(ONNXTransposeOp op) {
auto module = op.getParentOfType<ModuleOp>();
if (!module)
op.emitError("Expected to belong to a module.");
if (auto permutation = op.getAttrOfType<ArrayAttr>(
ONNXTransposeOp::getPermAttrName())) {
for (auto perm : permutation.getValue())
if (perm.cast<IntegerAttr>().getInt() < 0)
op.emitError("Cannot tranpose, permuation contains negative index.");
}
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -3102,6 +3102,8 @@ def ONNXTransposeOp:ONNX_Op<"Transpose",
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static StringRef getPermAttrName() { return "perm"; } static StringRef getPermAttrName() { return "perm"; }
}]; }];
let verifier = [{ return ::verify(*this); }];
} }
def ONNXUniqueOp:ONNX_Op<"Unique", def ONNXUniqueOp:ONNX_Op<"Unique",