Add the default shape inference for the transposition operation.
This commit is contained in:
		
							parent
							
								
									caeba371fb
								
							
						
					
					
						commit
						151f4f8c44
					
				|  | @ -267,7 +267,7 @@ def gen_schema(schema) : | |||
|                         'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', | ||||
|                         'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', | ||||
|                         'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', | ||||
|                         'Identity', 'Cos', 'Log'] | ||||
|                         'Identity', 'Cos', 'Log', 'Transpose'] | ||||
|     CanonicalList=['Add', 'Identity'] | ||||
|     line_indent = '  ' | ||||
| 
 | ||||
|  |  | |||
|  | @ -396,6 +396,24 @@ void ONNXReshapeOp::inferShapes() { | |||
|       RankedTensorType::get(dims, inputTensorTy.getElementType())); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| // Transpose
 | ||||
| 
 | ||||
| void ONNXTransposeOp::inferShapes() { | ||||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!getOperand()->getType().isa<RankedTensorType>()) | ||||
|     emitError("Shape tensor not ranked."); | ||||
| 
 | ||||
|   // Naive transposition which handles the default case of
 | ||||
|   // reversing the shape of the tensor (similar to numpy.transpose).
 | ||||
|   // TODO: Once attributes are supported we can handle the case where the
 | ||||
|   // transposition uses a permutation vector to interchange the axes.
 | ||||
|   auto arrayTy = getOperand()->getType().cast<RankedTensorType>(); | ||||
|   SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape())); | ||||
|   getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // TableGen'd op method definitions
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  |  | |||
|  | @ -3089,7 +3089,7 @@ def ONNXTopKOp:ONNX_Op<"TopK", | |||
| } | ||||
| 
 | ||||
| def ONNXTransposeOp:ONNX_Op<"Transpose",  | ||||
|     [NoSideEffect]> { | ||||
|     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||
|   let summary = "ONNX Transpose operation"; | ||||
|   let description = [{ | ||||
|     "Transpose the input tensor similar to numpy.transpose. For example, when" | ||||
|  |  | |||
|  | @ -115,7 +115,8 @@ public: | |||
|         op->getName().getStringRef() != "onnx.MatMul" && | ||||
|         op->getName().getStringRef() != "onnx.Gemm" && | ||||
|         op->getName().getStringRef() != "onnx.FullGemm" && | ||||
|         op->getName().getStringRef() != "onnx.Reshape") | ||||
|         op->getName().getStringRef() != "onnx.Reshape" && | ||||
|         op->getName().getStringRef() != "onnx.Transpose") | ||||
|       return false; | ||||
|     return llvm::any_of(op->getResultTypes(), [](Type result_type) { | ||||
|       return !result_type.isa<RankedTensorType>(); | ||||
|  |  | |||
|  | @ -0,0 +1,12 @@ | |||
| // RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s | ||||
| 
 | ||||
| /// Test the default behavior of transpose when no information for the | ||||
| /// permutation of the axes is provided. | ||||
| func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32> | ||||
|   "std.return"(%0) : (tensor<*xf32>) -> () | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: test_default_transpose | ||||
| // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32> | ||||
| // CHECK: return [[RES]] : tensor<32x1x5x5xf32> | ||||
		Loading…
	
		Reference in New Issue