//===---- ONNXRewrite.td - Pattern Match Rewriting for ONNX --*- tablegen -===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// Defines language-specific pattern match optimizations for ONNX using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
//===----------------------------------------------------------------------===//

#ifndef ONNX_REWRITE
#define ONNX_REWRITE

#ifndef OP_BASE
include "src/Dialect/ONNX/ONNXOps.td"
#endif // OP_BASE

/// Note: The DRR definition used for defining patterns is shown below:
///
/// class Pattern<
///    dag sourcePattern, list<dag> resultPatterns,
///    list<dag> additionalConstraints = [],
///    dag benefitsAdded = (addBenefit 0)
/// >;

// Produce a default-constructed (null) mlir::Attribute, e.g. to leave an
// optional attribute of a generated op unset.
def GetNullAttr :
  NativeCodeCall<"Attribute()">;

// Create a StringAttr holding the literal string `val`.
class StringAttrOfValue<string val>:
  NativeCodeCall<"$_builder.getStringAttr(\"" # val # "\")">;

// Create a float DenseElementsAttr from an integer value `val` by calling the
// C++ helper createDenseFloatAttrOfValue($_builder, $0, val).
// NOTE(review): $0 is presumably used by the helper to pick the element type;
// confirm against its C++ definition.
// TableGen has no `float` type, so the value is passed as an `int`.
class FloatAttrOfValue<int val>:
  NativeCodeCall<"createDenseFloatAttrOfValue($_builder, $0, " # val # ")">;

// Create an ArrayAttr of IntegerAttr(s) of zero values, via the C++ helper
// createArrayAttrOfZeros($_builder, $0).
// This function is used for the padding attribute in Conv.
def createArrayAttrOfZerosFrom:
  NativeCodeCall<"createArrayAttrOfZeros($_builder, $0)">;

// Pad an ArrayAttr with zeros.
//
// pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
//
// becomes:
//
// pads = [0,... 0, B1, B2, ... Bk, 0,... 0, E1, E2, ..., Ek]
//         |_____|                  |_____|
//         nZeros                   nZeros
//
// where nZeros is the template argument `extensionLength`.
// This function is used for the padding attribute in Conv.
class insertZerosForNonPaddedDims<int extensionLength>:
  NativeCodeCall<"insertZerosForNonPaddedDims($_builder, $0,"
                 # extensionLength # ")">;

// Check whether an ArrayAttr contains non-zero values or not.
// Applied to the bound entity via `(HasNonZeroInArrayAttr:$x)`, so the
// predicate refers to it as $_self.
def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">,
                                       "has non-zero elements">;

// Check that a StringAttr is not equal to a specific value.
class IsNotStringAttrOfValue: Constraint().getValue() != \"" # val # "\"">>; //===----------------------------------------------------------------------===// // Rewrite: // %0 = onnx.ConvOp(%D : tensor, %K) // {pads = [b0, b1, ... bK, e0, e1, ..., eK]} -> // tensor // // as: // %0 = onnx.PadConstantValuePadOp(%D) // {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} -> // tensor // %1 = onnx.Conv(%0 : tensor, %K) {pads = [0, ..., 0]} -> // tensor //===----------------------------------------------------------------------===// def ConvOpPaddingPattern: Pat< (ONNXConvOp:$res $x, $w, $b, $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides), (ONNXConvOp (ONNXPadOp $x, (ONNXConstantOp (GetNullAttr), (insertZerosForNonPaddedDims<2> $pads)), (ONNXConstantOp (GetNullAttr), (FloatAttrOfValue<0> $res)), (StringAttrOfValue<"constant">)), $w, $b, $auto_pad, $dilation, $group, $kernel_shape, (createArrayAttrOfZerosFrom $pads), $strides), [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)] >; #endif // ONNX_REWRITE