PR #49598: [MLIR][DISC] legalize tensor_load inserted during hlo-to-lhlo conversion

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49598

This PR implements logic for lowering memref.tensor_load ops that are
inserted during `mhlo-legalize-to-lmhlo`.
Copybara import of the project:

--
80eb377af4e02182e1aecc943a41ca5d7d1c2100 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] legalize tensor_load inserted during hlo-to-lhlo conversion

This PR implements logic for lowering memref.tensor_load ops that are
inserted during `mhlo-legalize-to-lmhlo`.

--
ac452fe3dcd591211cd5c59be9189fe2f7153b41 by Wenyi Zhao <reyizero@gmail.com>:

minor fix

--
6b36017f8632a06adbc3e05a62975fa641d0260f by Wenyi Zhao <reyizero@gmail.com>:

minor refine

--
846005cc76d0033112e47825c2e9a97790b6925f by Wenyi Zhao <reyizero@gmail.com>:

minor fix

--
f6a4becaa287d5ca323b2d152a4d0ae053730fd9 by Wenyi Zhao <reyizero@gmail.com>:

fix

--
5555749f60f7fce8f57962860ef65efccf0362ba by Wenyi Zhao <reyizero@gmail.com>:

fix

--
8873b9b6d9315c1199ca9f7c133ecf377ecd2fa6 by Wenyi Zhao <reyizero@gmail.com>:

fix

PiperOrigin-RevId: 376942547
This commit is contained in:
wyzhao 2021-06-01 16:27:07 -07:00 committed by TensorFlow MLIR Team
parent 5baf6e7709
commit 968d4b8709
6 changed files with 161 additions and 0 deletions

19
BUILD
View File

@ -1099,6 +1099,24 @@ cc_library(
], ],
) )
# Pass library that forwards consumers of memref.tensor_load ops (inserted by
# the hlo-to-lhlo bufferization) to the underlying memref, removing the loads.
cc_library(
name = "legalize_tensor_load_op",
srcs = ["lib/Dialect/mhlo/transforms/legalize_tensor_load_op.cc"],
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
deps = [
":lhlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
cc_library( cc_library(
name = "chlo_legalize_to_hlo", name = "chlo_legalize_to_hlo",
srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"], srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"],
@ -1193,6 +1211,7 @@ cc_library(
":hlo_legalize_to_lhlo", ":hlo_legalize_to_lhlo",
":legalize_control_flow", ":legalize_control_flow",
":legalize_gather_to_torch_index_select", ":legalize_gather_to_torch_index_select",
":legalize_tensor_load_op",
":legalize_to_linalg", ":legalize_to_linalg",
":legalize_to_standard", ":legalize_to_standard",
":legalize_trigonometric_to_approximation", ":legalize_trigonometric_to_approximation",

View File

@ -51,3 +51,8 @@ def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "F
let constructor = "createLegalizeLhloToParallelLoopsPass()"; let constructor = "createLegalizeLhloToParallelLoopsPass()";
} }
// Function-level pass that removes memref.tensor_load ops left behind by the
// mhlo-to-lmhlo conversion by forwarding their users to the source memref.
def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> {
let summary = "Legalize tensor load ops that are inserted during mhlo to lmhlo conversion.";
let constructor = "createLegalizeTensorLoadOpPass()";
}

View File

@ -111,6 +111,9 @@ std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
// Lowers from LHLO dialect to parallel loops. // Lowers from LHLO dialect to parallel loops.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass(); std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
// Legalizes tensor load ops that are inserted during mhlo to lmhlo conversion.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTensorLoadOpPass();
} // namespace lmhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -136,6 +136,7 @@ add_mlir_library(MhloLhloToLinalg
) )
add_mlir_library(LmhloPasses add_mlir_library(LmhloPasses
legalize_tensor_load_op.cc
lhlo_fuse_linalg.cc lhlo_fuse_linalg.cc
lhlo_legalize_to_affine.cc lhlo_legalize_to_affine.cc
lhlo_legalize_to_gpu.cc lhlo_legalize_to_gpu.cc

View File

@ -0,0 +1,97 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering memref.tensor_load ops that are
// inserted during `mhlo-legalize-to-lmhlo`.
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace lmhlo {
namespace {
using shape::ShapeOfOp;
using tensor::ExtractOp;
// Rewrites the chain
//   memref (operand) -> memref.tensor_load -> tensor.extract
// into the direct form
//   memref (operand) -> memref.load
struct ForwardExtractOp : public OpRewritePattern<ExtractOp> {
  using OpRewritePattern<ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extract,
                                PatternRewriter& rewriter) const override {
    // Only fire when the extracted tensor is produced by a memref.tensor_load.
    auto defining_load = extract.tensor().getDefiningOp<memref::TensorLoadOp>();
    if (!defining_load) return failure();
    // Read the element straight from the underlying buffer, using the same
    // indices; the tensor_load becomes dead once all its users are forwarded.
    auto source_buffer = defining_load.memref();
    rewriter.replaceOpWithNewOp<memref::LoadOp>(extract, extract.getType(),
                                                source_buffer,
                                                extract.indices());
    return success();
  }
};
// Rewrites the chain
//   memref (operand) -> memref.tensor_load -> shape.shape_of
// into the direct form
//   memref (operand) -> shape.shape_of
struct ForwardShapeOfOp : public OpRewritePattern<ShapeOfOp> {
  using OpRewritePattern<ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeOfOp shape_of,
                                PatternRewriter& rewriter) const override {
    // Only fire when the shape query reads the result of a memref.tensor_load.
    auto defining_load = shape_of.arg().getDefiningOp<memref::TensorLoadOp>();
    if (!defining_load) return failure();
    // shape.shape_of accepts shaped memref operands as well, so it can query
    // the buffer directly with the same result type.
    auto source_buffer = defining_load.memref();
    rewriter.replaceOpWithNewOp<ShapeOfOp>(shape_of, shape_of.getType(),
                                           source_buffer);
    return success();
  }
};
// Function pass that removes memref.tensor_load ops inserted during
// `mhlo-legalize-to-lmhlo` by forwarding their consumers (tensor.extract,
// shape.shape_of) to the underlying memref.
struct LegalizeTensorLoadOpPass
    : public mlir::PassWrapper<LegalizeTensorLoadOpPass, FunctionPass> {
  // Greedily applies the forwarding patterns until a fixed point is reached;
  // fails the pass if the driver does not converge.
  void runOnFunction() override {
    auto function = getFunction();
    auto* ctx = &getContext();
    OwningRewritePatternList patterns(ctx);
    patterns.insert<ForwardExtractOp, ForwardShapeOfOp>(ctx);
    if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns)))) {
      function.emitError("applyPatternsAndFoldGreedily does not converge");
      signalPassFailure();
    }
  }
};
} // namespace
} // namespace lmhlo
} // namespace mlir
// Factory for the tensor_load legalization pass declared in passes.h.
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
mlir::lmhlo::createLegalizeTensorLoadOpPass() {
  auto pass = std::make_unique<LegalizeTensorLoadOpPass>();
  return pass;
}

View File

@ -0,0 +1,36 @@
// RUN: mlir-hlo-opt -lhlo-legalize-tensor-load-op %s -o - | FileCheck %s
// test: `memref -> memref.tensor_load -> tensor.extract` -> `memref -> memref.load`
// CHECK-LABEL: forward_extract_op
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<3xindex>)
func @forward_extract_op(%arg0: memref<?x?xf32>, %arg1: memref<3xindex>) -> memref<?x?x?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
// CHECK-NOT: memref.tensor_load
// CHECK-NOT: tensor.extract
// CHECK: %[[DIM0:.*]] = memref.load %[[ARG1]][%c0]
// CHECK: %[[DIM1:.*]] = memref.load %[[ARG1]][%c1]
// CHECK: %[[DIM2:.*]] = memref.load %[[ARG1]][%c2]
// CHECK: memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
// The tensor_load below only feeds tensor.extract ops, so the pass rewrites
// each extract into a memref.load from %arg1 and the load becomes dead.
%0 = memref.tensor_load %arg1 : memref<3xindex>
%1 = tensor.extract %0[%c0] : tensor<3xindex>
%2 = tensor.extract %0[%c1] : tensor<3xindex>
%3 = tensor.extract %0[%c2] : tensor<3xindex>
%4 = memref.alloc(%1, %2, %3) : memref<?x?x?xf32>
"lmhlo.dynamic_broadcast_in_dim"(%arg0, %arg1, %4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xf32>, memref<3xindex>, memref<?x?x?xf32>) -> ()
return %4 : memref<?x?x?xf32>
}
// -----
// test: `memref -> memref.tensor_load -> shape.shape_of` -> `memref -> shape.shape_of`
// CHECK-LABEL: forward_shape_of_op
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>)
func @forward_shape_of_op(%arg0: memref<?x?xf32>) -> tensor<2xindex> {
// CHECK-NOT: memref.tensor_load
// CHECK: shape.shape_of %[[ARG]] : memref<?x?xf32> -> tensor<2xindex>
// The shape query is forwarded to the memref operand; the tensor_load then
// has no remaining users and is erased.
%0 = memref.tensor_load %arg0 : memref<?x?xf32>
%1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
return %1 : tensor<2xindex>
}