Merge pull request #11 from clang-ykt/naive-transpose
Add default shape inference for the transposition operation.
This commit is contained in:
		
						commit
						7607edefe9
					
				|  | @ -267,7 +267,7 @@ def gen_schema(schema) : | ||||||
|                         'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', |                         'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', | ||||||
|                         'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', |                         'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', | ||||||
|                         'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', |                         'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', | ||||||
|                         'Identity', 'Cos', 'Log'] |                         'Identity', 'Cos', 'Log', 'Transpose'] | ||||||
|     CanonicalList=['Add', 'Identity'] |     CanonicalList=['Add', 'Identity'] | ||||||
|     line_indent = '  ' |     line_indent = '  ' | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -396,6 +396,24 @@ void ONNXReshapeOp::inferShapes() { | ||||||
|       RankedTensorType::get(dims, inputTensorTy.getElementType())); |       RankedTensorType::get(dims, inputTensorTy.getElementType())); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | // Transpose
 | ||||||
|  | 
 | ||||||
|  | void ONNXTransposeOp::inferShapes() { | ||||||
|  |   // Cannot infer shape if no shape exists.
 | ||||||
|  |   if (!getOperand()->getType().isa<RankedTensorType>()) | ||||||
|  |     emitError("Shape tensor not ranked."); | ||||||
|  | 
 | ||||||
|  |   // Naive transposition which handles the default case of
 | ||||||
|  |   // reversing the shape of the tensor (similar to numpy.transpose).
 | ||||||
|  |   // TODO: Once attributes are supported we can handle the case where the
 | ||||||
|  |   // transposition uses a permutation vector to interchange the axes.
 | ||||||
|  |   auto arrayTy = getOperand()->getType().cast<RankedTensorType>(); | ||||||
|  |   SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape())); | ||||||
|  |   getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| // TableGen'd op method definitions
 | // TableGen'd op method definitions
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
|  |  | ||||||
|  | @ -3089,7 +3089,7 @@ def ONNXTopKOp:ONNX_Op<"TopK", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def ONNXTransposeOp:ONNX_Op<"Transpose",  | def ONNXTransposeOp:ONNX_Op<"Transpose",  | ||||||
|     [NoSideEffect]> { |     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||||
|   let summary = "ONNX Transpose operation"; |   let summary = "ONNX Transpose operation"; | ||||||
|   let description = [{ |   let description = [{ | ||||||
|     "Transpose the input tensor similar to numpy.transpose. For example, when" |     "Transpose the input tensor similar to numpy.transpose. For example, when" | ||||||
|  |  | ||||||
|  | @ -115,7 +115,8 @@ public: | ||||||
|         op->getName().getStringRef() != "onnx.MatMul" && |         op->getName().getStringRef() != "onnx.MatMul" && | ||||||
|         op->getName().getStringRef() != "onnx.Gemm" && |         op->getName().getStringRef() != "onnx.Gemm" && | ||||||
|         op->getName().getStringRef() != "onnx.FullGemm" && |         op->getName().getStringRef() != "onnx.FullGemm" && | ||||||
|         op->getName().getStringRef() != "onnx.Reshape") |         op->getName().getStringRef() != "onnx.Reshape" && | ||||||
|  |         op->getName().getStringRef() != "onnx.Transpose") | ||||||
|       return false; |       return false; | ||||||
|     return llvm::any_of(op->getResultTypes(), [](Type result_type) { |     return llvm::any_of(op->getResultTypes(), [](Type result_type) { | ||||||
|       return !result_type.isa<RankedTensorType>(); |       return !result_type.isa<RankedTensorType>(); | ||||||
|  |  | ||||||
|  | @ -0,0 +1,12 @@ | ||||||
|  | // RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s | ||||||
|  | 
 | ||||||
|  | /// Test the default behavior of transpose when no information for the | ||||||
|  | /// permutation of the axes is provided. | ||||||
|  | func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_default_transpose | ||||||
|  | // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32> | ||||||
|  | // CHECK: return [[RES]] : tensor<32x1x5x5xf32> | ||||||
		Loading…
	
		Reference in New Issue