Add verifier function for checking negative perms.
This commit is contained in:
		
							parent
							
								
									f0b484c0bc
								
							
						
					
					
						commit
						bd44d8402e
					
				|  | @ -13,6 +13,7 @@ | |||
| #include "mlir/IR/Function.h" | ||||
| #include "mlir/IR/IntegerSet.h" | ||||
| #include "mlir/IR/Matchers.h" | ||||
| #include "mlir/IR/Module.h" | ||||
| #include "mlir/IR/OpImplementation.h" | ||||
| #include "mlir/IR/PatternMatch.h" | ||||
| #include "llvm/ADT/SetVector.h" | ||||
|  | @ -413,21 +414,32 @@ void ONNXTransposeOp::inferShapes() { | |||
|   if (auto permutation = getAttrOfType<ArrayAttr>( | ||||
|           ONNXTransposeOp::getPermAttrName())) { | ||||
|     // Perform transposition according to perm attribute.
 | ||||
|     for (auto perm : permutation.getValue()) { | ||||
|       int32_t index = perm.cast<IntegerAttr>().getInt(); | ||||
|       if (index < 0) | ||||
|         emitError("Cannot tranpose when permutation contains negative index."); | ||||
|       dims.emplace_back(arrayTy.getShape()[index]); | ||||
|     } | ||||
|     for (auto perm : permutation.getValue()) | ||||
|       dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]); | ||||
|   } else { | ||||
|     // Default
 | ||||
|     for (auto shape : llvm::reverse(arrayTy.getShape())) | ||||
|       dims.emplace_back(shape); | ||||
|     for (auto dim : llvm::reverse(arrayTy.getShape())) | ||||
|       dims.emplace_back(dim); | ||||
|   } | ||||
| 
 | ||||
|   getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); | ||||
| } | ||||
| 
 | ||||
| LogicalResult verify(ONNXTransposeOp op) { | ||||
|   auto module = op.getParentOfType<ModuleOp>(); | ||||
|   if (!module) | ||||
|     op.emitError("Expected to belong to a module."); | ||||
| 
 | ||||
|   if (auto permutation = op.getAttrOfType<ArrayAttr>( | ||||
|           ONNXTransposeOp::getPermAttrName())) { | ||||
|     for (auto perm : permutation.getValue()) | ||||
|       if (perm.cast<IntegerAttr>().getInt() < 0) | ||||
|         op.emitError("Cannot tranpose, permuation contains negative index."); | ||||
|   } | ||||
| 
 | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // TableGen'd op method definitions
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  |  | |||
|  | @ -3102,6 +3102,8 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", | |||
|   let extraClassDeclaration = [{ | ||||
|     static StringRef getPermAttrName() { return "perm"; } | ||||
|   }]; | ||||
| 
 | ||||
|   let verifier = [{ return ::verify(*this); }]; | ||||
| } | ||||
| 
 | ||||
| def ONNXUniqueOp:ONNX_Op<"Unique",  | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue