From 8388303fd2a53b494283b10700215667038b484c Mon Sep 17 00:00:00 2001
From: Wenyi Zhao <951425797@qq.com>
Date: Fri, 11 Jun 2021 16:31:53 -0700
Subject: [PATCH] PR #50211: [MLIR][DISC] Bufferize RealDynamicSliceOp and
 ReduceOp

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50211

Support hlo-to-lhlo conversion for RealDynamicSliceOp and ReduceOp.

Copybara import of the project:

--
c417b336670a1fc256f7026dfe8080e46d13d79a by Wenyi Zhao <951425797@qq.com>:

    [MLIR][DISC] Bufferize RealDynamicSliceOp and ReduceOp

PiperOrigin-RevId: 378972113
---
 include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td   |  4 +-
 .../mhlo/transforms/map_hlo_to_lhlo_op.h      |  1 +
 lib/Dialect/mhlo/IR/hlo_ops.cc                | 86 +++++++++++++++++++
 .../mhlo/transforms/hlo_legalize_to_lhlo.cc   |  8 +-
 tests/hlo-legalize-to-lhlo-only-dynamic.mlir  | 69 +++++++++++++++
 5 files changed, 161 insertions(+), 7 deletions(-)

diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 5ac612a..2632876 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -967,7 +967,7 @@ def HLO_AllToAllOp : HLO_Op<"all_to_all",
   let results = (outs HLO_Tensor);
 }
 
-def HLO_ReduceOp: HLO_Op<"reduce", [
+def HLO_ReduceOp: HLO_ShapedInterfaceOp<"reduce", [
       RecursiveSideEffects,
       SameVariadicOperandSize,
       SingleBlockImplicitTerminator<"ReturnOp">,
@@ -2199,7 +2199,7 @@ def HLO_ReducePrecisionOp :
   let results = (outs HLO_FpTensor:$output);
 }
 
-def HLO_RealDynamicSliceOp: HLO_Op<
+def HLO_RealDynamicSliceOp: HLO_ShapedInterfaceOp<
   "real_dynamic_slice",
   [NoSideEffect, AllElementTypesMatch<["operand", "result"]>,
    AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
index 59ef375..af42eb9 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
@@ -75,6 +75,7 @@ MAP_HLO_TO_LHLO(NegOp);
 MAP_HLO_TO_LHLO(NotOp);
 MAP_HLO_TO_LHLO(OrOp);
 MAP_HLO_TO_LHLO(PowOp);
+MAP_HLO_TO_LHLO(RealDynamicSliceOp);
 MAP_HLO_TO_LHLO(RealOp);
 MAP_HLO_TO_LHLO(ReduceOp);
 MAP_HLO_TO_LHLO(ReshapeOp);
diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index c1bbe52..75a2d41 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1626,6 +1626,52 @@ static LogicalResult Verify(RealDynamicSliceOp op) {
   return success();
 }
 
+LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  RealDynamicSliceOp::Adaptor adaptor(operands);
+  Value operand = adaptor.operand();
+  Value start_indices = adaptor.start_indices();
+  Value limit_indices = adaptor.limit_indices();
+  Value strides = adaptor.strides();
+
+  auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
+  // Unranked operands are not supported at the moment.
+  if (!operand_type) return failure();
+
+  Location loc = this->getLoc();
+  SmallVector<Value, 4> shape_values;
+  shape_values.reserve(operand_type.getRank());
+  Type shape_scalar_type =
+      start_indices.getType().cast<ShapedType>().getElementType();
+  Value one = builder.create<ConstantIndexOp>(loc, 1);
+  one = MaybeCastTo(builder, loc, one, shape_scalar_type);
+  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
+    Value offset = builder.create<ConstantIndexOp>(loc, element.index());
+    Value value_start =
+        builder.create<tensor::ExtractOp>(loc, start_indices, offset);
+    Value value_limit =
+        builder.create<tensor::ExtractOp>(loc, limit_indices, offset);
+    Value value_stride =
+        builder.create<tensor::ExtractOp>(loc, strides, offset);
+    // size = (limit - start + stride - 1) / stride
+    shape_values.push_back(builder.create<SignedDivIOp>(
+        loc,
+        builder.create<SubIOp>(
+            loc,
+            builder.create<AddIOp>(
+                loc, value_stride,
+                builder.create<SubIOp>(loc, value_limit, value_start)),
+            one),
+        value_stride));
+  }
+
+  reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values));
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // InfeedOp
 //===----------------------------------------------------------------------===//
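The size reified above is a ceiling division: `size = (limit - start + stride - 1) / stride` counts how many stride-sized steps fit in the half-open range [start, limit). A minimal standalone sketch of the same arithmetic, in plain C++ rather than MLIR builder calls (the helper name SliceDimSize is illustrative, not part of the patch):

#include <cassert>
#include <cstdint>

// Extent of one dimension of a strided slice over the half-open range
// [start, limit): ceil((limit - start) / stride), assuming stride > 0.
int64_t SliceDimSize(int64_t start, int64_t limit, int64_t stride) {
  return (limit - start + stride - 1) / stride;
}

int main() {
  assert(SliceDimSize(0, 10, 3) == 4);  // keeps elements 0, 3, 6, 9
  assert(SliceDimSize(1, 8, 3) == 3);   // keeps elements 1, 4, 7
  return 0;
}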
@@ -2072,6 +2118,46 @@ void ReduceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                            MLIRContext* context) {
   results.insert<LowerBoolSplatConstantsIntoRegion>(context);
 }
+
+LogicalResult ReduceOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  ReduceOp::Adaptor adaptor(operands);
+  auto inputs = adaptor.inputs();
+
+  auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
+  // Unranked operands are not supported at the moment.
+  if (!operand_type) return failure();
+
+  Location loc = this->getLoc();
+  SmallVector<Value, 4> shape_values;
+  SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>());
+  shape_values.reserve(operand_type.getRank());
+  Type shape_scalar_type = builder.getIndexType();
+  auto to_shape_scalar_type = [&](Value v) {
+    return MaybeCastTo(builder, loc, v, shape_scalar_type);
+  };
+
+  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
+    int64_t idx = element.index();
+    auto it = std::find(dimensions.begin(), dimensions.end(), idx);
+    if (it != dimensions.end()) {
+      continue;
+    }
+    Value value_dim = to_shape_scalar_type(
+        builder.create<memref::DimOp>(loc, inputs[0], element.index()));
+    shape_values.push_back(value_dim);
+  }
+
+  Value output_shape = builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values);
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    reifiedReturnShapes.push_back(output_shape);
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // RngNormalOp
 //===----------------------------------------------------------------------===//
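The loop above encodes the shape rule for mhlo.reduce: every dimension listed in `dimensions` is dropped, the remaining dimensions are kept in order, and the same reified shape is pushed once per input. A standalone sketch of that selection rule in plain C++ (ReduceResultShape is an illustrative name, not part of the patch):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Result shape of a reduction: drop every dimension whose index appears in
// `dimensions`, keep the rest in order (mirrors the loop in
// ReduceOp::reifyReturnTypeShapes).
std::vector<int64_t> ReduceResultShape(const std::vector<int64_t>& operand_shape,
                                       const std::vector<int64_t>& dimensions) {
  std::vector<int64_t> result;
  for (int64_t idx = 0; idx < static_cast<int64_t>(operand_shape.size()); ++idx) {
    if (std::find(dimensions.begin(), dimensions.end(), idx) != dimensions.end())
      continue;  // this dimension is reduced away
    result.push_back(operand_shape[idx]);
  }
  return result;
}

int main() {
  // Row reduction over dimension 1 keeps dimension 0.
  assert(ReduceResultShape({5, 7}, {1}) == std::vector<int64_t>({5}));
  // Column reduction over dimension 0 keeps dimension 1.
  assert(ReduceResultShape({5, 7}, {0}) == std::vector<int64_t>({7}));
  return 0;
}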
diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index 355cd02..a5fd06c 100644
--- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -434,11 +434,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
              << "tensor to buffer conversion expects a single block "
                 "in the region containing the operation";
     }
-    const auto& original_results = op.getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
-    for (auto result : original_results) {
-      buffer_args.push_back(InsertAlloc(loc, result, &rewriter));
-    }
+    if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
     auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
                                                    op->getAttrs());
 
@@ -671,7 +668,8 @@ void populateDynamicHLOToLHLOOnlyConversionPattern(
   patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
                    HloToLhloOpConverter<mhlo::DynamicIotaOp>,
                    HloToLhloOpConverter<mhlo::DynamicPadOp>,
-                   HloToLhloOpConverter<mhlo::DynamicReshapeOp>
+                   HloToLhloOpConverter<mhlo::DynamicReshapeOp>,
+                   HloToLhloOpConverter<mhlo::RealDynamicSliceOp>
   >(*converter, context);
   // clang-format on
 }
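Swapping the InsertAlloc loop for ConvertResults is what makes dynamically shaped results bufferizable: InsertAlloc can only size an output buffer from a static result type, while ConvertResults falls back to the op's reifyReturnTypeShapes (the methods added above) to obtain runtime extents for the allocation. A rough sketch of that two-path strategy in plain C++ (the types and helpers are illustrative stand-ins, not the actual MLIR API):

#include <cstdint>
#include <functional>
#include <vector>

constexpr int64_t kDynamicDim = -1;  // stand-in for a dynamic memref dimension
struct Buffer { std::vector<int64_t> extents; };

// Allocate the result buffer: use the static shape when it is fully known,
// otherwise call back into the op to reify its extents at runtime.
Buffer AllocResult(const std::vector<int64_t>& result_shape,
                   const std::function<std::vector<int64_t>()>& reify_shape) {
  for (int64_t d : result_shape) {
    if (d == kDynamicDim) return Buffer{reify_shape()};  // dynamic path
  }
  return Buffer{result_shape};  // static path, as InsertAlloc handled
}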
"mhlo.reduce"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): // no predecessors + %1 = mhlo.add %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<0> : tensor<1xi64>} + : (tensor, tensor) -> tensor + return %0: tensor +}