[mlir][hlo] Make select ready for dynamic shapes (ranked only for now)

Move tf.SelectV2 broadcast lowering to a chlo.broadcast_select op, and lower
it to broadcasts on mhlo from there.

PiperOrigin-RevId: 358179975
parent a6f03aecfb
commit ca4034b56e
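For orientation, a minimal hand-written sketch of the two-stage flow this
commit sets up (illustrative only; the SSA names are made up, and the exact
tf.SelectV2-to-CHLO step lives in the TF bridge, outside this diff):

    // Stage 1: tf.SelectV2 now lowers to the new CHLO op instead of
    // expanding broadcasts directly.
    %0 = "chlo.broadcast_select"(%pred, %on_true, %on_false)
        : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
    // Stage 2: -chlo-legalize-to-hlo broadcasts the operands as needed and
    // emits a plain mhlo.select (here the scalar pred needs no broadcast):
    %1 = "mhlo.select"(%pred, %on_true, %on_false)
        : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>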
@@ -678,4 +678,34 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
       "StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>];
 }
 
+//===----------------------------------------------------------------------===//
+// Broadcasting select op
+//===----------------------------------------------------------------------===//
+
+def HLOClient_BroadcastSelectOp : HLOClient_Op<
+    "broadcast_select",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface>]> {
+  string summary = "Select operator (with optional numpy-style broadcasting)";
+
+  string description = [{
+    Constructs an output array from elements of two input arrays, based on the
+    values of a predicate array.
+
+    See https://www.tensorflow.org/xla/operation_semantics#select
+  }];
+
+  let arguments = (ins
+    HLO_PredTensor:$pred,
+    HLO_Tensor:$on_true,
+    HLO_Tensor:$on_false
+  );
+
+  let results = (outs HLO_Tensor);
+
+  let assemblyFormat = [{
+    $pred `,` $on_true `,` $on_false attr-dict `:`
+    `(` type($pred) `,` type($on_true) `,` type($on_false) `)` `->` type(results)
+  }];
+}
+
 #endif  // CHLO_OPS
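Given the assemblyFormat above, the op's pretty-printed form should look like
the following (illustrative sketch; the tests in this commit use the generic
quoted-op syntax instead):

    %0 = chlo.broadcast_select %pred, %on_true, %on_false
        : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>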
@@ -373,6 +373,31 @@ void ConstantLikeOp::getCanonicalizationPatterns(
   results.insert<ConstantLikeToConstant>(context);
 }
 
+LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
+    MLIRContext*, Optional<Location> location, ValueRange operands,
+    DictionaryAttr, RegionRange,
+    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
+  BroadcastSelectOp::Adaptor op(operands);
+  auto pred_type = op.pred().getType().dyn_cast<ShapedType>();
+  auto on_true_type = op.on_true().getType().dyn_cast<ShapedType>();
+  auto on_false_type = op.on_false().getType().dyn_cast<ShapedType>();
+
+  if (!pred_type || !on_true_type || !on_false_type ||
+      on_true_type.getElementType() != on_false_type.getElementType()) {
+    return emitOptionalError(location, "mismatched operand types");
+  }
+
+  Type element_type = on_true_type.getElementType();
+
+  // Compute the result shape as two binary broadcasts.
+  Type other =
+      GetBroadcastType(on_true_type, on_false_type, element_type, nullptr);
+  Type output = GetBroadcastType(other, pred_type, element_type, nullptr);
+
+  inferredReturnShapes.push_back(output);
+  return success();
+}
+
 }  // namespace chlo
 }  // namespace mlir
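The "two binary broadcasts" step computes the result shape as
broadcast(broadcast(on_true, on_false), pred). A worked example using the
shapes from the @selectv2_broadcast_all test below:

    // broadcast(1x8x1, 1x1x8) -> 1x8x8   (on_true against on_false)
    // broadcast(1x8x8, 8x1x1) -> 8x8x8   (intermediate against pred: result)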
@@ -993,6 +993,90 @@ struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
   }
 };
 
+struct ConvertSelectOp : public OpConversionPattern<BroadcastSelectOp> {
+  using OpConversionPattern<BroadcastSelectOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      BroadcastSelectOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // Only support ranked operands.
+    typename BroadcastSelectOp::Adaptor transformed(operands);
+    Value pred = transformed.pred();
+    Value on_true = transformed.on_true();
+    Value on_false = transformed.on_false();
+    auto pred_type = pred.getType().dyn_cast<RankedTensorType>();
+    auto on_true_type = on_true.getType().dyn_cast<RankedTensorType>();
+    auto on_false_type = on_false.getType().dyn_cast<RankedTensorType>();
+    auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
+    if (!pred_type || !on_true_type || !on_false_type || !result_type) {
+      return failure();
+    }
+
+    auto loc = op.getLoc();
+
+    Value pred_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
+    Value on_true_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, on_true);
+    Value on_false_shape =
+        rewriter.createOrFold<shape::ShapeOfOp>(loc, on_false);
+    int64_t result_rank = std::max(
+        {pred_type.getRank(), on_true_type.getRank(), on_false_type.getRank()});
+
+    Value broadcastable_cstr =
+        rewriter.createOrFold<shape::CstrBroadcastableOp>(
+            loc, ValueRange{pred_shape, on_true_shape, on_false_shape});
+    auto assuming_op = rewriter.create<shape::AssumingOp>(
+        loc, ArrayRef<Type>{result_type}, broadcastable_cstr);
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.createBlock(&assuming_op.doRegion());
+
+    Value result_extents = rewriter.createOrFold<shape::BroadcastOp>(
+        loc, shape::getExtentTensorType(op.getContext()),
+        ValueRange{pred_shape, on_true_shape, on_false_shape},
+        /*error=*/nullptr);
+    auto shape_type =
+        RankedTensorType::get({result_rank}, rewriter.getIndexType());
+    result_extents =
+        rewriter.createOrFold<tensor::CastOp>(loc, shape_type, result_extents);
+
+    Value broadcasted_pred = pred;
+    // Pred has an implicit broadcast for scalars, so use that when convenient.
+    if (pred_type.getRank() > 0) {
+      auto pred_broadcast_dimensions = llvm::to_vector<4>(
+          llvm::seq<int64_t>(result_rank - pred_type.getRank(), result_rank));
+      broadcasted_pred = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
+          loc,
+          RankedTensorType::get(result_type.getShape(),
+                                pred_type.getElementType()),
+          pred, result_extents,
+          rewriter.getI64TensorAttr(pred_broadcast_dimensions));
+    }
+    auto on_true_broadcast_dimensions = llvm::to_vector<4>(
+        llvm::seq<int64_t>(result_rank - on_true_type.getRank(), result_rank));
+    Value broadcasted_on_true = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
+        loc,
+        RankedTensorType::get(result_type.getShape(),
+                              on_true_type.getElementType()),
+        on_true, result_extents,
+        rewriter.getI64TensorAttr(on_true_broadcast_dimensions));
+    auto on_false_broadcast_dimensions = llvm::to_vector<4>(
+        llvm::seq<int64_t>(result_rank - on_false_type.getRank(), result_rank));
+    Value broadcasted_on_false = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
+        loc,
+        RankedTensorType::get(result_type.getShape(),
+                              on_false_type.getElementType()),
+        on_false, result_extents,
+        rewriter.getI64TensorAttr(on_false_broadcast_dimensions));
+
+    // And generate the final non-broadcasted ternary op.
+    Value final_result = rewriter.create<mhlo::SelectOp>(
+        loc, result_type, broadcasted_pred, broadcasted_on_true,
+        broadcasted_on_false);
+    rewriter.create<shape::AssumingYieldOp>(loc, final_result);
+    rewriter.replaceOp(op, {assuming_op.getResult(0)});
+    return success();
+  }
+};
+
 // Converts binary ops that statically are determined to not broadcast directly
 // to the corresponding mhlo non-broadcasting op.
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
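A condensed sketch of the IR this pattern emits for ranked dynamic operands
(hand-abbreviated with `...`; the @selectv2_dynamic_ranked test below checks
the exact output, including the tensor.cast of the extents to tensor<3xindex>):

    %cstr = shape.cstr_broadcastable %s_pred, %s_true, %s_false : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
    %res = shape.assuming %cstr -> (tensor<2x?x8xi32>) {
      %ext = shape.broadcast %s_pred, %s_true, %s_false : ... -> tensor<?xindex>
      // Each operand is broadcast to the common extents...
      %p = "mhlo.dynamic_broadcast_in_dim"(%pred, %ext) {...} : ...
      %t = "mhlo.dynamic_broadcast_in_dim"(%on_true, %ext) {...} : ...
      %f = "mhlo.dynamic_broadcast_in_dim"(%on_false, %ext) {...} : ...
      // ...then the non-broadcasting ternary op runs inside the guarded region.
      %sel = "mhlo.select"(%p, %t, %f) : ...
      shape.assuming_yield %sel : tensor<2x?x8xi32>
    }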
@@ -1140,6 +1224,7 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
                                       context, patterns, 10);
   PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
       context, patterns, 5);
+  patterns->insert<ConvertSelectOp>(context);
 }
 
 void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s
 
 // Check the non-broadcast case for each registered op, then just check a
 // representative op for detailed broadcast semantics.
@@ -72,6 +72,84 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
   return %0 : tensor<?x?xi1>
 }
 
+// -----
+
+// CHECK-LABEL: func @selectv2
+func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
+  // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2)
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  return %0: tensor<2xi32>
+}
+
+// CHECK-LABEL: func @selectv2_pred_scalar
+func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
+  // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2)
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  return %0: tensor<2xi32>
+}
+
+// CHECK-LABEL: func @selectv2_broadcast_then
+func @selectv2_broadcast_then(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
+  // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
+  // CHECK-NEXT: "mhlo.select"(%arg0, %[[BROADCAST]], %arg2)
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
+  return %0: tensor<2x8x8xi32>
+}
+
+// CHECK-LABEL: func @selectv2_broadcast_else
+func @selectv2_broadcast_else(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
+  // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
+  // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %[[BROADCAST]])
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
+  return %0: tensor<2x8x8xi32>
+}
+
+// CHECK-LABEL: func @selectv2_broadcast_pred
+func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
+  // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<2x8x8xi1>
+  // CHECK-NEXT: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2)
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
+  return %0: tensor<2x8x8xi32>
+}
+
+// CHECK-LABEL: func @selectv2_broadcast_tensor_pred
+func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
+  // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1>
+  // CHECK-NEXT: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2)
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
+  return %0: tensor<2x3xf16>
+}
+
+// CHECK-LABEL: func @selectv2_broadcast_all
+func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
+  // CHECK-DAG: %[[BROADCAST_0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1>
+  // CHECK-DAG: %[[BROADCAST_1:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32>
+  // CHECK-DAG: %[[BROADCAST_2:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
+  // CHECK: "mhlo.select"(%[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]])
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
+  return %0: tensor<8x8x8xi32>
+}
+
+// CHECK-LABEL: func @selectv2_dynamic_ranked
+func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
+  // CHECK-NEXT: %[[SHAPE0:.*]] = shape.const_shape [1] : tensor<?xindex>
+  // CHECK-NEXT: %[[SHAPE2:.*]] = shape.const_shape [2, 8, 8] : tensor<?xindex>
+  // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<2x?x8xi32> -> tensor<?xindex>
+  // CHECK-NEXT: %[[CSTR:.*]] = shape.cstr_broadcastable %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+  // CHECK-NEXT: %[[ASSUME:.*]] = shape.assuming %[[CSTR]] -> (tensor<2x?x8xi32>) {
+  // CHECK-NEXT: %[[BCST_V:.*]] = shape.broadcast %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  // CHECK-NEXT: %[[BCST:.*]] = tensor.cast %[[BCST_V]] : tensor<?xindex> to tensor<3xindex>
+  // CHECK-NEXT: %[[BCST0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[BCST]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>, tensor<3xindex>) -> tensor<2x?x8xi1>
+  // CHECK-NEXT: %[[BCST1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x?x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
+  // CHECK-NEXT: %[[BCST2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x8x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
+  // CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%[[BCST0]], %[[BCST1]], %[[BCST2]]) : (tensor<2x?x8xi1>, tensor<2x?x8xi32>, tensor<2x?x8xi32>) -> tensor<2x?x8xi32>
+  // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<2x?x8xi32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return %[[ASSUME]] : tensor<2x?x8xi32>
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
+  return %0: tensor<2x?x8xi32>
+}
+
 // -----
 // Verifies that broadcast_dimensions validity checks are valid.
 // CHECK-LABEL: @dynamicNonScalarBroadcastDimensions