[mhlo:linalg] Add support for lowering mhlo.concatenate to Linalg ops.

This uses a indexed linalg.generic, which is rather awkward standalone but
allows fusing into the output of the concatenate and avoid to ever materialize
it in memory. I think this is the only way to get that with the current linalg
stack, fusion across a concatenate would require more infrastructure.

PiperOrigin-RevId: 369677652
This commit is contained in:
Benjamin Kramer 2021-04-21 10:00:12 -07:00 committed by TensorFlow MLIR Team
parent c5302511f0
commit 4d435a817e
2 changed files with 154 additions and 1 deletions

View File

@ -941,6 +941,104 @@ class IotaConverter : public OpConversionPattern<OpTy> {
}
};
/// Converts mhlo.concatenate operation to a linalg.generic op.
struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::ConcatenateOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const override {
// Shortcut the one-operand case, simplifies code below.
if (args.size() == 1) {
rewriter.replaceOp(op, args[0]);
return success();
}
auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!result_type) return failure();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
uint64_t dim = op.dimension();
int64_t rank = result_type.getRank();
Value zero = b.create<ConstantIndexOp>(0);
SmallVector<Value, 3> sizes;
for (int64_t i = 0; i < rank; ++i) {
sizes.push_back(i == dim ? Value() : b.create<memref::DimOp>(args[0], i));
}
// Calculate the size of the concatenated dimension.
Value result_dim_size;
for (auto arg : args) {
Value size = b.create<memref::DimOp>(arg, dim);
result_dim_size =
result_dim_size ? b.create<AddIOp>(result_dim_size, size) : size;
}
sizes[dim] = result_dim_size;
// Allocate the output tensor with init_tensor.
SmallVector<Value, 3> dyn_sizes;
for (int64_t i = 0; i < rank; ++i) {
if (result_type.isDynamicDim(i)) dyn_sizes.push_back(sizes[i]);
}
Value result = b.create<linalg::InitTensorOp>(
dyn_sizes, result_type.getShape(), result_type.getElementType());
// Generate a generic op to gather the elements of the concatenate. This is
// awkward standalone but allows fusion with other generic ops.
unsigned nloops = result_type.getRank();
auto linalg_op = b.create<linalg::IndexedGenericOp>(
/*resultTensorTypes=*/result_type,
/*inputs=*/ValueRange{}, /*outputBuffers=*/result,
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nested_builder, Location loc, ValueRange ivs,
ValueRange) {
OpBuilder b = nested_builder;
Value concat_dim_size = zero;
Value result;
auto extract_indices = llvm::to_vector<4>(ivs);
for (const Value& arg : args) {
Value new_concat_dim_size;
scf::IfOp if_op;
if (&arg != &args.back()) {
// Calculate how far along we have iterated along the concatenate
// dimension. That way we can tell which input to select.
new_concat_dim_size = b.create<AddIOp>(
loc, concat_dim_size, b.create<memref::DimOp>(loc, arg, dim));
Value cmp = b.create<CmpIOp>(loc, rewriter.getI1Type(),
CmpIPredicate::ult, ivs[dim],
new_concat_dim_size);
if_op = b.create<scf::IfOp>(loc, result_type.getElementType(),
cmp, true);
if (result) {
b.create<scf::YieldOp>(loc, if_op->getResults()[0]);
} else {
result = if_op->getResults()[0];
}
b = if_op.getThenBodyBuilder(b.getListener());
}
// Now adjust the index for the concatenated dimension to fit into
// the selected tensor and do an extract at that position.
extract_indices[dim] =
b.create<SubIOp>(loc, ivs[dim], concat_dim_size);
Value extract =
b.create<tensor::ExtractOp>(loc, arg, extract_indices);
b.create<scf::YieldOp>(loc, extract);
if (if_op) {
b = if_op.getElseBodyBuilder(b.getListener());
concat_dim_size = new_concat_dim_size;
}
}
nested_builder.create<linalg::YieldOp>(loc, result);
});
rewriter.replaceOp(op, linalg_op.result_tensors());
return success();
}
};
template <typename OpTy>
class ConstConverter : public OpConversionPattern<OpTy> {
public:
@ -2107,7 +2205,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<
BroadcastConverter<mhlo::BroadcastOp, false>,
BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter,
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
IotaConverter<mhlo::DynamicIotaOp, false>,

View File

@ -2098,3 +2098,58 @@ func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
// CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
// CHECK: linalg.yield %[[YIELD]]
// -----
// CHECK-LABEL: func @concatenate(
// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[VAL_1:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[VAL_2:[a-zA-Z0-9_]*]]
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 0 : index
// CHECK: %[[VAL_5:.*]] = memref.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xi32>
// CHECK: %[[VAL_6:.*]] = constant 1 : index
// CHECK: %[[VAL_7:.*]] = memref.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?xi32>
// CHECK: %[[VAL_8:.*]] = constant 1 : index
// CHECK: %[[VAL_9:.*]] = memref.dim %[[VAL_1]], %[[VAL_8]] : tensor<?x?xi32>
// CHECK: %[[VAL_10:.*]] = addi %[[VAL_7]], %[[VAL_9]] : index
// CHECK: %[[VAL_11:.*]] = constant 1 : index
// CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_2]], %[[VAL_11]] : tensor<?x?xi32>
// CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
// CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
// CHECK: %[[VAL_15:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32):
// CHECK: %[[VAL_19:.*]] = constant 1 : index
// CHECK: %[[VAL_20:.*]] = memref.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xi32>
// CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index
// CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index
// CHECK: %[[VAL_25:.*]] = tensor.extract %[[VAL_0]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_25]] : i32
// CHECK: } else {
// CHECK: %[[VAL_26:.*]] = constant 1 : index
// CHECK: %[[VAL_27:.*]] = memref.dim %[[VAL_1]], %[[VAL_26]] : tensor<?x?xi32>
// CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index
// CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index
// CHECK: %[[VAL_32:.*]] = tensor.extract %[[VAL_1]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_32]] : i32
// CHECK: } else {
// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index
// CHECK: %[[VAL_34:.*]] = tensor.extract %[[VAL_2]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_34]] : i32
// CHECK: }
// CHECK: scf.yield %[[VAL_35:.*]] : i32
// CHECK: }
// CHECK: linalg.yield %[[VAL_36:.*]] : i32
// CHECK: } -> tensor<?x?xi32>
// CHECK: return %[[VAL_37:.*]] : tensor<?x?xi32>
// CHECK: }
func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>) -> tensor<?x?xi32> {
%concat = "mhlo.concatenate"(%a, %b, %c) {
dimension = 1
} : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %concat : tensor<?x?xi32>
}