ShapeInference for SizeOp (#299)
* add shape inference * Revert "add shape inference" This reverts commit f9d42f39e68e14b5648abccfc8617fff00244d16. * shape inference * test case * format
This commit is contained in:
		
							parent
							
								
									1fcf97ef8d
								
							
						
					
					
						commit
						fa04c32a0c
					
				| 
						 | 
					@ -2378,6 +2378,18 @@ LogicalResult ONNXShapeOp::inferShapes() {
 | 
				
			||||||
  return success();
 | 
					  return success();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Size
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					LogicalResult ONNXSizeOp::inferShapes() {
 | 
				
			||||||
 | 
					  // Output is scalar of int64 containing the size of the input tensor.
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 1> outDims;
 | 
				
			||||||
 | 
					  getResult().setType(
 | 
				
			||||||
 | 
					      RankedTensorType::get(outDims, IntegerType::get(64, getContext())));
 | 
				
			||||||
 | 
					  return success();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// Tile
 | 
					// Tile
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4863,7 +4863,7 @@ def ONNXSinhOp:ONNX_Op<"Sinh",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ONNXSizeOp:ONNX_Op<"Size",
 | 
					def ONNXSizeOp:ONNX_Op<"Size",
 | 
				
			||||||
  [NoSideEffect]> {
 | 
					  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
				
			||||||
  let hasCanonicalizer = 1;
 | 
					  let hasCanonicalizer = 1;
 | 
				
			||||||
  let summary = "ONNX Size operation";
 | 
					  let summary = "ONNX Size operation";
 | 
				
			||||||
  let description = [{
 | 
					  let description = [{
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1565,3 +1565,14 @@ func @test_onehotencoder_float2(%arg0: tensor<20x2x3xf32>) -> tensor<*xf32> {
 | 
				
			||||||
  // CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<20x2x3x3xf32>
 | 
					  // CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<20x2x3x3xf32>
 | 
				
			||||||
  // CHECK: return [[RES]] : tensor<20x2x3x3xf32>
 | 
					  // CHECK: return [[RES]] : tensor<20x2x3x3xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_size(%arg0: tensor<*xf32>) -> tensor<*xi64> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64>  
 | 
				
			||||||
 | 
					  "std.return"(%0) : (tensor<*xi64>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_size
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<i64>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<i64>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -303,6 +303,7 @@ OpsWithShapeInference=[
 | 
				
			||||||
    'Sign',
 | 
					    'Sign',
 | 
				
			||||||
    'Sin',
 | 
					    'Sin',
 | 
				
			||||||
    'Sinh',
 | 
					    'Sinh',
 | 
				
			||||||
 | 
					    'Size',
 | 
				
			||||||
    'Slice',
 | 
					    'Slice',
 | 
				
			||||||
    'Softmax',
 | 
					    'Softmax',
 | 
				
			||||||
    'Softplus',
 | 
					    'Softplus',
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue