From 236e7db5c081c7c54e9672f7dea3ef61d1c9ca4a Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Wed, 14 Apr 2021 11:05:28 -0700 Subject: [PATCH] PR #47315: [MLIR] Add concatenateOp lowering from lmhlo to Affine. Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/47315 Lowering of `concatenateOp` is added from lmhlo to Affine. The lowering has been added as a part of `lhlo-legalize-to-affine` pass. Signed-off-by: Prashant Kumar Copybara import of the project: -- 15314e4579f7a6901cf3475eff25962a34772eaf by Prashant Kumar : [MLIR] Add concatenateOp lowering from lmhlo to Affine. Lowering of `concatenateOp` is added from lmhlo to Affine. The lowering has been added as a part of `lhlo-legalize-to-affine` pass. Signed-off-by: Prashant Kumar PiperOrigin-RevId: 368465992 --- .../transforms/lhlo_legalize_to_affine.cc | 79 +++++++++++++++++++ tests/lhlo-legalize-to-affine.mlir | 36 +++++++++ 2 files changed, 115 insertions(+) diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index 0235d2d..2e7d2a4 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -98,6 +98,84 @@ struct DotOpConverter : public OpRewritePattern { } }; +/// Concat Operation Example (2D): +/// Given inpA[2][1], inpB[2][2], concat_dimension = 1. +/// Compute output[x1][x2]. +/// Implementation Pseudocode: +/// s = 0 +/// for a in range(0, 2): +/// for b in range(0, 1): +/// output[a][b] = inpA[a][b - s] +/// s = 1 +/// for a in range(0, 2): +/// for b in range(1, 3): +/// output[a][b] = inpB[a][b - s] +/// +/// Concatenate composes an array from multiple array operands. The array is of +/// the same rank as each of the input array operands (which must be of the same +/// rank as each other) and contains the arguments in the order that they were +/// specified. 
+struct ConcatOpConverter : public OpRewritePattern<ConcatenateOp> {
+  using OpRewritePattern<ConcatenateOp>::OpRewritePattern;
+
+  // Rewrites `op` into one affine load/store loop nest per operand.
+  LogicalResult matchAndRewrite(ConcatenateOp op,
+                                PatternRewriter& rewriter) const override {
+    Location loc = op.getLoc();
+    Value output = op.output();
+    MemRefType outputType = output.getType().cast<MemRefType>();
+    unsigned outputRank = outputType.getRank();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    ValueRange operands = op.val();
+    uint64_t concatDim = op.dimension();
+
+    // TODO(pashu123): Extend support for dynamic dimensions.
+    // Check all shapes first: a pattern that fails must not have mutated IR.
+    if (!outputType.hasStaticShape()) return failure();
+    for (Value operand : operands) {
+      if (!operand.getType().cast<MemRefType>().hasStaticShape())
+        return failure();
+    }
+
+    int64_t prevBound = 0;
+    for (Value operand : operands) {
+      ArrayRef<int64_t> operandShape =
+          operand.getType().cast<MemRefType>().getShape();
+
+      // Only the concatenation dimension is shifted: dimension - prevBound.
+      SmallVector<AffineExpr, 4> expr;
+      for (unsigned i = 0; i < outputRank; i++) {
+        expr.push_back(i == concatDim
+                           ? rewriter.getAffineDimExpr(concatDim) - prevBound
+                           : rewriter.getAffineDimExpr(i));
+      }
+      AffineMap map =
+          AffineMap::get(outputRank, 0, expr, rewriter.getContext());
+
+      // Build this operand's loop nest; the guard restores the insert point.
+      OpBuilder::InsertionGuard guard(rewriter);
+      SmallVector<Value, 4> indices;
+      AffineForOp forOp;
+      for (unsigned i = 0; i < outputRank; i++) {
+        if (i == concatDim) {
+          forOp = rewriter.create<AffineForOp>(loc, prevBound,
+                                               prevBound + operandShape[i]);
+          prevBound += operandShape[i];
+        } else {
+          forOp = rewriter.create<AffineForOp>(loc, 0, outputShape[i]);
+        }
+        indices.push_back(forOp.getInductionVar());
+        rewriter.setInsertionPointToStart(forOp.getBody());
+      }
+      Value storeVal =
+          rewriter.create<AffineLoadOp>(loc, operand, map, indices);
+      rewriter.create<AffineStoreOp>(loc, storeVal, output, indices);
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 template
 struct BinaryOpConverter : public OpRewritePattern {
   using OpRewritePattern::OpRewritePattern;
@@ -145,6 +223,7 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
       BinaryOpConverter,
       BinaryOpConverter,
       BinaryOpConverter,
+      ConcatOpConverter,
       DotOpConverter>(context);
   // clang-format on
 }
diff --git a/tests/lhlo-legalize-to-affine.mlir b/tests/lhlo-legalize-to-affine.mlir
index 35c0bef..9177f43 100644
--- a/tests/lhlo-legalize-to-affine.mlir
+++ b/tests/lhlo-legalize-to-affine.mlir
@@ -202,3 +202,39 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
     (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
   return
 }
+
+// CHECK-LABEL: func @concatenate
+func @concatenate(%arg0: memref<1x1xf32>, %arg1: memref<1x100xf32>, %arg2: memref<1x200xf32>, %arg3: memref<1x301xf32>) {
+  // CHECK-NEXT: %[[RESULT:.*]] = memref.alloc() : memref<1x301xf32>
+  // CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
+  // CHECK-NEXT: affine.for %[[Y:.*]] = 0 to 1 {
+  // CHECK-NEXT: %[[LOAD:.*]] = affine.load %arg0[%[[X]], %[[Y]]] : memref<1x1xf32>
+  // CHECK-NEXT: affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
+  // CHECK-NEXT: affine.for %[[Y:.*]] = 1 to 101 {
+  // CHECK-NEXT: %[[LOAD:.*]] = affine.load %arg1[%[[X]], %[[Y]] - 1] : memref<1x100xf32>
+  // CHECK-NEXT: affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
+  // CHECK-NEXT: affine.for %[[Y:.*]] = 101 to 301 {
+  // CHECK-NEXT: %[[LOAD:.*]] = affine.load %arg2[%[[X]], %[[Y]] - 101] : memref<1x200xf32>
+  // CHECK-NEXT: affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
+  %0 = memref.alloc() : memref<1x301xf32>
+  "lmhlo.concatenate"(%arg0, %arg1, %arg2, %0) {dimension = 1 : i64} : (memref<1x1xf32>, memref<1x100xf32>, memref<1x200xf32>, memref<1x301xf32>) -> ()
+  "lmhlo.copy"(%0, %arg3) : (memref<1x301xf32>, memref<1x301xf32>) -> ()
+  "lmhlo.terminator"() : () -> ()
+}
+
+// TODO(pashu123): Extend Support for dynamic dimensions.
+// CHECK-LABEL: func @concatenate_dynamic
+func @concatenate_dynamic(%arg0: memref<1x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<1x?xf32>) {
+  // CHECK: "lmhlo.concatenate"
+  %cst_1 = constant 1 : index
+  %0 = memref.alloc(%cst_1) : memref<1x?xf32>
+  "lmhlo.concatenate"(%arg0, %arg1, %0) {dimension = 1 : i64} : (memref<1x?xf32>, memref<1x?xf32>, memref<1x?xf32>) -> ()
+  "lmhlo.copy"(%0, %arg2) : (memref<1x?xf32>, memref<1x?xf32>) -> ()
+  "lmhlo.terminator"() : () -> ()
+}