PR #49598: [MLIR][DISC] legalize tensor_load inserted during hlo-to-lhlo conversion
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49598 This PR implements logic for lowering memref.tensor_load ops that are inserted during `mhlo-legalize-to-lmhlo` Copybara import of the project: -- 80eb377af4e02182e1aecc943a41ca5d7d1c2100 by Wenyi Zhao <reyizero@gmail.com>: [MLIR][DISC] legalize tensor_load inserted during hlo-to-lhlo conversion This PR implements logic for lowering memref.tensor_load ops that are inserted during `mhlo-legalize-to-lmhlo`. -- ac452fe3dcd591211cd5c59be9189fe2f7153b41 by Wenyi Zhao <reyizero@gmail.com>: minor fix -- 6b36017f8632a06adbc3e05a62975fa641d0260f by Wenyi Zhao <reyizero@gmail.com>: minor refine -- 846005cc76d0033112e47825c2e9a97790b6925f by Wenyi Zhao <reyizero@gmail.com>: minor fix -- f6a4becaa287d5ca323b2d152a4d0ae053730fd9 by Wenyi Zhao <reyizero@gmail.com>: fix -- 5555749f60f7fce8f57962860ef65efccf0362ba by Wenyi Zhao <reyizero@gmail.com>: fix -- 8873b9b6d9315c1199ca9f7c133ecf377ecd2fa6 by Wenyi Zhao <reyizero@gmail.com>: fix PiperOrigin-RevId: 376942547
This commit is contained in:
parent
5baf6e7709
commit
968d4b8709
19
BUILD
19
BUILD
|
@ -1099,6 +1099,24 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "legalize_tensor_load_op",
|
||||
srcs = ["lib/Dialect/mhlo/transforms/legalize_tensor_load_op.cc"],
|
||||
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
|
||||
deps = [
|
||||
":lhlo",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:MemRefDialect",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TensorDialect",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "chlo_legalize_to_hlo",
|
||||
srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"],
|
||||
|
@ -1193,6 +1211,7 @@ cc_library(
|
|||
":hlo_legalize_to_lhlo",
|
||||
":legalize_control_flow",
|
||||
":legalize_gather_to_torch_index_select",
|
||||
":legalize_tensor_load_op",
|
||||
":legalize_to_linalg",
|
||||
":legalize_to_standard",
|
||||
":legalize_trigonometric_to_approximation",
|
||||
|
|
|
@ -51,3 +51,8 @@ def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "F
|
|||
let constructor = "createLegalizeLhloToParallelLoopsPass()";
|
||||
}
|
||||
|
||||
def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> {
|
||||
let summary = "Legalize tensor load ops that are inserted during mhlo to lmhlo conversion.";
|
||||
let constructor = "createLegalizeTensorLoadOpPass()";
|
||||
}
|
||||
|
||||
|
|
|
@ -111,6 +111,9 @@ std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
|
|||
// Lowers from LHLO dialect to parallel loops.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
||||
|
||||
// Legalizes tensor load ops that are inserted during mhlo to lmhlo conversion.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTensorLoadOpPass();
|
||||
|
||||
} // namespace lmhlo
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -136,6 +136,7 @@ add_mlir_library(MhloLhloToLinalg
|
|||
)
|
||||
|
||||
add_mlir_library(LmhloPasses
|
||||
legalize_tensor_load_op.cc
|
||||
lhlo_fuse_linalg.cc
|
||||
lhlo_legalize_to_affine.cc
|
||||
lhlo_legalize_to_gpu.cc
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file implements logic for lowering memref.tensor_load ops that are
|
||||
// inserted during `mhlo-legalize-to-lmhlo`.
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
using shape::ShapeOfOp;
|
||||
using tensor::ExtractOp;
|
||||
|
||||
// Converting:
|
||||
// memref (operand) -> memref.tensor_load -> tensor.extract
|
||||
// to
|
||||
// memref (operand) -> memref.load
|
||||
struct ForwardExtractOp : public OpRewritePattern<ExtractOp> {
|
||||
using OpRewritePattern<ExtractOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractOp extract,
|
||||
PatternRewriter& rewriter) const override {
|
||||
auto tensor_load = extract.tensor().getDefiningOp<memref::TensorLoadOp>();
|
||||
if (!tensor_load) return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<memref::LoadOp>(
|
||||
extract, extract.getType(), tensor_load.memref(), extract.indices());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Converting:
|
||||
// memref (operand) -> memref.tensor_load -> shape.shape_of
|
||||
// to
|
||||
// memref (operand) -> shape.shape_of
|
||||
struct ForwardShapeOfOp : public OpRewritePattern<ShapeOfOp> {
|
||||
using OpRewritePattern<ShapeOfOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ShapeOfOp shape_of,
|
||||
PatternRewriter& rewriter) const override {
|
||||
auto tensor_load = shape_of.arg().getDefiningOp<memref::TensorLoadOp>();
|
||||
if (!tensor_load) return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<ShapeOfOp>(shape_of, shape_of.getType(),
|
||||
tensor_load.memref());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct LegalizeTensorLoadOpPass
|
||||
: public mlir::PassWrapper<LegalizeTensorLoadOpPass, FunctionPass> {
|
||||
// Perform the lowering to remove memref.tensor_load ops inserted during
|
||||
// `mhlo-legalize-to-lmhlo`.
|
||||
void runOnFunction() override {
|
||||
auto func = getFunction();
|
||||
auto context = &getContext();
|
||||
OwningRewritePatternList patterns(context);
|
||||
patterns.insert<ForwardShapeOfOp, ForwardExtractOp>(context);
|
||||
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
|
||||
func.emitError("applyPatternsAndFoldGreedily does not converge");
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
|
||||
mlir::lmhlo::createLegalizeTensorLoadOpPass() {
|
||||
return std::make_unique<LegalizeTensorLoadOpPass>();
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
// RUN: mlir-hlo-opt -lhlo-legalize-tensor-load-op %s -o - | FileCheck %s
|
||||
|
||||
// test: `memref -> memref.tensor_load -> tensor.extract` -> `memref -> memref.load`
|
||||
// CHECK-LABEL: forward_extract_op
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<3xindex>)
|
||||
func @forward_extract_op(%arg0: memref<?x?xf32>, %arg1: memref<3xindex>) -> memref<?x?x?xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
// CHECK-NOT: memref.tensor_load
|
||||
// CHECK-NOT: tensor.extract
|
||||
// CHECK: %[[DIM0:.*]] = memref.load %[[ARG1]][%c0]
|
||||
// CHECK: %[[DIM1:.*]] = memref.load %[[ARG1]][%c1]
|
||||
// CHECK: %[[DIM2:.*]] = memref.load %[[ARG1]][%c2]
|
||||
// CHECK: memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
|
||||
%0 = memref.tensor_load %arg1 : memref<3xindex>
|
||||
%1 = tensor.extract %0[%c0] : tensor<3xindex>
|
||||
%2 = tensor.extract %0[%c1] : tensor<3xindex>
|
||||
%3 = tensor.extract %0[%c2] : tensor<3xindex>
|
||||
%4 = memref.alloc(%1, %2, %3) : memref<?x?x?xf32>
|
||||
"lmhlo.dynamic_broadcast_in_dim"(%arg0, %arg1, %4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xf32>, memref<3xindex>, memref<?x?x?xf32>) -> ()
|
||||
return %4 : memref<?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test: `memref -> memref.tensor_load -> shape.shape_of` -> `memref -> shape.shape_of`
|
||||
// CHECK-LABEL: forward_shape_of_op
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>)
|
||||
func @forward_shape_of_op(%arg0: memref<?x?xf32>) -> tensor<2xindex> {
|
||||
// CHECK-NOT: memref.tensor_load
|
||||
// CHECK: shape.shape_of %[[ARG]] : memref<?x?xf32> -> tensor<2xindex>
|
||||
%0 = memref.tensor_load %arg0 : memref<?x?xf32>
|
||||
%1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
|
||||
return %1 : tensor<2xindex>
|
||||
}
|
Loading…
Reference in New Issue