PR #50100: [MLIR][DISC] Bufferize DynamicIotaOp and DynamicPadOp
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50100 support hlo-to-lhlo conversion for DynamicIotaOp and DynamicPadOp Copybara import of the project: -- c3aae94954e35d3f8ad265f619ef9765665a5115 by Wenyi Zhao <reyizero@gmail.com>: [MLIR][DISC] Bufferize DynamicIotaOp and DynamicPadOp -- adc6996d70b804d61310d56a33fac975d70c8636 by Wenyi Zhao <reyizero@gmail.com>: minor PiperOrigin-RevId: 378733284
This commit is contained in:
parent
642ca86a3f
commit
6660234d80
|
@ -94,7 +94,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]> {
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {
|
||||
def HLO_DynamicIotaOp: HLO_ShapedInterfaceOp<"dynamic_iota", [NoSideEffect]> {
|
||||
let summary = "Create linear increasing values from 0 to length -1.";
|
||||
let description = [{
|
||||
Produces an HLO Tensor of the specified shape, with an incremental set of
|
||||
|
@ -2220,7 +2220,7 @@ def HLO_RealDynamicSliceOp: HLO_Op<
|
|||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_DynamicPadOp: HLO_Op<"dynamic_pad",
|
||||
def HLO_DynamicPadOp: HLO_ShapedInterfaceOp<"dynamic_pad",
|
||||
[NoSideEffect, AllElementTypesMatch<["operand", "padding_value", "result"]>,
|
||||
AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
|
||||
let summary = "Dynamic Pad operator";
|
||||
|
|
|
@ -1514,7 +1514,7 @@ def LHLO_DynamicIotaOp : LHLO_Op<"dynamic_iota", []> {
|
|||
See
|
||||
https://www.tensorflow.org/xla/operation_semantics#iota
|
||||
}];
|
||||
let arguments = (ins Arg<LHLO_IntBuffer, "", [MemRead]>:$shape,
|
||||
let arguments = (ins Arg<LHLO_DimensionBuffer, "", [MemRead]>:$shape,
|
||||
I64Attr:$iota_dimension,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
|
||||
}
|
||||
|
|
|
@ -55,6 +55,8 @@ MAP_HLO_TO_LHLO(CustomCallOp);
|
|||
MAP_HLO_TO_LHLO(DivOp);
|
||||
MAP_HLO_TO_LHLO(DotOp);
|
||||
MAP_HLO_TO_LHLO(DynamicBroadcastInDimOp);
|
||||
MAP_HLO_TO_LHLO(DynamicIotaOp);
|
||||
MAP_HLO_TO_LHLO(DynamicPadOp);
|
||||
MAP_HLO_TO_LHLO(DynamicReshapeOp);
|
||||
MAP_HLO_TO_LHLO(ExpOp);
|
||||
MAP_HLO_TO_LHLO(Expm1Op);
|
||||
|
|
|
@ -186,6 +186,14 @@ static LogicalResult rngInferReturnTypeComponents(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Returns a new scalar integer value having type `type`. Here `type` must be
|
||||
// an integer or index type.
|
||||
Value MaybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
|
||||
if (type == value.getType()) return value;
|
||||
assert(type.isIndex() || value.getType().isIndex());
|
||||
return b.create<IndexCastOp>(loc, value, type);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -480,6 +488,14 @@ void DynamicIotaOp::getCanonicalizationPatterns(
|
|||
results.insert<DynamicIotaBroadcast>(context);
|
||||
}
|
||||
|
||||
LogicalResult DynamicIotaOp::reifyReturnTypeShapes(
|
||||
OpBuilder&, ValueRange operands,
|
||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
DynamicIotaOp::Adaptor adaptor(operands);
|
||||
reifiedReturnShapes.push_back(adaptor.output_shape());
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DynamicUpdateSliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2374,6 +2390,65 @@ static LogicalResult Verify(DynamicPadOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult DynamicPadOp::reifyReturnTypeShapes(
|
||||
OpBuilder& builder, ValueRange operands,
|
||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
DynamicPadOp::Adaptor adaptor(operands);
|
||||
Value operand = adaptor.operand();
|
||||
Value edge_padding_low = adaptor.edge_padding_low();
|
||||
Value edge_padding_high = adaptor.edge_padding_high();
|
||||
Value interior_padding = adaptor.interior_padding();
|
||||
|
||||
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
|
||||
// Not support unranked pad a.t.m.
|
||||
if (!operand_type) return failure();
|
||||
|
||||
auto loc = this->getLoc();
|
||||
SmallVector<Value, 4> shape_values;
|
||||
shape_values.reserve(operand_type.getRank());
|
||||
Type shape_scalar_type =
|
||||
edge_padding_low.getType().cast<ShapedType>().getElementType();
|
||||
|
||||
auto to_shape_scalar_type = [&](Value v) {
|
||||
return MaybeCastTo(builder, loc, v, shape_scalar_type);
|
||||
};
|
||||
|
||||
Value zero = to_shape_scalar_type(builder.create<ConstantIndexOp>(loc, 0));
|
||||
Value one = to_shape_scalar_type(builder.create<ConstantIndexOp>(loc, 1));
|
||||
|
||||
for (int idx : llvm::seq<int>(0, operand_type.getShape().size())) {
|
||||
Value value_dim =
|
||||
to_shape_scalar_type(builder.create<memref::DimOp>(loc, operand, idx));
|
||||
Value offset = builder.create<ConstantIndexOp>(loc, idx);
|
||||
Value value_low =
|
||||
builder.create<tensor::ExtractOp>(loc, edge_padding_low, offset);
|
||||
Value value_high =
|
||||
builder.create<tensor::ExtractOp>(loc, edge_padding_high, offset);
|
||||
Value value_interior =
|
||||
builder.create<tensor::ExtractOp>(loc, interior_padding, offset);
|
||||
// output_size = input_size + padding_low + padding_high + interior *
|
||||
// max(input_size - 1, 0)
|
||||
Value value_dim_less_than_one =
|
||||
builder.create<CmpIOp>(loc, CmpIPredicate::slt, value_dim, one);
|
||||
Value interior_size = builder.create<MulIOp>(
|
||||
loc, value_interior,
|
||||
builder.create<mlir::SelectOp>(
|
||||
loc, value_dim_less_than_one, zero,
|
||||
builder.create<SubIOp>(loc, value_dim, one)));
|
||||
shape_values.push_back(builder.create<AddIOp>(
|
||||
loc,
|
||||
builder.create<AddIOp>(
|
||||
loc, builder.create<AddIOp>(loc, interior_size, value_dim),
|
||||
value_low),
|
||||
value_high));
|
||||
}
|
||||
|
||||
reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
|
||||
loc, shape_scalar_type, shape_values));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -669,6 +669,8 @@ void populateDynamicHLOToLHLOOnlyConversionPattern(
|
|||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
|
||||
HloToLhloOpConverter<mhlo::DynamicIotaOp>,
|
||||
HloToLhloOpConverter<mhlo::DynamicPadOp>,
|
||||
HloToLhloOpConverter<mhlo::DynamicReshapeOp>
|
||||
>(*converter, context);
|
||||
// clang-format on
|
||||
|
|
|
@ -33,3 +33,51 @@ func @dynamic_broadcast_in_dim(%operand: tensor<?x?xf32>, %shape: tensor<3xindex
|
|||
} : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
return %result : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @dynamic_iota
|
||||
// CHECK-SAME: (%[[SHAPE:.*]]: memref<2xindex>) -> memref<5x?xi32>
|
||||
func @dynamic_iota(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
|
||||
// CHECK-NOT: tensor_load
|
||||
// CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c1]
|
||||
// CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]])
|
||||
// CHECK: "lmhlo.dynamic_iota"(%[[SHAPE]], %[[OUTPUT]])
|
||||
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>
|
||||
return %0 : tensor<5x?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @dynamic_pad
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[VAL:.*]]: memref<f32>,
|
||||
// CHECK-SAME: %[[LOW:.*]]: memref<2xindex>, %[[HIGH:.*]]: memref<2xindex>, %[[INTER:.*]]: memref<2xindex>) -> memref<?x?xf32>
|
||||
func @dynamic_pad(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: tensor<2xindex>, %arg3: tensor<2xindex>, %arg4: tensor<2xindex>) -> tensor<?x?xf32> {
|
||||
// CHECK-NOT: tensor_load
|
||||
// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %c0 : memref<?x?xf32>
|
||||
// CHECK: %[[TMP1:.*]] = memref.load %[[LOW]][%c0] : memref<2xindex>
|
||||
// CHECK: %[[TMP2:.*]] = memref.load %[[HIGH]][%c0] : memref<2xindex>
|
||||
// CHECK: %[[TMP3:.*]] = memref.load %[[INTER]][%c0] : memref<2xindex>
|
||||
// CHECK: %[[TMP4:.*]] = cmpi slt, %[[DIM0]], %c1 : index
|
||||
// CHECK: %[[TMP5:.*]] = subi %[[DIM0]], %c1 : index
|
||||
// CHECK: %[[TMP6:.*]] = select %[[TMP4]], %c0, %[[TMP5]] : index
|
||||
// CHECK: %[[TMP7:.*]] = muli %[[TMP3]], %[[TMP6]] : index
|
||||
// CHECK: %[[TMP8:.*]] = addi %[[TMP7]], %[[DIM0]] : index
|
||||
// CHECK: %[[TMP9:.*]] = addi %[[TMP8]], %[[TMP1]] : index
|
||||
// CHECK: %[[TMP10:.*]] = addi %[[TMP9]], %[[TMP2]] : index
|
||||
// CHECK: %[[TMP11:.*]] = memref.dim %[[ARG]], %c1 : memref<?x?xf32>
|
||||
// CHECK: %[[TMP12:.*]] = memref.load %[[LOW]][%c1] : memref<2xindex>
|
||||
// CHECK: %[[TMP13:.*]] = memref.load %[[HIGH]][%c1] : memref<2xindex>
|
||||
// CHECK: %[[TMP14:.*]] = memref.load %[[INTER]][%c1] : memref<2xindex>
|
||||
// CHECK: %[[TMP15:.*]] = cmpi slt, %[[TMP11]], %c1 : index
|
||||
// CHECK: %[[TMP16:.*]] = subi %[[TMP11]], %c1 : index
|
||||
// CHECK: %[[TMP17:.*]] = select %[[TMP15]], %c0, %[[TMP16]] : index
|
||||
// CHECK: %[[TMP18:.*]] = muli %[[TMP14]], %[[TMP17]] : index
|
||||
// CHECK: %[[TMP19:.*]] = addi %[[TMP18]], %[[TMP11]] : index
|
||||
// CHECK: %[[TMP20:.*]] = addi %[[TMP19]], %[[TMP12]] : index
|
||||
// CHECK: %[[TMP21:.*]] = addi %[[TMP20]], %[[TMP13]] : index
|
||||
// CHECK: %[[OUT:.*]] = memref.alloc(%[[TMP10]], %[[TMP21]]) : memref<?x?xf32>
|
||||
// CHECK: "lmhlo.dynamic_pad"(%[[ARG]], %[[VAL]], %[[LOW]], %[[HIGH]], %[[INTER]], %[[OUT]])
|
||||
%0 = "mhlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
return %0: tensor<?x?xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue