2019-11-13 02:37:46 +08:00
|
|
|
//=- ONNXCombine.td - Pattern Match Optimizations for ONNX -*- tablegen -*-===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The IBM Research Authors.
|
|
|
|
//
|
|
|
|
// =============================================================================
|
|
|
|
//
|
|
|
|
// Defines language-specific pattern match optimizations for ONNX using
|
|
|
|
// Declarative Rewrite Rules (DRR) specified using TableGen records.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#ifndef ONNX_COMBINE
|
|
|
|
#define ONNX_COMBINE
|
|
|
|
|
|
|
|
#ifndef OP_BASE
|
|
|
|
include "dialect/onnx/onnx.td"
|
|
|
|
#endif // OP_BASE
|
|
|
|
|
|
|
|
/// Note: The DRR definition used for defining patterns is shown below:
|
|
|
|
///
|
|
|
|
/// class Pattern<
|
|
|
|
/// dag sourcePattern, list<dag> resultPatterns,
|
|
|
|
/// list<dag> additionalConstraints = [],
|
|
|
|
/// dag benefitsAdded = (addBenefit 0)
|
|
|
|
/// >;
|
|
|
|
|
2020-01-14 01:21:29 +08:00
|
|
|
// Satisfied when the bound value has exactly one use.
def HasOneUse
    : Constraint<CPred<"$0.hasOneUse()">>;
|
2020-02-08 02:51:44 +08:00
|
|
|
// Satisfied when the bound value has a *ranked* shaped type whose rank equals
// `rank`. Values of non-shaped type, or of unranked shaped type
// (e.g. tensor<*xf32>), do not satisfy the constraint.
// The hasRank() test must precede getRank(): ShapedType::getRank() asserts
// when called on an unranked type, so omitting it would crash the matcher on
// values whose shapes have not been inferred yet.
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().hasRank() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
|
2019-11-19 10:23:46 +08:00
|
|
|
|
2019-11-13 02:37:46 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Pattern-Match and Rewrite
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-01-27 23:09:14 +08:00
|
|
|
// Attribute builders for the onnx.Gemm emitted by MulAddToGemmOptPattern.
// ONNX Gemm computes Y = alpha * A' * B' + beta * C. Choosing
// alpha = beta = 1.0 and transA = transB = 0 (no transposition) reduces this
// to A * B + C, which is exactly the MatMul + Add being replaced.
def GemmAlpha  : NativeCodeCall<"$_builder.getF32FloatAttr(1.0f)">;
def GemmBeta   : NativeCodeCall<"$_builder.getF32FloatAttr(1.0f)">;

// Transposition flags are built as 64-bit integer attributes; 0 = keep as is.
def GemmTransA : NativeCodeCall<"$_builder.getI64IntegerAttr(0)">;
def GemmTransB : NativeCodeCall<"$_builder.getI64IntegerAttr(0)">;
|
|
|
|
|
2020-01-16 03:27:21 +08:00
|
|
|
// Satisfied when the bound value has a ranked shaped type whose rank is at
// most `rank`. Unranked or non-shaped values do not satisfy it (the
// hasRank() test guards getRank(), which asserts on unranked types).
class HasRankAtMost<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().hasRank() && $0.getType().cast<ShapedType>().getRank() <= " # rank>>;

// Fuses a matrix product followed by an addition into a single Gemm:
//   onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z)
// using alpha = beta = 1.0 and no transposition, i.e. X * Y + Z.
//
// Constraints:
//  - the MatMul result has no other user, so it can be absorbed into Gemm;
//  - both MatMul operands are 2-D (Gemm only multiplies matrices);
//  - the addend %Z is ranked with rank <= 2. onnx.Add broadcasts in both
//    directions, so a higher-rank %Z would give the Add a result of rank > 2,
//    while Gemm always produces a 2-D result and only broadcasts C towards
//    (M, N). Rewriting in that case would change the result shape.
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
                                 (ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
                                 [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasRankAtMost<2> $m3)]>;
|
2019-11-13 02:37:46 +08:00
|
|
|
|
2019-11-21 10:57:13 +08:00
|
|
|
// Removes an onnx.Identity by handing its input straight to every user:
//   ONNX_Op(onnx.Identity(%X)) = ONNX_Op(%X)
def IdentityEliminationPattern : Pat<
  (ONNXIdentityOp $arg),
  (replaceWithValue $arg)
>;
|
2019-11-21 10:57:13 +08:00
|
|
|
|
2020-02-12 04:32:01 +08:00
|
|
|
// Folds an onnx.Constant feeding the second operand of an
// onnx.PadConstantValue into an onnx.PadConstantValuePad. The constant's
// second bound attribute ($v2) takes the place of the constant operand; the
// other bound arguments ($m1, $m2, $m3) are forwarded unchanged.
// HasOneUse limits the rewrite to constants whose only user is this op.
// NOTE(review): presumably $v1/$v2 are the Constant's `sparse_value`/`value`
// attributes; if a constant only carries `sparse_value`, $v2 would be unset
// here — confirm against the ONNXConstantOp definition.
def ConstantPadPattern : Pat<
  (ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $v1, $v2), $m2, $m3),
  (ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
  [(HasOneUse $res)]
>;
|
2020-02-12 04:32:01 +08:00
|
|
|
|
2019-11-13 02:37:46 +08:00
|
|
|
#endif // ONNX_COMBINE
|