Add verifier function for checking negative perms.

This commit is contained in:
Doru Bercea 2020-01-20 14:46:54 -05:00
parent f0b484c0bc
commit bd44d8402e
2 changed files with 22 additions and 8 deletions

View File

@ -13,6 +13,7 @@
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
@ -413,21 +414,32 @@ void ONNXTransposeOp::inferShapes() {
if (auto permutation = getAttrOfType<ArrayAttr>(
ONNXTransposeOp::getPermAttrName())) {
// Perform transposition according to perm attribute.
for (auto perm : permutation.getValue()) {
int32_t index = perm.cast<IntegerAttr>().getInt();
if (index < 0)
emitError("Cannot tranpose when permutation contains negative index.");
dims.emplace_back(arrayTy.getShape()[index]);
}
for (auto perm : permutation.getValue())
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
} else {
// Default
for (auto shape : llvm::reverse(arrayTy.getShape()))
dims.emplace_back(shape);
for (auto dim : llvm::reverse(arrayTy.getShape()))
dims.emplace_back(dim);
}
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
LogicalResult verify(ONNXTransposeOp op) {
auto module = op.getParentOfType<ModuleOp>();
if (!module)
op.emitError("Expected to belong to a module.");
if (auto permutation = op.getAttrOfType<ArrayAttr>(
ONNXTransposeOp::getPermAttrName())) {
for (auto perm : permutation.getValue())
if (perm.cast<IntegerAttr>().getInt() < 0)
op.emitError("Cannot tranpose, permuation contains negative index.");
}
return success();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

View File

@ -3102,6 +3102,8 @@ def ONNXTransposeOp:ONNX_Op<"Transpose",
let extraClassDeclaration = [{
static StringRef getPermAttrName() { return "perm"; }
}];
let verifier = [{ return ::verify(*this); }];
}
def ONNXUniqueOp:ONNX_Op<"Unique",