[mhlo:linalg] Add support for lowering mhlo.concatenate to Linalg ops.
This uses an indexed linalg.generic, which is rather awkward standalone but allows fusing into the output of the concatenate and avoids ever materializing it in memory. I think this is the only way to get that with the current linalg stack; fusion across a concatenate would require more infrastructure.

PiperOrigin-RevId: 369677652
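For reference, a minimal example of the op this pattern lowers (operand shapes, element type, and value names here are illustrative, not taken from the change):

  %concat = "mhlo.concatenate"(%a, %b) { dimension = 0 }
      : (tensor<2x4xi32>, tensor<3x4xi32>) -> tensor<5x4xi32>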
This commit is contained in:
parent c5302511f0
commit 4d435a817e
@@ -941,6 +941,104 @@ class IotaConverter : public OpConversionPattern<OpTy> {
  }
};

/// Converts mhlo.concatenate operation to a linalg.generic op.
struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
  using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConcatenateOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const override {
    // Shortcut the one-operand case, simplifies code below.
    if (args.size() == 1) {
      rewriter.replaceOp(op, args[0]);
      return success();
    }

    auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_type) return failure();

    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    uint64_t dim = op.dimension();
    int64_t rank = result_type.getRank();
    Value zero = b.create<ConstantIndexOp>(0);
    SmallVector<Value, 3> sizes;
    for (int64_t i = 0; i < rank; ++i) {
      sizes.push_back(i == dim ? Value() : b.create<memref::DimOp>(args[0], i));
    }

    // Calculate the size of the concatenated dimension.
    Value result_dim_size;
    for (auto arg : args) {
      Value size = b.create<memref::DimOp>(arg, dim);
      result_dim_size =
          result_dim_size ? b.create<AddIOp>(result_dim_size, size) : size;
    }
    sizes[dim] = result_dim_size;

    // Allocate the output tensor with init_tensor.
    SmallVector<Value, 3> dyn_sizes;
    for (int64_t i = 0; i < rank; ++i) {
      if (result_type.isDynamicDim(i)) dyn_sizes.push_back(sizes[i]);
    }
    Value result = b.create<linalg::InitTensorOp>(
        dyn_sizes, result_type.getShape(), result_type.getElementType());

    // Generate a generic op to gather the elements of the concatenate. This is
    // awkward standalone but allows fusion with other generic ops.
    unsigned nloops = result_type.getRank();
    auto linalg_op = b.create<linalg::IndexedGenericOp>(
        /*resultTensorTypes=*/result_type,
        /*inputs=*/ValueRange{}, /*outputBuffers=*/result,
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location loc, ValueRange ivs,
            ValueRange) {
          OpBuilder b = nested_builder;
          Value concat_dim_size = zero;
          Value result;
          auto extract_indices = llvm::to_vector<4>(ivs);
          for (const Value& arg : args) {
            Value new_concat_dim_size;
            scf::IfOp if_op;
            if (&arg != &args.back()) {
              // Calculate how far along we have iterated along the concatenate
              // dimension. That way we can tell which input to select.
              new_concat_dim_size = b.create<AddIOp>(
                  loc, concat_dim_size, b.create<memref::DimOp>(loc, arg, dim));
              Value cmp = b.create<CmpIOp>(loc, rewriter.getI1Type(),
                                           CmpIPredicate::ult, ivs[dim],
                                           new_concat_dim_size);
              if_op = b.create<scf::IfOp>(loc, result_type.getElementType(),
                                          cmp, true);
              if (result) {
                b.create<scf::YieldOp>(loc, if_op->getResults()[0]);
              } else {
                result = if_op->getResults()[0];
              }

              b = if_op.getThenBodyBuilder(b.getListener());
            }

            // Now adjust the index for the concatenated dimension to fit into
            // the selected tensor and do an extract at that position.
            extract_indices[dim] =
                b.create<SubIOp>(loc, ivs[dim], concat_dim_size);
            Value extract =
                b.create<tensor::ExtractOp>(loc, arg, extract_indices);
            b.create<scf::YieldOp>(loc, extract);

            if (if_op) {
              b = if_op.getElseBodyBuilder(b.getListener());
              concat_dim_size = new_concat_dim_size;
            }
          }
          nested_builder.create<linalg::YieldOp>(loc, result);
        });
    rewriter.replaceOp(op, linalg_op.result_tensors());
    return success();
  }
};

template <typename OpTy>
class ConstConverter : public OpConversionPattern<OpTy> {
 public:

@@ -2107,7 +2205,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                          OwningRewritePatternList* patterns) {
  // clang-format off
  patterns->insert<
      BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter,
      ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
      HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
      IotaConverter<mhlo::DynamicIotaOp, false>,
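The shape of the emitted IR is easiest to see in the test below. As a hand-written sketch (value names and the two-operand i32 case are assumptions, not actual pass output), the body of the indexed generic for a concatenation of %a and %b along dimension 1 is a chain of scf.if ops keyed on the index along that dimension:

^bb0(%i: index, %j: index, %out: i32):
  // Extent of %a along the concatenation dimension.
  %size_a = memref.dim %a, %c1 : tensor<?x?xi32>
  %in_a = cmpi ult, %j, %size_a : index
  %elem = scf.if %in_a -> (i32) {
    // The index falls into %a: extract directly.
    %e = tensor.extract %a[%i, %j] : tensor<?x?xi32>
    scf.yield %e : i32
  } else {
    // Otherwise shift the index into %b's coordinate system.
    %j_b = subi %j, %size_a : index
    %e = tensor.extract %b[%i, %j_b] : tensor<?x?xi32>
    scf.yield %e : i32
  }
  linalg.yield %elem : i32

Because the result is produced by an ordinary linalg op on tensors, existing Linalg fusion can pull a consumer's element-wise work into this body instead of materializing the concatenated tensor first.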
@@ -2098,3 +2098,58 @@ func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
// CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
// CHECK: linalg.yield %[[YIELD]]

// -----

// CHECK-LABEL: func @concatenate(
// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[VAL_1:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[VAL_2:[a-zA-Z0-9_]*]]
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 0 : index
// CHECK: %[[VAL_5:.*]] = memref.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xi32>
// CHECK: %[[VAL_6:.*]] = constant 1 : index
// CHECK: %[[VAL_7:.*]] = memref.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?xi32>
// CHECK: %[[VAL_8:.*]] = constant 1 : index
// CHECK: %[[VAL_9:.*]] = memref.dim %[[VAL_1]], %[[VAL_8]] : tensor<?x?xi32>
// CHECK: %[[VAL_10:.*]] = addi %[[VAL_7]], %[[VAL_9]] : index
// CHECK: %[[VAL_11:.*]] = constant 1 : index
// CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_2]], %[[VAL_11]] : tensor<?x?xi32>
// CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
// CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
// CHECK: %[[VAL_15:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32):
// CHECK: %[[VAL_19:.*]] = constant 1 : index
// CHECK: %[[VAL_20:.*]] = memref.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xi32>
// CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index
// CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index
// CHECK: %[[VAL_25:.*]] = tensor.extract %[[VAL_0]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_25]] : i32
// CHECK: } else {
// CHECK: %[[VAL_26:.*]] = constant 1 : index
// CHECK: %[[VAL_27:.*]] = memref.dim %[[VAL_1]], %[[VAL_26]] : tensor<?x?xi32>
// CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index
// CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index
// CHECK: %[[VAL_32:.*]] = tensor.extract %[[VAL_1]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_32]] : i32
// CHECK: } else {
// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index
// CHECK: %[[VAL_34:.*]] = tensor.extract %[[VAL_2]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_34]] : i32
// CHECK: }
// CHECK: scf.yield %[[VAL_35:.*]] : i32
// CHECK: }
// CHECK: linalg.yield %[[VAL_36:.*]] : i32
// CHECK: } -> tensor<?x?xi32>
// CHECK: return %[[VAL_37:.*]] : tensor<?x?xi32>
// CHECK: }
func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>) -> tensor<?x?xi32> {
  %concat = "mhlo.concatenate"(%a, %b, %c) {
    dimension = 1
  } : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
  return %concat : tensor<?x?xi32>
}
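The test above is assumed to run under this file's usual RUN header, along the lines of the following (flag spelling is an assumption, matching the pass this file exercises):

// RUN: mlir-hlo-opt -hlo-legalize-to-linalg -split-input-file %s | FileCheck %s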