From 282dba6d381c3272377e006b54ae941b739229ed Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Sun, 23 Aug 2020 12:27:48 -0700 Subject: [PATCH] PR #42508: [MLIR] Erase dead lmhlo.constant ops Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/42508 An lmhlo.constant op on an memref that is locally allocated and with no users other than dealloc's can be deleted. Add a canonicalization pattern for this. Copybara import of the project: -- 8758c409a15f567e7cb8e1077faa020f5705c85a by Uday Bondhugula : [MLIR] Erase dead lmhlo.constant ops An lmhlo.constant op on an memref that is locally allocated and with no other users (other than dealloc's) can be deleted. Add a canonicalization patter for this. COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/42508 from polymage-labs:lhlo_constant_erase 8758c409a15f567e7cb8e1077faa020f5705c85a PiperOrigin-RevId: 328042416 --- include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 2 ++ lib/Dialect/mhlo/IR/lhlo_ops.cc | 33 ++++++++++++++++++++ tests/canonicalize.mlir | 21 +++++++++++++ 3 files changed, 56 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 3fa4658..750cce6 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -81,6 +81,8 @@ def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp { ElementsAttr:$value, Arg:$output ); + + let hasCanonicalizer = 1; } def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp { diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index f61a663..81407c8 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -29,6 +29,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" @@ -56,6 +57,38 @@ LmhloDialect::LmhloDialect(MLIRContext *context) >(); } +//===----------------------------------------------------------------------===// +// ConstOp. +//===----------------------------------------------------------------------===// + +/// An lho.constant on an memref that is locally allocated and with no other +/// users (other than dealloc's) can be erased. +// TODO: This can be generalized to an arbitrary op by making use of memory +// effects (write memory effect). +struct EraseConstOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConstOp op, + PatternRewriter& rewriter) const override { + Value memref = op.output(); + if (!memref.getDefiningOp()) { + return failure(); + } + + // Check that all uses of the memref are either DeallocOps or this op. + for (Operation* user : memref.getUsers()) + if (user != op && !isa(user)) return failure(); + + rewriter.eraseOp(op); + return success(); + } +}; + +void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // StaticMemRefCastOp //===----------------------------------------------------------------------===// diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 15b1a15..0d20c3f 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -597,3 +597,24 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple>) -> tupl // CHECK: return [[ARG0]] return %3 : tuple> } + +// CHECK-LABEL: func @erase_dead_lhlo_constant +func @erase_dead_lhlo_constant() { + %M = alloc() : memref<256x1024xf32> + // CHECK-NEXT: return + "lmhlo.constant"(%M) {value = dense<0.0> : tensor} : (memref<256x1024xf32>) -> () + dealloc %M : memref<256x1024xf32> + return +} + +// A negative test for dead lhlo constant op erasure. +// CHECK-LABEL: func @erase_dead_lhlo_constant_negative +func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> { + // CHECK-NEXT: lmhlo.constant + "lmhlo.constant"(%M) {value = dense<0.0> : tensor} : (memref<4xf32>) -> () + // CHECK-NEXT: alloc + // CHECK-NEXT: lmhlo.constant + %N = alloc() : memref<256x1024xf32> + "lmhlo.constant"(%N) {value = dense<0.0> : tensor} : (memref<256x1024xf32>) -> () + return %N : memref<256x1024xf32> +}