PR #50100: [MLIR][DISC] Bufferize DynamicIotaOp and DynamicPadOp
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50100

Adds hlo-to-lhlo conversion support for DynamicIotaOp and DynamicPadOp.

Copybara import of the project:

--
c3aae94954e35d3f8ad265f619ef9765665a5115 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Bufferize DynamicIotaOp and DynamicPadOp

--
adc6996d70b804d61310d56a33fac975d70c8636 by Wenyi Zhao <reyizero@gmail.com>:

minor

PiperOrigin-RevId: 378733284
This commit is contained in:
parent 642ca86a3f
commit 6660234d80
@@ -94,7 +94,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
-def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {
+def HLO_DynamicIotaOp: HLO_ShapedInterfaceOp<"dynamic_iota", [NoSideEffect]> {
   let summary = "Create linear increasing values from 0 to length -1.";
   let description = [{
     Produces an HLO Tensor of the specified shape, with an incremental set of
@@ -2220,7 +2220,7 @@ def HLO_RealDynamicSliceOp: HLO_Op<
   let hasCustomHLOConverter = 1;
 }
 
-def HLO_DynamicPadOp: HLO_Op<"dynamic_pad",
+def HLO_DynamicPadOp: HLO_ShapedInterfaceOp<"dynamic_pad",
   [NoSideEffect, AllElementTypesMatch<["operand", "padding_value", "result"]>,
    AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
   let summary = "Dynamic Pad operator";
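The base-class change above (HLO_Op to HLO_ShapedInterfaceOp), applied to dynamic_iota earlier in the diff and to dynamic_pad here, is what opts each op into shape reification: the op must now materialize its result shape as SSA values so a later pass can allocate an output buffer of the right size. The method bodies are added further down in this diff; the hook they implement looks like this (signature copied from those definitions):

```c++
LogicalResult reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes);
```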
@@ -1514,7 +1514,7 @@ def LHLO_DynamicIotaOp : LHLO_Op<"dynamic_iota", []> {
     See
     https://www.tensorflow.org/xla/operation_semantics#iota
   }];
-  let arguments = (ins Arg<LHLO_IntBuffer, "", [MemRead]>:$shape,
+  let arguments = (ins Arg<LHLO_DimensionBuffer, "", [MemRead]>:$shape,
                    I64Attr:$iota_dimension,
                    Arg<LHLO_Buffer, "", [MemWrite]>:$output);
 }
@@ -55,6 +55,8 @@ MAP_HLO_TO_LHLO(CustomCallOp);
 MAP_HLO_TO_LHLO(DivOp);
 MAP_HLO_TO_LHLO(DotOp);
 MAP_HLO_TO_LHLO(DynamicBroadcastInDimOp);
+MAP_HLO_TO_LHLO(DynamicIotaOp);
+MAP_HLO_TO_LHLO(DynamicPadOp);
 MAP_HLO_TO_LHLO(DynamicReshapeOp);
 MAP_HLO_TO_LHLO(ExpOp);
 MAP_HLO_TO_LHLO(Expm1Op);
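Each MAP_HLO_TO_LHLO entry registers a compile-time mapping from a tensor-level mhlo op to its buffer-level lmhlo counterpart, which is how the generic conversion pattern registered later in this diff finds the target op type. A sketch of the trait-style specialization such a macro typically expands to (assumed shape for illustration, not copied from the header):

```c++
// Primary template: specialized once per supported op.
template <typename HloOpTy>
struct HloToLhloOpImpl;

// Assumed expansion of MAP_HLO_TO_LHLO(DynamicIotaOp):
template <>
struct HloToLhloOpImpl<mhlo::DynamicIotaOp> {
  using Type = lmhlo::DynamicIotaOp;
};

// A converter can then name the buffer op purely from the tensor op:
// using LhloOpTy = typename HloToLhloOpImpl<mhlo::DynamicIotaOp>::Type;
```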
@@ -186,6 +186,14 @@ static LogicalResult rngInferReturnTypeComponents(
   return success();
 }
 
+// Returns a new scalar integer value having type `type`. Here `type` must be
+// an integer or index type.
+Value MaybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
+  if (type == value.getType()) return value;
+  assert(type.isIndex() || value.getType().isIndex());
+  return b.create<IndexCastOp>(loc, value, type);
+}
+
 }  // namespace
 
 //===----------------------------------------------------------------------===//
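MaybeCastTo is the glue that lets the shape arithmetic added below mix `index` values (from memref.dim and ConstantIndexOp) with the integer element type of the padding operands: IndexCastOp is only legal between `index` and a fixed-width integer type, which is exactly what the assert encodes, and no op is created when the types already agree. A hedged usage sketch (hypothetical helper, not part of the diff):

```c++
// Bring an index-typed dimension size and an i32-typed padding amount to a
// common scalar type before adding them; no cast is emitted if types match.
Value AddDimAndPad(OpBuilder& b, Location loc, Value dim /*index*/,
                   Value pad /*i32*/) {
  Value dim_cast = MaybeCastTo(b, loc, dim, pad.getType());  // index -> i32
  return b.create<AddIOp>(loc, dim_cast, pad);
}
```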
@@ -480,6 +488,14 @@ void DynamicIotaOp::getCanonicalizationPatterns(
   results.insert<DynamicIotaBroadcast>(context);
 }
 
+LogicalResult DynamicIotaOp::reifyReturnTypeShapes(
+    OpBuilder&, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  DynamicIotaOp::Adaptor adaptor(operands);
+  reifiedReturnShapes.push_back(adaptor.output_shape());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DynamicUpdateSliceOp
 //===----------------------------------------------------------------------===//
@@ -2374,6 +2390,65 @@ static LogicalResult Verify(DynamicPadOp op) {
   return success();
 }
 
+LogicalResult DynamicPadOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  DynamicPadOp::Adaptor adaptor(operands);
+  Value operand = adaptor.operand();
+  Value edge_padding_low = adaptor.edge_padding_low();
+  Value edge_padding_high = adaptor.edge_padding_high();
+  Value interior_padding = adaptor.interior_padding();
+
+  auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
+  // Unranked pad is not supported at the moment.
+  if (!operand_type) return failure();
+
+  auto loc = this->getLoc();
+  SmallVector<Value, 4> shape_values;
+  shape_values.reserve(operand_type.getRank());
+  Type shape_scalar_type =
+      edge_padding_low.getType().cast<ShapedType>().getElementType();
+
+  auto to_shape_scalar_type = [&](Value v) {
+    return MaybeCastTo(builder, loc, v, shape_scalar_type);
+  };
+
+  Value zero = to_shape_scalar_type(builder.create<ConstantIndexOp>(loc, 0));
+  Value one = to_shape_scalar_type(builder.create<ConstantIndexOp>(loc, 1));
+
+  for (int idx : llvm::seq<int>(0, operand_type.getShape().size())) {
+    Value value_dim =
+        to_shape_scalar_type(builder.create<memref::DimOp>(loc, operand, idx));
+    Value offset = builder.create<ConstantIndexOp>(loc, idx);
+    Value value_low =
+        builder.create<tensor::ExtractOp>(loc, edge_padding_low, offset);
+    Value value_high =
+        builder.create<tensor::ExtractOp>(loc, edge_padding_high, offset);
+    Value value_interior =
+        builder.create<tensor::ExtractOp>(loc, interior_padding, offset);
+    // output_size = input_size + padding_low + padding_high +
+    //               interior * max(input_size - 1, 0)
+    Value value_dim_less_than_one =
+        builder.create<CmpIOp>(loc, CmpIPredicate::slt, value_dim, one);
+    Value interior_size = builder.create<MulIOp>(
+        loc, value_interior,
+        builder.create<mlir::SelectOp>(
+            loc, value_dim_less_than_one, zero,
+            builder.create<SubIOp>(loc, value_dim, one)));
+    shape_values.push_back(builder.create<AddIOp>(
+        loc,
+        builder.create<AddIOp>(
+            loc, builder.create<AddIOp>(loc, interior_size, value_dim),
+            value_low),
+        value_high));
+  }
+
+  reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values));
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
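The loop above implements, per dimension, output_size = input_size + padding_low + padding_high + interior * max(input_size - 1, 0): interior padding inserts one run between each pair of adjacent input elements, and the cmpi/select pair guards the empty-dimension case where input_size - 1 would go negative. A plain C++ check of the same arithmetic (illustrative only, not part of the diff):

```c++
#include <algorithm>
#include <cassert>

// Mirrors the SSA arithmetic emitted above for a single dimension.
long PaddedSize(long input, long low, long high, long interior) {
  return input + low + high + interior * std::max(input - 1, 0L);
}

int main() {
  assert(PaddedSize(4, 1, 2, 3) == 16);  // 3 gaps * 3 interior + 4 + 1 + 2
  assert(PaddedSize(0, 1, 2, 3) == 3);   // empty dim: edge padding only
  assert(PaddedSize(1, 0, 0, 5) == 1);   // one element: no interior runs
  return 0;
}
```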
@@ -669,6 +669,8 @@ void populateDynamicHLOToLHLOOnlyConversionPattern(
     OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
+                   HloToLhloOpConverter<mhlo::DynamicIotaOp>,
+                   HloToLhloOpConverter<mhlo::DynamicPadOp>,
                    HloToLhloOpConverter<mhlo::DynamicReshapeOp>
   >(*converter, context);
   // clang-format on
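Registering the generic HloToLhloOpConverter for the two new ops is sufficient because the converter leans on the reifyReturnTypeShapes hooks added above: it asks the op for its result shape as SSA values, allocates an output memref with those dynamic extents, and rebuilds the op in lmhlo form with the buffer appended as the output operand. A hedged sketch of the allocation step (helper name and exact flow are illustrative, not code from this PR):

```c++
// Size a memref.alloc from a reified 1-D shape value: static extents come
// from the result type, dynamic ones are extracted from the shape tensor.
Value AllocFromReifiedShape(OpBuilder& b, Location loc, RankedTensorType ty,
                            Value reified_shape) {
  SmallVector<Value, 4> dyn_sizes;
  for (auto en : llvm::enumerate(ty.getShape())) {
    if (en.value() != ShapedType::kDynamicSize) continue;
    Value idx = b.create<ConstantIndexOp>(loc, en.index());
    Value ext = b.create<tensor::ExtractOp>(loc, reified_shape, idx);
    dyn_sizes.push_back(MaybeCastTo(b, loc, ext, b.getIndexType()));
  }
  auto memref_ty = MemRefType::get(ty.getShape(), ty.getElementType());
  return b.create<memref::AllocOp>(loc, memref_ty, dyn_sizes);
}
```

This lines up with the CHECK lines in the new tests below, where memref.alloc receives exactly the computed dynamic extents.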
@@ -32,4 +32,52 @@ func @dynamic_broadcast_in_dim(%operand: tensor<?x?xf32>, %shape: tensor<3xindex
     broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
   } : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
   return %result : tensor<?x?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @dynamic_iota
+// CHECK-SAME: (%[[SHAPE:.*]]: memref<2xindex>) -> memref<5x?xi32>
+func @dynamic_iota(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c1]
+  // CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]])
+  // CHECK: "lmhlo.dynamic_iota"(%[[SHAPE]], %[[OUTPUT]])
+  %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>
+  return %0 : tensor<5x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dynamic_pad
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[VAL:.*]]: memref<f32>,
+// CHECK-SAME: %[[LOW:.*]]: memref<2xindex>, %[[HIGH:.*]]: memref<2xindex>, %[[INTER:.*]]: memref<2xindex>) -> memref<?x?xf32>
+func @dynamic_pad(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: tensor<2xindex>, %arg3: tensor<2xindex>, %arg4: tensor<2xindex>) -> tensor<?x?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %c0 : memref<?x?xf32>
+  // CHECK: %[[TMP1:.*]] = memref.load %[[LOW]][%c0] : memref<2xindex>
+  // CHECK: %[[TMP2:.*]] = memref.load %[[HIGH]][%c0] : memref<2xindex>
+  // CHECK: %[[TMP3:.*]] = memref.load %[[INTER]][%c0] : memref<2xindex>
+  // CHECK: %[[TMP4:.*]] = cmpi slt, %[[DIM0]], %c1 : index
+  // CHECK: %[[TMP5:.*]] = subi %[[DIM0]], %c1 : index
+  // CHECK: %[[TMP6:.*]] = select %[[TMP4]], %c0, %[[TMP5]] : index
+  // CHECK: %[[TMP7:.*]] = muli %[[TMP3]], %[[TMP6]] : index
+  // CHECK: %[[TMP8:.*]] = addi %[[TMP7]], %[[DIM0]] : index
+  // CHECK: %[[TMP9:.*]] = addi %[[TMP8]], %[[TMP1]] : index
+  // CHECK: %[[TMP10:.*]] = addi %[[TMP9]], %[[TMP2]] : index
+  // CHECK: %[[TMP11:.*]] = memref.dim %[[ARG]], %c1 : memref<?x?xf32>
+  // CHECK: %[[TMP12:.*]] = memref.load %[[LOW]][%c1] : memref<2xindex>
+  // CHECK: %[[TMP13:.*]] = memref.load %[[HIGH]][%c1] : memref<2xindex>
+  // CHECK: %[[TMP14:.*]] = memref.load %[[INTER]][%c1] : memref<2xindex>
+  // CHECK: %[[TMP15:.*]] = cmpi slt, %[[TMP11]], %c1 : index
+  // CHECK: %[[TMP16:.*]] = subi %[[TMP11]], %c1 : index
+  // CHECK: %[[TMP17:.*]] = select %[[TMP15]], %c0, %[[TMP16]] : index
+  // CHECK: %[[TMP18:.*]] = muli %[[TMP14]], %[[TMP17]] : index
+  // CHECK: %[[TMP19:.*]] = addi %[[TMP18]], %[[TMP11]] : index
+  // CHECK: %[[TMP20:.*]] = addi %[[TMP19]], %[[TMP12]] : index
+  // CHECK: %[[TMP21:.*]] = addi %[[TMP20]], %[[TMP13]] : index
+  // CHECK: %[[OUT:.*]] = memref.alloc(%[[TMP10]], %[[TMP21]]) : memref<?x?xf32>
+  // CHECK: "lmhlo.dynamic_pad"(%[[ARG]], %[[VAL]], %[[LOW]], %[[HIGH]], %[[INTER]], %[[OUT]])
+  %0 = "mhlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
+  return %0: tensor<?x?xf32>
+}