[MLIR][LHLO] Replace lhlo-copy-removal pass with mlir-copy-removal pass

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/43137

This PR removes lhlo-copy-removal pass entirely and replace its usages with ```mlir::createCopyRemovalPass()```.

--
7ce1a06f507c8db46c6d7b43c7870cf56002e18e by Ehsan Toosi <ehsan.nadjaran_toosi@dfki.de>:

[mlir][lhlo] Replace lhlo-copy-removal pass with mlir-copy-removal pass

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/43137 from dfki-ehna:using_mlir_copy_removal 7ce1a06f507c8db46c6d7b43c7870cf56002e18e
PiperOrigin-RevId: 331498501
This commit is contained in:
Ehsan Toosi 2020-09-14 01:21:28 -07:00 committed by TensorFlow MLIR Team
parent 7cfcc2c79d
commit ce1c8a1ebc
9 changed files with 10 additions and 233 deletions

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

View File

@ -34,6 +34,7 @@ limitations under the License.
#define LHLO_OPS
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
@ -616,11 +617,16 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
);
}
def LHLO_CopyOp: LHLO_Op<"copy", []>, BASE_HLO_CopyOp {
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
let extraClassDeclaration = [{
Value getSource() { return operand();}
Value getTarget() { return output(); }
}];
}
def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {

View File

@ -15,12 +15,6 @@ limitations under the License.
include "mlir/Pass/PassBase.td"
def LhloCopyRemovalPass : Pass<"lhlo-copy-removal", "FuncOp"> {
let summary = "Removes redundant LHLO copy operations.";
let constructor = "createLhloCopyRemovalPass()";
}
def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> {
let summary = "Legalize from LHLO dialect to Linalg dialect.";
let constructor = "createLegalizeLhloToLinalgPass()";

View File

@ -95,12 +95,6 @@ std::unique_ptr<FunctionPass> createLegalizeToGpuPass();
std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
bool use_parallel_loops = false, llvm::ArrayRef<unsigned> tile_sizes = {});
// Removes unnecessary LHLO copies which copy from the allocated buffers to the
// block arguments. The block arguments are used instead of all uses of these
// buffers. The buffers are freed. This pass only works in regions that contain
// a single block.
std::unique_ptr<Pass> createLhloCopyRemovalPass();
// Lowers from LHLO dialect to parallel loops.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();

View File

@ -125,7 +125,6 @@ add_mlir_library(MhloLhloToLinalg
)
add_mlir_library(LmhloPasses
lhlo_copy_removal.cc
lhlo_fuse_linalg.cc
lhlo_legalize_to_affine.cc
lhlo_legalize_to_gpu.cc

View File

@ -1,102 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements a pass to remove redundant LHLO copy operations.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace lmhlo {
namespace {
// Removes LHLO copy operations that copy from allocated buffers to block
// arguments. All uses of each buffer are replaced with the corresponding block
// argument and the buffer is freed. Note that this pass only works in regions
// with a single block.
struct LhloCopyRemovalPass
: mlir::PassWrapper<LhloCopyRemovalPass, OperationPass<>> {
void runOnOperation() override {
llvm::SmallVector<mlir::Operation*, 2> eraseList;
auto operation = getOperation();
operation->walk([&](mlir::lmhlo::CopyOp copyOp) {
// If this region contains more than one block, then ignore this copy
// operation.
if (copyOp.getParentRegion()->getBlocks().size() > 1) {
return;
}
mlir::Value fromOperand = copyOp.operand();
mlir::Value toOperand = copyOp.output();
// If the fromOperand value is a block argument or the toOperand
// value is not a block argument, then ignore this copy operation.
if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) {
return;
}
// The copy operation removal is illegal if there is at least a single use
// of toOperand value that lies between the first use of fromOperand value
// and the copy operation.
auto fromOperandUsers = fromOperand.getUsers();
auto firstUser = *fromOperandUsers.begin();
for (auto op : fromOperandUsers) {
if (op->isBeforeInBlock(firstUser)) firstUser = op;
}
for (auto op : toOperand.getUsers()) {
if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) {
return;
}
}
// TODO(DFKI): Use live variable analysis to solve aliasing issues among
// block arguments.
// Remove the associated alloc operation.
auto allocOp = fromOperand.getDefiningOp();
eraseList.push_back(allocOp);
// Iterate over all uses of the fromOperand to find the associated
// deallocOp (if any).
for (auto op : fromOperandUsers) {
if (isa<mlir::DeallocOp>(op)) {
eraseList.push_back(op);
break;
}
}
// Replace all uses of the fromOperand with the toOperand. This rewires
// all references pointing to the original alloc operation to the new
// target operation in order to safely remove the copy op.
fromOperand.replaceAllUsesWith(toOperand);
copyOp.erase();
});
for (auto op : eraseList) {
op->erase();
}
};
};
} // namespace
std::unique_ptr<Pass> createLhloCopyRemovalPass() {
return std::make_unique<LhloCopyRemovalPass>();
}
} // namespace lmhlo
} // namespace mlir

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -lhlo-copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
func @main() -> () {
call @trivial_broadcast_wrapper() : () -> ()

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -lhlo-copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
func @main() -> () {
call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()

View File

@ -1,115 +0,0 @@
// RUN: mlir-hlo-opt -lhlo-copy-removal %s -o - | FileCheck %s
// CHECK-LABEL: func @remove_simple
func @remove_simple(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @remove_without_dealloc
func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @replace_dependency
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @keep_copies
func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
// CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @must_not_be_removed
func @must_not_be_removed(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
// CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32>
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"lmhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @must_be_removed_first
func @must_be_removed_first(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"lmhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @must_be_removed_second
func @must_be_removed_second(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"lmhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @reduce
func @reduce(%arg0: memref<1x8xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) {
%0 = alloc() : memref<1xf32>
"lmhlo.reduce"(%arg0, %arg1, %0) ( {
// CHECK: ^bb0(%[[ARG0:.*]]: memref<f32>, %[[ARG1:.*]]: memref<f32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<f32>)
^bb0(%arg3: memref<f32>, %arg4: memref<f32>, %arg5: memref<f32>):
%1 = alloc() : memref<f32>
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
"lmhlo.add"(%arg3, %arg4, %1)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK-NOT; lmhlo.copy
"lmhlo.copy"(%1, %arg5) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}) {dimensions = dense<1> : tensor<1xi64>}
: (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<1xf32>, memref<1xf32>) -> ()
return
}