PR #47315: [MLIR] Add concatenateOp lowering from lmhlo to Affine.

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/47315

Lowering of `concatenateOp` from lmhlo to Affine is added, as part of the
`lhlo-legalize-to-affine` pass.
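
For illustration, here is a hand-written before/after sketch of the lowering for the 2-D example used in the code comment (operands of shape 2x1 and 2x2 concatenated along dimension 1 into a 2x3 output). It mirrors the new FileCheck test rather than actual pass output, and all value and function names are illustrative:

```mlir
// Before: a 2-D concatenation along dimension 1.
func @concat_example(%a: memref<2x1xf32>, %b: memref<2x2xf32>, %out: memref<2x3xf32>) {
  "lmhlo.concatenate"(%a, %b, %out) {dimension = 1 : i64}
      : (memref<2x1xf32>, memref<2x2xf32>, memref<2x3xf32>) -> ()
  "lmhlo.terminator"() : () -> ()
}

// After -lhlo-legalize-to-affine (sketch): one affine loop nest per operand,
// with the concat-dimension index shifted by the running offset (0, then 1).
func @concat_example(%a: memref<2x1xf32>, %b: memref<2x2xf32>, %out: memref<2x3xf32>) {
  affine.for %i = 0 to 2 {
    affine.for %j = 0 to 1 {
      %v = affine.load %a[%i, %j] : memref<2x1xf32>
      affine.store %v, %out[%i, %j] : memref<2x3xf32>
    }
  }
  affine.for %i = 0 to 2 {
    affine.for %j = 1 to 3 {
      %v = affine.load %b[%i, %j - 1] : memref<2x2xf32>
      affine.store %v, %out[%i, %j] : memref<2x3xf32>
    }
  }
  "lmhlo.terminator"() : () -> ()
}
```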

Signed-off-by: Prashant Kumar <prashantk@polymagelabs.com>
Copybara import of the project:

--
15314e4579f7a6901cf3475eff25962a34772eaf by Prashant Kumar <prashantk@polymagelabs.com>:

[MLIR] Add concatenateOp lowering from lmhlo to Affine.

Lowering of `concatenateOp` from lmhlo to Affine is added, as part of the
`lhlo-legalize-to-affine` pass.

Signed-off-by: Prashant Kumar <prashantk@polymagelabs.com>
PiperOrigin-RevId: 368465992
commit 236e7db5c0 (parent c10167d4a8)
Author: Prashant Kumar, 2021-04-14 11:05:28 -07:00
Committed by: TensorFlow MLIR Team
2 changed files with 115 additions and 0 deletions


@@ -98,6 +98,84 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
  }
};
/// Concat Operation Example (2D):
/// Given inpA[2][1], inpB[2][2], concat_dimension = 1.
/// Compute output[x1][x2].
/// Implementation Pseudocode:
/// s = 0
/// for a in range(0, 2):
///   for b in range(0, 1):
///     output[a][b] = inpA[a][b - s]
/// s = 1
/// for a in range(0, 2):
///   for b in range(1, 3):
///     output[a][b] = inpB[a][b - s]
///
/// Concatenate composes an array from multiple array operands. The array is of
/// the same rank as each of the input array operands (which must be of the same
/// rank as each other) and contains the arguments in the order that they were
/// specified.
struct ConcatOpConverter : public OpRewritePattern<ConcatenateOp> {
  using OpRewritePattern<ConcatenateOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    Location loc = op.getLoc();
    Value output = op.output();
    MemRefType outputType = output.getType().cast<MemRefType>();
    unsigned outputRank = outputType.getRank();
    ArrayRef<int64_t> outputShape = outputType.getShape();
    ValueRange operands = op.val();
    uint64_t concatDim = op.dimension();
    int64_t prevBound = 0;
    for (Value operand : operands) {
      MemRefType operandType = operand.getType().cast<MemRefType>();
      ArrayRef<int64_t> operandShape = operandType.getShape();
      // TODO(pashu123): Extend support for dynamic dimensions.
      if (!operandType.hasStaticShape()) return failure();
      // Only for the concatenation dimension, the value is dimension -
      // prevBound.
      SmallVector<AffineExpr, 4> expr;
      for (unsigned i = 0; i < outputRank; i++) {
        AffineExpr d0 = (i == concatDim)
                            ? rewriter.getAffineDimExpr(concatDim) - prevBound
                            : rewriter.getAffineDimExpr(i);
        expr.push_back(d0);
      }
      AffineMap map =
          AffineMap::get(outputRank, 0, expr, rewriter.getContext());
      // Create multiple for loop nests iterating along the concatenation
      // dimension.
      OpBuilder::InsertionGuard guard(rewriter);
      SmallVector<Value, 3> indices;
      AffineForOp forOp;
      for (unsigned i = 0; i < outputRank; i++) {
        if (i == concatDim) {
          forOp = rewriter.create<AffineForOp>(loc, prevBound,
                                               prevBound + operandShape[i]);
          prevBound += operandShape[i];
          indices.push_back(forOp.getInductionVar());
        } else {
          forOp = rewriter.create<AffineForOp>(loc, 0, outputShape[i]);
          indices.push_back(forOp.getInductionVar());
        }
        rewriter.setInsertionPointToStart(forOp.getBody());
      }
      Value storeVal =
          rewriter.create<AffineLoadOp>(loc, operand, map, indices);
      rewriter.create<AffineStoreOp>(loc, storeVal, output, indices);
    }
    rewriter.eraseOp(op);
    return success();
  }
};
template <typename LhloOpTy>
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
  using OpRewritePattern<LhloOpTy>::OpRewritePattern;
@@ -145,6 +223,7 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
      BinaryOpConverter<lmhlo::MinOp>,
      BinaryOpConverter<lmhlo::MulOp>,
      BinaryOpConverter<lmhlo::SubOp>,
      ConcatOpConverter,
      DotOpConverter>(context);
  // clang-format on
}
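
The pattern is exercised by the new cases in the `lhlo-legalize-to-affine` test file below: a static-shape concatenation that is fully lowered, and a dynamic-shape case that is expected to stay untouched for now. Assuming the `mlir-hlo-opt` driver used by the existing lhlo tests, the test is driven by a RUN line of the form:

```mlir
// RUN: mlir-hlo-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s
```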


@@ -202,3 +202,39 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
    (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
  return
}
// CHECK-LABEL: func @concatenate
func @concatenate(%arg0: memref<1x1xf32>, %arg1: memref<1x100xf32>, %arg2: memref<1x200xf32>, %arg3: memref<1x301xf32>) {
  // CHECK-NEXT: %[[RESULT:.*]] = memref.alloc() : memref<1x301xf32>
  // CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
  // CHECK-NEXT:   affine.for %[[Y:.*]] = 0 to 1 {
  // CHECK-NEXT:     %[[LOAD:.*]] = affine.load %arg0[%[[X]], %[[Y]]] : memref<1x1xf32>
  // CHECK-NEXT:     affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT: }
  // CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
  // CHECK-NEXT:   affine.for %[[Y:.*]] = 1 to 101 {
  // CHECK-NEXT:     %[[LOAD:.*]] = affine.load %arg1[%[[X]], %[[Y]] - 1] : memref<1x100xf32>
  // CHECK-NEXT:     affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT: }
  // CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
  // CHECK-NEXT:   affine.for %[[Y:.*]] = 101 to 301 {
  // CHECK-NEXT:     %[[LOAD:.*]] = affine.load %arg2[%[[X]], %[[Y]] - 101] : memref<1x200xf32>
  // CHECK-NEXT:     affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
  %0 = memref.alloc() : memref<1x301xf32>
  "lmhlo.concatenate"(%arg0, %arg1, %arg2, %0) {dimension = 1 : i64} : (memref<1x1xf32>, memref<1x100xf32>, memref<1x200xf32>, memref<1x301xf32>) -> ()
  "lmhlo.copy"(%0, %arg3) : (memref<1x301xf32>, memref<1x301xf32>) -> ()
  "lmhlo.terminator"() : () -> ()
}
// TODO(pashu123): Extend support for dynamic dimensions.
// CHECK-LABEL: func @concatenate_dynamic
func @concatenate_dynamic(%arg0: memref<1x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<1x?xf32>) {
  // CHECK: "lmhlo.concatenate"
  %cst_1 = constant 1 : index
  %0 = memref.alloc(%cst_1) : memref<1x?xf32>
  "lmhlo.concatenate"(%arg0, %arg1, %0) {dimension = 1 : i64} : (memref<1x?xf32>, memref<1x?xf32>, memref<1x?xf32>) -> ()
  "lmhlo.copy"(%0, %arg2) : (memref<1x?xf32>, memref<1x?xf32>) -> ()
  "lmhlo.terminator"() : () -> ()
}