PR #47315: [MLIR] Add concatenateOp lowering from lmhlo to Affine.
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/47315

Lowering of `concatenateOp` from lmhlo to Affine is added as part of the
`lhlo-legalize-to-affine` pass.

Signed-off-by: Prashant Kumar <prashantk@polymagelabs.com>

Copybara import of the project:

--
15314e4579f7a6901cf3475eff25962a34772eaf by Prashant Kumar <prashantk@polymagelabs.com>:

[MLIR] Add concatenateOp lowering from lmhlo to Affine.

Lowering of `concatenateOp` from lmhlo to Affine is added as part of the
`lhlo-legalize-to-affine` pass.

Signed-off-by: Prashant Kumar <prashantk@polymagelabs.com>

PiperOrigin-RevId: 368465992
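Not part of the commit, but as a usage sketch: the function below is the kind of lmhlo input this lowering targets. Assuming the pass is exposed through the repository's mlir-hlo-opt driver under the flag -lhlo-legalize-to-affine (the pass name stated above; the driver name is an assumption), it could be exercised with something like `mlir-hlo-opt -lhlo-legalize-to-affine input.mlir`. Function and value names here are illustrative.

    // Hypothetical input for the concatenate lowering; shapes chosen to match
    // the 2D example documented in the pattern below.
    func @concat_example(%a: memref<2x1xf32>, %b: memref<2x2xf32>, %out: memref<2x3xf32>) {
      // Concatenate a 2x1 and a 2x2 buffer along dimension 1 into a 2x3 result.
      "lmhlo.concatenate"(%a, %b, %out) {dimension = 1 : i64}
          : (memref<2x1xf32>, memref<2x2xf32>, memref<2x3xf32>) -> ()
      return
    }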
parent c10167d4a8, commit 236e7db5c0
@@ -98,6 +98,84 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
  }
};

/// Concat Operation Example (2D):
/// Given inpA[2][1], inpB[2][2], concat_dimension = 1.
/// Compute output[x1][x2].
/// Implementation Pseudocode:
/// s = 0
/// for a in range(0, 2):
///   for b in range(0, 1):
///     output[a][b] = inpA[a][b - s]
///
/// s = 1
/// for a in range(0, 2):
///   for b in range(1, 3):
///     output[a][b] = inpB[a][b - s]
///
/// Concatenate composes an array from multiple array operands. The array is of
/// the same rank as each of the input array operands (which must be of the same
/// rank as each other) and contains the arguments in the order that they were
/// specified.
struct ConcatOpConverter : public OpRewritePattern<ConcatenateOp> {
  using OpRewritePattern<ConcatenateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    Location loc = op.getLoc();
    Value output = op.output();
    MemRefType outputType = output.getType().cast<MemRefType>();
    unsigned outputRank = outputType.getRank();
    ArrayRef<int64_t> outputShape = outputType.getShape();

    ValueRange operands = op.val();
    uint64_t concatDim = op.dimension();
    int64_t prevBound = 0;

    for (Value operand : operands) {
      MemRefType operandType = operand.getType().cast<MemRefType>();
      ArrayRef<int64_t> operandShape = operandType.getShape();

      // TODO(pashu123): Extend support for dynamic dimensions.
      if (!operandType.hasStaticShape()) return failure();

      // Only for the concatenation dimension, the value is dimension -
      // prevBound.
      SmallVector<AffineExpr, 4> expr;
      for (unsigned i = 0; i < outputRank; i++) {
        AffineExpr d0 = (i == concatDim)
                            ? rewriter.getAffineDimExpr(concatDim) - prevBound
                            : rewriter.getAffineDimExpr(i);

        expr.push_back(d0);
      }
      AffineMap map =
          AffineMap::get(outputRank, 0, expr, rewriter.getContext());

      // Create multiple for loop nests iterating along the concatenation
      // dimension.
      OpBuilder::InsertionGuard guard(rewriter);
      SmallVector<Value, 3> indices;
      AffineForOp forOp;
      for (unsigned i = 0; i < outputRank; i++) {
        if (i == concatDim) {
          forOp = rewriter.create<AffineForOp>(loc, prevBound,
                                               prevBound + operandShape[i]);
          prevBound += operandShape[i];
          indices.push_back(forOp.getInductionVar());
        } else {
          forOp = rewriter.create<AffineForOp>(loc, 0, outputShape[i]);
          indices.push_back(forOp.getInductionVar());
        }
        rewriter.setInsertionPointToStart(forOp.getBody());
      }
      Value storeVal =
          rewriter.create<AffineLoadOp>(loc, operand, map, indices);
      rewriter.create<AffineStoreOp>(loc, storeVal, output, indices);
    }
    rewriter.eraseOp(op);
    return success();
  }
};

template <typename LhloOpTy>
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
  using OpRewritePattern<LhloOpTy>::OpRewritePattern;
@@ -145,6 +223,7 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
                   BinaryOpConverter<lmhlo::MinOp>,
                   BinaryOpConverter<lmhlo::MulOp>,
                   BinaryOpConverter<lmhlo::SubOp>,
                   ConcatOpConverter,
                   DotOpConverter>(context);
  // clang-format on
}
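For reference, a hand-written sketch of the affine IR that the ConcatOpConverter above is expected to produce for the 2D example in its documentation comment (inpA of shape 2x1, inpB of shape 2x2, concat_dimension = 1). It mirrors the CHECK patterns in the test hunk below; the function and loop-variable names are illustrative and the actual SSA naming will differ.

    func @concat_2d_lowered(%inpA: memref<2x1xf32>, %inpB: memref<2x2xf32>,
                            %out: memref<2x3xf32>) {
      // First operand: the concat-dimension loop runs over [0, 1) and the
      // load map is the identity (prevBound = 0).
      affine.for %i0 = 0 to 2 {
        affine.for %j0 = 0 to 1 {
          %v0 = affine.load %inpA[%i0, %j0] : memref<2x1xf32>
          affine.store %v0, %out[%i0, %j0] : memref<2x3xf32>
        }
      }
      // Second operand: the concat-dimension loop runs over [1, 3) and the
      // load index is remapped by (d0, d1) -> (d0, d1 - 1), i.e. j - prevBound.
      affine.for %i1 = 0 to 2 {
        affine.for %j1 = 1 to 3 {
          %v1 = affine.load %inpB[%i1, %j1 - 1] : memref<2x2xf32>
          affine.store %v1, %out[%i1, %j1] : memref<2x3xf32>
        }
      }
      return
    }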
@@ -202,3 +202,39 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
     (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
  return
}

// CHECK-LABEL: func @concatenate
func @concatenate(%arg0: memref<1x1xf32>, %arg1: memref<1x100xf32>, %arg2: memref<1x200xf32>, %arg3: memref<1x301xf32>) {
  // CHECK-NEXT: %[[RESULT:.*]] = memref.alloc() : memref<1x301xf32>
  // CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
  // CHECK-NEXT:   affine.for %[[Y:.*]] = 0 to 1 {
  // CHECK-NEXT:     %[[LOAD:.*]] = affine.load %arg0[%[[X]], %[[Y]]] : memref<1x1xf32>
  // CHECK-NEXT:     affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT: }
  // CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
  // CHECK-NEXT:   affine.for %[[Y:.*]] = 1 to 101 {
  // CHECK-NEXT:     %[[LOAD:.*]] = affine.load %arg1[%[[X]], %[[Y]] - 1] : memref<1x100xf32>
  // CHECK-NEXT:     affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT: }
  // CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
  // CHECK-NEXT:   affine.for %[[Y:.*]] = 101 to 301 {
  // CHECK-NEXT:     %[[LOAD:.*]] = affine.load %arg2[%[[X]], %[[Y]] - 101] : memref<1x200xf32>
  // CHECK-NEXT:     affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT: }
  %0 = memref.alloc() : memref<1x301xf32>
  "lmhlo.concatenate"(%arg0, %arg1, %arg2, %0) {dimension = 1 : i64} : (memref<1x1xf32>, memref<1x100xf32>, memref<1x200xf32>, memref<1x301xf32>) -> ()
  "lmhlo.copy"(%0, %arg3) : (memref<1x301xf32>, memref<1x301xf32>) -> ()
  "lmhlo.terminator"() : () -> ()
}

// TODO(pashu123): Extend support for dynamic dimensions.
// CHECK-LABEL: func @concatenate_dynamic
func @concatenate_dynamic(%arg0: memref<1x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<1x?xf32>) {
  // CHECK: "lmhlo.concatenate"
  %cst_1 = constant 1 : index
  %0 = memref.alloc(%cst_1) : memref<1x?xf32>
  "lmhlo.concatenate"(%arg0, %arg1, %0) {dimension = 1 : i64} : (memref<1x?xf32>, memref<1x?xf32>, memref<1x?xf32>) -> ()
  "lmhlo.copy"(%0, %arg2) : (memref<1x?xf32>, memref<1x?xf32>) -> ()
  "lmhlo.terminator"() : () -> ()
}