PR #42508: [MLIR] Erase dead lmhlo.constant ops
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/42508 An lmhlo.constant op on an memref that is locally allocated and with no users other than dealloc's can be deleted. Add a canonicalization pattern for this. Copybara import of the project: -- 8758c409a15f567e7cb8e1077faa020f5705c85a by Uday Bondhugula <uday@polymagelabs.com>: [MLIR] Erase dead lmhlo.constant ops An lmhlo.constant op on an memref that is locally allocated and with no other users (other than dealloc's) can be deleted. Add a canonicalization patter for this. COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/42508 from polymage-labs:lhlo_constant_erase 8758c409a15f567e7cb8e1077faa020f5705c85a PiperOrigin-RevId: 328042416
This commit is contained in:
parent
bfd629ecb0
commit
282dba6d38
|
@ -81,6 +81,8 @@ def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp {
|
||||||
ElementsAttr:$value,
|
ElementsAttr:$value,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
|
def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
|
||||||
|
|
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
@ -56,6 +57,38 @@ LmhloDialect::LmhloDialect(MLIRContext *context)
|
||||||
>();
|
>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstOp.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// An lho.constant on an memref that is locally allocated and with no other
|
||||||
|
/// users (other than dealloc's) can be erased.
|
||||||
|
// TODO: This can be generalized to an arbitrary op by making use of memory
|
||||||
|
// effects (write memory effect).
|
||||||
|
struct EraseConstOp : public OpRewritePattern<ConstOp> {
|
||||||
|
using OpRewritePattern<ConstOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ConstOp op,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
Value memref = op.output();
|
||||||
|
if (!memref.getDefiningOp<AllocOp>()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that all uses of the memref are either DeallocOps or this op.
|
||||||
|
for (Operation* user : memref.getUsers())
|
||||||
|
if (user != op && !isa<DeallocOp>(user)) return failure();
|
||||||
|
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||||
|
MLIRContext* context) {
|
||||||
|
results.insert<EraseConstOp>(context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// StaticMemRefCastOp
|
// StaticMemRefCastOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -597,3 +597,24 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tupl
|
||||||
// CHECK: return [[ARG0]]
|
// CHECK: return [[ARG0]]
|
||||||
return %3 : tuple<tensor<i32>>
|
return %3 : tuple<tensor<i32>>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @erase_dead_lhlo_constant
|
||||||
|
func @erase_dead_lhlo_constant() {
|
||||||
|
%M = alloc() : memref<256x1024xf32>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
|
||||||
|
dealloc %M : memref<256x1024xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// A negative test for dead lhlo constant op erasure.
|
||||||
|
// CHECK-LABEL: func @erase_dead_lhlo_constant_negative
|
||||||
|
func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
|
||||||
|
// CHECK-NEXT: lmhlo.constant
|
||||||
|
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
|
||||||
|
// CHECK-NEXT: alloc
|
||||||
|
// CHECK-NEXT: lmhlo.constant
|
||||||
|
%N = alloc() : memref<256x1024xf32>
|
||||||
|
"lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
|
||||||
|
return %N : memref<256x1024xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue