PR #49228: [MLIR][DISC] porting dynamic shape related OPs to mhlo and lmhlo dialect

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49228

We are porting our MLIR-based dynamic shape compiler to the TensorFlow community (from op definitions and patterns to optimization passes, etc.).
This is the first PR; it adds definitions of some dynamic shape ops to the mhlo and lmhlo dialects.
For the mhlo dialect, we add:
- HLO_RealDynamicSliceOp
- HLO_DynamicPadOp
- HLO_DynamicGatherOp
- HLO_DynamicConvOp

For the lmhlo dialect, we add:
- LHLO_RealDynamicSliceOp
- LHLO_DynamicBroadcastInDimOp
- LHLO_DynamicGatherOp
- LHLO_DynamicPadOp
- LHLO_DynamicBitcastOp
- LHLO_DynamicConvOp
- LHLO_DynamicIotaOp
- LHLO_DynamicReshapeOp
- LHLO_DotGeneralOp
- LHLO_BitcastOp

Remaining ops to add:
* We will send a separate PR containing LHLO_DynamicWhileOp and LHLO_DynamicCaseOp for control flow.
* We will add a separate dedicated dialect (e.g. mhlo_ral) that includes D2HOp/H2DOp/DebugPrintOp/TopKOp, etc.

Previous discussions: [RFC](https://groups.google.com/a/tensorflow.org/g/mlir/c/_X48poNcbDI/m/jCC8BWIICQAJ), [discussion_1](https://llvm.discourse.group/t/updates-on-mlir-based-dynamic-shape-compiler/2384), [recording of meeting](https://drive.google.com/file/d/1_uEISlV5MUWdG9faKAdKlCWnPtGjRC-D/view?usp=sharing).
Copybara import of the project:

--
e22d9e61106e00a1a1c6f368cc4a03e3bd1f414c by azazhu <azazhu@gmail.com>:

[DISC]fea: porting mhlo and lmhlo OPs

--
9ec3e76290da07cbd53d7da5fa86ff67179441a1 by azazhu <azazhu@gmail.com>:

[DISC][MLIR] 1. add summary and description for dynamic OPs in mhlo and lmhlo; 2. rm InferOutputTypes; 3. add verify for RealDynamicSliceOp and DynamicPadOp

--
0d68cd135555fd935991c12456b21329e628f23f by azazhu <azazhu@gmail.com>:

[DISC][MLIR] 1.remove D2H,H2D and DebugPrint Ops from mhlo/lmhlo dialect; 2. add type constraint to DynamicPadOp and RealDynamicSliceOp; 3.refine lmhlo type constraint; 4.rename RealDynamicSliceOp as name conflict.

--
698762a77d60f6a844cb1ab3f32740d4ef3c5843 by azazhu <azazhu@gmail.com>:

[DISC][MLIR] 1. replace dyn_cast to cast 2. refine code

PiperOrigin-RevId: 375022260
Authored by Feiwen on 2021-05-20 23:15:58 -07:00; committed by TensorFlow MLIR Team
parent cd8f585cf7
commit a7884196f5
5 changed files with 349 additions and 0 deletions


@@ -2173,4 +2173,88 @@ def HLO_ReducePrecisionOp :
let results = (outs HLO_FpTensor:$output);
}
def HLO_RealDynamicSliceOp: HLO_Op<
"real_dynamic_slice",
[NoSideEffect, AllElementTypesMatch<["operand", "result"]>,
AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
let summary = "Real Dynamic Slice operator";
let description = [{
The dynamic shape version of SliceOp. Extracts a sub-array from the input
array according to start_indices, limit_indices and strides. Expects
start_indices/limit_indices/strides to be statically shaped, with sizes
matching the rank of the input.
}];
let arguments = (ins
HLO_Tensor:$operand,
HLO_DimensionTensor:$start_indices,
HLO_DimensionTensor:$limit_indices,
HLO_DimensionTensor:$strides
);
let results = (outs HLO_Tensor:$result);
let hasCustomHLOConverter = 1;
}
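// Illustrative usage sketch (shapes chosen for the example): start/limit/strides
// are rank-1 index tensors computed at runtime.
//   %result = "mhlo.real_dynamic_slice"(%operand, %start, %limit, %strides)
//       : (tensor<?x?xf32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>)
//       -> tensor<?x?xf32>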
def HLO_DynamicPadOp: HLO_Op<"dynamic_pad",
[NoSideEffect, AllElementTypesMatch<["operand", "padding_value", "result"]>,
AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
let summary = "Dynamic Pad operator";
let description = [{
The dynamic shape version of PadOp. Pads the edges of `operand` with
`padding_value`, with the amount of low/high/interior padding passed in as
runtime tensors. Expects edge_padding_low/edge_padding_high/interior_padding
to be statically shaped, with sizes matching the rank of the input.
See
https://www.tensorflow.org/xla/operation_semantics#pad
}];
let arguments = (ins
HLO_Tensor:$operand,
HLO_Tensor:$padding_value,
HLO_DimensionTensor:$edge_padding_low,
HLO_DimensionTensor:$edge_padding_high,
HLO_DimensionTensor:$interior_padding
);
let results = (outs HLO_Tensor:$result);
let hasCanonicalizer = 1;
let hasCustomHLOConverter = 1;
}
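// Illustrative usage sketch (shapes chosen for the example): the padding
// amounts are rank-1 tensors computed at runtime.
//   %result = "mhlo.dynamic_pad"(%operand, %value, %low, %high, %interior)
//       : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>,
//          tensor<2xindex>) -> tensor<?x?xf32>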
def HLO_DynamicGatherOp: HLO_Op<"dynamic_gather", [NoSideEffect]> {
let summary = "Dynamic Gather operator";
let description = [{
The dynamic shape version of GatherOp. Stitches together several slices of an
input array. slice_sizes is provided as an operand rather than a compile-time
constant.
}];
let arguments = (ins
HLO_Tensor:$operand,
HLO_IntTensor:$start_indices,
HLO_IntTensor:$slice_sizes,
GatherDimensionNumbers:$dimension_numbers,
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
);
let results = (outs HLO_Tensor);
let hasCustomHLOConverter = 1;
}
def HLO_DynamicConvOp : HLO_Op<"dynamic_conv", [NoSideEffect]>, BASE_HLO_ConvOp {
let summary = "Dynamic Convolution operator";
let description = [{
The dynamic shape version of ConvOp. Computes a convolution with dynamic padding.
}];
let arguments = !con(
(ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
HLO_Tensor:$d_padding),
ConvolutionAttributes.attributes);
let results = (outs HLO_Tensor);
let hasCustomHLOConverter = 1;
}
#endif // HLO_OPS


@@ -1386,4 +1386,158 @@ def TerminatorOp :
[{ build($_builder, $_state, llvm::None, operands, llvm::None); }]>];
}
def LHLO_RealDynamicSliceOp: LHLO_Op<
"real_dynamic_slice",
[AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
let summary = "LHLO Real Dynamic Slice operator";
let description = [{
The dynamic shape version of SliceOp. Extracts a sub-array from the
input array according to dynamic start_indices, limit_indices and strides.
}];
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_DimensionBuffer, "", [MemRead]>:$start_indices,
Arg<LHLO_DimensionBuffer, "", [MemRead]>:$limit_indices,
Arg<LHLO_DimensionBuffer, "", [MemRead]>:$strides,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
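// Illustrative usage sketch (buffer semantics; shapes chosen for the example):
// the result is written into the pre-allocated %output buffer.
//   "lmhlo.real_dynamic_slice"(%operand, %start, %limit, %strides, %output)
//       : (memref<?x?xf32>, memref<2xindex>, memref<2xindex>, memref<2xindex>,
//          memref<?x?xf32>) -> ()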
def LHLO_DynamicBroadcastInDimOp : LHLO_Op<"dynamic_broadcast_in_dim",
[]> {
let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
let description = [{
The dynamic shape version of BroadcastInDimOp. This is a generalization of the
BroadcastInDimOp which accepts its output dimensions as an argument. It should
eventually supersede the statically shaped original, but is being phased in as
a separate op in order to support compatibility with lowerings and translations
that precede dynamic shapes.
}];
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_DimensionBuffer, "", [MemRead]>:$output_dimensions,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
BroadcastDimAttr:$broadcast_dimensions
);
}
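// Illustrative usage sketch (shapes chosen for the example): broadcast a
// rank-1 operand into a rank-2 output whose shape is read from
// %output_dimensions at runtime.
//   "lmhlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions, %output)
//       {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
//       : (memref<?xf32>, memref<2xindex>, memref<?x?xf32>) -> ()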
def LHLO_DotGeneralOp: LHLO_Op<"dot_general", []> {
let summary = "LHLO General Dot operator";
let description = [{
Performs general dot products between vectors, vector/matrix and
matrix/matrix multiplications.
See https://www.tensorflow.org/xla/operation_semantics#dotgeneral.
}];
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
DotDimensionNumbers:$dot_dimension_numbers,
HLO_PrecisionConfigAttr:$precision_config,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_DynamicGatherOp: LHLO_Op<"dynamic_gather", []> {
let summary = "LHLO Dynamic Gather operator";
let description = [{
The dynamic shape version of GatherOp. Stitches together several slices of an
input array. slice_sizes is not a compile-time constant.
}];
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
Arg<LHLO_IntBuffer, "", [MemRead]>:$slice_sizes,
GatherDimensionNumbers:$dimension_numbers,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_DynamicPadOp: LHLO_Op<
"dynamic_pad",
[AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
let summary = "LHLO Dynamic Pad operator";
let description = [{
The dynamic shape version of PadOp. Pads the edges of `operand` with the
`padding_value` according to the passed configuration, where the padding
amounts are passed as runtime values.
See
https://www.tensorflow.org/xla/operation_semantics#pad
}];
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$padding_value,
Arg<LHLO_DimensionBuffer, "", [MemRead]>:$edge_padding_low,
Arg<LHLO_DimensionBuffer, "", [MemRead]>:$edge_padding_high,
Arg<LHLO_DimensionBuffer, "", [MemRead]>:$interior_padding,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_BitcastOp: LHLO_Op<"bitcast", []> {
let summary = "LHLO Bitcast operator";
let description = [{
This op changes the shape of the input while keeping the physical
arrangement of elements unchanged.
However, the op needs layout information to make sense of "physical
arrangement of elements". Layout support in MHLO is currently under
exploration.
}];
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_DynamicBitcastOp: LHLO_Op<"dynamic_bitcast", []> {
let summary = "LHLO Dynamic Bitcast operator";
let description = [{
The dynamic shape version of BitcastOp. This op changes the shape of the
input while keeping the physical arrangement of elements unchanged.
However, the op needs layout information to make sense of "physical
arrangement of elements". Layout support in MHLO is currently under
exploration.
}];
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_IntBuffer, "", [MemRead]>:$shape,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_DynamicIotaOp : LHLO_Op<"dynamic_iota", []> {
let summary = "Create linear increasing values from 0 to length -1.";
let description = [{
The dynamic shape version of IotaOp. Produces an output of the specified shape,
with an incremental set of values along the specified dimension starting at 0.
See
https://www.tensorflow.org/xla/operation_semantics#iota
}];
let arguments = (ins Arg<LHLO_IntBuffer, "", [MemRead]>:$shape,
I64Attr:$iota_dimension,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}
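// Illustrative usage sketch (shapes chosen for the example): the output shape
// is read from %shape at runtime; values increase along dimension 0.
//   "lmhlo.dynamic_iota"(%shape, %output) {iota_dimension = 0 : i64}
//       : (memref<2xi32>, memref<?x?xf32>) -> ()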
def LHLO_DynamicConvOp : LHLO_Op<"dynamic_conv", []> {
let arguments = !con(
(ins Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemRead]>:$d_padding,
Arg<LHLO_Buffer, "", [MemWrite]>:$output),
ConvolutionAttributes.attributes);
}
def LHLO_DynamicReshapeOp: LHLO_Op<"dynamic_reshape", []> {
let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
let description = [{
The dynamic shape version of ReshapeOp. Reshapes `operand` to `output`.
}];
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_IntBuffer, "", [MemRead]>:$shape,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
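// Illustrative usage sketch (shapes chosen for the example): reshape to a
// target shape that is only known at runtime.
//   "lmhlo.dynamic_reshape"(%operand, %shape, %output)
//       : (memref<?x?xf32>, memref<3xi32>, memref<?x?x?xf32>) -> ()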
#endif // LHLO_OPS


@@ -43,4 +43,9 @@ def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnyInteger, AnyComplex]>;
def LHLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;
// Dynamic representation of a shape vector
def LHLO_DimensionBuffer : MemRefRankOf<[LHLO_DimensionValue], [1]>;
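// For example, memref<2xindex> and memref<3xi32> satisfy this constraint,
// while a rank-2 buffer such as memref<2x2xi32> does not.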
#endif // LHLO_OPS_BASE


@@ -1547,6 +1547,41 @@ static LogicalResult Verify(DynamicSliceOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// RealDynamicSliceOp
//===----------------------------------------------------------------------===//
// Verifies that operand rank matches start_indices/limit_indices/strides size
static LogicalResult Verify(RealDynamicSliceOp op) {
auto input_type = op.operand().getType().dyn_cast<RankedTensorType>();
// If operand is unranked, there is very little to verify statically.
if (!input_type) return success();
int input_rank = input_type.getRank();
auto start_type = op.start_indices().getType().cast<RankedTensorType>();
auto limit_type = op.limit_indices().getType().cast<RankedTensorType>();
auto strides_type = op.strides().getType().cast<RankedTensorType>();
if (input_rank != start_type.getNumElements()) {
return op.emitOpError() << "has mismatched number of operand rank ("
<< input_rank << ") and start_indices size ("
<< start_type.getNumElements() << ")";
}
if (input_rank != limit_type.getNumElements()) {
return op.emitOpError() << "has mismatched number of operand rank ("
<< input_rank << ") and limit_indices size ("
<< limit_type.getNumElements() << ")";
}
if (input_rank != strides_type.getNumElements()) {
return op.emitOpError()
<< "has mismatched number of operand rank (" << input_rank
<< ") and strides size (" << strides_type.getNumElements() << ")";
}
return success();
}
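// Illustrative example of IR this verifier rejects (shapes chosen for the
// example): a rank-2 operand combined with 3-element
// start_indices/limit_indices/strides, e.g.
//   "mhlo.real_dynamic_slice"(%operand, %start, %limit, %strides)
//       : (tensor<?x?xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>)
//       -> tensor<?x?xf32>
// trips the operand-rank vs. start_indices-size check above.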
//===----------------------------------------------------------------------===//
// InfeedOp
//===----------------------------------------------------------------------===//

@@ -2254,6 +2289,63 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
return DenseElementsAttr::get(return_type, result);
}
//===----------------------------------------------------------------------===//
// DynamicPadOp
//===----------------------------------------------------------------------===//
void DynamicPadOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DPadToPad>(context);
}
static LogicalResult Verify(DynamicPadOp op) {
auto input_type = op.operand().getType().dyn_cast<RankedTensorType>();
// If operand is unranked, there is very little to verify statically.
if (!input_type) return success();
int input_rank = input_type.getRank();
auto pad_type = op.padding_value().getType().cast<RankedTensorType>();
if (pad_type.getRank() != 0) {
return op.emitOpError() << "padding value type should be a rank-0";
}
auto padding_low_type =
op.edge_padding_low().getType().cast<RankedTensorType>();
if (padding_low_type.getNumElements() != input_rank) {
return op.emitOpError()
<< "edge_padding_low length(" << padding_low_type.getNumElements()
<< ") must match operand rank(" << input_rank << ").";
}
auto padding_high_type =
op.edge_padding_high().getType().cast<RankedTensorType>();
if (padding_high_type.getNumElements() != input_rank) {
return op.emitOpError()
<< "edge_padding_high length(" << padding_high_type.getNumElements()
<< ") must match operand rank(" << input_rank << ").";
}
auto interior_padding_type =
op.interior_padding().getType().cast<RankedTensorType>();
if (interior_padding_type.getNumElements() != input_rank) {
return op.emitOpError()
<< "edge_padding_interior length("
<< interior_padding_type.getNumElements()
<< ") must match operand rank(" << input_rank << ").";
}
auto output_type = op.getResult().getType().dyn_cast<RankedTensorType>();
// If result is unranked, there is very little to verify statically.
if (!output_type) return success();
int output_rank = output_type.getRank();
if (input_rank != output_rank) {
return op.emitOpError() << "operand rank(" << input_rank
<< ") must match result(" << output_rank << ").";
}
return success();
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//


@@ -41,3 +41,17 @@ def RemoveRedundantDynamicBroadcast : Pat<
(HLO_DynamicReshapeOp $operand, $shape),
$shape, IdentityBroadcastDims:$dims),
(HLO_DynamicReshapeOp $operand, $shape)>;
// Convert DPad to Pad if edge_padding_low, edge_padding_high and
// interior_padding are HLO_ConstOp
def DPadToPad: Pat<
(HLO_DynamicPadOp HLO_Tensor:$input,
HLO_Tensor:$padding_value,
(HLO_ConstOp I64ElementsAttr:$edge_padding_low),
(HLO_ConstOp I64ElementsAttr:$edge_padding_high),
(HLO_ConstOp I64ElementsAttr:$interior_padding)),
(HLO_PadOp $input, $padding_value,
(CastIntElementsAttr $edge_padding_low),
(CastIntElementsAttr $edge_padding_high),
(CastIntElementsAttr $interior_padding))>;
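// Illustrative sketch of the rewrite (shapes chosen for the example): with all
// three padding operands produced by constants, e.g.
//   %low      = mhlo.constant dense<[0, 1]> : tensor<2xi64>
//   %high     = mhlo.constant dense<[0, 0]> : tensor<2xi64>
//   %interior = mhlo.constant dense<[0, 0]> : tensor<2xi64>
//   %0 = "mhlo.dynamic_pad"(%arg, %pad, %low, %high, %interior)
//       : (tensor<?x?xf32>, tensor<f32>, tensor<2xi64>, tensor<2xi64>,
//          tensor<2xi64>) -> tensor<?x?xf32>
// the op canonicalizes to a static mhlo.pad carrying the same amounts as
// I64ElementsAttr attributes.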