Add ONNXScalerOp pattern (#220)

* add ONNXScalerOp pattern

* move ScalerOp rewrite rule to Rewrite.cpp .td

* attempt to fix format issue

* fixing format issue

* fixing format issue2

* add ONNXScalerOp pattern

* move ScalerOp rewrite rule to Rewrite.cpp .td

* attempt to fix format issue

* fixing format issue

* fixing format issue2
This commit is contained in:
Anh Leu 2020-07-17 10:01:30 -05:00 committed by GitHub
parent 13b8591af8
commit 4b33c312d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 164 additions and 1 deletions

View File

@ -6090,6 +6090,7 @@ def ONNXSVMRegressorOp:ONNX_Op<"SVMRegressor",
def ONNXScalerOp:ONNX_Op<"Scaler", def ONNXScalerOp:ONNX_Op<"Scaler",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX Scaler operation"; let summary = "ONNX Scaler operation";
let description = [{ let description = [{
"Rescale input data, for example to standardize features by removing the mean and scaling to unit variance." "Rescale input data, for example to standardize features by removing the mean and scaling to unit variance."

View File

@ -18,6 +18,23 @@ using namespace mlir;
namespace { namespace {
// Create an DenseElementsAttr of ArrayAttr.
// This function is used to get Value Type for Scaler function.
DenseElementsAttr createDenseArrayAttr(
PatternRewriter &rewriter, ArrayAttr origAttrs) {
mlir::Type elementType = rewriter.getF32Type();
int nElements = origAttrs.getValue().size();
SmallVector<float, 4> wrapper(nElements, 0);
if (origAttrs) {
for (int i = 0; i < nElements; ++i) {
wrapper[i] = origAttrs.getValue()[i].cast<FloatAttr>().getValueAsDouble();
}
}
return DenseElementsAttr::get(
RankedTensorType::get(wrapper.size(), elementType),
llvm::makeArrayRef(wrapper));
}
// Check whether an ArrayAttr contains non-zero values or not. // Check whether an ArrayAttr contains non-zero values or not.
bool hasNonZeroInArrayAttr(ArrayAttr attrs) { bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
bool allZeros = true; bool allZeros = true;
@ -92,3 +109,13 @@ void ONNXConvOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ConvOpPaddingPattern>(context); results.insert<ConvOpPaddingPattern>(context);
} }
/// on the ONNXScalerOp.
void ONNXScalerOp::getCanonicalizationPatterns(
OwningRewritePatternList &result, MLIRContext *context) {
result.insert<ScalerNullPattern>(context);
result.insert<ScalerNullPattern2>(context);
result.insert<ScalerNoScalePattern>(context);
result.insert<ScalerNoOffsetPattern>(context);
result.insert<ScalerPattern>(context);
}

View File

@ -100,4 +100,56 @@ def ConvOpPaddingPattern: Pat<
[(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)] [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
>; >;
//===----------------------------------------------------------------------===//
// ONNXScalerOp %X, %Offest, %Scale
// x input, a offset, b scale
//===----------------------------------------------------------------------===//
// Useful test definitions.
def AttributeIsNull :
Constraint<CPred<"! ($_self)">,
"Attribute is null">;
def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>().getElementType().isF32())">>;
// Create a DenseElementsAttr from an ArrayAttr.
def createDenseArrayAttr:
NativeCodeCall<"createDenseArrayAttr($_builder, $0)">;
def ScalerT : NativeCodeCall<"$_builder.getI64IntegerAttr(0)">;
// No attribute
def ScalerNullPattern : Pat<
(ONNXScalerOp $x, $a, $b),
(replaceWithValue $x),
[(HasFloatType:$x),(AttributeIsNull:$a), (AttributeIsNull:$b)]>;
// No attribute, input x not float type
def ScalerNullPattern2 : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXCastOp $x, (ScalerT)),
[(AttributeIsNull:$a), (AttributeIsNull:$b)]>;
// No scale
def ScalerNoScalePattern : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXSubOp $x,
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
[(AttributeIsNull:$b)]>;
// No offset
def ScalerNoOffsetPattern : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXMulOp $x,
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
[(AttributeIsNull:$a)]>;
// Normal ONNXScalerOp
def ScalerPattern : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXMulOp
(ONNXSubOp $x,
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b)))>;
#endif // ONNX_REWRITE #endif // ONNX_REWRITE

View File

@ -96,3 +96,86 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32> // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
// return [[GEMM]] : tensor<*xf32> // return [[GEMM]] : tensor<*xf32>
} }
// -----
// Scaler Pattern test
// -----
// null
// CHECK-LABEL: func @test_scaler_null_float(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
func @test_scaler_null_float(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: return %arg0 : tensor<3xf32>
}
// -----
// null not float
// CHECK-LABEL: func @test_scaler_null(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
func @test_scaler_null(%arg0: tensor<3xi32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) : (tensor<3xi32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 0 : i64} : (tensor<3xi32>) -> tensor<3xf32>
// CHECK-NEXT: return %0 : tensor<3xf32>
}
// -----
// scaler no offset
// CHECK-LABEL: func @test_scaler_no_offset(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
func @test_scaler_no_offset(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %1 = "onnx.Mul"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %1 : tensor<3xf32>
}
// -----
// scaler no scale
// CHECK-LABEL: func @test_scaler_no_scale(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %1 : tensor<3xf32>
}
// -----
// normal scaler
// CHECK-LABEL: func @test_scaler_normal(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %3 : tensor<3xf32>
}
// -----
// normal scaler with constant offset and scale
// CHECK-LABEL: func @test_scaler_constant(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32], scale = [3.125000e-02 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1986.99939> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
// CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<3.125000e-02> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %3 : tensor<3xf32>
}

View File

@ -256,7 +256,7 @@ OpsWithShapeInference = [
] ]
# Operations supporting canonicalization. # Operations supporting canonicalization.
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv'] OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Scaler']
# Operations who have operands that, if produced by constant operations, should # Operations who have operands that, if produced by constant operations, should
# be promoted to become an attribute (via attribute promotion). # be promoted to become an attribute (via attribute promotion).