ShapeInference for SizeOp (#299)
* add shape inference * Revert "add shape inference" This reverts commit f9d42f39e68e14b5648abccfc8617fff00244d16. * shape inference * test case * format
This commit is contained in:
parent
1fcf97ef8d
commit
fa04c32a0c
|
@ -2378,6 +2378,18 @@ LogicalResult ONNXShapeOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Size
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult ONNXSizeOp::inferShapes() {
|
||||||
|
// Output is scalar of int64 containing the size of the input tensor.
|
||||||
|
SmallVector<int64_t, 1> outDims;
|
||||||
|
getResult().setType(
|
||||||
|
RankedTensorType::get(outDims, IntegerType::get(64, getContext())));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Tile
|
// Tile
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -4863,7 +4863,7 @@ def ONNXSinhOp:ONNX_Op<"Sinh",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXSizeOp:ONNX_Op<"Size",
|
def ONNXSizeOp:ONNX_Op<"Size",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX Size operation";
|
let summary = "ONNX Size operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -1564,4 +1564,15 @@ func @test_onehotencoder_float2(%arg0: tensor<20x2x3xf32>) -> tensor<*xf32> {
|
||||||
// CHECK-LABEL: test_onehotencoder_float2
|
// CHECK-LABEL: test_onehotencoder_float2
|
||||||
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<20x2x3x3xf32>
|
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<20x2x3x3xf32>
|
||||||
// CHECK: return [[RES]] : tensor<20x2x3x3xf32>
|
// CHECK: return [[RES]] : tensor<20x2x3x3xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @test_size(%arg0: tensor<*xf32>) -> tensor<*xi64> {
|
||||||
|
%0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64>
|
||||||
|
"std.return"(%0) : (tensor<*xi64>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_size
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<i64>
|
||||||
|
// CHECK: return [[RES]] : tensor<i64>
|
||||||
|
}
|
||||||
|
|
|
@ -303,6 +303,7 @@ OpsWithShapeInference=[
|
||||||
'Sign',
|
'Sign',
|
||||||
'Sin',
|
'Sin',
|
||||||
'Sinh',
|
'Sinh',
|
||||||
|
'Size',
|
||||||
'Slice',
|
'Slice',
|
||||||
'Softmax',
|
'Softmax',
|
||||||
'Softplus',
|
'Softplus',
|
||||||
|
|
Loading…
Reference in New Issue