Add shape inference for ScalerOp (#228)
* move scalerop to decompose * change clang format * change clang format * add shape inference for scaler op * fixing generated onnxop * generate onnx.md * Add shape inference for scaler op * add benefit for scaler decompose and simplify scaler shape inference
This commit is contained in:
parent
034f98c00c
commit
c9e3ba2d64
|
@ -1877,6 +1877,17 @@ LogicalResult ONNXCastOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Scaler
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult ONNXScalerOp::inferShapes() {
|
||||||
|
ShapedType inputType = X().getType().dyn_cast<ShapedType>();
|
||||||
|
getResult().setType(RankedTensorType::get(
|
||||||
|
inputType.getShape(), FloatType::getF32(getContext())));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Constant
|
// Constant
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -6103,8 +6103,7 @@ def ONNXSVMRegressorOp:ONNX_Op<"SVMRegressor",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXScalerOp:ONNX_Op<"Scaler",
|
def ONNXScalerOp:ONNX_Op<"Scaler",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let hasCanonicalizer = 1;
|
|
||||||
let summary = "ONNX Scaler operation";
|
let summary = "ONNX Scaler operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Rescale input data, for example to standardize features by removing the mean and scaling to unit variance."
|
"Rescale input data, for example to standardize features by removing the mean and scaling to unit variance."
|
||||||
|
|
|
@ -24,6 +24,23 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Create an DenseElementsAttr of ArrayAttr.
|
||||||
|
// This function is used to get Value Type of an EXISTING ArrayAttr for Scaler
|
||||||
|
// function.
|
||||||
|
DenseElementsAttr createDenseArrayAttr(
|
||||||
|
PatternRewriter &rewriter, ArrayAttr origAttrs) {
|
||||||
|
mlir::Type elementType = rewriter.getF32Type();
|
||||||
|
int nElements = origAttrs.getValue().size();
|
||||||
|
SmallVector<float, 4> wrapper(nElements, 0);
|
||||||
|
for (int i = 0; i < nElements; ++i) {
|
||||||
|
wrapper[i] = origAttrs.getValue()[i].cast<FloatAttr>().getValueAsDouble();
|
||||||
|
}
|
||||||
|
return DenseElementsAttr::get(
|
||||||
|
RankedTensorType::get(wrapper.size(), elementType),
|
||||||
|
llvm::makeArrayRef(wrapper));
|
||||||
|
}
|
||||||
|
|
||||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||||
#include "src/Transform/ONNX/ONNXDecompose.inc"
|
#include "src/Transform/ONNX/ONNXDecompose.inc"
|
||||||
|
|
||||||
|
@ -47,6 +64,7 @@ void DecomposeONNXToONNXPass::runOnFunction() {
|
||||||
target.addIllegalOp<ONNXReduceLogSumOp>();
|
target.addIllegalOp<ONNXReduceLogSumOp>();
|
||||||
target.addIllegalOp<ONNXReduceLogSumExpOp>();
|
target.addIllegalOp<ONNXReduceLogSumExpOp>();
|
||||||
target.addIllegalOp<ONNXReduceSumSquareOp>();
|
target.addIllegalOp<ONNXReduceSumSquareOp>();
|
||||||
|
target.addIllegalOp<ONNXScalerOp>();
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
populateWithGenerated(context, &patterns);
|
populateWithGenerated(context, &patterns);
|
||||||
|
|
|
@ -54,4 +54,68 @@ def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims
|
||||||
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
|
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
|
||||||
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
|
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXScalerOp %X, %Offest, %Scale
|
||||||
|
// x input, a offset, b scale
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Useful test definitions.
|
||||||
|
def AttributeIsNull :
|
||||||
|
Constraint<CPred<"! ($_self)">,
|
||||||
|
"Attribute is null">;
|
||||||
|
|
||||||
|
def AttributeNotNull :
|
||||||
|
Constraint<CPred<" ($_self)">,
|
||||||
|
"Attribute exists">;
|
||||||
|
|
||||||
|
def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>().getElementType().isF32())">>;
|
||||||
|
|
||||||
|
def GetNullAttr :
|
||||||
|
NativeCodeCall<"Attribute()">;
|
||||||
|
|
||||||
|
// Create a DenseElementsAttr from an ArrayAttr.
|
||||||
|
def createDenseArrayAttr:
|
||||||
|
NativeCodeCall<"createDenseArrayAttr($_builder, $0)">;
|
||||||
|
|
||||||
|
def ScalerT : NativeCodeCall<"$_builder.getI64IntegerAttr(1)">;
|
||||||
|
|
||||||
|
// No attribute
|
||||||
|
def ScalerNullPattern : Pat<
|
||||||
|
(ONNXScalerOp $x, $a, $b),
|
||||||
|
(replaceWithValue $x),
|
||||||
|
[(HasFloatType:$x), (AttributeIsNull:$a), (AttributeIsNull:$b)],
|
||||||
|
(addBenefit 4)>;
|
||||||
|
|
||||||
|
// No attribute, input x not float type
|
||||||
|
def ScalerNullPattern2 : Pat<
|
||||||
|
(ONNXScalerOp $x, $a, $b),
|
||||||
|
(ONNXCastOp $x, (ScalerT)),
|
||||||
|
[(AttributeIsNull:$a), (AttributeIsNull:$b)],
|
||||||
|
(addBenefit 3)>;
|
||||||
|
|
||||||
|
// No scale
|
||||||
|
def ScalerNoScalePattern : Pat<
|
||||||
|
(ONNXScalerOp $x, $a, $b),
|
||||||
|
(ONNXSubOp $x,
|
||||||
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
||||||
|
[(AttributeIsNull:$b)],
|
||||||
|
(addBenefit 2)>;
|
||||||
|
|
||||||
|
// No offset
|
||||||
|
def ScalerNoOffsetPattern : Pat<
|
||||||
|
(ONNXScalerOp $x, $a, $b),
|
||||||
|
(ONNXMulOp $x,
|
||||||
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
||||||
|
[(AttributeIsNull:$a)],
|
||||||
|
(addBenefit 2)>;
|
||||||
|
|
||||||
|
// Normal ONNXScalerOp
|
||||||
|
def ScalerPattern : Pat<
|
||||||
|
(ONNXScalerOp $x, $a, $b),
|
||||||
|
(ONNXMulOp
|
||||||
|
(ONNXSubOp $x,
|
||||||
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
||||||
|
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
||||||
|
[],
|
||||||
|
(addBenefit 1)>;
|
||||||
|
|
||||||
#endif // ONNX_DECOMPOSE
|
#endif // ONNX_DECOMPOSE
|
||||||
|
|
|
@ -18,23 +18,6 @@ using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Create an DenseElementsAttr of ArrayAttr.
|
|
||||||
// This function is used to get Value Type for Scaler function.
|
|
||||||
DenseElementsAttr createDenseArrayAttr(
|
|
||||||
PatternRewriter &rewriter, ArrayAttr origAttrs) {
|
|
||||||
mlir::Type elementType = rewriter.getF32Type();
|
|
||||||
int nElements = origAttrs.getValue().size();
|
|
||||||
SmallVector<float, 4> wrapper(nElements, 0);
|
|
||||||
if (origAttrs) {
|
|
||||||
for (int i = 0; i < nElements; ++i) {
|
|
||||||
wrapper[i] = origAttrs.getValue()[i].cast<FloatAttr>().getValueAsDouble();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return DenseElementsAttr::get(
|
|
||||||
RankedTensorType::get(wrapper.size(), elementType),
|
|
||||||
llvm::makeArrayRef(wrapper));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check whether an ArrayAttr contains non-zero values or not.
|
// Check whether an ArrayAttr contains non-zero values or not.
|
||||||
bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
|
bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
|
||||||
bool allZeros = true;
|
bool allZeros = true;
|
||||||
|
@ -109,13 +92,3 @@ void ONNXConvOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
results.insert<ConvOpPaddingPattern>(context);
|
results.insert<ConvOpPaddingPattern>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// on the ONNXScalerOp.
|
|
||||||
void ONNXScalerOp::getCanonicalizationPatterns(
|
|
||||||
OwningRewritePatternList &result, MLIRContext *context) {
|
|
||||||
result.insert<ScalerNullPattern>(context);
|
|
||||||
result.insert<ScalerNullPattern2>(context);
|
|
||||||
result.insert<ScalerNoScalePattern>(context);
|
|
||||||
result.insert<ScalerNoOffsetPattern>(context);
|
|
||||||
result.insert<ScalerPattern>(context);
|
|
||||||
}
|
|
||||||
|
|
|
@ -100,56 +100,4 @@ def ConvOpPaddingPattern: Pat<
|
||||||
[(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
|
[(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
|
||||||
>;
|
>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXScalerOp %X, %Offest, %Scale
|
|
||||||
// x input, a offset, b scale
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Useful test definitions.
|
|
||||||
def AttributeIsNull :
|
|
||||||
Constraint<CPred<"! ($_self)">,
|
|
||||||
"Attribute is null">;
|
|
||||||
|
|
||||||
def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>().getElementType().isF32())">>;
|
|
||||||
|
|
||||||
// Create a DenseElementsAttr from an ArrayAttr.
|
|
||||||
def createDenseArrayAttr:
|
|
||||||
NativeCodeCall<"createDenseArrayAttr($_builder, $0)">;
|
|
||||||
|
|
||||||
def ScalerT : NativeCodeCall<"$_builder.getI64IntegerAttr(0)">;
|
|
||||||
|
|
||||||
// No attribute
|
|
||||||
def ScalerNullPattern : Pat<
|
|
||||||
(ONNXScalerOp $x, $a, $b),
|
|
||||||
(replaceWithValue $x),
|
|
||||||
[(HasFloatType:$x),(AttributeIsNull:$a), (AttributeIsNull:$b)]>;
|
|
||||||
|
|
||||||
// No attribute, input x not float type
|
|
||||||
def ScalerNullPattern2 : Pat<
|
|
||||||
(ONNXScalerOp $x, $a, $b),
|
|
||||||
(ONNXCastOp $x, (ScalerT)),
|
|
||||||
[(AttributeIsNull:$a), (AttributeIsNull:$b)]>;
|
|
||||||
|
|
||||||
// No scale
|
|
||||||
def ScalerNoScalePattern : Pat<
|
|
||||||
(ONNXScalerOp $x, $a, $b),
|
|
||||||
(ONNXSubOp $x,
|
|
||||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
|
||||||
[(AttributeIsNull:$b)]>;
|
|
||||||
|
|
||||||
// No offset
|
|
||||||
def ScalerNoOffsetPattern : Pat<
|
|
||||||
(ONNXScalerOp $x, $a, $b),
|
|
||||||
(ONNXMulOp $x,
|
|
||||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
|
|
||||||
[(AttributeIsNull:$a)]>;
|
|
||||||
|
|
||||||
// Normal ONNXScalerOp
|
|
||||||
def ScalerPattern : Pat<
|
|
||||||
(ONNXScalerOp $x, $a, $b),
|
|
||||||
(ONNXMulOp
|
|
||||||
(ONNXSubOp $x,
|
|
||||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
|
|
||||||
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b)))>;
|
|
||||||
|
|
||||||
|
|
||||||
#endif // ONNX_REWRITE
|
#endif // ONNX_REWRITE
|
||||||
|
|
|
@ -96,86 +96,3 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
|
||||||
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
|
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
|
||||||
// return [[GEMM]] : tensor<*xf32>
|
// return [[GEMM]] : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
|
||||||
// Scaler Pattern test
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// null
|
|
||||||
// CHECK-LABEL: func @test_scaler_null_float(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
|
||||||
func @test_scaler_null_float(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
|
||||||
%0 = "onnx.Scaler"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
|
|
||||||
return %0 : tensor<3xf32>
|
|
||||||
|
|
||||||
// CHECK-NEXT: return %arg0 : tensor<3xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// null not float
|
|
||||||
// CHECK-LABEL: func @test_scaler_null(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
|
|
||||||
func @test_scaler_null(%arg0: tensor<3xi32>) -> tensor<3xf32> {
|
|
||||||
%0 = "onnx.Scaler"(%arg0) : (tensor<3xi32>) -> tensor<3xf32>
|
|
||||||
return %0 : tensor<3xf32>
|
|
||||||
|
|
||||||
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 0 : i64} : (tensor<3xi32>) -> tensor<3xf32>
|
|
||||||
// CHECK-NEXT: return %0 : tensor<3xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// scaler no offset
|
|
||||||
// CHECK-LABEL: func @test_scaler_no_offset(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
|
||||||
func @test_scaler_no_offset(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
|
||||||
%0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
|
|
||||||
return %0 : tensor<3xf32>
|
|
||||||
|
|
||||||
|
|
||||||
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
|
|
||||||
// CHECK-NEXT: %1 = "onnx.Mul"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
|
|
||||||
// CHECK-NEXT: return %1 : tensor<3xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// scaler no scale
|
|
||||||
// CHECK-LABEL: func @test_scaler_no_scale(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
|
||||||
func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
|
||||||
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
|
|
||||||
return %0 : tensor<3xf32>
|
|
||||||
|
|
||||||
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
|
|
||||||
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
|
|
||||||
// CHECK-NEXT: return %1 : tensor<3xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// normal scaler
|
|
||||||
// CHECK-LABEL: func @test_scaler_normal(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
|
||||||
func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
|
||||||
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
|
|
||||||
return %0 : tensor<3xf32>
|
|
||||||
|
|
||||||
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
|
|
||||||
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
|
|
||||||
// CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
|
|
||||||
// CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
|
|
||||||
// CHECK-NEXT: return %3 : tensor<3xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// normal scaler with constant offset and scale
|
|
||||||
// CHECK-LABEL: func @test_scaler_constant(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
|
||||||
func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
|
||||||
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32], scale = [3.125000e-02 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
|
|
||||||
return %0 : tensor<3xf32>
|
|
||||||
|
|
||||||
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1986.99939> : tensor<1xf32>} : () -> tensor<1xf32>
|
|
||||||
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
|
|
||||||
// CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<3.125000e-02> : tensor<1xf32>} : () -> tensor<1xf32>
|
|
||||||
// CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
|
|
||||||
// CHECK-NEXT: return %3 : tensor<3xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
|
@ -57,3 +57,83 @@ func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// Scaler Pattern test
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// null
|
||||||
|
// CHECK-LABEL: func @test_scaler_null_float(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
func @test_scaler_null_float(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
%0 = "onnx.Scaler"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
return %0 : tensor<3xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: return %arg0 : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// null not float
|
||||||
|
// CHECK-LABEL: func @test_scaler_null(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
|
||||||
|
func @test_scaler_null(%arg0: tensor<3xi32>) -> tensor<3xf32> {
|
||||||
|
%0 = "onnx.Scaler"(%arg0) : (tensor<3xi32>) -> tensor<3xf32>
|
||||||
|
return %0 : tensor<3xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: return %0 : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// scaler no offset
|
||||||
|
// CHECK-LABEL: func @test_scaler_no_offset(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
func @test_scaler_no_offset(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
%0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
return %0 : tensor<3xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: %1 = "onnx.Mul"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: return %1 : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// scaler no scale
|
||||||
|
// CHECK-LABEL: func @test_scaler_no_scale(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
return %0 : tensor<3xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: return %1 : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// normal scaler
|
||||||
|
// CHECK-LABEL: func @test_scaler_normal(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
return %0 : tensor<3xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: return %3 : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// normal scaler with constant offset and scale
|
||||||
|
// CHECK-LABEL: func @test_scaler_constant(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32], scale = [3.125000e-02 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
return %0 : tensor<3xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1986.99939> : tensor<1xf32>} : () -> tensor<1xf32>
|
||||||
|
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<3.125000e-02> : tensor<1xf32>} : () -> tensor<1xf32>
|
||||||
|
// CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: return %3 : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
|
@ -1355,3 +1355,15 @@ func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> tensor<
|
||||||
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
|
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
|
||||||
// CHECK: return [[RES]] : tensor<1x2xf32>
|
// CHECK: return [[RES]] : tensor<1x2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test the shape inferencing for the scaler operation.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
func @test_scaler_no_scale_int(%arg0: tensor<3xi32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_scaler_no_scale_int
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
|
@ -252,11 +252,11 @@ OpsWithShapeInference = [
|
||||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
|
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
|
||||||
'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
|
'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
|
||||||
'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
|
'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
|
||||||
'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice'
|
'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice', 'Scaler'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Scaler']
|
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
|
||||||
|
|
||||||
# Operations who have operands that, if produced by constant operations, should
|
# Operations who have operands that, if produced by constant operations, should
|
||||||
# be promoted to become an attribute (via attribute promotion).
|
# be promoted to become an attribute (via attribute promotion).
|
||||||
|
|
Loading…
Reference in New Issue