[mlir][hlo] Make select ready for dynamic shapes (ranked only for now)

Move tf.SelectV2 broadcast lowering to a chlo.broadcast_select op, and lower it
to broadcasts on mhlo from there.

PiperOrigin-RevId: 358179975
Benjamin Kramer, 2021-02-18 08:07:36 -08:00, committed by TensorFlow MLIR Team
parent a6f03aecfb
commit ca4034b56e
4 changed files with 219 additions and 1 deletion
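
For context, a minimal sketch of the new path, using shapes that mirror the
tests added below (SSA names are illustrative):

  // The broadcasting select is now expressed as a chlo op rather than being
  // expanded directly out of tf.SelectV2:
  %0 = "chlo.broadcast_select"(%pred, %on_true, %on_false)
      : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>

  // chlo-legalize-to-hlo then lowers it to explicit mhlo broadcasts feeding
  // a plain mhlo.select (static-shape case shown; a scalar pred needs no
  // broadcast):
  %1 = "mhlo.broadcast_in_dim"(%on_true)
      {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
      : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
  %2 = "mhlo.select"(%pred, %1, %on_false)
      : (tensor<i1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>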

@@ -678,4 +678,34 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
"StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>];
}

//===----------------------------------------------------------------------===//
// Broadcasting select op
//===----------------------------------------------------------------------===//

def HLOClient_BroadcastSelectOp : HLOClient_Op<
"broadcast_select",
[NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface>]> {
string summary = "Select operator (with optional numpy-style broadcasting)";
string description = [{
Constructs an output array from elements of two input arrays, based on the
values of a predicate array.
See https://www.tensorflow.org/xla/operation_semantics#select
}];
let arguments = (ins
HLO_PredTensor:$pred,
HLO_Tensor:$on_true,
HLO_Tensor:$on_false
);
let results = (outs HLO_Tensor);
let assemblyFormat = [{
$pred `,` $on_true `,` $on_false attr-dict `:`
`(` type($pred) `,` type($on_true) `,` type($on_false) `)` `->` type(results)
}];
}

#endif // CHLO_OPS
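
Per the assemblyFormat above, the op prints in a custom form; a hedged
example with illustrative shapes:

  %result = chlo.broadcast_select %pred, %on_true, %on_false
      : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>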

@@ -373,6 +373,31 @@ void ConstantLikeOp::getCanonicalizationPatterns(
results.insert<ConstantLikeToConstant>(context);
}

LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
MLIRContext*, Optional<Location> location, ValueRange operands,
DictionaryAttr, RegionRange,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
BroadcastSelectOp::Adaptor op(operands);
auto pred_type = op.pred().getType().dyn_cast<ShapedType>();
auto on_true_type = op.on_true().getType().dyn_cast<ShapedType>();
auto on_false_type = op.on_false().getType().dyn_cast<ShapedType>();
if (!pred_type || !on_true_type || !on_false_type ||
on_true_type.getElementType() != on_false_type.getElementType()) {
return emitOptionalError(location, "mismatched operand types");
}
Type element_type = on_true_type.getElementType();
// Compute the result shape as two binary broadcasts.
Type other =
GetBroadcastType(on_true_type, on_false_type, element_type, nullptr);
Type output = GetBroadcastType(other, pred_type, element_type, nullptr);
inferredReturnShapes.push_back(output);
return success();
}

} // namespace chlo
} // namespace mlir
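
The two GetBroadcastType calls compose pairwise numpy-style broadcasting:
on_true against on_false first, then that intermediate shape against pred. A
worked example using the shapes from the @selectv2_broadcast_all test below:

  // 1x8x1 vs. 1x1x8 gives the intermediate 1x8x8; 1x8x8 vs. 8x1x1 gives the
  // inferred result shape 8x8x8.
  %r = chlo.broadcast_select %pred, %on_true, %on_false
      : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>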

@@ -993,6 +993,90 @@ struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
}
};

struct ConvertSelectOp : public OpConversionPattern<BroadcastSelectOp> {
using OpConversionPattern<BroadcastSelectOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
BroadcastSelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Only support ranked operands.
typename BroadcastSelectOp::Adaptor transformed(operands);
Value pred = transformed.pred();
Value on_true = transformed.on_true();
Value on_false = transformed.on_false();
auto pred_type = pred.getType().dyn_cast<RankedTensorType>();
auto on_true_type = on_true.getType().dyn_cast<RankedTensorType>();
auto on_false_type = on_false.getType().dyn_cast<RankedTensorType>();
auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!pred_type || !on_true_type || !on_false_type || !result_type) {
return failure();
}
auto loc = op.getLoc();
Value pred_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
Value on_true_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, on_true);
Value on_false_shape =
rewriter.createOrFold<shape::ShapeOfOp>(loc, on_false);
int64_t result_rank = std::max(
{pred_type.getRank(), on_true_type.getRank(), on_false_type.getRank()});
Value broadcastable_cstr =
rewriter.createOrFold<shape::CstrBroadcastableOp>(
loc, ValueRange{pred_shape, on_true_shape, on_false_shape});
auto assuming_op = rewriter.create<shape::AssumingOp>(
loc, ArrayRef<Type>{result_type}, broadcastable_cstr);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&assuming_op.doRegion());
Value result_extents = rewriter.createOrFold<shape::BroadcastOp>(
loc, shape::getExtentTensorType(op.getContext()),
ValueRange{pred_shape, on_true_shape, on_false_shape},
/*error=*/nullptr);
auto shape_type =
RankedTensorType::get({result_rank}, rewriter.getIndexType());
result_extents =
rewriter.createOrFold<tensor::CastOp>(loc, shape_type, result_extents);
Value broadcasted_pred = pred;
// Pred has an implicit broadcast for scalars, so use that when convenient.
if (pred_type.getRank() > 0) {
auto pred_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - pred_type.getRank(), result_rank));
broadcasted_pred = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
pred_type.getElementType()),
pred, result_extents,
rewriter.getI64TensorAttr(pred_broadcast_dimensions));
}
auto on_true_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - on_true_type.getRank(), result_rank));
Value broadcasted_on_true = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
on_true_type.getElementType()),
on_true, result_extents,
rewriter.getI64TensorAttr(on_true_broadcast_dimensions));
auto on_false_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - on_false_type.getRank(), result_rank));
Value broadcasted_on_false = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
on_false_type.getElementType()),
on_false, result_extents,
rewriter.getI64TensorAttr(on_false_broadcast_dimensions));
// And generate the final non-broadcasted ternary op.
Value final_result = rewriter.create<mhlo::SelectOp>(
loc, result_type, broadcasted_pred, broadcasted_on_true,
broadcasted_on_false);
rewriter.create<shape::AssumingYieldOp>(loc, final_result);
rewriter.replaceOp(op, {assuming_op.getResult(0)});
return success();
}
};
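
The llvm::seq computation right-aligns each operand against the result,
numpy-style: for result rank r and operand rank k, the operand's dimensions
map to the trailing range [r - k, r). A sketch with the shapes from the
dynamic test below:

  // result rank 3, pred rank 1: broadcast_dimensions = [2], i.e. the rank-1
  // pred lines up with the trailing result dimension.
  %p = "mhlo.dynamic_broadcast_in_dim"(%pred, %extents)
      {broadcast_dimensions = dense<2> : tensor<1xi64>}
      : (tensor<1xi1>, tensor<3xindex>) -> tensor<2x?x8xi1>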
// Converts binary ops that are statically determined not to broadcast
// directly to the corresponding non-broadcasting mhlo op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
@@ -1140,6 +1224,7 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
context, patterns, 10);
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
context, patterns, 5);
patterns->insert<ConvertSelectOp>(context);
}

void PopulateLegalizeChloToHloPatterns(MLIRContext *context,

@@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s

// Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics.
@@ -72,6 +72,84 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
return %0 : tensor<?x?xi1>
}

// -----

// CHECK-LABEL: func @selectv2
func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2)
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}

// CHECK-LABEL: func @selectv2_pred_scalar
func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2)
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}

// CHECK-LABEL: func @selectv2_broadcast_then
func @selectv2_broadcast_then(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
// CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
// CHECK-NEXT: "mhlo.select"(%arg0, %[[BROADCAST]], %arg2)
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
return %0: tensor<2x8x8xi32>
}

// CHECK-LABEL: func @selectv2_broadcast_else
func @selectv2_broadcast_else(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
// CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
// CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %[[BROADCAST]])
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
return %0: tensor<2x8x8xi32>
}

// CHECK-LABEL: func @selectv2_broadcast_pred
func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
// CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<2x8x8xi1>
// CHECK-NEXT: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2)
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
return %0: tensor<2x8x8xi32>
}

// CHECK-LABEL: func @selectv2_broadcast_tensor_pred
func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
// CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1>
// CHECK-NEXT: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2)
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
return %0: tensor<2x3xf16>
}

// CHECK-LABEL: func @selectv2_broadcast_all
func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
// CHECK-DAG: %[[BROADCAST_0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1>
// CHECK-DAG: %[[BROADCAST_1:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32>
// CHECK-DAG: %[[BROADCAST_2:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
// CHECK: "mhlo.select"(%[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]])
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
return %0: tensor<8x8x8xi32>
}

// CHECK-LABEL: func @selectv2_dynamic_ranked
func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
// CHECK-NEXT: %[[SHAPE0:.*]] = shape.const_shape [1] : tensor<?xindex>
// CHECK-NEXT: %[[SHAPE2:.*]] = shape.const_shape [2, 8, 8] : tensor<?xindex>
// CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<2x?x8xi32> -> tensor<?xindex>
// CHECK-NEXT: %[[CSTR:.*]] = shape.cstr_broadcastable %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[ASSUME:.*]] = shape.assuming %[[CSTR]] -> (tensor<2x?x8xi32>) {
// CHECK-NEXT: %[[BCST_V:.*]] = shape.broadcast %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[BCST:.*]] = tensor.cast %[[BCST_V]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[BCST0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[BCST]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>, tensor<3xindex>) -> tensor<2x?x8xi1>
// CHECK-NEXT: %[[BCST1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x?x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
// CHECK-NEXT: %[[BCST2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x8x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
// CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%[[BCST0]], %[[BCST1]], %[[BCST2]]) : (tensor<2x?x8xi1>, tensor<2x?x8xi32>, tensor<2x?x8xi32>) -> tensor<2x?x8xi32>
// CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<2x?x8xi32>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[ASSUME]] : tensor<2x?x8xi32>
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
return %0: tensor<2x?x8xi32>
}
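
As the commit title notes, only ranked operands are handled: the
RankedTensorType checks in ConvertSelectOp fail to match otherwise, so a
hypothetical unranked select such as the following is left on chlo for now:

  %0 = "chlo.broadcast_select"(%pred, %a, %b)
      : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>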
// -----

// Verifies that broadcast_dimensions validity checks are enforced.
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions