PR #47315: [MLIR] Add concatenateOp lowering from lmhlo to Affine.
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/47315 Lowering of `concatenateOp` is added from lmhlo to Affine. The lowering has been added as a part of `lhlo-legalize-to-affine` pass. Signed-off-by: Prashant Kumar <prashantk@polymagelabs.com> Copybara import of the project: -- 15314e4579f7a6901cf3475eff25962a34772eaf by Prashant Kumar <prashantk@polymagelabs.com>: [MLIR] Add concatenateOp lowering from lmhlo to Affine. Lowering of `concatenateOp` is added from lmhlo to Affine. The lowering has been added as a part of `lhlo-legalize-to-affine` pass. Signed-off-by: Prashant Kumar <prashantk@polymagelabs.com> PiperOrigin-RevId: 368465992
This commit is contained in:
parent
c10167d4a8
commit
236e7db5c0
|
@ -98,6 +98,84 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Concat Operation Example (2D):
|
||||||
|
/// Given inpA[2][1], inpB[2][2], concat_dimension = 1.
|
||||||
|
/// Compute output[x1][x2].
|
||||||
|
/// Implementation Pseudocode:
|
||||||
|
/// s = 0
|
||||||
|
/// for a in range(0, 2):
|
||||||
|
/// for b in range(0, 1):
|
||||||
|
/// output[a][b] = inpA[a][b - s]
|
||||||
|
/// s = 1
|
||||||
|
/// for a in range(0, 2):
|
||||||
|
/// for b in range(1, 3):
|
||||||
|
/// output[a][b] = inpB[a][b - s]
|
||||||
|
///
|
||||||
|
/// Concatenate composes an array from multiple array operands. The array is of
|
||||||
|
/// the same rank as each of the input array operands (which must be of the same
|
||||||
|
/// rank as each other) and contains the arguments in the order that they were
|
||||||
|
/// specified.
|
||||||
|
struct ConcatOpConverter : public OpRewritePattern<ConcatenateOp> {
|
||||||
|
using OpRewritePattern<ConcatenateOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ConcatenateOp op,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value output = op.output();
|
||||||
|
MemRefType outputType = output.getType().cast<MemRefType>();
|
||||||
|
unsigned outputRank = outputType.getRank();
|
||||||
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||||
|
|
||||||
|
ValueRange operands = op.val();
|
||||||
|
uint64_t concatDim = op.dimension();
|
||||||
|
int64_t prevBound = 0;
|
||||||
|
|
||||||
|
for (Value operand : operands) {
|
||||||
|
MemRefType operandType = operand.getType().cast<MemRefType>();
|
||||||
|
ArrayRef<int64_t> operandShape = operandType.getShape();
|
||||||
|
|
||||||
|
// TODO(pashu123): Extend support for dynamic dimensions.
|
||||||
|
if (!operandType.hasStaticShape()) return failure();
|
||||||
|
|
||||||
|
// Only for the concatenation dimension, the value is dimension -
|
||||||
|
// prevBound.
|
||||||
|
SmallVector<AffineExpr, 4> expr;
|
||||||
|
for (unsigned i = 0; i < outputRank; i++) {
|
||||||
|
AffineExpr d0 = (i == concatDim)
|
||||||
|
? rewriter.getAffineDimExpr(concatDim) - prevBound
|
||||||
|
: rewriter.getAffineDimExpr(i);
|
||||||
|
|
||||||
|
expr.push_back(d0);
|
||||||
|
}
|
||||||
|
AffineMap map =
|
||||||
|
AffineMap::get(outputRank, 0, expr, rewriter.getContext());
|
||||||
|
|
||||||
|
// Create multiple for loop nests iterating along the concatenation
|
||||||
|
// dimension.
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
SmallVector<Value, 3> indices;
|
||||||
|
AffineForOp forOp;
|
||||||
|
for (unsigned i = 0; i < outputRank; i++) {
|
||||||
|
if (i == concatDim) {
|
||||||
|
forOp = rewriter.create<AffineForOp>(loc, prevBound,
|
||||||
|
prevBound + operandShape[i]);
|
||||||
|
prevBound += operandShape[i];
|
||||||
|
indices.push_back(forOp.getInductionVar());
|
||||||
|
} else {
|
||||||
|
forOp = rewriter.create<AffineForOp>(loc, 0, outputShape[i]);
|
||||||
|
indices.push_back(forOp.getInductionVar());
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointToStart(forOp.getBody());
|
||||||
|
}
|
||||||
|
Value storeVal =
|
||||||
|
rewriter.create<AffineLoadOp>(loc, operand, map, indices);
|
||||||
|
rewriter.create<AffineStoreOp>(loc, storeVal, output, indices);
|
||||||
|
}
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename LhloOpTy>
|
template <typename LhloOpTy>
|
||||||
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
|
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
|
||||||
using OpRewritePattern<LhloOpTy>::OpRewritePattern;
|
using OpRewritePattern<LhloOpTy>::OpRewritePattern;
|
||||||
|
@ -145,6 +223,7 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
|
||||||
BinaryOpConverter<lmhlo::MinOp>,
|
BinaryOpConverter<lmhlo::MinOp>,
|
||||||
BinaryOpConverter<lmhlo::MulOp>,
|
BinaryOpConverter<lmhlo::MulOp>,
|
||||||
BinaryOpConverter<lmhlo::SubOp>,
|
BinaryOpConverter<lmhlo::SubOp>,
|
||||||
|
ConcatOpConverter,
|
||||||
DotOpConverter>(context);
|
DotOpConverter>(context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
|
@ -202,3 +202,39 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
|
||||||
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
|
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @concatenate
|
||||||
|
func @concatenate(%arg0: memref<1x1xf32>, %arg1: memref<1x100xf32>, %arg2: memref<1x200xf32>, %arg3: memref<1x301xf32>) {
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = memref.alloc() : memref<1x301xf32>
|
||||||
|
// CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
|
||||||
|
// CHECK-NEXT: affine.for %[[Y:.*]] = 0 to 1 {
|
||||||
|
// CHECK-NEXT: %[[LOAD:.*]] = affine.load %arg0[%[[X]], %[[Y]]] : memref<1x1xf32>
|
||||||
|
// CHECK-NEXT: affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
|
||||||
|
// CHECK-NEXT: affine.for %[[Y:.*]] = 1 to 101 {
|
||||||
|
// CHECK-NEXT: %[[LOAD:.*]] = affine.load %arg1[%[[X]], %[[Y]] - 1] : memref<1x100xf32>
|
||||||
|
// CHECK-NEXT: affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: affine.for %[[X:.*]] = 0 to 1 {
|
||||||
|
// CHECK-NEXT: affine.for %[[Y:.*]] = 101 to 301 {
|
||||||
|
// CHECK-NEXT: %[[LOAD:.*]] = affine.load %arg2[%[[X]], %[[Y]] - 101] : memref<1x200xf32>
|
||||||
|
// CHECK-NEXT: affine.store %[[LOAD]], %[[RESULT]][%[[X]], %[[Y]]] : memref<1x301xf32>
|
||||||
|
%0 = memref.alloc() : memref<1x301xf32>
|
||||||
|
"lmhlo.concatenate"(%arg0, %arg1, %arg2, %0) {dimension = 1 : i64} : (memref<1x1xf32>, memref<1x100xf32>, memref<1x200xf32>, memref<1x301xf32>) -> ()
|
||||||
|
"lmhlo.copy"(%0, %arg3) : (memref<1x301xf32>, memref<1x301xf32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(pashu123): Extend Support for dynamic dimensions.
|
||||||
|
// CHECK-LABEL: func @concatenate_dynamic
|
||||||
|
func @concatenate_dynamic(%arg0: memref<1x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<1x?xf32>) {
|
||||||
|
// CHECK: "lmhlo.concatenate"
|
||||||
|
%cst_1 = constant 1 : index
|
||||||
|
%0 = memref.alloc(%cst_1) : memref<1x?xf32>
|
||||||
|
"lmhlo.concatenate"(%arg0, %arg1, %0) {dimension = 1 : i64} : (memref<1x?xf32>, memref<1x?xf32>, memref<1x?xf32>) -> ()
|
||||||
|
"lmhlo.copy"(%0, %arg2) : (memref<1x?xf32>, memref<1x?xf32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue