From 7f94bd923bd6c46c30876fe00ff422e9a4d1a922 Mon Sep 17 00:00:00 2001
From: Wenyi Zhao <951425797@qq.com>
Date: Mon, 14 Jun 2021 12:35:47 -0700
Subject: [PATCH] PR #50236: [MLIR][DISC] Bufferize TransposeOp and ConcatenateOp

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50236

Support hlo-to-lhlo conversion for TransposeOp and ConcatenateOp. Both ops
are switched to the HLO_ShapedInterfaceOp base class and implement
reifyReturnTypeShapes, so the hlo-legalize-to-lhlo bufferization can compute
the result shape and allocate the output buffer even when operand shapes are
only known at runtime.
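For a dynamically shaped transpose, the reified shape simply routes the
operand's dimension sizes through the permutation. Roughly, as an
illustrative sketch distilled from the new FileCheck test at the end of this
patch (SSA names are illustrative, the f32 element type is the test's
assumption, and the %c0/%c1 index constants are elided):

  %dim0 = memref.dim %arg, %c0 : memref<?x?xf32>
  %dim1 = memref.dim %arg, %c1 : memref<?x?xf32>
  %out = memref.alloc(%dim1, %dim0) : memref<?x?xf32>
  "lmhlo.transpose"(%arg, %out) {permutation = dense<[1,0]> : tensor<2xi64>}
      : (memref<?x?xf32>, memref<?x?xf32>) -> ()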
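For a dynamically shaped concatenate, the reified shape keeps the
non-concatenated extents of the first operand and sums the extents of all
operands along the concatenation dimension, one addi per additional operand.
Again an illustrative sketch under the same assumptions, for three rank-2
operands concatenated along dimension 1:

  %d0 = memref.dim %arg0, %c0 : memref<?x?xf32>
  %a_d1 = memref.dim %arg0, %c1 : memref<?x?xf32>
  %b_d1 = memref.dim %arg1, %c1 : memref<?x?xf32>
  %c_d1 = memref.dim %arg2, %c1 : memref<?x?xf32>
  %tmp = addi %a_d1, %b_d1 : index
  %out_d1 = addi %tmp, %c_d1 : index
  %out = memref.alloc(%d0, %out_d1) : memref<?x?xf32>
  "lmhlo.concatenate"(%arg0, %arg1, %arg2, %out) {dimension = 1 : i64}
      : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()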

Copybara import of the project:

--
62860e717f2a14fbd3ddfb634aa6ff132d245a72 by Wenyi Zhao <951425797@qq.com>:

[MLIR][DISC] Bufferize TransposeOp and ConcatenateOp

--
ce2ff57c1edee1172cd2f36346cc0b34ec1c7467 by Wenyi Zhao <951425797@qq.com>:

fix

PiperOrigin-RevId: 379330954
---
 include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td  |  4 +-
 .../mhlo/transforms/map_hlo_to_lhlo_op.h     |  1 +
 lib/Dialect/mhlo/IR/hlo_ops.cc               | 85 +++++++++++++++++++
 .../mhlo/transforms/hlo_legalize_to_lhlo.cc  |  1 +
 tests/hlo-legalize-to-lhlo-only-dynamic.mlir | 34 ++++++++
 5 files changed, 123 insertions(+), 2 deletions(-)

diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 5274a96..ea599c6 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1382,7 +1382,7 @@ def HLO_ClampOp : HLO_Op<"clamp",
   let results = (outs HLO_Tensor);
 }
 
-def HLO_ConcatenateOp : HLO_Op<"concatenate",
+def HLO_ConcatenateOp : HLO_ShapedInterfaceOp<"concatenate",
     [NoSideEffect, SameOperandsAndResultElementType,
      DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "XLA's concatenate op";
@@ -1901,7 +1901,7 @@ def HLO_TraceOp: HLO_Op<"trace", []> {
   let hasCustomHLOConverter = 1;
 }
 
-def HLO_TransposeOp: HLO_Op<"transpose",
+def HLO_TransposeOp: HLO_ShapedInterfaceOp<"transpose",
     [NoSideEffect, SameOperandsAndResultElementType]> {
   let summary = "Transpose operator";
   let description = [{
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
index af42eb9..cb9b360 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
@@ -47,6 +47,7 @@ MAP_HLO_TO_LHLO(ClampOp);
 MAP_HLO_TO_LHLO(ConstOp);
 MAP_HLO_TO_LHLO(CompareOp);
 MAP_HLO_TO_LHLO(ComplexOp);
+MAP_HLO_TO_LHLO(ConcatenateOp);
 MAP_HLO_TO_LHLO(ConvOp);
 MAP_HLO_TO_LHLO(ConvertOp);
 MAP_HLO_TO_LHLO(CopyOp);
diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index 3c76042..b66235d 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1417,6 +1417,57 @@ static LogicalResult Verify(ConcatenateOp op) {
   return success();
 }
 
+LogicalResult ConcatenateOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  ConcatenateOp::Adaptor adaptor(operands);
+  auto inputs = adaptor.val();
+
+  auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
+  // Unranked tensor types are not supported at the moment.
+  if (!operand_type) return failure();
+
+  Location loc = this->getLoc();
+  Type shape_scalar_type = builder.getIndexType();
+  auto to_shape_scalar_type = [&](Value v) {
+    return MaybeCastTo(builder, loc, v, shape_scalar_type);
+  };
+
+  SmallVector<SmallVector<Value, 4>, 4> all_shape_values;
+  for (size_t input_id = 0; input_id < inputs.size(); ++input_id) {
+    Value operand = inputs[input_id];
+    auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
+    if (!operand_type) return failure();
+
+    SmallVector<Value, 4> shape_vals;
+    for (const auto& element : llvm::enumerate(operand_type.getShape())) {
+      Value value_dim = to_shape_scalar_type(
+          builder.create<memref::DimOp>(loc, operand, element.index()));
+      shape_vals.push_back(value_dim);
+    }
+    all_shape_values.emplace_back(std::move(shape_vals));
+  }
+
+  int axis = this->dimension();
+  auto& shape_values = all_shape_values[0];
+  for (size_t vec_id = 1; vec_id < all_shape_values.size(); ++vec_id) {
+    auto& other_shape_values = all_shape_values[vec_id];
+    if (other_shape_values.size() != shape_values.size()) {
+      this->emitOpError()
+          << "Concatenate expects all operands to be of the same rank";
+      return failure();
+    }
+    shape_values[axis] = builder.create<AddIOp>(loc, shape_values[axis],
+                                                other_shape_values[axis]);
+  }
+
+  Value output_shape = builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values);
+  reifiedReturnShapes.push_back(output_shape);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DynamicReshapeOp
 //===----------------------------------------------------------------------===//
@@ -3363,6 +3414,40 @@ static LogicalResult Verify(TransposeOp op) {
   return success();
 }
 
+LogicalResult TransposeOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  TransposeOp::Adaptor adaptor(operands);
+  Value operand = adaptor.operand();
+
+  auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
+  // Unranked tensor types are not supported at the moment.
+  if (!operand_type) return failure();
+
+  Location loc = this->getLoc();
+  SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>());
+  SmallVector<Value, 4> shape_values(permutation.size());
+
+  Type shape_scalar_type = builder.getIndexType();
+  auto to_shape_scalar_type = [&](Value v) {
+    return MaybeCastTo(builder, loc, v, shape_scalar_type);
+  };
+
+  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
+    int64_t idx = element.index();
+    auto it = std::find(permutation.begin(), permutation.end(), idx);
+    Value value_dim = to_shape_scalar_type(
+        builder.create<memref::DimOp>(loc, operand, element.index()));
+    shape_values[std::distance(permutation.begin(), it)] = value_dim;
+  }
+
+  Value output_shape = builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values);
+  reifiedReturnShapes.push_back(output_shape);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TriangularSolveOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index a5fd06c..0fa8af8 100644
--- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -696,6 +696,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
       HloToLhloOpConverter<mhlo::ConstOp>,
       HloToLhloOpConverter<mhlo::CompareOp>,
       HloToLhloOpConverter<mhlo::ComplexOp>,
+      HloToLhloOpConverter<mhlo::ConcatenateOp>,
       HloToLhloOpConverter<mhlo::ConvOp>,
       HloToLhloOpConverter<mhlo::ConvertOp>,
       HloToLhloOpConverter<mhlo::CopyOp>,
diff --git a/tests/hlo-legalize-to-lhlo-only-dynamic.mlir b/tests/hlo-legalize-to-lhlo-only-dynamic.mlir
index 095646d..4f3b5ee 100644
--- a/tests/hlo-legalize-to-lhlo-only-dynamic.mlir
+++ b/tests/hlo-legalize-to-lhlo-only-dynamic.mlir
@@ -150,3 +150,37 @@ func @column_reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?xf32>
       : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
   return %0: tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transpose
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>) -> memref<?x?xf32>
+func @transpose(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %c0 : memref<?x?xf32>
+  // CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %c1 : memref<?x?xf32>
+  // CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM1]], %[[DIM0]]) : memref<?x?xf32>
+  // CHECK: "lmhlo.transpose"(%[[ARG]], %[[OUT]])
+  %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1,0]> : tensor<2xi64>} : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0: tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @concatenate
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>) -> memref<?x?xf32>
+func @concatenate(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>, %c: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[ARG0_DIM0:.*]] = memref.dim %[[ARG0]], %c0 : memref<?x?xf32>
+  // CHECK: %[[ARG0_DIM1:.*]] = memref.dim %[[ARG0]], %c1 : memref<?x?xf32>
+  // CHECK: %[[ARG1_DIM1:.*]] = memref.dim %[[ARG1]], %c1 : memref<?x?xf32>
+  // CHECK: %[[ARG2_DIM1:.*]] = memref.dim %[[ARG2]], %c1 : memref<?x?xf32>
+  // CHECK: %[[TMP:.*]] = addi %[[ARG0_DIM1]], %[[ARG1_DIM1]] : index
+  // CHECK: %[[OUT_DIM1:.*]] = addi %[[TMP]], %[[ARG2_DIM1]] : index
+  // CHECK: %[[OUT:.*]] = memref.alloc(%[[ARG0_DIM0]], %[[OUT_DIM1]]) : memref<?x?xf32>
+  // CHECK: "lmhlo.concatenate"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[OUT]])
+  %concat = "mhlo.concatenate"(%a, %b, %c) {
+    dimension = 1
+  } : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %concat : tensor<?x?xf32>
+}