[MLIR][LHLO] Replace lhlo-copy-removal pass with mlir-copy-removal pass
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/43137 This PR removes the lhlo-copy-removal pass entirely and replaces its usages with ```mlir::createCopyRemovalPass()```. -- 7ce1a06f507c8db46c6d7b43c7870cf56002e18e by Ehsan Toosi <ehsan.nadjaran_toosi@dfki.de>: [mlir][lhlo] Replace lhlo-copy-removal pass with mlir-copy-removal pass COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/43137 from dfki-ehna:using_mlir_copy_removal 7ce1a06f507c8db46c6d7b43c7870cf56002e18e PiperOrigin-RevId: 331498501
This commit is contained in:
parent
7cfcc2c79d
commit
ce1c8a1ebc
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
#include "mlir/Interfaces/CopyOpInterface.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||||
#define LHLO_OPS
|
#define LHLO_OPS
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "mlir/Interfaces/CopyOpInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||||
|
@ -616,11 +617,16 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_CopyOp: LHLO_Op<"copy", []>, BASE_HLO_CopyOp {
|
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
Value getSource() { return operand();}
|
||||||
|
Value getTarget() { return output(); }
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
|
def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
|
||||||
|
|
|
@ -15,12 +15,6 @@ limitations under the License.
|
||||||
|
|
||||||
include "mlir/Pass/PassBase.td"
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
def LhloCopyRemovalPass : Pass<"lhlo-copy-removal", "FuncOp"> {
|
|
||||||
let summary = "Removes redundant LHLO copy operations.";
|
|
||||||
let constructor = "createLhloCopyRemovalPass()";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> {
|
def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> {
|
||||||
let summary = "Legalize from LHLO dialect to Linalg dialect.";
|
let summary = "Legalize from LHLO dialect to Linalg dialect.";
|
||||||
let constructor = "createLegalizeLhloToLinalgPass()";
|
let constructor = "createLegalizeLhloToLinalgPass()";
|
||||||
|
|
|
@ -95,12 +95,6 @@ std::unique_ptr<FunctionPass> createLegalizeToGpuPass();
|
||||||
std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
|
std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
|
||||||
bool use_parallel_loops = false, llvm::ArrayRef<unsigned> tile_sizes = {});
|
bool use_parallel_loops = false, llvm::ArrayRef<unsigned> tile_sizes = {});
|
||||||
|
|
||||||
// Removes unnecessary LHLO copies which copy from the allocated buffers to the
|
|
||||||
// block arguments. The block arguments are used instead of all uses of these
|
|
||||||
// buffers. The buffers are freed. This pass only works in regions that contain
|
|
||||||
// a single block.
|
|
||||||
std::unique_ptr<Pass> createLhloCopyRemovalPass();
|
|
||||||
|
|
||||||
// Lowers from LHLO dialect to parallel loops.
|
// Lowers from LHLO dialect to parallel loops.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
||||||
|
|
||||||
|
|
|
@ -125,7 +125,6 @@ add_mlir_library(MhloLhloToLinalg
|
||||||
)
|
)
|
||||||
|
|
||||||
add_mlir_library(LmhloPasses
|
add_mlir_library(LmhloPasses
|
||||||
lhlo_copy_removal.cc
|
|
||||||
lhlo_fuse_linalg.cc
|
lhlo_fuse_linalg.cc
|
||||||
lhlo_legalize_to_affine.cc
|
lhlo_legalize_to_affine.cc
|
||||||
lhlo_legalize_to_gpu.cc
|
lhlo_legalize_to_gpu.cc
|
||||||
|
|
|
@ -1,102 +0,0 @@
|
||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
// This file implements a pass to remove redundant LHLO copy operations.
|
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace lmhlo {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Removes LHLO copy operations that copy from allocated buffers to block
|
|
||||||
// arguments. All uses of each buffer are replaced with the corresponding block
|
|
||||||
// argument and the buffer is freed. Note that this pass only works in regions
|
|
||||||
// with a single block.
|
|
||||||
struct LhloCopyRemovalPass
|
|
||||||
: mlir::PassWrapper<LhloCopyRemovalPass, OperationPass<>> {
|
|
||||||
void runOnOperation() override {
|
|
||||||
llvm::SmallVector<mlir::Operation*, 2> eraseList;
|
|
||||||
auto operation = getOperation();
|
|
||||||
operation->walk([&](mlir::lmhlo::CopyOp copyOp) {
|
|
||||||
// If this region contains more than one block, then ignore this copy
|
|
||||||
// operation.
|
|
||||||
if (copyOp.getParentRegion()->getBlocks().size() > 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::Value fromOperand = copyOp.operand();
|
|
||||||
mlir::Value toOperand = copyOp.output();
|
|
||||||
|
|
||||||
// If the fromOperand value is a block argument or the toOperand
|
|
||||||
// value is not a block argument, then ignore this copy operation.
|
|
||||||
if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// The copy operation removal is illegal if there is at least a single use
|
|
||||||
// of toOperand value that lies between the first use of fromOperand value
|
|
||||||
// and the copy operation.
|
|
||||||
auto fromOperandUsers = fromOperand.getUsers();
|
|
||||||
auto firstUser = *fromOperandUsers.begin();
|
|
||||||
for (auto op : fromOperandUsers) {
|
|
||||||
if (op->isBeforeInBlock(firstUser)) firstUser = op;
|
|
||||||
}
|
|
||||||
for (auto op : toOperand.getUsers()) {
|
|
||||||
if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(DFKI): Use live variable analysis to solve aliasing issues among
|
|
||||||
// block arguments.
|
|
||||||
|
|
||||||
// Remove the associated alloc operation.
|
|
||||||
auto allocOp = fromOperand.getDefiningOp();
|
|
||||||
eraseList.push_back(allocOp);
|
|
||||||
|
|
||||||
// Iterate over all uses of the fromOperand to find the associated
|
|
||||||
// deallocOp (if any).
|
|
||||||
for (auto op : fromOperandUsers) {
|
|
||||||
if (isa<mlir::DeallocOp>(op)) {
|
|
||||||
eraseList.push_back(op);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace all uses of the fromOperand with the toOperand. This rewires
|
|
||||||
// all references pointing to the original alloc operation to the new
|
|
||||||
// target operation in order to safely remove the copy op.
|
|
||||||
fromOperand.replaceAllUsesWith(toOperand);
|
|
||||||
copyOp.erase();
|
|
||||||
});
|
|
||||||
for (auto op : eraseList) {
|
|
||||||
op->erase();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<Pass> createLhloCopyRemovalPass() {
|
|
||||||
return std::make_unique<LhloCopyRemovalPass>();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace lmhlo
|
|
||||||
} // namespace mlir
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -lhlo-copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
||||||
|
|
||||||
func @main() -> () {
|
func @main() -> () {
|
||||||
call @trivial_broadcast_wrapper() : () -> ()
|
call @trivial_broadcast_wrapper() : () -> ()
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -lhlo-copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
||||||
|
|
||||||
func @main() -> () {
|
func @main() -> () {
|
||||||
call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()
|
call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()
|
||||||
|
|
|
@ -1,115 +0,0 @@
|
||||||
// RUN: mlir-hlo-opt -lhlo-copy-removal %s -o - | FileCheck %s
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @remove_simple
|
|
||||||
func @remove_simple(%arg0: memref<2x2xf32>) {
|
|
||||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
|
||||||
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
dealloc %0 : memref<2x2xf32>
|
|
||||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
|
||||||
"lmhlo.terminator"() : () -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @remove_without_dealloc
|
|
||||||
func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
|
|
||||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
|
||||||
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
|
||||||
"lmhlo.terminator"() : () -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @replace_dependency
|
|
||||||
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
|
|
||||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
|
||||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
"lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
dealloc %0 : memref<2x2xf32>
|
|
||||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
|
||||||
"lmhlo.terminator"() : () -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @keep_copies
|
|
||||||
func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
|
|
||||||
// CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
"lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
|
||||||
"lmhlo.terminator"() : () -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @must_not_be_removed
|
|
||||||
func @must_not_be_removed(%arg0: memref<2x2xf32>,
|
|
||||||
%arg1: memref<2x2xf32>,
|
|
||||||
%arg2: memref<2x2xf32>) {
|
|
||||||
// CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32>
|
|
||||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
|
||||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
// CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
dealloc %0 : memref<2x2xf32>
|
|
||||||
"lmhlo.terminator"() : () -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @must_be_removed_first
|
|
||||||
func @must_be_removed_first(%arg0: memref<2x2xf32>,
|
|
||||||
%arg1: memref<2x2xf32>,
|
|
||||||
%arg2: memref<2x2xf32>) {
|
|
||||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
|
||||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
dealloc %0 : memref<2x2xf32>
|
|
||||||
"lmhlo.terminator"() : () -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @must_be_removed_second
|
|
||||||
func @must_be_removed_second(%arg0: memref<2x2xf32>,
|
|
||||||
%arg1: memref<2x2xf32>,
|
|
||||||
%arg2: memref<2x2xf32>) {
|
|
||||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
|
||||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
||||||
dealloc %0 : memref<2x2xf32>
|
|
||||||
"lmhlo.terminator"() : () -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @reduce
|
|
||||||
func @reduce(%arg0: memref<1x8xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) {
|
|
||||||
%0 = alloc() : memref<1xf32>
|
|
||||||
"lmhlo.reduce"(%arg0, %arg1, %0) ( {
|
|
||||||
// CHECK: ^bb0(%[[ARG0:.*]]: memref<f32>, %[[ARG1:.*]]: memref<f32>,
|
|
||||||
// CHECK-SAME: %[[ARG2:.*]]: memref<f32>)
|
|
||||||
^bb0(%arg3: memref<f32>, %arg4: memref<f32>, %arg5: memref<f32>):
|
|
||||||
%1 = alloc() : memref<f32>
|
|
||||||
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
|
|
||||||
"lmhlo.add"(%arg3, %arg4, %1)
|
|
||||||
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
|
||||||
// CHECK-NOT: lmhlo.copy
|
|
||||||
"lmhlo.copy"(%1, %arg5) : (memref<f32>, memref<f32>) -> ()
|
|
||||||
"lmhlo.terminator"() : () -> ()
|
|
||||||
}) {dimensions = dense<1> : tensor<1xi64>}
|
|
||||||
: (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
|
|
||||||
"lmhlo.copy"(%0, %arg2) : (memref<1xf32>, memref<1xf32>) -> ()
|
|
||||||
return
|
|
||||||
}
|
|
Loading…
Reference in New Issue