[mlir][hlo] Make select ready for dynamic shapes (ranked only for now)
Move tf.SelectV2 broadcast lowering to a chlo.broadcast_select op, and lower it to broadcasts on mhlo from there. PiperOrigin-RevId: 358179975
This commit is contained in:
parent
a6f03aecfb
commit
ca4034b56e
|
@ -678,4 +678,34 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
|
|||
"StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Broadcasting select op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HLOClient_BroadcastSelectOp : HLOClient_Op<
|
||||
"broadcast_select",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface>]> {
|
||||
string summary = "Select operator (with optional numpy-style broadcasting)";
|
||||
|
||||
string description = [{
|
||||
Constructs an output array from elements of two input arrays, based on the
|
||||
values of a predicate array.
|
||||
|
||||
See https://www.tensorflow.org/xla/operation_semantics#select
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
HLO_PredTensor:$pred,
|
||||
HLO_Tensor:$on_true,
|
||||
HLO_Tensor:$on_false
|
||||
);
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$pred `,` $on_true `,` $on_false attr-dict `:`
|
||||
`(` type($pred) `,` type($on_true) `,` type($on_false) `)` `->` type(results)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // CHLO_OPS
|
||||
|
|
|
@ -373,6 +373,31 @@ void ConstantLikeOp::getCanonicalizationPatterns(
|
|||
results.insert<ConstantLikeToConstant>(context);
|
||||
}
|
||||
|
||||
LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
|
||||
MLIRContext*, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr, RegionRange,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
|
||||
BroadcastSelectOp::Adaptor op(operands);
|
||||
auto pred_type = op.pred().getType().dyn_cast<ShapedType>();
|
||||
auto on_true_type = op.on_true().getType().dyn_cast<ShapedType>();
|
||||
auto on_false_type = op.on_false().getType().dyn_cast<ShapedType>();
|
||||
|
||||
if (!pred_type || !on_true_type || !on_false_type ||
|
||||
on_true_type.getElementType() != on_false_type.getElementType()) {
|
||||
return emitOptionalError(location, "mismatched operand types");
|
||||
}
|
||||
|
||||
Type element_type = on_true_type.getElementType();
|
||||
|
||||
// Compute the result shape as two binary broadcasts.
|
||||
Type other =
|
||||
GetBroadcastType(on_true_type, on_false_type, element_type, nullptr);
|
||||
Type output = GetBroadcastType(other, pred_type, element_type, nullptr);
|
||||
|
||||
inferredReturnShapes.push_back(output);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace chlo
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -993,6 +993,90 @@ struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
|
|||
}
|
||||
};
|
||||
|
||||
struct ConvertSelectOp : public OpConversionPattern<BroadcastSelectOp> {
|
||||
using OpConversionPattern<BroadcastSelectOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
BroadcastSelectOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Only support ranked operands.
|
||||
typename BroadcastSelectOp::Adaptor transformed(operands);
|
||||
Value pred = transformed.pred();
|
||||
Value on_true = transformed.on_true();
|
||||
Value on_false = transformed.on_false();
|
||||
auto pred_type = pred.getType().dyn_cast<RankedTensorType>();
|
||||
auto on_true_type = on_true.getType().dyn_cast<RankedTensorType>();
|
||||
auto on_false_type = on_false.getType().dyn_cast<RankedTensorType>();
|
||||
auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (!pred_type || !on_true_type || !on_false_type || !result_type) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto loc = op.getLoc();
|
||||
|
||||
Value pred_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
|
||||
Value on_true_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, on_true);
|
||||
Value on_false_shape =
|
||||
rewriter.createOrFold<shape::ShapeOfOp>(loc, on_false);
|
||||
int64_t result_rank = std::max(
|
||||
{pred_type.getRank(), on_true_type.getRank(), on_false_type.getRank()});
|
||||
|
||||
Value broadcastable_cstr =
|
||||
rewriter.createOrFold<shape::CstrBroadcastableOp>(
|
||||
loc, ValueRange{pred_shape, on_true_shape, on_false_shape});
|
||||
auto assuming_op = rewriter.create<shape::AssumingOp>(
|
||||
loc, ArrayRef<Type>{result_type}, broadcastable_cstr);
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.createBlock(&assuming_op.doRegion());
|
||||
|
||||
Value result_extents = rewriter.createOrFold<shape::BroadcastOp>(
|
||||
loc, shape::getExtentTensorType(op.getContext()),
|
||||
ValueRange{pred_shape, on_true_shape, on_false_shape},
|
||||
/*error=*/nullptr);
|
||||
auto shape_type =
|
||||
RankedTensorType::get({result_rank}, rewriter.getIndexType());
|
||||
result_extents =
|
||||
rewriter.createOrFold<tensor::CastOp>(loc, shape_type, result_extents);
|
||||
|
||||
Value broadcasted_pred = pred;
|
||||
// Pred has an implicit broadcast for scalars, so use that when convenient.
|
||||
if (pred_type.getRank() > 0) {
|
||||
auto pred_broadcast_dimensions = llvm::to_vector<4>(
|
||||
llvm::seq<int64_t>(result_rank - pred_type.getRank(), result_rank));
|
||||
broadcasted_pred = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
|
||||
loc,
|
||||
RankedTensorType::get(result_type.getShape(),
|
||||
pred_type.getElementType()),
|
||||
pred, result_extents,
|
||||
rewriter.getI64TensorAttr(pred_broadcast_dimensions));
|
||||
}
|
||||
auto on_true_broadcast_dimensions = llvm::to_vector<4>(
|
||||
llvm::seq<int64_t>(result_rank - on_true_type.getRank(), result_rank));
|
||||
Value broadcasted_on_true = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
|
||||
loc,
|
||||
RankedTensorType::get(result_type.getShape(),
|
||||
on_true_type.getElementType()),
|
||||
on_true, result_extents,
|
||||
rewriter.getI64TensorAttr(on_true_broadcast_dimensions));
|
||||
auto on_false_broadcast_dimensions = llvm::to_vector<4>(
|
||||
llvm::seq<int64_t>(result_rank - on_false_type.getRank(), result_rank));
|
||||
Value broadcasted_on_false = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
|
||||
loc,
|
||||
RankedTensorType::get(result_type.getShape(),
|
||||
on_false_type.getElementType()),
|
||||
on_false, result_extents,
|
||||
rewriter.getI64TensorAttr(on_false_broadcast_dimensions));
|
||||
|
||||
// And generate the final non-broadcasted ternary op.
|
||||
Value final_result = rewriter.create<mhlo::SelectOp>(
|
||||
loc, result_type, broadcasted_pred, broadcasted_on_true,
|
||||
broadcasted_on_false);
|
||||
rewriter.create<shape::AssumingYieldOp>(loc, final_result);
|
||||
rewriter.replaceOp(op, {assuming_op.getResult(0)});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Converts binary ops that statically are determined to not broadcast directly
|
||||
// to the corresponding mhlo non-broadcasting op.
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
|
@ -1140,6 +1224,7 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
|
|||
context, patterns, 10);
|
||||
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
|
||||
context, patterns, 5);
|
||||
patterns->insert<ConvertSelectOp>(context);
|
||||
}
|
||||
|
||||
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
// Check the non-broadcast case for each registered op, then just check a
|
||||
// representative op for detailed broadcast semantics.
|
||||
|
@ -72,6 +72,84 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
|
|||
return %0 : tensor<?x?xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @selectv2
|
||||
func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
|
||||
// CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2)
|
||||
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
|
||||
return %0: tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @selectv2_pred_scalar
|
||||
func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
|
||||
// CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2)
|
||||
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
|
||||
return %0: tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @selectv2_broadcast_then
|
||||
func @selectv2_broadcast_then(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
|
||||
// CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
|
||||
// CHECK-NEXT: "mhlo.select"(%arg0, %[[BROADCAST]], %arg2)
|
||||
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
|
||||
return %0: tensor<2x8x8xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @selectv2_broadcast_else
|
||||
func @selectv2_broadcast_else(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
|
||||
// CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
|
||||
// CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %[[BROADCAST]])
|
||||
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
|
||||
return %0: tensor<2x8x8xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @selectv2_broadcast_pred
|
||||
func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
|
||||
// CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<2x8x8xi1>
|
||||
// CHECK-NEXT: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2)
|
||||
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
|
||||
return %0: tensor<2x8x8xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @selectv2_broadcast_tensor_pred
|
||||
func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
|
||||
// CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1>
|
||||
// CHECK-NEXT: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2)
|
||||
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
|
||||
return %0: tensor<2x3xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @selectv2_broadcast_all
|
||||
func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
|
||||
// CHECK-DAG: %[[BROADCAST_0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1>
|
||||
// CHECK-DAG: %[[BROADCAST_1:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32>
|
||||
// CHECK-DAG: %[[BROADCAST_2:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
|
||||
// CHECK: "mhlo.select"(%[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]])
|
||||
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
|
||||
return %0: tensor<8x8x8xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @selectv2_dynamic_ranked
|
||||
func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
|
||||
// CHECK-NEXT: %[[SHAPE0:.*]] = shape.const_shape [1] : tensor<?xindex>
|
||||
// CHECK-NEXT: %[[SHAPE2:.*]] = shape.const_shape [2, 8, 8] : tensor<?xindex>
|
||||
// CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<2x?x8xi32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CSTR:.*]] = shape.cstr_broadcastable %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
|
||||
// CHECK-NEXT: %[[ASSUME:.*]] = shape.assuming %[[CSTR]] -> (tensor<2x?x8xi32>) {
|
||||
// CHECK-NEXT: %[[BCST_V:.*]] = shape.broadcast %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BCST:.*]] = tensor.cast %[[BCST_V]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK-NEXT: %[[BCST0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[BCST]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>, tensor<3xindex>) -> tensor<2x?x8xi1>
|
||||
// CHECK-NEXT: %[[BCST1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x?x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
|
||||
// CHECK-NEXT: %[[BCST2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x8x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
|
||||
// CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%[[BCST0]], %[[BCST1]], %[[BCST2]]) : (tensor<2x?x8xi1>, tensor<2x?x8xi32>, tensor<2x?x8xi32>) -> tensor<2x?x8xi32>
|
||||
// CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<2x?x8xi32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[ASSUME]] : tensor<2x?x8xi32>
|
||||
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
|
||||
return %0: tensor<2x?x8xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Verifies that broadcast_dimensions validity checks are valid.
|
||||
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
|
||||
|
|
Loading…
Reference in New Issue