ScalerOp support non-float input using CastOp (#231)
* move scalerop to decompose * change clang format * change clang format * add shape inference for scaler op * fixing generated onnxop * generate onnx.md * redo get onnx.md and onnxop.td.inc using onnx 1.6 * Add shape inference for scaler op * add benefit for scaler decompose and simplify scaler shape inference * add scaler decompose benefit num and simplify shape inference * add cast builder Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
parent
2e8f012195
commit
e631283c71
|
@ -494,6 +494,14 @@ def ONNXCastOp:ONNX_Op<"Cast",
|
|||
return resultTypes;
|
||||
}
|
||||
}];
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value input, IntegerAttr to", [{
|
||||
auto toAttr = to.getValue().getSExtValue();
|
||||
auto resultType = mlir::UnrankedTensorType::get(
|
||||
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr)));
|
||||
build(builder, state, resultType, input, to);
|
||||
}] >
|
||||
];
|
||||
}
|
||||
|
||||
def ONNXCeilOp:ONNX_Op<"Ceil",
|
||||
|
|
|
@ -63,10 +63,6 @@ def AttributeIsNull :
|
|||
Constraint<CPred<"! ($_self)">,
|
||||
"Attribute is null">;
|
||||
|
||||
def AttributeNotNull :
|
||||
Constraint<CPred<" ($_self)">,
|
||||
"Attribute exists">;
|
||||
|
||||
def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>().getElementType().isF32())">>;
|
||||
|
||||
def GetNullAttr :
|
||||
|
@ -83,20 +79,30 @@ def ScalerNullPattern : Pat<
|
|||
(ONNXScalerOp $x, $a, $b),
|
||||
(replaceWithValue $x),
|
||||
[(HasFloatType:$x), (AttributeIsNull:$a), (AttributeIsNull:$b)],
|
||||
(addBenefit 4)>;
|
||||
(addBenefit 5)>;
|
||||
|
||||
// No attribute, input x not float type
|
||||
def ScalerNullPattern2 : Pat<
|
||||
(ONNXScalerOp $x, $a, $b),
|
||||
(ONNXCastOp $x, (ScalerT)),
|
||||
[(AttributeIsNull:$a), (AttributeIsNull:$b)],
|
||||
(addBenefit 3)>;
|
||||
(addBenefit 4)>;
|
||||
|
||||
// No scale
|
||||
def ScalerNoScalePattern : Pat<
|
||||
(ONNXScalerOp $x, $a, $b),
|
||||
(ONNXSubOp $x,
|
||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
||||
[(HasFloatType:$x), (AttributeIsNull:$b)],
|
||||
(addBenefit 3)>;
|
||||
|
||||
// No scale, input x not float type
|
||||
def ScalerNoScalePattern2 : Pat<
|
||||
(ONNXScalerOp $x, $a, $b),
|
||||
(ONNXSubOp
|
||||
(ONNXCastOp $x, (ScalerT)),
|
||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))
|
||||
),
|
||||
[(AttributeIsNull:$b)],
|
||||
(addBenefit 2)>;
|
||||
|
||||
|
@ -105,6 +111,15 @@ def ScalerNoOffsetPattern : Pat<
|
|||
(ONNXScalerOp $x, $a, $b),
|
||||
(ONNXMulOp $x,
|
||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
||||
[(HasFloatType:$x), (AttributeIsNull:$a)],
|
||||
(addBenefit 3)>;
|
||||
|
||||
// No offset, input x not float type
|
||||
def ScalerNoOffsetPattern2 : Pat<
|
||||
(ONNXScalerOp $x, $a, $b),
|
||||
(ONNXMulOp
|
||||
(ONNXCastOp $x, (ScalerT)),
|
||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
||||
[(AttributeIsNull:$a)],
|
||||
(addBenefit 2)>;
|
||||
|
||||
|
@ -115,7 +130,17 @@ def ScalerPattern : Pat<
|
|||
(ONNXSubOp $x,
|
||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
||||
[],
|
||||
[(HasFloatType:$x)],
|
||||
(addBenefit 1)>;
|
||||
|
||||
// Normal ONNXScalerOp, input x not float type
|
||||
def ScalerPattern2 : Pat<
|
||||
(ONNXScalerOp $x, $a, $b),
|
||||
(ONNXMulOp
|
||||
(ONNXSubOp (ONNXCastOp $x, (ScalerT)),
|
||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
||||
[],
|
||||
(addBenefit 0)>;
|
||||
|
||||
#endif // ONNX_DECOMPOSE
|
||||
|
|
|
@ -97,6 +97,20 @@ func @test_scaler_no_offset(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// scaler no offset, int input
|
||||
// CHECK-LABEL: func @test_scaler_no_offset2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
|
||||
func @test_scaler_no_offset2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
|
||||
%0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
|
||||
return %0 : tensor<3xf32>
|
||||
|
||||
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||
// CHECK-NEXT: %2 = "onnx.Mul"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||
// CHECK-NEXT: return %2 : tensor<3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// scaler no scale
|
||||
// CHECK-LABEL: func @test_scaler_no_scale(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
||||
func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||
|
@ -110,6 +124,20 @@ func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// scaler no scale, int input
|
||||
// CHECK-LABEL: func @test_scaler_no_scale2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
|
||||
func @test_scaler_no_scale2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
|
||||
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
|
||||
return %0 : tensor<3xf32>
|
||||
|
||||
//CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
|
||||
//CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||
//CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||
//CHECK-NEXT: return %2 : tensor<3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// normal scaler
|
||||
// CHECK-LABEL: func @test_scaler_normal(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
||||
func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||
|
@ -125,6 +153,22 @@ func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// normal scaler, int input
|
||||
// CHECK-LABEL: func @test_scaler_normal2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
|
||||
func @test_scaler_normal2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
|
||||
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
|
||||
return %0 : tensor<3xf32>
|
||||
|
||||
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||
// CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %3 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||
// CHECK-NEXT: %4 = "onnx.Mul"(%2, %3) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||
// CHECK-NEXT: return %4 : tensor<3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// normal scaler with constant offset and scale
|
||||
// CHECK-LABEL: func @test_scaler_constant(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
||||
func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||
|
|
|
@ -323,10 +323,18 @@ custom_definition_misc = dict([ ('Constant',
|
|||
build(builder, state, tensorType, sparse_value, value);
|
||||
}
|
||||
}]>
|
||||
];'''),
|
||||
('Cast',
|
||||
''' let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value input, IntegerAttr to", [{
|
||||
auto toAttr = to.getValue().getSExtValue();
|
||||
auto resultType = mlir::UnrankedTensorType::get(
|
||||
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr)));
|
||||
build(builder, state, resultType, input, to);
|
||||
}] >
|
||||
];'''
|
||||
)])
|
||||
|
||||
|
||||
onnx_types = (
|
||||
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
|
||||
'float', 'double', 'complex64', 'complex128', 'string'
|
||||
|
|
Loading…
Reference in New Issue