[mhlo:linalg] Add support for lowering mhlo.concatenate to Linalg ops.
This uses a indexed linalg.generic, which is rather awkward standalone but allows fusing into the output of the concatenate and avoid to ever materialize it in memory. I think this is the only way to get that with the current linalg stack, fusion across a concatenate would require more infrastructure. PiperOrigin-RevId: 369677652
This commit is contained in:
parent
c5302511f0
commit
4d435a817e
|
@ -941,6 +941,104 @@ class IotaConverter : public OpConversionPattern<OpTy> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Converts mhlo.concatenate operation to a linalg.generic op.
|
||||
struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
|
||||
using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::ConcatenateOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
// Shortcut the one-operand case, simplifies code below.
|
||||
if (args.size() == 1) {
|
||||
rewriter.replaceOp(op, args[0]);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (!result_type) return failure();
|
||||
|
||||
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
uint64_t dim = op.dimension();
|
||||
int64_t rank = result_type.getRank();
|
||||
Value zero = b.create<ConstantIndexOp>(0);
|
||||
SmallVector<Value, 3> sizes;
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
sizes.push_back(i == dim ? Value() : b.create<memref::DimOp>(args[0], i));
|
||||
}
|
||||
|
||||
// Calculate the size of the concatenated dimension.
|
||||
Value result_dim_size;
|
||||
for (auto arg : args) {
|
||||
Value size = b.create<memref::DimOp>(arg, dim);
|
||||
result_dim_size =
|
||||
result_dim_size ? b.create<AddIOp>(result_dim_size, size) : size;
|
||||
}
|
||||
sizes[dim] = result_dim_size;
|
||||
|
||||
// Allocate the output tensor with init_tensor.
|
||||
SmallVector<Value, 3> dyn_sizes;
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
if (result_type.isDynamicDim(i)) dyn_sizes.push_back(sizes[i]);
|
||||
}
|
||||
Value result = b.create<linalg::InitTensorOp>(
|
||||
dyn_sizes, result_type.getShape(), result_type.getElementType());
|
||||
|
||||
// Generate a generic op to gather the elements of the concatenate. This is
|
||||
// awkward standalone but allows fusion with other generic ops.
|
||||
unsigned nloops = result_type.getRank();
|
||||
auto linalg_op = b.create<linalg::IndexedGenericOp>(
|
||||
/*resultTensorTypes=*/result_type,
|
||||
/*inputs=*/ValueRange{}, /*outputBuffers=*/result,
|
||||
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nested_builder, Location loc, ValueRange ivs,
|
||||
ValueRange) {
|
||||
OpBuilder b = nested_builder;
|
||||
Value concat_dim_size = zero;
|
||||
Value result;
|
||||
auto extract_indices = llvm::to_vector<4>(ivs);
|
||||
for (const Value& arg : args) {
|
||||
Value new_concat_dim_size;
|
||||
scf::IfOp if_op;
|
||||
if (&arg != &args.back()) {
|
||||
// Calculate how far along we have iterated along the concatenate
|
||||
// dimension. That way we can tell which input to select.
|
||||
new_concat_dim_size = b.create<AddIOp>(
|
||||
loc, concat_dim_size, b.create<memref::DimOp>(loc, arg, dim));
|
||||
Value cmp = b.create<CmpIOp>(loc, rewriter.getI1Type(),
|
||||
CmpIPredicate::ult, ivs[dim],
|
||||
new_concat_dim_size);
|
||||
if_op = b.create<scf::IfOp>(loc, result_type.getElementType(),
|
||||
cmp, true);
|
||||
if (result) {
|
||||
b.create<scf::YieldOp>(loc, if_op->getResults()[0]);
|
||||
} else {
|
||||
result = if_op->getResults()[0];
|
||||
}
|
||||
|
||||
b = if_op.getThenBodyBuilder(b.getListener());
|
||||
}
|
||||
|
||||
// Now adjust the index for the concatenated dimension to fit into
|
||||
// the selected tensor and do an extract at that position.
|
||||
extract_indices[dim] =
|
||||
b.create<SubIOp>(loc, ivs[dim], concat_dim_size);
|
||||
Value extract =
|
||||
b.create<tensor::ExtractOp>(loc, arg, extract_indices);
|
||||
b.create<scf::YieldOp>(loc, extract);
|
||||
|
||||
if (if_op) {
|
||||
b = if_op.getElseBodyBuilder(b.getListener());
|
||||
concat_dim_size = new_concat_dim_size;
|
||||
}
|
||||
}
|
||||
nested_builder.create<linalg::YieldOp>(loc, result);
|
||||
});
|
||||
rewriter.replaceOp(op, linalg_op.result_tensors());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
class ConstConverter : public OpConversionPattern<OpTy> {
|
||||
public:
|
||||
|
@ -2107,7 +2205,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
BroadcastConverter<mhlo::BroadcastOp, false>,
|
||||
BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter,
|
||||
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
|
||||
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
|
||||
IotaConverter<mhlo::DynamicIotaOp, false>,
|
||||
|
|
|
@ -2098,3 +2098,58 @@ func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
|
|||
// CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
|
||||
// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
|
||||
// CHECK: linalg.yield %[[YIELD]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @concatenate(
|
||||
// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[VAL_1:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[VAL_2:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[VAL_3:.*]] = constant 0 : index
|
||||
// CHECK: %[[VAL_4:.*]] = constant 0 : index
|
||||
// CHECK: %[[VAL_5:.*]] = memref.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_6:.*]] = constant 1 : index
|
||||
// CHECK: %[[VAL_7:.*]] = memref.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_8:.*]] = constant 1 : index
|
||||
// CHECK: %[[VAL_9:.*]] = memref.dim %[[VAL_1]], %[[VAL_8]] : tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_10:.*]] = addi %[[VAL_7]], %[[VAL_9]] : index
|
||||
// CHECK: %[[VAL_11:.*]] = constant 1 : index
|
||||
// CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_2]], %[[VAL_11]] : tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
|
||||
// CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_15:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
|
||||
// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32):
|
||||
// CHECK: %[[VAL_19:.*]] = constant 1 : index
|
||||
// CHECK: %[[VAL_20:.*]] = memref.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
|
||||
// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index
|
||||
// CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
|
||||
// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index
|
||||
// CHECK: %[[VAL_25:.*]] = tensor.extract %[[VAL_0]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
|
||||
// CHECK: scf.yield %[[VAL_25]] : i32
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[VAL_26:.*]] = constant 1 : index
|
||||
// CHECK: %[[VAL_27:.*]] = memref.dim %[[VAL_1]], %[[VAL_26]] : tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
|
||||
// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index
|
||||
// CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
|
||||
// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index
|
||||
// CHECK: %[[VAL_32:.*]] = tensor.extract %[[VAL_1]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
|
||||
// CHECK: scf.yield %[[VAL_32]] : i32
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index
|
||||
// CHECK: %[[VAL_34:.*]] = tensor.extract %[[VAL_2]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
|
||||
// CHECK: scf.yield %[[VAL_34]] : i32
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_35:.*]] : i32
|
||||
// CHECK: }
|
||||
// CHECK: linalg.yield %[[VAL_36:.*]] : i32
|
||||
// CHECK: } -> tensor<?x?xi32>
|
||||
// CHECK: return %[[VAL_37:.*]] : tensor<?x?xi32>
|
||||
// CHECK: }
|
||||
func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
%concat = "mhlo.concatenate"(%a, %b, %c) {
|
||||
dimension = 1
|
||||
} : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
return %concat : tensor<?x?xi32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue