diff --git a/doc/gen_doc.py b/doc/gen_doc.py index df1337d..d42eb27 100644 --- a/doc/gen_doc.py +++ b/doc/gen_doc.py @@ -54,6 +54,15 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu', CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum', 'ReduceLogSumExp', 'ReduceSumSquare'] +#add an Op in this list if the Op needs result type deduction which is required +#when writing declarative rewriting rules. Deduced type is always +#an UnrankedTensorType whose element type is the same as the first operand's +#element type. +#currenlty, there are only two build methods generated: +# - one with operands and attributes having a separate parameter, and +# - one with operands and attributes having aggregated parameters. +custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare'] + manual_code_in_op_def = dict([ ('DummyExample', ' let extraClassDeclaration = [{ \n'+ ' static StringRef getPermAttrName() { return "perm"; }\n'+ @@ -345,38 +354,23 @@ def gen_schema(schema) : #input s+= '\n'+line_indent+'let arguments = (ins ' isfirst = True - if schema.inputs: - isfirst = False - for input in schema.inputs: - if input != schema.inputs[0] : - s+= ',\n ' - etypes=collect_types(schema, input) + # add operands + operand_ins = get_operand_ins(schema) + for operand_type, operand_name in operand_ins: + if not isfirst: + s+= ',\n ' + else: + isfirst = False + s+=operand_type+':$'+operand_name - if OpSchema.FormalParameterOption.Optional == input.option: - #TODO: handle optional - print("warning: optional input for"+schema.name+' '+input.name) - elif OpSchema.FormalParameterOption.Variadic == input.option: - if input.isHomogeneous: - s+= 'Variadic<' - else: - #TODO handle (variadic, heterogeneous)" - print("warning: (variadic, heterogeneous) for"+schema.name+' '+input.name) - if etypes == '': - s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>' - else: - s+= 'TensorOf<['+etypes+']>' - - if OpSchema.FormalParameterOption.Optional == input.option: - #TODO: handle optional - t='' - elif OpSchema.FormalParameterOption.Variadic == input.option: - if input.isHomogeneous: - s+= '>' - else: - #TODO handle (variadic, heterogeneous)" - t='' - s+=':$'+input.name - s += gen_attr_ins(schema, isfirst) + # add attributes + attr_ins = get_attr_ins(schema) + for attr_type, attr_name in attr_ins: + if not isfirst: + s += ',\n ' + else : + isfirst = False + s += attr_type+':$'+attr_name s+= ');' #output @@ -395,6 +389,71 @@ def gen_schema(schema) : s+= ');\n' #s+= 'let hasCanonicalizer = 1;' + + #TODO: any better way to do this. + def get_attr_type_for_builder(attr_type) : + if 'I64Attr' in attr_type : + mytype = 'IntegerAttr' + elif 'F32Attr' in attr_type : + mytype = 'FloatAttr' + elif 'I64ArrayAttr' in attr_type or 'F32ArrayAttr' in attr_type: + mytype = 'ArrayAttr' + elif 'StrAttr' in attr_type : + mytype = 'StringAttr' + elif 'strings' in attr_type : + mytype = 'ArrayAttr' + else : + mytype ='Attribute' + return mytype + + def get_op_type_for_builder(op_type): + if op_type.startswith('Variadic'): + mytype = 'ValueRange' + else: + mytype = 'Value' + return mytype + + # add custom builders + # use element type of the first operand to construct an UnrankedTensorType for the output. + if schema.name in custom_builder_ops_list: + if len(operand_ins) == 0: + print("warning: not generate custom build methods for " + schema.name + " since it does not have operands.") + else: + if get_op_type_for_builder(operand_ins[0][0]) == 'ValueRange': + first_operand = operand_ins[0][1]+'[0]' + else: + first_operand = operand_ins[0][1] + + s += line_indent+'let builders = [\n' + + # custom builders with operands and attributes having a seperate parameter. + # E.g. OpBuilder<"Builder *builder, OperationState &state, Value X, Value, Y, Attribute A", [{}]> + s += line_indent*2+'OpBuilder<"Builder *builder, OperationState &state' + for arg_type, arg_name in operand_ins: + s += ', '+get_op_type_for_builder(arg_type)+' '+arg_name + for attr_type, attr_name in attr_ins: + s += ', '+get_attr_type_for_builder(attr_type)+' '+attr_name + s += '", [{\n' + s += line_indent*3+'auto elementType = '+first_operand+'.getType().cast().getElementType();\n' + s += line_indent*3+'build(builder, state, UnrankedTensorType::get(elementType)' + for _, arg_name in operand_ins: + s += ', '+arg_name + for _, attr_name in attr_ins: + s += ', '+attr_name + s += ');\n' + s += line_indent*2+'}]>,\n' + + # custom builders with all operands and attributes having aggregate parameters. + # E.g. OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{}]>' + s += line_indent*2+'OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{\n' + s += line_indent*3+'auto elementType = '+first_operand+'.getType().cast().getElementType();\n' + s += line_indent*3+'std::vector outputTypes;\n' + s += line_indent*3+'outputTypes.emplace_back(UnrankedTensorType::get(elementType));\n' + s += line_indent*3+'build(builder, state, outputTypes, operands, attributes);\n' + s += line_indent*2+'}]>' + + s += '\n'+line_indent+'];\n' + #add special code if schema.name in manual_code_in_op_def : s += manual_code_in_op_def[schema.name] @@ -447,7 +506,41 @@ def gen_code(schema,fefile) : else: fefile.write(', '+variadicIn+', '+variadicOut+');\n') -def gen_attr_ins(schema, isfirst) : +def get_operand_ins(schema): + operand_type_and_name_list = [] # [(optype, opname)] + if schema.inputs: + for input in schema.inputs: + optype = "" + + etypes=collect_types(schema, input) + + if OpSchema.FormalParameterOption.Optional == input.option: + #TODO : handle optional + print("warning: optional input for"+schema.name+' '+input.name) + elif OpSchema.FormalParameterOption.Variadic == input.option: + if input.isHomogeneous: + optype += 'Variadic<' + else: + #TODO handle(variadic, heterogeneous) " + print("warning: (variadic, heterogeneous) for"+schema.name+' '+input.name) + if etypes == '': + optype += 'AnyTypeOf<[AnyMemRef, AnyTensor]>' + else: + optype += 'TensorOf<['+etypes+']>' + + if OpSchema.FormalParameterOption.Optional == input.option: + #TODO : handle optional + t='' + elif OpSchema.FormalParameterOption.Variadic == input.option: + if input.isHomogeneous: + optype += '>' + else: + #TODO handle(variadic, heterogeneous) " + t='' + operand_type_and_name_list.append((optype, input.name)) + return operand_type_and_name_list + +def get_attr_ins(schema) : def get_attr_type_basic(attr_type) : if attr_type == 'int' : @@ -479,24 +572,22 @@ def gen_attr_ins(schema, isfirst) : mytype += ', "'+attr_default+'">' return mytype + attr_type_and_name_list = [] # :: [(attrtype, attrname)] attr_line = '' if schema.attributes: for _, attr in sorted(schema.attributes.items()): #attr_line = line_indent+line_indent+line_indent+line_indent - if not isfirst: - attr_line += ',\n ' - else : - isfirst = False - + found = False + attr_type = "" if schema.name+' '+attr.name in special_attr_defaults: (attr_type_str, attr_default_str) = special_attr_defaults[schema.name+' '+attr.name] - attr_line += get_attr_type_with_default(attr_type_str, attr_default_str) - attr_line += ':$'+attr.name + attr_type = get_attr_type_with_default(attr_type_str, attr_default_str) + found = True elif attr.required: s = Text(attr.type) attr_type_str = s[s.rfind('.') + 1:].lower() - attr_line += get_attr_type_basic(attr_type_str) - attr_line += ':$'+attr.name + attr_type = get_attr_type_basic(attr_type_str) + found = True # option holds either required or default value elif attr.default_value.name: @@ -527,14 +618,15 @@ def gen_attr_ins(schema, isfirst) : else: default_value = format_value(default_value) attr_option_str = default_value - attr_line += get_attr_type_with_default(attr_type_str, attr_option_str) - attr_line += ':$'+attr.name + attr_type = get_attr_type_with_default(attr_type_str, attr_option_str) + found = True else: s = Text(attr.type) attr_type_str = s[s.rfind('.') + 1:].lower() - attr_line += get_attr_type_optional(attr_type_str) - attr_line += ':$'+attr.name - return attr_line + attr_type = get_attr_type_optional(attr_type_str) + if found: + attr_type_and_name_list.append((attr_type, attr.name)) + return attr_type_and_name_list def main(args): # type: (Type[Args]) -> None with io.open(args.changelog, 'w', newline='') as fout: diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4a03cbd..d895be5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -36,6 +36,11 @@ onnf_tablegen(onnx_combine.inc -gen-rewriters) add_public_tablegen_target(gen_onnx_combine) add_dependencies(compiler gen_onnx_combine) +set(LLVM_TARGET_DEFINITIONS pass/onnx_rewrite.td) +onnf_tablegen(onnx_rewrite.inc -gen-rewriters) +add_public_tablegen_target(gen_onnx_rewrite) +add_dependencies(compiler gen_onnx_rewrite) + set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td) onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass") onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass") diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index b293000..38d7075 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -14,6 +14,18 @@ def ONNXAbsOp:ONNX_Op<"Abs", }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, Value X", [{ + auto elementType = X.getType().cast().getElementType(); + build(builder, state, UnrankedTensorType::get(elementType), X); + }]>, + OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto elementType = operands[0].getType().cast().getElementType(); + std::vector outputTypes; + outputTypes.emplace_back(UnrankedTensorType::get(elementType)); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; } def ONNXAcosOp:ONNX_Op<"Acos", @@ -649,6 +661,18 @@ def ONNXExpOp:ONNX_Op<"Exp", }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, Value input", [{ + auto elementType = input.getType().cast().getElementType(); + build(builder, state, UnrankedTensorType::get(elementType), input); + }]>, + OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto elementType = operands[0].getType().cast().getElementType(); + std::vector outputTypes; + outputTypes.emplace_back(UnrankedTensorType::get(elementType)); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; } def ONNXExpandOp:ONNX_Op<"Expand", @@ -1763,6 +1787,18 @@ def ONNXMulOp:ONNX_Op<"Mul", let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, Value A, Value B", [{ + auto elementType = A.getType().cast().getElementType(); + build(builder, state, UnrankedTensorType::get(elementType), A, B); + }]>, + OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto elementType = operands[0].getType().cast().getElementType(); + std::vector outputTypes; + outputTypes.emplace_back(UnrankedTensorType::get(elementType)); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; } def ONNXMultinomialOp:ONNX_Op<"Multinomial", @@ -2431,6 +2467,18 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum", OptionalAttr:$axes, DefaultValuedAttr:$keepdims); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, Value data, ArrayAttr axes, IntegerAttr keepdims", [{ + auto elementType = data.getType().cast().getElementType(); + build(builder, state, UnrankedTensorType::get(elementType), data, axes, keepdims); + }]>, + OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto elementType = operands[0].getType().cast().getElementType(); + std::vector outputTypes; + outputTypes.emplace_back(UnrankedTensorType::get(elementType)); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; } def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", @@ -2449,6 +2497,18 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", OptionalAttr:$axes, DefaultValuedAttr:$keepdims); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, Value data, ArrayAttr axes, IntegerAttr keepdims", [{ + auto elementType = data.getType().cast().getElementType(); + build(builder, state, UnrankedTensorType::get(elementType), data, axes, keepdims); + }]>, + OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto elementType = operands[0].getType().cast().getElementType(); + std::vector outputTypes; + outputTypes.emplace_back(UnrankedTensorType::get(elementType)); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; } def ONNXReluOp:ONNX_Op<"Relu", diff --git a/src/pass/onnx_rewrite.cpp b/src/pass/onnx_rewrite.cpp index bf2527b..afe43c1 100644 --- a/src/pass/onnx_rewrite.cpp +++ b/src/pass/onnx_rewrite.cpp @@ -17,252 +17,8 @@ using namespace mlir; namespace { - -// There are two ways to write rewrite rules: -// - Declarative manner: specify rewrite rules in a TableGen record, and -// - Manual Manner: subclass the mlir::RewritePattern. -// -// We prefer to use the former way as much as possible. However, there is a -// limitation about operation definition specification (ODS) in TableGen that -// requires us to write custom builders, that is -// "all ODS-generated `build()` methods require specifying the result type(s), -// unless the op has known traits like `SameOperandsAndResultType` that we can -// use to auto-generate a `build()` method with result type deduction". -// -// More information about the limitation can be found here: -// https://github.com/llvm/llvm-project/blob/master/mlir/docs/DeclarativeRewrites.md#building-operations -// -// Currently, we use the latter way of writing rewrite rules. There are two -// reasons for this decision: -// - To insert custom builders for operations, it is better to change the script -// gen_doc.py to generate all possibles custom builders for a large class of -// operations. At the time of this patch created, the gen_doc.py was changing, -// so we decided to write manually to reduce conflicts. -// - In declarative rewriting, we should deal with optional attributes. E.g. for -// to handle optional attributes, but I haven't tried it yet. -// -// Once we have done the above issues, we will switch to use the declarative -// manner. - -//===----------------------------------------------------------------------===// -// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X) -//===----------------------------------------------------------------------===// -struct ReduceL1OpPattern : public RewritePattern { - ReduceL1OpPattern(MLIRContext *context) - : RewritePattern(ONNXReduceL1Op::getOperationName(), - {ONNXAbsOp::getOperationName(), - ONNXReduceSumOp::getOperationName()}, - 1, context) {} - - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto opInput = op->getOperands()[0]; // %X - auto opResults = op->getResults(); - auto opAttrs = op->getAttrs(); - - // Rewrite - ONNXAbsOp absOp; - { - auto elementType = opInput.getType().cast().getElementType(); - absOp = rewriter.create( - loc, UnrankedTensorType::get(elementType), opInput); - } - - ONNXReduceSumOp sumOp; - { - SmallVector types; - for (auto v : opResults) { - types.emplace_back(v.getType()); - } - - SmallVector values; - values.emplace_back(absOp.getResult()); - - SmallVector attrs; - for (auto attr : opAttrs) { - attrs.emplace_back(attr); - } - - sumOp = rewriter.create(loc, types, values, attrs); - } - - rewriter.replaceOp(op, sumOp.getResult()); - return matchSuccess(); - }; -}; - -//===----------------------------------------------------------------------===// -// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X)) -//===----------------------------------------------------------------------===// -struct ReduceL2OpPattern : public RewritePattern { - ReduceL2OpPattern(MLIRContext *context) - : RewritePattern(ONNXReduceL2Op::getOperationName(), - {ONNXSqrtOp::getOperationName(), - ONNXReduceSumSquareOp::getOperationName()}, - 1, context) {} - - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto opInput = op->getOperands()[0]; // %X - auto opResults = op->getResults(); - auto opAttrs = op->getAttrs(); - - // Rewrite - ONNXReduceSumSquareOp sumSquareOp; - { - auto elementType = opInput.getType().cast().getElementType(); - sumSquareOp = rewriter.create( - loc, UnrankedTensorType::get(elementType), opInput, opAttrs); - } - - ONNXSqrtOp sqrtOp; - { - SmallVector types; - for (auto v : opResults) { - types.emplace_back(v.getType()); - } - sqrtOp = rewriter.create(loc, types, sumSquareOp.getResult()); - } - - rewriter.replaceOp(op, sqrtOp.getResult()); - return matchSuccess(); - }; -}; - -//===----------------------------------------------------------------------===// -// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X)) -//===----------------------------------------------------------------------===// -struct ReduceLogSumOpPattern : public RewritePattern { - ReduceLogSumOpPattern(MLIRContext *context) - : RewritePattern(ONNXReduceLogSumOp::getOperationName(), - {ONNXReduceSumOp::getOperationName(), - ONNXLogOp::getOperationName()}, - 1, context) {} - - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto opInput = op->getOperands()[0]; // %X - auto opResults = op->getResults(); - auto opAttrs = op->getAttrs(); - - // Rewrite - ONNXReduceSumOp sumOp; - { - auto elementType = opInput.getType().cast().getElementType(); - sumOp = rewriter.create( - loc, UnrankedTensorType::get(elementType), opInput, opAttrs); - } - - ONNXLogOp logOp; - { - SmallVector types; - for (auto v : opResults) { - types.emplace_back(v.getType()); - } - logOp = rewriter.create(loc, types, sumOp.getResult()); - } - - rewriter.replaceOp(op, logOp.getResult()); - return matchSuccess(); - }; -}; - -//===----------------------------------------------------------------------===// -// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X) -//===----------------------------------------------------------------------===// -struct ReduceLogSumExpOpPattern : public RewritePattern { - ReduceLogSumExpOpPattern(MLIRContext *context) - : RewritePattern(ONNXReduceLogSumExpOp::getOperationName(), - {ONNXExpOp::getOperationName(), - ONNXReduceLogSumOp::getOperationName()}, - 1, context) {} - - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto opInput = op->getOperands()[0]; // %X - auto opResults = op->getResults(); - auto opAttrs = op->getAttrs(); - - // Rewrite - ONNXExpOp expOp; - { - auto elementType = opInput.getType().cast().getElementType(); - expOp = rewriter.create( - loc, UnrankedTensorType::get(elementType), opInput); - } - - ONNXReduceLogSumOp logSumOp; - { - SmallVector types; - for (auto v : opResults) { - types.emplace_back(v.getType()); - } - - SmallVector values; - values.emplace_back(expOp.getResult()); - - SmallVector attrs; - for (auto attr : opAttrs) { - attrs.emplace_back(attr); - } - logSumOp = rewriter.create(loc, types, values, attrs); - } - - rewriter.replaceOp(op, logSumOp.getResult()); - return matchSuccess(); - }; -}; - -//===----------------------------------------------------------------------===// -// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X) -//===----------------------------------------------------------------------===// -struct ReduceSumSquareOpPattern : public RewritePattern { - ReduceSumSquareOpPattern(MLIRContext *context) - : RewritePattern(ONNXReduceSumSquareOp::getOperationName(), - {ONNXMulOp::getOperationName(), - ONNXReduceSumOp::getOperationName()}, - 1, context) {} - - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto opInput = op->getOperands()[0]; // %X - auto opResults = op->getResults(); - auto opAttrs = op->getAttrs(); - - // Rewrite - ONNXMulOp mulOp; - { - auto elementType = opInput.getType().cast().getElementType(); - mulOp = rewriter.create( - loc, UnrankedTensorType::get(elementType), opInput, opInput); - } - - ONNXReduceSumOp sumOp; - { - SmallVector types; - for (auto v : opResults) { - types.emplace_back(v.getType()); - } - - SmallVector values; - values.emplace_back(mulOp.getResult()); - - SmallVector attrs; - for (auto attr : opAttrs) { - attrs.emplace_back(attr); - } - sumOp = rewriter.create(loc, types, values, attrs); - } - - rewriter.replaceOp(op, sumOp.getResult()); - return matchSuccess(); - }; -}; +/// Include the patterns defined in the Declarative Rewrite framework. +#include "src/onnx_rewrite.inc" //===----------------------------------------------------------------------===// // Rewrite: diff --git a/src/pass/onnx_rewrite.td b/src/pass/onnx_rewrite.td new file mode 100644 index 0000000..43dc99c --- /dev/null +++ b/src/pass/onnx_rewrite.td @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +//=- onnx_rewrite.td - Pattern Match Rewriting for ONNX -*- tablegen -*----===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// Defines language-specific pattern match optimizations for ONNX using +// Declarative Rewrite Rules (DRR) specified using TableGen records. +// + +#ifndef ONNX_REWRITE +#define ONNX_REWRITE + +#ifndef OP_BASE +include "dialect/onnx/onnx.td" +#endif // OP_BASE + +/// Note: The DRR definition used for defining patterns is shown below: +/// +/// class Pattern< +/// dag sourcePattern, list resultPatterns, +/// list additionalConstraints = [], +/// dag benefitsAdded = (addBenefit 0) +/// >; + +//===----------------------------------------------------------------------===// +// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X) +//===----------------------------------------------------------------------===// +def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims), + (ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>; + +//===----------------------------------------------------------------------===// +// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X)) +//===----------------------------------------------------------------------===// +def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims), + (ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>; + +//===----------------------------------------------------------------------===// +// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X)) +//===----------------------------------------------------------------------===// +def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims), + (ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>; + +//===----------------------------------------------------------------------===// +// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X) +//===----------------------------------------------------------------------===// +def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims), + (ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>; + +//===----------------------------------------------------------------------===// +// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X) +//===----------------------------------------------------------------------===// +def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims), + (ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>; + +#endif // ONNX_REWRITE