ScalerOp support non-float input using CastOp (#231)
* move scalerop to decompose * change clang format * change clang format * add shape inference for scaler op * fixing generated onnxop * generate onnx.md * redo get onnx.md and onnxop.td.inc using onnx 1.6 * Add shape inference for scaler op * add benefit for scaler decompose and simplify scaler shape inference * add scaler decompose benefit num and simplify shape inference * add cast builder Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
parent
2e8f012195
commit
e631283c71
|
@ -494,6 +494,14 @@ def ONNXCastOp:ONNX_Op<"Cast",
|
||||||
return resultTypes;
|
return resultTypes;
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value input, IntegerAttr to", [{
|
||||||
|
auto toAttr = to.getValue().getSExtValue();
|
||||||
|
auto resultType = mlir::UnrankedTensorType::get(
|
||||||
|
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr)));
|
||||||
|
build(builder, state, resultType, input, to);
|
||||||
|
}] >
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXCeilOp:ONNX_Op<"Ceil",
|
def ONNXCeilOp:ONNX_Op<"Ceil",
|
||||||
|
|
|
@ -63,10 +63,6 @@ def AttributeIsNull :
|
||||||
Constraint<CPred<"! ($_self)">,
|
Constraint<CPred<"! ($_self)">,
|
||||||
"Attribute is null">;
|
"Attribute is null">;
|
||||||
|
|
||||||
def AttributeNotNull :
|
|
||||||
Constraint<CPred<" ($_self)">,
|
|
||||||
"Attribute exists">;
|
|
||||||
|
|
||||||
def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>().getElementType().isF32())">>;
|
def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>().getElementType().isF32())">>;
|
||||||
|
|
||||||
def GetNullAttr :
|
def GetNullAttr :
|
||||||
|
@ -83,20 +79,30 @@ def ScalerNullPattern : Pat<
|
||||||
(ONNXScalerOp $x, $a, $b),
|
(ONNXScalerOp $x, $a, $b),
|
||||||
(replaceWithValue $x),
|
(replaceWithValue $x),
|
||||||
[(HasFloatType:$x), (AttributeIsNull:$a), (AttributeIsNull:$b)],
|
[(HasFloatType:$x), (AttributeIsNull:$a), (AttributeIsNull:$b)],
|
||||||
(addBenefit 4)>;
|
(addBenefit 5)>;
|
||||||
|
|
||||||
// No attribute, input x not float type
|
// No attribute, input x not float type
|
||||||
def ScalerNullPattern2 : Pat<
|
def ScalerNullPattern2 : Pat<
|
||||||
(ONNXScalerOp $x, $a, $b),
|
(ONNXScalerOp $x, $a, $b),
|
||||||
(ONNXCastOp $x, (ScalerT)),
|
(ONNXCastOp $x, (ScalerT)),
|
||||||
[(AttributeIsNull:$a), (AttributeIsNull:$b)],
|
[(AttributeIsNull:$a), (AttributeIsNull:$b)],
|
||||||
(addBenefit 3)>;
|
(addBenefit 4)>;
|
||||||
|
|
||||||
// No scale
|
// No scale
|
||||||
def ScalerNoScalePattern : Pat<
|
def ScalerNoScalePattern : Pat<
|
||||||
(ONNXScalerOp $x, $a, $b),
|
(ONNXScalerOp $x, $a, $b),
|
||||||
(ONNXSubOp $x,
|
(ONNXSubOp $x,
|
||||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
||||||
|
[(HasFloatType:$x), (AttributeIsNull:$b)],
|
||||||
|
(addBenefit 3)>;
|
||||||
|
|
||||||
|
// No scale, input x not float type
|
||||||
|
def ScalerNoScalePattern2 : Pat<
|
||||||
|
(ONNXScalerOp $x, $a, $b),
|
||||||
|
(ONNXSubOp
|
||||||
|
(ONNXCastOp $x, (ScalerT)),
|
||||||
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))
|
||||||
|
),
|
||||||
[(AttributeIsNull:$b)],
|
[(AttributeIsNull:$b)],
|
||||||
(addBenefit 2)>;
|
(addBenefit 2)>;
|
||||||
|
|
||||||
|
@ -105,6 +111,15 @@ def ScalerNoOffsetPattern : Pat<
|
||||||
(ONNXScalerOp $x, $a, $b),
|
(ONNXScalerOp $x, $a, $b),
|
||||||
(ONNXMulOp $x,
|
(ONNXMulOp $x,
|
||||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
||||||
|
[(HasFloatType:$x), (AttributeIsNull:$a)],
|
||||||
|
(addBenefit 3)>;
|
||||||
|
|
||||||
|
// No offset, input x not float type
|
||||||
|
def ScalerNoOffsetPattern2 : Pat<
|
||||||
|
(ONNXScalerOp $x, $a, $b),
|
||||||
|
(ONNXMulOp
|
||||||
|
(ONNXCastOp $x, (ScalerT)),
|
||||||
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
||||||
[(AttributeIsNull:$a)],
|
[(AttributeIsNull:$a)],
|
||||||
(addBenefit 2)>;
|
(addBenefit 2)>;
|
||||||
|
|
||||||
|
@ -115,7 +130,17 @@ def ScalerPattern : Pat<
|
||||||
(ONNXSubOp $x,
|
(ONNXSubOp $x,
|
||||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
||||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
||||||
[],
|
[(HasFloatType:$x)],
|
||||||
(addBenefit 1)>;
|
(addBenefit 1)>;
|
||||||
|
|
||||||
|
// Normal ONNXScalerOp, input x not float type
|
||||||
|
def ScalerPattern2 : Pat<
|
||||||
|
(ONNXScalerOp $x, $a, $b),
|
||||||
|
(ONNXMulOp
|
||||||
|
(ONNXSubOp (ONNXCastOp $x, (ScalerT)),
|
||||||
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
||||||
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
||||||
|
[],
|
||||||
|
(addBenefit 0)>;
|
||||||
|
|
||||||
#endif // ONNX_DECOMPOSE
|
#endif // ONNX_DECOMPOSE
|
||||||
|
|
|
@ -97,6 +97,20 @@ func @test_scaler_no_offset(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// scaler no offset, int input
|
||||||
|
// CHECK-LABEL: func @test_scaler_no_offset2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
|
||||||
|
func @test_scaler_no_offset2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
|
||||||
|
%0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
|
||||||
|
return %0 : tensor<3xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: %2 = "onnx.Mul"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: return %2 : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// scaler no scale
|
// scaler no scale
|
||||||
// CHECK-LABEL: func @test_scaler_no_scale(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
// CHECK-LABEL: func @test_scaler_no_scale(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
@ -110,6 +124,20 @@ func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// scaler no scale, int input
|
||||||
|
// CHECK-LABEL: func @test_scaler_no_scale2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
|
||||||
|
func @test_scaler_no_scale2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
|
||||||
|
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
|
||||||
|
return %0 : tensor<3xf32>
|
||||||
|
|
||||||
|
//CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
|
||||||
|
//CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
//CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
//CHECK-NEXT: return %2 : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// normal scaler
|
// normal scaler
|
||||||
// CHECK-LABEL: func @test_scaler_normal(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
// CHECK-LABEL: func @test_scaler_normal(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
@ -125,6 +153,22 @@ func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// normal scaler, int input
|
||||||
|
// CHECK-LABEL: func @test_scaler_normal2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
|
||||||
|
func @test_scaler_normal2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
|
||||||
|
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
|
||||||
|
return %0 : tensor<3xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: %3 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: %4 = "onnx.Mul"(%2, %3) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: return %4 : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// normal scaler with constant offset and scale
|
// normal scaler with constant offset and scale
|
||||||
// CHECK-LABEL: func @test_scaler_constant(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
// CHECK-LABEL: func @test_scaler_constant(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
|
|
@ -323,10 +323,18 @@ custom_definition_misc = dict([ ('Constant',
|
||||||
build(builder, state, tensorType, sparse_value, value);
|
build(builder, state, tensorType, sparse_value, value);
|
||||||
}
|
}
|
||||||
}]>
|
}]>
|
||||||
|
];'''),
|
||||||
|
('Cast',
|
||||||
|
''' let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value input, IntegerAttr to", [{
|
||||||
|
auto toAttr = to.getValue().getSExtValue();
|
||||||
|
auto resultType = mlir::UnrankedTensorType::get(
|
||||||
|
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr)));
|
||||||
|
build(builder, state, resultType, input, to);
|
||||||
|
}] >
|
||||||
];'''
|
];'''
|
||||||
)])
|
)])
|
||||||
|
|
||||||
|
|
||||||
onnx_types = (
|
onnx_types = (
|
||||||
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
|
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
|
||||||
'float', 'double', 'complex64', 'complex128', 'string'
|
'float', 'double', 'complex64', 'complex128', 'string'
|
||||||
|
|
Loading…
Reference in New Issue