Add shape inference for ScalerOp (#228)

* move ScalerOp to decompose

* change clang-format

* change clang-format

* add shape inference for ScalerOp

* fix generated ONNX ops

* generate onnx.md

* Add shape inference for ScalerOp

* add benefit to the Scaler decompose patterns and simplify Scaler shape inference
Anh Leu 2020-07-23 12:05:19 -05:00 committed by GitHub
parent 034f98c00c
commit c9e3ba2d64
10 changed files with 188 additions and 166 deletions
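For orientation: this change moves the Scaler lowering out of canonicalization and into the ONNX decomposition pass, and gives ONNXScalerOp a shape inference hook. A minimal sketch of the resulting rewrite for a fully-attributed Scaler, computing (x - offset) * scale, mirroring the test_scaler_constant case below:

%0 = "onnx.Constant"() {value = dense<1986.99939> : tensor<1xf32>} : () -> tensor<1xf32>
%1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
%2 = "onnx.Constant"() {value = dense<3.125000e-02> : tensor<1xf32>} : () -> tensor<1xf32>
%3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>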


@@ -1877,6 +1877,17 @@ LogicalResult ONNXCastOp::inferShapes() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Scaler
+//===----------------------------------------------------------------------===//
+
+LogicalResult ONNXScalerOp::inferShapes() {
+  ShapedType inputType = X().getType().dyn_cast<ShapedType>();
+  getResult().setType(RankedTensorType::get(
+      inputType.getShape(), FloatType::getF32(getContext())));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Constant
 //===----------------------------------------------------------------------===//
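The new hook keeps the input's shape and fixes the element type to f32, which matches the ONNX-ML Scaler spec: the op always produces a float tensor regardless of the input element type. Note that the hook assumes X() has a ShapedType; a non-shaped input would need a guard before inputType.getShape() is called. A minimal illustration, mirroring the shape inference test at the end of this commit:

// Before shape inference: the result type is unranked.
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<*xf32>
// After ONNXScalerOp::inferShapes(): same shape as the input, element type f32.
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>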


@@ -6103,8 +6103,7 @@ def ONNXSVMRegressorOp:ONNX_Op<"SVMRegressor",
 }
 
 def ONNXScalerOp:ONNX_Op<"Scaler",
-  [NoSideEffect]> {
-  let hasCanonicalizer = 1;
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Scaler operation";
   let description = [{
   "Rescale input data, for example to standardize features by removing the mean and scaling to unit variance."


@@ -24,6 +24,23 @@
 using namespace mlir;
 
 namespace {
 
+// Create a DenseElementsAttr from an ArrayAttr.
+// This function is used to get the value type of an EXISTING ArrayAttr for the
+// Scaler patterns.
+DenseElementsAttr createDenseArrayAttr(
+    PatternRewriter &rewriter, ArrayAttr origAttrs) {
+  mlir::Type elementType = rewriter.getF32Type();
+  int nElements = origAttrs.getValue().size();
+  SmallVector<float, 4> wrapper(nElements, 0);
+  for (int i = 0; i < nElements; ++i) {
+    wrapper[i] = origAttrs.getValue()[i].cast<FloatAttr>().getValueAsDouble();
+  }
+  return DenseElementsAttr::get(
+      RankedTensorType::get(wrapper.size(), elementType),
+      llvm::makeArrayRef(wrapper));
+}
+
 /// Include the patterns defined in the Declarative Rewrite framework.
 #include "src/Transform/ONNX/ONNXDecompose.inc"

@@ -47,6 +64,7 @@ void DecomposeONNXToONNXPass::runOnFunction() {
   target.addIllegalOp<ONNXReduceLogSumOp>();
   target.addIllegalOp<ONNXReduceLogSumExpOp>();
   target.addIllegalOp<ONNXReduceSumSquareOp>();
+  target.addIllegalOp<ONNXScalerOp>();
 
   OwningRewritePatternList patterns;
   populateWithGenerated(context, &patterns);
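Marking ONNXScalerOp illegal means the conversion fails unless one of the Scaler patterns fires, so the pass guarantees that no onnx.Scaler survives decomposition. For example, a Scaler carrying only an offset (the test_scaler_no_scale case below) is reduced to a Sub against a constant:

%0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
%1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>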


@@ -54,4 +54,68 @@ def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims
 def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
                                   (ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
 
+//===----------------------------------------------------------------------===//
+// ONNXScalerOp %X, %Offset, %Scale
+// x input, a offset, b scale
+//===----------------------------------------------------------------------===//
+
+// Useful test definitions.
+def AttributeIsNull :
+    Constraint<CPred<"! ($_self)">,
+    "Attribute is null">;
+
+def AttributeNotNull :
+    Constraint<CPred<" ($_self)">,
+    "Attribute exists">;
+
+def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>().getElementType().isF32())">>;
+
+def GetNullAttr :
+    NativeCodeCall<"Attribute()">;
+
+// Create a DenseElementsAttr from an ArrayAttr.
+def createDenseArrayAttr:
+    NativeCodeCall<"createDenseArrayAttr($_builder, $0)">;
+
+def ScalerT : NativeCodeCall<"$_builder.getI64IntegerAttr(1)">;
+
+// No attribute
+def ScalerNullPattern : Pat<
+    (ONNXScalerOp $x, $a, $b),
+    (replaceWithValue $x),
+    [(HasFloatType:$x), (AttributeIsNull:$a), (AttributeIsNull:$b)],
+    (addBenefit 4)>;
+
+// No attribute, input x not float type
+def ScalerNullPattern2 : Pat<
+    (ONNXScalerOp $x, $a, $b),
+    (ONNXCastOp $x, (ScalerT)),
+    [(AttributeIsNull:$a), (AttributeIsNull:$b)],
+    (addBenefit 3)>;
+
+// No scale
+def ScalerNoScalePattern : Pat<
+    (ONNXScalerOp $x, $a, $b),
+    (ONNXSubOp $x,
+        (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
+    [(AttributeIsNull:$b)],
+    (addBenefit 2)>;
+
+// No offset
+def ScalerNoOffsetPattern : Pat<
+    (ONNXScalerOp $x, $a, $b),
+    (ONNXMulOp $x,
+        (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
+    [(AttributeIsNull:$a)],
+    (addBenefit 2)>;
+
+// Normal ONNXScalerOp
+def ScalerPattern : Pat<
+    (ONNXScalerOp $x, $a, $b),
+    (ONNXMulOp
+        (ONNXSubOp $x,
+            (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
+        (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
+    [],
+    (addBenefit 1)>;
+
 #endif // ONNX_DECOMPOSE
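All five patterns share the same root, (ONNXScalerOp $x, $a, $b), so the addBenefit values order them from most to least specific: drop the op entirely (4), cast only (3), single-attribute Sub or Mul (2), general Sub-then-Mul (1). Note also that ScalerT, the Cast target, is now 1 rather than 0: in ONNX's TensorProto data-type encoding, 1 is FLOAT while 0 is UNDEFINED, so the old canonicalizer emitted an invalid cast. A sketch of the highest-benefit case, where a float input with no attributes folds away entirely:

// ScalerNullPattern (benefit 4) wins over ScalerPattern (benefit 1):
%0 = "onnx.Scaler"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
// rewrites to a plain use of %arg0; no Cast, Sub, or Mul is emitted.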


@@ -18,23 +18,6 @@
 using namespace mlir;
 
 namespace {
 
-// Create an DenseElementsAttr of ArrayAttr.
-// This function is used to get Value Type for Scaler function.
-DenseElementsAttr createDenseArrayAttr(
-    PatternRewriter &rewriter, ArrayAttr origAttrs) {
-  mlir::Type elementType = rewriter.getF32Type();
-  int nElements = origAttrs.getValue().size();
-  SmallVector<float, 4> wrapper(nElements, 0);
-  if (origAttrs) {
-    for (int i = 0; i < nElements; ++i) {
-      wrapper[i] = origAttrs.getValue()[i].cast<FloatAttr>().getValueAsDouble();
-    }
-  }
-  return DenseElementsAttr::get(
-      RankedTensorType::get(wrapper.size(), elementType),
-      llvm::makeArrayRef(wrapper));
-}
-
 // Check whether an ArrayAttr contains non-zero values or not.
 bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
   bool allZeros = true;

@@ -109,13 +92,3 @@ void ONNXConvOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<ConvOpPaddingPattern>(context);
 }
-
-/// on the ONNXScalerOp.
-void ONNXScalerOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &result, MLIRContext *context) {
-  result.insert<ScalerNullPattern>(context);
-  result.insert<ScalerNullPattern2>(context);
-  result.insert<ScalerNoScalePattern>(context);
-  result.insert<ScalerNoOffsetPattern>(context);
-  result.insert<ScalerPattern>(context);
-}


@@ -100,56 +100,4 @@ def ConvOpPaddingPattern: Pat<
   [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
 >;
 
-//===----------------------------------------------------------------------===//
-// ONNXScalerOp %X, %Offest, %Scale
-// x input, a offset, b scale
-//===----------------------------------------------------------------------===//
-
-// Useful test definitions.
-def AttributeIsNull :
-    Constraint<CPred<"! ($_self)">,
-    "Attribute is null">;
-
-def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>().getElementType().isF32())">>;
-
-// Create a DenseElementsAttr from an ArrayAttr.
-def createDenseArrayAttr:
-    NativeCodeCall<"createDenseArrayAttr($_builder, $0)">;
-
-def ScalerT : NativeCodeCall<"$_builder.getI64IntegerAttr(0)">;
-
-// No attribute
-def ScalerNullPattern : Pat<
-    (ONNXScalerOp $x, $a, $b),
-    (replaceWithValue $x),
-    [(HasFloatType:$x), (AttributeIsNull:$a), (AttributeIsNull:$b)]>;
-
-// No attribute, input x not float type
-def ScalerNullPattern2 : Pat<
-    (ONNXScalerOp $x, $a, $b),
-    (ONNXCastOp $x, (ScalerT)),
-    [(AttributeIsNull:$a), (AttributeIsNull:$b)]>;
-
-// No scale
-def ScalerNoScalePattern : Pat<
-    (ONNXScalerOp $x, $a, $b),
-    (ONNXSubOp $x,
-        (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
-    [(AttributeIsNull:$b)]>;
-
-// No offset
-def ScalerNoOffsetPattern : Pat<
-    (ONNXScalerOp $x, $a, $b),
-    (ONNXMulOp $x,
-        (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
-    [(AttributeIsNull:$a)]>;
-
-// Normal ONNXScalerOp
-def ScalerPattern : Pat<
-    (ONNXScalerOp $x, $a, $b),
-    (ONNXMulOp
-        (ONNXSubOp $x,
-            (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
-        (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b)))>;
-
 #endif // ONNX_REWRITE


@@ -96,86 +96,3 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
   // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
   // return [[GEMM]] : tensor<*xf32>
 }
-
-// -----
-
-// Scaler Pattern test
-
-// -----
-
-// null
-// CHECK-LABEL: func @test_scaler_null_float(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
-func @test_scaler_null_float(%arg0: tensor<3xf32>) -> tensor<3xf32> {
-  %0 = "onnx.Scaler"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
-  return %0 : tensor<3xf32>
-
-  // CHECK-NEXT: return %arg0 : tensor<3xf32>
-}
-
-// -----
-
-// null not float
-// CHECK-LABEL: func @test_scaler_null(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
-func @test_scaler_null(%arg0: tensor<3xi32>) -> tensor<3xf32> {
-  %0 = "onnx.Scaler"(%arg0) : (tensor<3xi32>) -> tensor<3xf32>
-  return %0 : tensor<3xf32>
-
-  // CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 0 : i64} : (tensor<3xi32>) -> tensor<3xf32>
-  // CHECK-NEXT: return %0 : tensor<3xf32>
-}
-
-// -----
-
-// scaler no offset
-// CHECK-LABEL: func @test_scaler_no_offset(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
-func @test_scaler_no_offset(%arg0: tensor<3xf32>) -> tensor<3xf32> {
-  %0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
-  return %0 : tensor<3xf32>
-
-  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
-  // CHECK-NEXT: %1 = "onnx.Mul"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: return %1 : tensor<3xf32>
-}
-
-// -----
-
-// scaler no scale
-// CHECK-LABEL: func @test_scaler_no_scale(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
-func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
-  %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
-  return %0 : tensor<3xf32>
-
-  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
-  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: return %1 : tensor<3xf32>
-}
-
-// -----
-
-// normal scaler
-// CHECK-LABEL: func @test_scaler_normal(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
-func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
-  %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
-  return %0 : tensor<3xf32>
-
-  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
-  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
-  // CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: return %3 : tensor<3xf32>
-}
-
-// -----
-
-// normal scaler with constant offset and scale
-// CHECK-LABEL: func @test_scaler_constant(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
-func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> {
-  %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32], scale = [3.125000e-02 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
-  return %0 : tensor<3xf32>
-
-  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1986.99939> : tensor<1xf32>} : () -> tensor<1xf32>
-  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<3.125000e-02> : tensor<1xf32>} : () -> tensor<1xf32>
-  // CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: return %3 : tensor<3xf32>
-}


@@ -57,3 +57,83 @@ func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
   // CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
 }
+
+// -----
+
+// Scaler Pattern test
+
+// -----
+
+// null
+// CHECK-LABEL: func @test_scaler_null_float(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
+func @test_scaler_null_float(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+  %0 = "onnx.Scaler"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+
+  // CHECK-NEXT: return %arg0 : tensor<3xf32>
+}
+
+// -----
+
+// null not float
+// CHECK-LABEL: func @test_scaler_null(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
+func @test_scaler_null(%arg0: tensor<3xi32>) -> tensor<3xf32> {
+  %0 = "onnx.Scaler"(%arg0) : (tensor<3xi32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+
+  // CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %0 : tensor<3xf32>
+}
+
+// -----
+
+// scaler no offset
+// CHECK-LABEL: func @test_scaler_no_offset(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
+func @test_scaler_no_offset(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+  %0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+
+  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %1 = "onnx.Mul"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %1 : tensor<3xf32>
+}
+
+// -----
+
+// scaler no scale
+// CHECK-LABEL: func @test_scaler_no_scale(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
+func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+  %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+
+  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %1 : tensor<3xf32>
+}
+
+// -----
+
+// normal scaler
+// CHECK-LABEL: func @test_scaler_normal(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
+func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+  %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+
+  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %3 : tensor<3xf32>
+}
+
+// -----
+
+// normal scaler with constant offset and scale
+// CHECK-LABEL: func @test_scaler_constant(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
+func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+  %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32], scale = [3.125000e-02 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+
+  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1986.99939> : tensor<1xf32>} : () -> tensor<1xf32>
+  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<3.125000e-02> : tensor<1xf32>} : () -> tensor<1xf32>
+  // CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %3 : tensor<3xf32>
+}
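For reference, these FileCheck tests are driven by a RUN line at the top of the file (not shown in this diff); a plausible invocation, assuming the pass is registered under the flag --decompose-onnx:

// RUN: onnx-mlir-opt --decompose-onnx %s -split-input-file | FileCheck %s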


@@ -1355,3 +1355,15 @@ func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> tensor<
   // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
   // CHECK: return [[RES]] : tensor<1x2xf32>
 }
+
+//===----------------------------------------------------------------------===//
+/// Test the shape inference for the Scaler operation.
+//===----------------------------------------------------------------------===//
+
+func @test_scaler_no_scale_int(%arg0: tensor<3xi32>) -> tensor<*xf32> {
+  %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_scaler_no_scale_int
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
+  // CHECK: return [[RES_ATTR]] : tensor<3xf32>
+}


@@ -252,11 +252,11 @@ OpsWithShapeInference = [
     'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
     'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
     'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
-    'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice'
+    'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice', 'Scaler'
 ]
 
 # Operations supporting canonicalization.
-OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Scaler']
+OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
 
 # Operations who have operands that, if produced by constant operations, should
 # be promoted to become an attribute (via attribute promotion).