[MLIR][LHLO] Replace lhlo-copy-removal pass with mlir-copy-removal pass
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/43137 This PR removes the lhlo-copy-removal pass entirely and replaces its usages with ```mlir::createCopyRemovalPass()```. -- 7ce1a06f507c8db46c6d7b43c7870cf56002e18e by Ehsan Toosi <ehsan.nadjaran_toosi@dfki.de>: [mlir][lhlo] Replace lhlo-copy-removal pass with mlir-copy-removal pass COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/43137 from dfki-ehna:using_mlir_copy_removal 7ce1a06f507c8db46c6d7b43c7870cf56002e18e PiperOrigin-RevId: 331498501
This commit is contained in:
parent
7cfcc2c79d
commit
ce1c8a1ebc
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/CopyOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||
#define LHLO_OPS
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/CopyOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
|
@ -616,11 +617,16 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
|||
);
|
||||
}
|
||||
|
||||
def LHLO_CopyOp: LHLO_Op<"copy", []>, BASE_HLO_CopyOp {
|
||||
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Value getSource() { return operand();}
|
||||
Value getTarget() { return output(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
|
||||
|
|
|
@ -15,12 +15,6 @@ limitations under the License.
|
|||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def LhloCopyRemovalPass : Pass<"lhlo-copy-removal", "FuncOp"> {
|
||||
let summary = "Removes redundant LHLO copy operations.";
|
||||
let constructor = "createLhloCopyRemovalPass()";
|
||||
}
|
||||
|
||||
|
||||
def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> {
|
||||
let summary = "Legalize from LHLO dialect to Linalg dialect.";
|
||||
let constructor = "createLegalizeLhloToLinalgPass()";
|
||||
|
|
|
@ -95,12 +95,6 @@ std::unique_ptr<FunctionPass> createLegalizeToGpuPass();
|
|||
std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
|
||||
bool use_parallel_loops = false, llvm::ArrayRef<unsigned> tile_sizes = {});
|
||||
|
||||
// Removes unnecessary LHLO copies which copy from the allocated buffers to the
|
||||
// block arguments. The block arguments are used instead of all uses of these
|
||||
// buffers. The buffers are freed. This pass only works in regions that contain
|
||||
// a single block.
|
||||
std::unique_ptr<Pass> createLhloCopyRemovalPass();
|
||||
|
||||
// Lowers from LHLO dialect to parallel loops.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
||||
|
||||
|
|
|
@ -125,7 +125,6 @@ add_mlir_library(MhloLhloToLinalg
|
|||
)
|
||||
|
||||
add_mlir_library(LmhloPasses
|
||||
lhlo_copy_removal.cc
|
||||
lhlo_fuse_linalg.cc
|
||||
lhlo_legalize_to_affine.cc
|
||||
lhlo_legalize_to_gpu.cc
|
||||
|
|
|
@ -1,102 +0,0 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file implements a pass to remove redundant LHLO copy operations.
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
// Removes LHLO copy operations that copy from allocated buffers to block
|
||||
// arguments. All uses of each buffer are replaced with the corresponding block
|
||||
// argument and the buffer is freed. Note that this pass only works in regions
|
||||
// with a single block.
|
||||
struct LhloCopyRemovalPass
|
||||
: mlir::PassWrapper<LhloCopyRemovalPass, OperationPass<>> {
|
||||
void runOnOperation() override {
|
||||
llvm::SmallVector<mlir::Operation*, 2> eraseList;
|
||||
auto operation = getOperation();
|
||||
operation->walk([&](mlir::lmhlo::CopyOp copyOp) {
|
||||
// If this region contains more than one block, then ignore this copy
|
||||
// operation.
|
||||
if (copyOp.getParentRegion()->getBlocks().size() > 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
mlir::Value fromOperand = copyOp.operand();
|
||||
mlir::Value toOperand = copyOp.output();
|
||||
|
||||
// If the fromOperand value is a block argument or the toOperand
|
||||
// value is not a block argument, then ignore this copy operation.
|
||||
if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The copy operation removal is illegal if there is at least a single use
|
||||
// of toOperand value that lies between the first use of fromOperand value
|
||||
// and the copy operation.
|
||||
auto fromOperandUsers = fromOperand.getUsers();
|
||||
auto firstUser = *fromOperandUsers.begin();
|
||||
for (auto op : fromOperandUsers) {
|
||||
if (op->isBeforeInBlock(firstUser)) firstUser = op;
|
||||
}
|
||||
for (auto op : toOperand.getUsers()) {
|
||||
if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(DFKI): Use live variable analysis to solve aliasing issues among
|
||||
// block arguments.
|
||||
|
||||
// Remove the associated alloc operation.
|
||||
auto allocOp = fromOperand.getDefiningOp();
|
||||
eraseList.push_back(allocOp);
|
||||
|
||||
// Iterate over all uses of the fromOperand to find the associated
|
||||
// deallocOp (if any).
|
||||
for (auto op : fromOperandUsers) {
|
||||
if (isa<mlir::DeallocOp>(op)) {
|
||||
eraseList.push_back(op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Replace all uses of the fromOperand with the toOperand. This rewires
|
||||
// all references pointing to the original alloc operation to the new
|
||||
// target operation in order to safely remove the copy op.
|
||||
fromOperand.replaceAllUsesWith(toOperand);
|
||||
copyOp.erase();
|
||||
});
|
||||
for (auto op : eraseList) {
|
||||
op->erase();
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createLhloCopyRemovalPass() {
|
||||
return std::make_unique<LhloCopyRemovalPass>();
|
||||
}
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -lhlo-copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
||||
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
||||
|
||||
func @main() -> () {
|
||||
call @trivial_broadcast_wrapper() : () -> ()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -lhlo-copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
||||
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
||||
|
||||
func @main() -> () {
|
||||
call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()
|
||||
|
|
|
@ -1,115 +0,0 @@
|
|||
// RUN: mlir-hlo-opt -lhlo-copy-removal %s -o - | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @remove_simple
|
||||
func @remove_simple(%arg0: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @remove_without_dealloc
|
||||
func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @replace_dependency
|
||||
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @keep_copies
|
||||
func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
|
||||
// CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @must_not_be_removed
|
||||
func @must_not_be_removed(%arg0: memref<2x2xf32>,
|
||||
%arg1: memref<2x2xf32>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
// CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32>
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @must_be_removed_first
|
||||
func @must_be_removed_first(%arg0: memref<2x2xf32>,
|
||||
%arg1: memref<2x2xf32>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @must_be_removed_second
|
||||
func @must_be_removed_second(%arg0: memref<2x2xf32>,
|
||||
%arg1: memref<2x2xf32>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @reduce
|
||||
func @reduce(%arg0: memref<1x8xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) {
|
||||
%0 = alloc() : memref<1xf32>
|
||||
"lmhlo.reduce"(%arg0, %arg1, %0) ( {
|
||||
// CHECK: ^bb0(%[[ARG0:.*]]: memref<f32>, %[[ARG1:.*]]: memref<f32>,
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<f32>)
|
||||
^bb0(%arg3: memref<f32>, %arg4: memref<f32>, %arg5: memref<f32>):
|
||||
%1 = alloc() : memref<f32>
|
||||
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
|
||||
"lmhlo.add"(%arg3, %arg4, %1)
|
||||
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
// CHECK-NOT; lmhlo.copy
|
||||
"lmhlo.copy"(%1, %arg5) : (memref<f32>, memref<f32>) -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}) {dimensions = dense<1> : tensor<1xi64>}
|
||||
: (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg2) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue