diff --git a/BUILD b/BUILD
index cf623ab..64498fe 100644
--- a/BUILD
+++ b/BUILD
@@ -1099,6 +1099,24 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "legalize_tensor_load_op",
+    srcs = ["lib/Dialect/mhlo/transforms/legalize_tensor_load_op.cc"],
+    hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
+    deps = [
+        ":lhlo",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:MemRefDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
 cc_library(
     name = "chlo_legalize_to_hlo",
     srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"],
@@ -1193,6 +1211,7 @@ cc_library(
         ":hlo_legalize_to_lhlo",
        ":legalize_control_flow",
        ":legalize_gather_to_torch_index_select",
+        ":legalize_tensor_load_op",
        ":legalize_to_linalg",
        ":legalize_to_standard",
        ":legalize_trigonometric_to_approximation",
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td
index 17c0524..e1d840c 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td
@@ -51,3 +51,8 @@ def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> {
   let constructor = "createLegalizeLhloToParallelLoopsPass()";
 }
 
+def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> {
+  let summary = "Legalize tensor load ops that are inserted during mhlo to lmhlo conversion.";
+  let constructor = "createLegalizeTensorLoadOpPass()";
+}
+
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
index 9c74818..269424b 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
@@ -111,6 +111,9 @@ std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
 // Lowers from LHLO dialect to parallel loops.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
 
+// Legalizes tensor load ops that are inserted during mhlo to lmhlo conversion.
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTensorLoadOpPass();
+
 }  // namespace lmhlo
 }  // namespace mlir
 
diff --git a/lib/Dialect/mhlo/transforms/CMakeLists.txt b/lib/Dialect/mhlo/transforms/CMakeLists.txt
index 8949ac3..0445037 100644
--- a/lib/Dialect/mhlo/transforms/CMakeLists.txt
+++ b/lib/Dialect/mhlo/transforms/CMakeLists.txt
@@ -136,6 +136,7 @@ add_mlir_library(MhloLhloToLinalg
 )
 
 add_mlir_library(LmhloPasses
+  legalize_tensor_load_op.cc
   lhlo_fuse_linalg.cc
   lhlo_legalize_to_affine.cc
   lhlo_legalize_to_gpu.cc
diff --git a/lib/Dialect/mhlo/transforms/legalize_tensor_load_op.cc b/lib/Dialect/mhlo/transforms/legalize_tensor_load_op.cc
new file mode 100644
index 0000000..6d2733d
--- /dev/null
+++ b/lib/Dialect/mhlo/transforms/legalize_tensor_load_op.cc
@@ -0,0 +1,97 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements logic for lowering memref.tensor_load ops that are
+// inserted during `mhlo-legalize-to-lmhlo`.
+
+#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace lmhlo {
+namespace {
+using shape::ShapeOfOp;
+using tensor::ExtractOp;
+
+// Converting:
+//   memref (operand) -> memref.tensor_load -> tensor.extract
+// to
+//   memref (operand) -> memref.load
+struct ForwardExtractOp : public OpRewritePattern<ExtractOp> {
+  using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extract,
+                                PatternRewriter& rewriter) const override {
+    auto tensor_load = extract.tensor().getDefiningOp<memref::TensorLoadOp>();
+    if (!tensor_load) return failure();
+
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(
+        extract, extract.getType(), tensor_load.memref(), extract.indices());
+    return success();
+  }
+};
+
+// Converting:
+//   memref (operand) -> memref.tensor_load -> shape.shape_of
+// to
+//   memref (operand) -> shape.shape_of
+struct ForwardShapeOfOp : public OpRewritePattern<ShapeOfOp> {
+  using OpRewritePattern<ShapeOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeOfOp shape_of,
+                                PatternRewriter& rewriter) const override {
+    auto tensor_load = shape_of.arg().getDefiningOp<memref::TensorLoadOp>();
+    if (!tensor_load) return failure();
+
+    rewriter.replaceOpWithNewOp<ShapeOfOp>(shape_of, shape_of.getType(),
+                                           tensor_load.memref());
+    return success();
+  }
+};
+
+struct LegalizeTensorLoadOpPass
+    : public mlir::PassWrapper<LegalizeTensorLoadOpPass, FunctionPass> {
+  // Perform the lowering to remove memref.tensor_load ops inserted during
+  // `mhlo-legalize-to-lmhlo`.
+  void runOnFunction() override {
+    auto func = getFunction();
+    auto context = &getContext();
+    OwningRewritePatternList patterns(context);
+    patterns.insert<ForwardExtractOp, ForwardShapeOfOp>(context);
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
+      func.emitError("applyPatternsAndFoldGreedily does not converge");
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+}  // namespace lmhlo
+}  // namespace mlir
+
+std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
+mlir::lmhlo::createLegalizeTensorLoadOpPass() {
+  return std::make_unique<LegalizeTensorLoadOpPass>();
+}
diff --git a/tests/lhlo-legalize-tensor-load-op.mlir b/tests/lhlo-legalize-tensor-load-op.mlir
new file mode 100644
index 0000000..ca2c7bb
--- /dev/null
+++ b/tests/lhlo-legalize-tensor-load-op.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-hlo-opt -lhlo-legalize-tensor-load-op %s -o - | FileCheck %s
+
+// test: `memref -> memref.tensor_load -> tensor.extract` -> `memref -> memref.load`
+// CHECK-LABEL: forward_extract_op
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<3xindex>)
+func @forward_extract_op(%arg0: memref<?x?xf32>, %arg1: memref<3xindex>) -> memref<?x?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  // CHECK-NOT: memref.tensor_load
+  // CHECK-NOT: tensor.extract
+  // CHECK: %[[DIM0:.*]] = memref.load %[[ARG1]][%c0]
+  // CHECK: %[[DIM1:.*]] = memref.load %[[ARG1]][%c1]
+  // CHECK: %[[DIM2:.*]] = memref.load %[[ARG1]][%c2]
+  // CHECK: memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
+  %0 = memref.tensor_load %arg1 : memref<3xindex>
+  %1 = tensor.extract %0[%c0] : tensor<3xindex>
+  %2 = tensor.extract %0[%c1] : tensor<3xindex>
+  %3 = tensor.extract %0[%c2] : tensor<3xindex>
+  %4 = memref.alloc(%1, %2, %3) : memref<?x?x?xf32>
+  "lmhlo.dynamic_broadcast_in_dim"(%arg0, %arg1, %4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xf32>, memref<3xindex>, memref<?x?x?xf32>) -> ()
+  return %4 : memref<?x?x?xf32>
+}
+
+// -----
+
+// test: `memref -> memref.tensor_load -> shape.shape_of` -> `memref -> shape.shape_of`
+// CHECK-LABEL: forward_shape_of_op
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>)
+func @forward_shape_of_op(%arg0: memref<?x?xf32>) -> tensor<2xindex> {
+  // CHECK-NOT: memref.tensor_load
+  // CHECK: shape.shape_of %[[ARG]] : memref<?x?xf32> -> tensor<2xindex>
+  %0 = memref.tensor_load %arg0 : memref<?x?xf32>
+  %1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
+  return %1 : tensor<2xindex>
+}
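
Note for reviewers: beyond the mlir-hlo-opt RUN line above, a minimal C++ sketch of how the new pass could be wired into a pipeline. Only createLegalizeTensorLoadOpPass() comes from this change; the helper name and the surrounding pipeline are assumptions for illustration.

// Sketch only: addLmhloTensorLoadCleanup and its call site are hypothetical.
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Run the cleanup after mhlo-to-lmhlo bufferization so that leftover
// memref.tensor_load ops feeding tensor.extract / shape.shape_of are
// rewritten to use the underlying memrefs directly.
static void addLmhloTensorLoadCleanup(mlir::PassManager& pm) {
  pm.addNestedPass<mlir::FuncOp>(
      mlir::lmhlo::createLegalizeTensorLoadOpPass());
}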