Add verifier function for checking negative perms.
This commit is contained in:
parent
f0b484c0bc
commit
bd44d8402e
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
@ -413,21 +414,32 @@ void ONNXTransposeOp::inferShapes() {
|
|||
if (auto permutation = getAttrOfType<ArrayAttr>(
|
||||
ONNXTransposeOp::getPermAttrName())) {
|
||||
// Perform transposition according to perm attribute.
|
||||
for (auto perm : permutation.getValue()) {
|
||||
int32_t index = perm.cast<IntegerAttr>().getInt();
|
||||
if (index < 0)
|
||||
emitError("Cannot tranpose when permutation contains negative index.");
|
||||
dims.emplace_back(arrayTy.getShape()[index]);
|
||||
}
|
||||
for (auto perm : permutation.getValue())
|
||||
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
|
||||
} else {
|
||||
// Default
|
||||
for (auto shape : llvm::reverse(arrayTy.getShape()))
|
||||
dims.emplace_back(shape);
|
||||
for (auto dim : llvm::reverse(arrayTy.getShape()))
|
||||
dims.emplace_back(dim);
|
||||
}
|
||||
|
||||
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
LogicalResult verify(ONNXTransposeOp op) {
|
||||
auto module = op.getParentOfType<ModuleOp>();
|
||||
if (!module)
|
||||
op.emitError("Expected to belong to a module.");
|
||||
|
||||
if (auto permutation = op.getAttrOfType<ArrayAttr>(
|
||||
ONNXTransposeOp::getPermAttrName())) {
|
||||
for (auto perm : permutation.getValue())
|
||||
if (perm.cast<IntegerAttr>().getInt() < 0)
|
||||
op.emitError("Cannot tranpose, permuation contains negative index.");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -3102,6 +3102,8 @@ def ONNXTransposeOp:ONNX_Op<"Transpose",
|
|||
let extraClassDeclaration = [{
|
||||
static StringRef getPermAttrName() { return "perm"; }
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def ONNXUniqueOp:ONNX_Op<"Unique",
|
||||
|
|
Loading…
Reference in New Issue