ShapeInference for SizeOp (#299)
* add shape inference * Revert "add shape inference" This reverts commit f9d42f39e68e14b5648abccfc8617fff00244d16. * shape inference * test case * format
This commit is contained in:
parent
1fcf97ef8d
commit
fa04c32a0c
|
@ -2378,6 +2378,18 @@ LogicalResult ONNXShapeOp::inferShapes() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Size
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXSizeOp::inferShapes() {
|
||||
// Output is scalar of int64 containing the size of the input tensor.
|
||||
SmallVector<int64_t, 1> outDims;
|
||||
getResult().setType(
|
||||
RankedTensorType::get(outDims, IntegerType::get(64, getContext())));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tile
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -4863,7 +4863,7 @@ def ONNXSinhOp:ONNX_Op<"Sinh",
|
|||
}
|
||||
|
||||
def ONNXSizeOp:ONNX_Op<"Size",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let hasCanonicalizer = 1;
|
||||
let summary = "ONNX Size operation";
|
||||
let description = [{
|
||||
|
|
|
@ -1565,3 +1565,14 @@ func @test_onehotencoder_float2(%arg0: tensor<20x2x3xf32>) -> tensor<*xf32> {
|
|||
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<20x2x3x3xf32>
|
||||
// CHECK: return [[RES]] : tensor<20x2x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_size(%arg0: tensor<*xf32>) -> tensor<*xi64> {
|
||||
%0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64>
|
||||
"std.return"(%0) : (tensor<*xi64>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_size
|
||||
// CHECK: [[RES:%.+]] = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<i64>
|
||||
// CHECK: return [[RES]] : tensor<i64>
|
||||
}
|
||||
|
|
|
@ -303,6 +303,7 @@ OpsWithShapeInference=[
|
|||
'Sign',
|
||||
'Sin',
|
||||
'Sinh',
|
||||
'Size',
|
||||
'Slice',
|
||||
'Softmax',
|
||||
'Softplus',
|
||||
|
|
Loading…
Reference in New Issue