From fa04c32a0ccb9af6c0c50c457b078d1728759d51 Mon Sep 17 00:00:00 2001
From: chentong319
Date: Fri, 11 Sep 2020 13:47:11 -0400
Subject: [PATCH] ShapeInference for SizeOp (#299)

* add shape inference

* Revert "add shape inference"

This reverts commit f9d42f39e68e14b5648abccfc8617fff00244d16.

* shape inference

* test case

* format
---
 src/Dialect/ONNX/ONNXOps.cpp             | 12 ++++++++++++
 src/Dialect/ONNX/ONNXOps.td.inc          |  2 +-
 test/mlir/onnx/onnx_shape_inference.mlir | 13 ++++++++++++-
 utils/gen_onnx_mlir.py                   |  1 +
 4 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index ac680b4..aaf375a 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -2378,6 +2378,18 @@ LogicalResult ONNXShapeOp::inferShapes() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Size
+//===----------------------------------------------------------------------===//
+
+LogicalResult ONNXSizeOp::inferShapes() {
+  // Output is scalar of int64 containing the size of the input tensor.
+  SmallVector<int64_t, 1> outDims;
+  getResult().setType(
+      RankedTensorType::get(outDims, IntegerType::get(64, getContext())));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Tile
 //===----------------------------------------------------------------------===//
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index 56fbeeb..fc9bb83 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -4863,7 +4863,7 @@ def ONNXSinhOp:ONNX_Op<"Sinh",
 }
 
 def ONNXSizeOp:ONNX_Op<"Size",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let hasCanonicalizer = 1;
   let summary = "ONNX Size operation";
   let description = [{
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 1249122..07a78c7 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -1564,4 +1564,15 @@ func @test_onehotencoder_float2(%arg0: tensor<20x2x3xf32>) -> tensor<*xf32> {
   // CHECK-LABEL: test_onehotencoder_float2
   // CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<20x2x3x3xf32>
   // CHECK: return [[RES]] : tensor<20x2x3x3xf32>
-}
\ No newline at end of file
+}
+
+// -----
+
+func @test_size(%arg0: tensor<*xf32>) -> tensor<*xi64> {
+  %0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64>
+  "std.return"(%0) : (tensor<*xi64>) -> ()
+
+  // CHECK-LABEL: test_size
+  // CHECK: [[RES:%.+]] = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<i64>
+  // CHECK: return [[RES]] : tensor<i64>
+}
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index 474866a..171206a 100644
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -303,6 +303,7 @@ OpsWithShapeInference=[
     'Sign',
     'Sin',
     'Sinh',
+    'Size',
     'Slice',
     'Softmax',
     'Softplus',