PR #42508: [MLIR] Erase dead lmhlo.constant ops

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/42508

An lmhlo.constant op on an memref that is locally allocated and with
no users other than dealloc's can be deleted. Add a canonicalization
pattern for this.
Copybara import of the project:

--
8758c409a15f567e7cb8e1077faa020f5705c85a by Uday Bondhugula <uday@polymagelabs.com>:

[MLIR] Erase dead lmhlo.constant ops

An lmhlo.constant op on an memref that is locally allocated and with
no other users (other than dealloc's) can be deleted. Add a
canonicalization patter for this.

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/42508 from polymage-labs:lhlo_constant_erase 8758c409a15f567e7cb8e1077faa020f5705c85a
PiperOrigin-RevId: 328042416
This commit is contained in:
Uday Bondhugula 2020-08-23 12:27:48 -07:00 committed by TensorFlow MLIR Team
parent bfd629ecb0
commit 282dba6d38
3 changed files with 56 additions and 0 deletions

View File

@ -81,6 +81,8 @@ def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp {
ElementsAttr:$value,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
let hasCanonicalizer = 1;
}
def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
@ -56,6 +57,38 @@ LmhloDialect::LmhloDialect(MLIRContext *context)
>();
}
//===----------------------------------------------------------------------===//
// ConstOp.
//===----------------------------------------------------------------------===//
/// An lho.constant on an memref that is locally allocated and with no other
/// users (other than dealloc's) can be erased.
// TODO: This can be generalized to an arbitrary op by making use of memory
// effects (write memory effect).
struct EraseConstOp : public OpRewritePattern<ConstOp> {
using OpRewritePattern<ConstOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ConstOp op,
PatternRewriter& rewriter) const override {
Value memref = op.output();
if (!memref.getDefiningOp<AllocOp>()) {
return failure();
}
// Check that all uses of the memref are either DeallocOps or this op.
for (Operation* user : memref.getUsers())
if (user != op && !isa<DeallocOp>(user)) return failure();
rewriter.eraseOp(op);
return success();
}
};
void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<EraseConstOp>(context);
}
//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//

View File

@ -597,3 +597,24 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tupl
// CHECK: return [[ARG0]]
return %3 : tuple<tensor<i32>>
}
// CHECK-LABEL: func @erase_dead_lhlo_constant
func @erase_dead_lhlo_constant() {
%M = alloc() : memref<256x1024xf32>
// CHECK-NEXT: return
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
dealloc %M : memref<256x1024xf32>
return
}
// A negative test for dead lhlo constant op erasure.
// CHECK-LABEL: func @erase_dead_lhlo_constant_negative
func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
// CHECK-NEXT: lmhlo.constant
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
// CHECK-NEXT: alloc
// CHECK-NEXT: lmhlo.constant
%N = alloc() : memref<256x1024xf32>
"lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
return %N : memref<256x1024xf32>
}