ShapeInference for SizeOp (#299)

* add shape inference

* Revert "add shape inference"

This reverts commit f9d42f39e68e14b5648abccfc8617fff00244d16.

* shape inference

* test case

* format
This commit is contained in:
chentong319 2020-09-11 13:47:11 -04:00 committed by GitHub
parent 1fcf97ef8d
commit fa04c32a0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 26 additions and 2 deletions

View File

@ -2378,6 +2378,18 @@ LogicalResult ONNXShapeOp::inferShapes() {
return success();
}
//===----------------------------------------------------------------------===//
// Size
//===----------------------------------------------------------------------===//
LogicalResult ONNXSizeOp::inferShapes() {
// Output is scalar of int64 containing the size of the input tensor.
SmallVector<int64_t, 1> outDims;
getResult().setType(
RankedTensorType::get(outDims, IntegerType::get(64, getContext())));
return success();
}
//===----------------------------------------------------------------------===//
// Tile
//===----------------------------------------------------------------------===//

View File

@ -4863,7 +4863,7 @@ def ONNXSinhOp:ONNX_Op<"Sinh",
}
def ONNXSizeOp:ONNX_Op<"Size",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1;
let summary = "ONNX Size operation";
let description = [{

View File

@ -1564,4 +1564,15 @@ func @test_onehotencoder_float2(%arg0: tensor<20x2x3xf32>) -> tensor<*xf32> {
// CHECK-LABEL: test_onehotencoder_float2
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<20x2x3x3xf32>
// CHECK: return [[RES]] : tensor<20x2x3x3xf32>
}
}
// -----
func @test_size(%arg0: tensor<*xf32>) -> tensor<*xi64> {
%0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64>
"std.return"(%0) : (tensor<*xi64>) -> ()
// CHECK-LABEL: test_size
// CHECK: [[RES:%.+]] = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<i64>
// CHECK: return [[RES]] : tensor<i64>
}

View File

@ -303,6 +303,7 @@ OpsWithShapeInference=[
'Sign',
'Sin',
'Sinh',
'Size',
'Slice',
'Softmax',
'Softplus',