PR #42508: [MLIR] Erase dead lmhlo.constant ops
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/42508 An lmhlo.constant op on a memref that is locally allocated and has no users other than deallocs can be deleted. Add a canonicalization pattern for this. Copybara import of the project: -- 8758c409a15f567e7cb8e1077faa020f5705c85a by Uday Bondhugula <uday@polymagelabs.com>: [MLIR] Erase dead lmhlo.constant ops An lmhlo.constant op on a memref that is locally allocated and has no other users (other than deallocs) can be deleted. Add a canonicalization pattern for this. COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/42508 from polymage-labs:lhlo_constant_erase 8758c409a15f567e7cb8e1077faa020f5705c85a PiperOrigin-RevId: 328042416
This commit is contained in:
parent
bfd629ecb0
commit
282dba6d38
|
@ -81,6 +81,8 @@ def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp {
|
|||
ElementsAttr:$value,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||
);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
|
||||
|
|
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
@ -56,6 +57,38 @@ LmhloDialect::LmhloDialect(MLIRContext *context)
|
|||
>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// An lho.constant on an memref that is locally allocated and with no other
|
||||
/// users (other than dealloc's) can be erased.
|
||||
// TODO: This can be generalized to an arbitrary op by making use of memory
|
||||
// effects (write memory effect).
|
||||
struct EraseConstOp : public OpRewritePattern<ConstOp> {
|
||||
using OpRewritePattern<ConstOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ConstOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Value memref = op.output();
|
||||
if (!memref.getDefiningOp<AllocOp>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check that all uses of the memref are either DeallocOps or this op.
|
||||
for (Operation* user : memref.getUsers())
|
||||
if (user != op && !isa<DeallocOp>(user)) return failure();
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Adds the canonicalization patterns of ConstOp to `results`; at present
/// this is only EraseConstOp.
void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                          MLIRContext* context) {
  results.insert<EraseConstOp>(context);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StaticMemRefCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -597,3 +597,24 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tupl
|
|||
// CHECK: return [[ARG0]]
|
||||
return %3 : tuple<tensor<i32>>
|
||||
}
|
||||
|
||||
// Positive test for dead lmhlo.constant erasure: the constant below writes to
// %M, which is locally allocated and otherwise used only by a dealloc. The
// expectations require the line right after the function header to be the
// return, i.e. nothing else is left in the function body.
// CHECK-LABEL: func @erase_dead_lhlo_constant
|
||||
func @erase_dead_lhlo_constant() {
|
||||
%M = alloc() : memref<256x1024xf32>
|
||||
// CHECK-NEXT: return
|
||||
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
|
||||
dealloc %M : memref<256x1024xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// A negative test for dead lhlo constant op erasure.
// Neither lmhlo.constant below is expected to be erased:
//  - the first writes to %M, a function argument rather than a local alloc;
//  - the second writes to %N, which is locally allocated but is also used by
//    the return, so it has a user other than a dealloc.
|
||||
// CHECK-LABEL: func @erase_dead_lhlo_constant_negative
|
||||
func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
|
||||
// CHECK-NEXT: lmhlo.constant
|
||||
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
|
||||
// CHECK-NEXT: alloc
|
||||
// CHECK-NEXT: lmhlo.constant
|
||||
%N = alloc() : memref<256x1024xf32>
|
||||
"lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
|
||||
return %N : memref<256x1024xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue