PR #49598: [MLIR][DISC] legalize tensor_load inserted during hlo-to-lhlo conversion
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49598

This PR implements logic for lowering memref.tensor_load ops that are inserted during `mhlo-legalize-to-lmhlo`.

Copybara import of the project:

-- 80eb377af4e02182e1aecc943a41ca5d7d1c2100 by Wenyi Zhao <reyizero@gmail.com>:
[MLIR][DISC] legalize tensor_load inserted during hlo-to-lhlo conversion
This PR implements logic for lowering memref.tensor_load ops that are inserted during `mhlo-legalize-to-lmhlo`.

-- ac452fe3dcd591211cd5c59be9189fe2f7153b41 by Wenyi Zhao <reyizero@gmail.com>:
minor fix

-- 6b36017f8632a06adbc3e05a62975fa641d0260f by Wenyi Zhao <reyizero@gmail.com>:
minor refine

-- 846005cc76d0033112e47825c2e9a97790b6925f by Wenyi Zhao <reyizero@gmail.com>:
minor fix

-- f6a4becaa287d5ca323b2d152a4d0ae053730fd9 by Wenyi Zhao <reyizero@gmail.com>:
fix

-- 5555749f60f7fce8f57962860ef65efccf0362ba by Wenyi Zhao <reyizero@gmail.com>:
fix

-- 8873b9b6d9315c1199ca9f7c133ecf377ecd2fa6 by Wenyi Zhao <reyizero@gmail.com>:
fix

PiperOrigin-RevId: 376942547
This commit is contained in:
parent
5baf6e7709
commit
968d4b8709
19
BUILD
19
BUILD
|
@ -1099,6 +1099,24 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "legalize_tensor_load_op",
|
||||||
|
srcs = ["lib/Dialect/mhlo/transforms/legalize_tensor_load_op.cc"],
|
||||||
|
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
|
||||||
|
deps = [
|
||||||
|
":lhlo",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:Shape",
|
||||||
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
"@llvm-project//mlir:Support",
|
||||||
|
"@llvm-project//mlir:TensorDialect",
|
||||||
|
"@llvm-project//mlir:Transforms",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "chlo_legalize_to_hlo",
|
name = "chlo_legalize_to_hlo",
|
||||||
srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"],
|
srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"],
|
||||||
|
@ -1193,6 +1211,7 @@ cc_library(
|
||||||
":hlo_legalize_to_lhlo",
|
":hlo_legalize_to_lhlo",
|
||||||
":legalize_control_flow",
|
":legalize_control_flow",
|
||||||
":legalize_gather_to_torch_index_select",
|
":legalize_gather_to_torch_index_select",
|
||||||
|
":legalize_tensor_load_op",
|
||||||
":legalize_to_linalg",
|
":legalize_to_linalg",
|
||||||
":legalize_to_standard",
|
":legalize_to_standard",
|
||||||
":legalize_trigonometric_to_approximation",
|
":legalize_trigonometric_to_approximation",
|
||||||
|
|
|
@ -51,3 +51,8 @@ def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "F
|
||||||
let constructor = "createLegalizeLhloToParallelLoopsPass()";
|
let constructor = "createLegalizeLhloToParallelLoopsPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> {
|
||||||
|
let summary = "Legalize tensor load ops that are inserted during mhlo to lmhlo conversion.";
|
||||||
|
let constructor = "createLegalizeTensorLoadOpPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -111,6 +111,9 @@ std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
|
||||||
// Lowers from LHLO dialect to parallel loops.
|
// Lowers from LHLO dialect to parallel loops.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
||||||
|
|
||||||
|
// Legalizes tensor load ops that are inserted during mhlo to lmhlo conversion.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTensorLoadOpPass();
|
||||||
|
|
||||||
} // namespace lmhlo
|
} // namespace lmhlo
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -136,6 +136,7 @@ add_mlir_library(MhloLhloToLinalg
|
||||||
)
|
)
|
||||||
|
|
||||||
add_mlir_library(LmhloPasses
|
add_mlir_library(LmhloPasses
|
||||||
|
legalize_tensor_load_op.cc
|
||||||
lhlo_fuse_linalg.cc
|
lhlo_fuse_linalg.cc
|
||||||
lhlo_legalize_to_affine.cc
|
lhlo_legalize_to_affine.cc
|
||||||
lhlo_legalize_to_gpu.cc
|
lhlo_legalize_to_gpu.cc
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements logic for lowering memref.tensor_load ops that are
|
||||||
|
// inserted during `mhlo-legalize-to-lmhlo`.
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace lmhlo {
|
||||||
|
namespace {
|
||||||
|
using shape::ShapeOfOp;
|
||||||
|
using tensor::ExtractOp;
|
||||||
|
|
||||||
|
// Converting:
|
||||||
|
// memref (operand) -> memref.tensor_load -> tensor.extract
|
||||||
|
// to
|
||||||
|
// memref (operand) -> memref.load
|
||||||
|
struct ForwardExtractOp : public OpRewritePattern<ExtractOp> {
|
||||||
|
using OpRewritePattern<ExtractOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ExtractOp extract,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
auto tensor_load = extract.tensor().getDefiningOp<memref::TensorLoadOp>();
|
||||||
|
if (!tensor_load) return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<memref::LoadOp>(
|
||||||
|
extract, extract.getType(), tensor_load.memref(), extract.indices());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Converting:
|
||||||
|
// memref (operand) -> memref.tensor_load -> shape.shape_of
|
||||||
|
// to
|
||||||
|
// memref (operand) -> shape.shape_of
|
||||||
|
struct ForwardShapeOfOp : public OpRewritePattern<ShapeOfOp> {
|
||||||
|
using OpRewritePattern<ShapeOfOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ShapeOfOp shape_of,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
auto tensor_load = shape_of.arg().getDefiningOp<memref::TensorLoadOp>();
|
||||||
|
if (!tensor_load) return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<ShapeOfOp>(shape_of, shape_of.getType(),
|
||||||
|
tensor_load.memref());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LegalizeTensorLoadOpPass
|
||||||
|
: public mlir::PassWrapper<LegalizeTensorLoadOpPass, FunctionPass> {
|
||||||
|
// Perform the lowering to remove memref.tensor_load ops inserted during
|
||||||
|
// `mhlo-legalize-to-lmhlo`.
|
||||||
|
void runOnFunction() override {
|
||||||
|
auto func = getFunction();
|
||||||
|
auto context = &getContext();
|
||||||
|
OwningRewritePatternList patterns(context);
|
||||||
|
patterns.insert<ForwardShapeOfOp, ForwardExtractOp>(context);
|
||||||
|
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
|
||||||
|
func.emitError("applyPatternsAndFoldGreedily does not converge");
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace lmhlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
|
||||||
|
mlir::lmhlo::createLegalizeTensorLoadOpPass() {
|
||||||
|
return std::make_unique<LegalizeTensorLoadOpPass>();
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
// RUN: mlir-hlo-opt -lhlo-legalize-tensor-load-op %s -o - | FileCheck %s
|
||||||
|
|
||||||
|
// test: `memref -> memref.tensor_load -> tensor.extract` -> `memref -> memref.load`
|
||||||
|
// CHECK-LABEL: forward_extract_op
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<3xindex>)
|
||||||
|
func @forward_extract_op(%arg0: memref<?x?xf32>, %arg1: memref<3xindex>) -> memref<?x?x?xf32> {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%c1 = constant 1 : index
|
||||||
|
%c2 = constant 2 : index
|
||||||
|
// CHECK-NOT: memref.tensor_load
|
||||||
|
// CHECK-NOT: tensor.extract
|
||||||
|
// CHECK: %[[DIM0:.*]] = memref.load %[[ARG1]][%c0]
|
||||||
|
// CHECK: %[[DIM1:.*]] = memref.load %[[ARG1]][%c1]
|
||||||
|
// CHECK: %[[DIM2:.*]] = memref.load %[[ARG1]][%c2]
|
||||||
|
// CHECK: memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
|
||||||
|
%0 = memref.tensor_load %arg1 : memref<3xindex>
|
||||||
|
%1 = tensor.extract %0[%c0] : tensor<3xindex>
|
||||||
|
%2 = tensor.extract %0[%c1] : tensor<3xindex>
|
||||||
|
%3 = tensor.extract %0[%c2] : tensor<3xindex>
|
||||||
|
%4 = memref.alloc(%1, %2, %3) : memref<?x?x?xf32>
|
||||||
|
"lmhlo.dynamic_broadcast_in_dim"(%arg0, %arg1, %4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xf32>, memref<3xindex>, memref<?x?x?xf32>) -> ()
|
||||||
|
return %4 : memref<?x?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// test: `memref -> memref.tensor_load -> shape.shape_of` -> `memref -> shape.shape_of`
|
||||||
|
// CHECK-LABEL: forward_shape_of_op
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>)
|
||||||
|
func @forward_shape_of_op(%arg0: memref<?x?xf32>) -> tensor<2xindex> {
|
||||||
|
// CHECK-NOT: memref.tensor_load
|
||||||
|
// CHECK: shape.shape_of %[[ARG]] : memref<?x?xf32> -> tensor<2xindex>
|
||||||
|
%0 = memref.tensor_load %arg0 : memref<?x?xf32>
|
||||||
|
%1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
|
||||||
|
return %1 : tensor<2xindex>
|
||||||
|
}
|
Loading…
Reference in New Issue