diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 306813c..6530f7c 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -941,6 +941,104 @@ class IotaConverter : public OpConversionPattern<OpTy> {
   }
 };
 
+/// Converts mhlo.concatenate operation to a linalg.generic op.
+struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
+  using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::ConcatenateOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const override {
+    // Shortcut the one-operand case, simplifies code below.
+    if (args.size() == 1) {
+      rewriter.replaceOp(op, args[0]);
+      return success();
+    }
+
+    auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
+    if (!result_type) return failure();
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    uint64_t dim = op.dimension();
+    int64_t rank = result_type.getRank();
+    Value zero = b.create<ConstantIndexOp>(0);
+    SmallVector<Value, 3> sizes;
+    for (int64_t i = 0; i < rank; ++i) {
+      sizes.push_back(i == dim ? Value() : b.create<memref::DimOp>(args[0], i));
+    }
+
+    // Calculate the size of the concatenated dimension.
+    Value result_dim_size;
+    for (auto arg : args) {
+      Value size = b.create<memref::DimOp>(arg, dim);
+      result_dim_size =
+          result_dim_size ? b.create<AddIOp>(result_dim_size, size) : size;
+    }
+    sizes[dim] = result_dim_size;
+
+    // Allocate the output tensor with init_tensor.
+    SmallVector<Value, 3> dyn_sizes;
+    for (int64_t i = 0; i < rank; ++i) {
+      if (result_type.isDynamicDim(i)) dyn_sizes.push_back(sizes[i]);
+    }
+    Value result = b.create<linalg::InitTensorOp>(
+        dyn_sizes, result_type.getShape(), result_type.getElementType());
+
+    // Generate a generic op to gather the elements of the concatenate. This is
+    // awkward standalone but allows fusion with other generic ops.
+    unsigned nloops = result_type.getRank();
+    auto linalg_op = b.create<linalg::IndexedGenericOp>(
+        /*resultTensorTypes=*/result_type,
+        /*inputs=*/ValueRange{}, /*outputBuffers=*/result,
+        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
+        GetNParallelLoopsAttrs(nloops),
+        [&](OpBuilder& nested_builder, Location loc, ValueRange ivs,
+            ValueRange) {
+          OpBuilder b = nested_builder;
+          Value concat_dim_size = zero;
+          Value result;
+          auto extract_indices = llvm::to_vector<4>(ivs);
+          for (const Value& arg : args) {
+            Value new_concat_dim_size;
+            scf::IfOp if_op;
+            if (&arg != &args.back()) {
+              // Calculate how far along we have iterated along the concatenate
+              // dimension. That way we can tell which input to select.
+              new_concat_dim_size = b.create<AddIOp>(
+                  loc, concat_dim_size, b.create<memref::DimOp>(loc, arg, dim));
+              Value cmp = b.create<CmpIOp>(loc, rewriter.getI1Type(),
+                                           CmpIPredicate::ult, ivs[dim],
+                                           new_concat_dim_size);
+              if_op = b.create<scf::IfOp>(loc, result_type.getElementType(),
+                                          cmp, true);
+              if (result) {
+                b.create<scf::YieldOp>(loc, if_op->getResults()[0]);
+              } else {
+                result = if_op->getResults()[0];
+              }
+
+              b = if_op.getThenBodyBuilder(b.getListener());
+            }
+
+            // Now adjust the index for the concatenated dimension to fit into
+            // the selected tensor and do an extract at that position.
+            extract_indices[dim] =
+                b.create<SubIOp>(loc, ivs[dim], concat_dim_size);
+            Value extract =
+                b.create<tensor::ExtractOp>(loc, arg, extract_indices);
+            b.create<scf::YieldOp>(loc, extract);
+
+            if (if_op) {
+              b = if_op.getElseBodyBuilder(b.getListener());
+              concat_dim_size = new_concat_dim_size;
+            }
+          }
+          nested_builder.create<linalg::YieldOp>(loc, result);
+        });
+    rewriter.replaceOp(op, linalg_op.result_tensors());
+    return success();
+  }
+};
+
 template <typename OpTy>
 class ConstConverter : public OpConversionPattern<OpTy> {
  public:
@@ -2107,7 +2205,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<
-      BroadcastConverter<mhlo::BroadcastOp, false>,
+      BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter,
       ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
       HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
       IotaConverter<mhlo::DynamicIotaOp, false>,
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index 76a1629..b853fc1 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -2098,3 +2098,58 @@ func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
 // CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
 // CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
 // CHECK: linalg.yield %[[YIELD]]
+
+// -----
+
+// CHECK-LABEL: func @concatenate(
+// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[VAL_1:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[VAL_2:[a-zA-Z0-9_]*]]
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 0 : index
+// CHECK: %[[VAL_5:.*]] = memref.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xi32>
+// CHECK: %[[VAL_6:.*]] = constant 1 : index
+// CHECK: %[[VAL_7:.*]] = memref.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?xi32>
+// CHECK: %[[VAL_8:.*]] = constant 1 : index
+// CHECK: %[[VAL_9:.*]] = memref.dim %[[VAL_1]], %[[VAL_8]] : tensor<?x?xi32>
+// CHECK: %[[VAL_10:.*]] = addi %[[VAL_7]], %[[VAL_9]] : index
+// CHECK: %[[VAL_11:.*]] = constant 1 : index
+// CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_2]], %[[VAL_11]] : tensor<?x?xi32>
+// CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
+// CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
+// CHECK: %[[VAL_15:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
+// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32):
+// CHECK: %[[VAL_19:.*]] = constant 1 : index
+// CHECK: %[[VAL_20:.*]] = memref.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xi32>
+// CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
+// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
+// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index
+// CHECK: %[[VAL_25:.*]] = tensor.extract %[[VAL_0]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
+// CHECK: scf.yield %[[VAL_25]] : i32
+// CHECK: } else {
+// CHECK: %[[VAL_26:.*]] = constant 1 : index
+// CHECK: %[[VAL_27:.*]] = memref.dim %[[VAL_1]], %[[VAL_26]] : tensor<?x?xi32>
+// CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
+// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
+// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_32:.*]] = tensor.extract %[[VAL_1]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
+// CHECK: scf.yield %[[VAL_32]] : i32
+// CHECK: } else {
+// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_34:.*]] = tensor.extract %[[VAL_2]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
+// CHECK: scf.yield %[[VAL_34]] : i32
+// CHECK: }
+// CHECK: scf.yield %[[VAL_35:.*]] : i32
+// CHECK: }
+// CHECK: linalg.yield %[[VAL_36:.*]] : i32
+// CHECK: } -> tensor<?x?xi32>
+// CHECK: return %[[VAL_37:.*]] : tensor<?x?xi32>
+// CHECK: }
+func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %concat = "mhlo.concatenate"(%a, %b, %c) {
+    dimension = 1
+  } : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+  return %concat : tensor<?x?xi32>
+}
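
Note (not part of the patch): the new test case relies on the RUN line already at the top of tests/hlo-legalize-to-linalg.mlir, which this hunk does not show. As a minimal sketch, assuming the file drives the lowering through mlir-hlo-opt like the other mhlo lit tests, the invocation would look roughly like:

    // RUN: mlir-hlo-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s

The added "// -----" separator is what -split-input-file keys on, so the concatenate case is checked independently of the preceding torch_index_select_dynamic test.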