diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
index 13d0f08..c9db345 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
@@ -678,4 +678,34 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
     "StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>];
 }
 
+//===----------------------------------------------------------------------===//
+// Broadcasting select op
+//===----------------------------------------------------------------------===//
+
+def HLOClient_BroadcastSelectOp : HLOClient_Op<
+    "broadcast_select",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface>]> {
+  string summary = "Select operator (with optional numpy-style broadcasting)";
+
+  string description = [{
+    Constructs an output array from elements of two input arrays, based on the
+    values of a predicate array.
+
+    See https://www.tensorflow.org/xla/operation_semantics#select
+  }];
+
+  let arguments = (ins
+    HLO_PredTensor:$pred,
+    HLO_Tensor:$on_true,
+    HLO_Tensor:$on_false
+  );
+
+  let results = (outs HLO_Tensor);
+
+  let assemblyFormat = [{
+    $pred `,` $on_true `,` $on_false attr-dict `:`
+    `(` type($pred) `,` type($on_true) `,` type($on_false) `)` `->` type(results)
+  }];
+}
+
 #endif  // CHLO_OPS
diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc
index 31a1ee3..57ae271 100644
--- a/lib/Dialect/mhlo/IR/chlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/chlo_ops.cc
@@ -373,6 +373,31 @@ void ConstantLikeOp::getCanonicalizationPatterns(
   results.insert<ConstantLikeToConstant>(context);
 }
 
+LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
+    MLIRContext*, Optional<Location> location, ValueRange operands,
+    DictionaryAttr, RegionRange,
+    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
+  BroadcastSelectOp::Adaptor op(operands);
+  auto pred_type = op.pred().getType().dyn_cast<ShapedType>();
+  auto on_true_type = op.on_true().getType().dyn_cast<ShapedType>();
+  auto on_false_type = op.on_false().getType().dyn_cast<ShapedType>();
+
+  if (!pred_type || !on_true_type || !on_false_type ||
+      on_true_type.getElementType() != on_false_type.getElementType()) {
+    return emitOptionalError(location, "mismatched operand types");
+  }
+
+  Type element_type = on_true_type.getElementType();
+
+  // Compute the result shape as two binary broadcasts.
+  Type other =
+      GetBroadcastType(on_true_type, on_false_type, element_type, nullptr);
+  Type output = GetBroadcastType(other, pred_type, element_type, nullptr);
+
+  inferredReturnShapes.push_back(output);
+  return success();
+}
+
 }  // namespace chlo
 }  // namespace mlir
diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index fabf8ef..d3679ba 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -993,6 +993,90 @@ struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
   }
 };
 
+struct ConvertSelectOp : public OpConversionPattern<BroadcastSelectOp> {
+  using OpConversionPattern<BroadcastSelectOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      BroadcastSelectOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // Only support ranked operands.
+    typename BroadcastSelectOp::Adaptor transformed(operands);
+    Value pred = transformed.pred();
+    Value on_true = transformed.on_true();
+    Value on_false = transformed.on_false();
+    auto pred_type = pred.getType().dyn_cast<RankedTensorType>();
+    auto on_true_type = on_true.getType().dyn_cast<RankedTensorType>();
+    auto on_false_type = on_false.getType().dyn_cast<RankedTensorType>();
+    auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
+    if (!pred_type || !on_true_type || !on_false_type || !result_type) {
+      return failure();
+    }
+
+    auto loc = op.getLoc();
+
+    Value pred_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
+    Value on_true_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, on_true);
+    Value on_false_shape =
+        rewriter.createOrFold<shape::ShapeOfOp>(loc, on_false);
+    int64_t result_rank = std::max(
+        {pred_type.getRank(), on_true_type.getRank(), on_false_type.getRank()});
+
+    Value broadcastable_cstr =
+        rewriter.createOrFold<shape::CstrBroadcastableOp>(
+            loc, ValueRange{pred_shape, on_true_shape, on_false_shape});
+    auto assuming_op = rewriter.create<shape::AssumingOp>(
+        loc, ArrayRef<Type>{result_type}, broadcastable_cstr);
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.createBlock(&assuming_op.doRegion());
+
+    Value result_extents = rewriter.createOrFold<shape::BroadcastOp>(
+        loc, shape::getExtentTensorType(op.getContext()),
+        ValueRange{pred_shape, on_true_shape, on_false_shape},
+        /*error=*/nullptr);
+    auto shape_type =
+        RankedTensorType::get({result_rank}, rewriter.getIndexType());
+    result_extents =
+        rewriter.createOrFold<tensor::CastOp>(loc, shape_type, result_extents);
+
+    Value broadcasted_pred = pred;
+    // Pred has an implicit broadcast for scalars, so use that when convenient.
+    if (pred_type.getRank() > 0) {
+      auto pred_broadcast_dimensions = llvm::to_vector<4>(
+          llvm::seq<int64_t>(result_rank - pred_type.getRank(), result_rank));
+      broadcasted_pred = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
+          loc,
+          RankedTensorType::get(result_type.getShape(),
+                                pred_type.getElementType()),
+          pred, result_extents,
+          rewriter.getI64TensorAttr(pred_broadcast_dimensions));
+    }
+    auto on_true_broadcast_dimensions = llvm::to_vector<4>(
+        llvm::seq<int64_t>(result_rank - on_true_type.getRank(), result_rank));
+    Value broadcasted_on_true = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
+        loc,
+        RankedTensorType::get(result_type.getShape(),
+                              on_true_type.getElementType()),
+        on_true, result_extents,
+        rewriter.getI64TensorAttr(on_true_broadcast_dimensions));
+    auto on_false_broadcast_dimensions = llvm::to_vector<4>(
+        llvm::seq<int64_t>(result_rank - on_false_type.getRank(), result_rank));
+    Value broadcasted_on_false = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
+        loc,
+        RankedTensorType::get(result_type.getShape(),
+                              on_false_type.getElementType()),
+        on_false, result_extents,
+        rewriter.getI64TensorAttr(on_false_broadcast_dimensions));
+
+    // And generate the final non-broadcasted ternary op.
+    Value final_result = rewriter.create<mhlo::SelectOp>(
+        loc, result_type, broadcasted_pred, broadcasted_on_true,
+        broadcasted_on_false);
+    rewriter.create<shape::AssumingYieldOp>(loc, final_result);
+    rewriter.replaceOp(op, {assuming_op.getResult(0)});
+    return success();
+  }
+};
+
 // Converts binary ops that statically are determined to not broadcast directly
 // to the corresponding mhlo non-broadcasting op.
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
@@ -1140,6 +1224,7 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
       context, patterns, 10);
   PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
       context, patterns, 5);
+  patterns->insert<ConvertSelectOp>(context);
 }
 
 void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir
index efe96cf..4e4762f 100644
--- a/tests/chlo_legalize_to_hlo_broadcasts.mlir
+++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s
 
 // Check the non-broadcast case for each registered op, then just check a
 // representative op for detailed broadcast semantics.
@@ -72,6 +72,84 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> t
   return %0 : tensor<?xi1>
 }
 
+// -----
+
+// CHECK-LABEL: func @selectv2
+func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
+  // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2)
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  return %0: tensor<2xi32>
+}
+
+// CHECK-LABEL: func @selectv2_pred_scalar
+func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
+  // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2)
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  return %0: tensor<2xi32>
+}
+
+// CHECK-LABEL: func @selectv2_broadcast_then
+func @selectv2_broadcast_then(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
+  // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
+  // CHECK-NEXT: "mhlo.select"(%arg0, %[[BROADCAST]], %arg2)
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
+  return %0: tensor<2x8x8xi32>
+}
+
+// CHECK-LABEL: func @selectv2_broadcast_else
+func @selectv2_broadcast_else(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
+  // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
+  // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %[[BROADCAST]])
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
+  return %0: tensor<2x8x8xi32>
+}
+
+// CHECK-LABEL: func @selectv2_broadcast_pred
+func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
+  // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<2x8x8xi1>
+  // CHECK-NEXT: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2)
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
+  return %0: tensor<2x8x8xi32>
+}
+
+// CHECK-LABEL: func @selectv2_broadcast_tensor_pred
+func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
+  // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1>
+  // CHECK-NEXT: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2)
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
+  return %0: tensor<2x3xf16>
+}
+
+// CHECK-LABEL: func @selectv2_broadcast_all
+func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
+  // CHECK-DAG: %[[BROADCAST_0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1>
+  // CHECK-DAG: %[[BROADCAST_1:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32>
+  // CHECK-DAG: %[[BROADCAST_2:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
+  // CHECK: "mhlo.select"(%[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]])
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
+  return %0: tensor<8x8x8xi32>
+}
+
+// CHECK-LABEL: func @selectv2_dynamic_ranked
+func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
+  // CHECK-NEXT: %[[SHAPE0:.*]] = shape.const_shape [1] : tensor<1xindex>
+  // CHECK-NEXT: %[[SHAPE2:.*]] = shape.const_shape [2, 8, 8] : tensor<3xindex>
+  // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<2x?x8xi32> -> tensor<3xindex>
+  // CHECK-NEXT: %[[CSTR:.*]] = shape.cstr_broadcastable %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<3xindex>, tensor<1xindex>, tensor<3xindex>
+  // CHECK-NEXT: %[[ASSUME:.*]] = shape.assuming %[[CSTR]] -> (tensor<2x?x8xi32>) {
+  // CHECK-NEXT: %[[BCST_V:.*]] = shape.broadcast %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<3xindex>, tensor<1xindex>, tensor<3xindex> -> tensor<?xindex>
+  // CHECK-NEXT: %[[BCST:.*]] = tensor.cast %[[BCST_V]] : tensor<?xindex> to tensor<3xindex>
+  // CHECK-NEXT: %[[BCST0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[BCST]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>, tensor<3xindex>) -> tensor<2x?x8xi1>
+  // CHECK-NEXT: %[[BCST1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x?x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
+  // CHECK-NEXT: %[[BCST2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x8x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
+  // CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%[[BCST0]], %[[BCST1]], %[[BCST2]]) : (tensor<2x?x8xi1>, tensor<2x?x8xi32>, tensor<2x?x8xi32>) -> tensor<2x?x8xi32>
+  // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<2x?x8xi32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return %[[ASSUME]] : tensor<2x?x8xi32>
+  %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
+  return %0: tensor<2x?x8xi32>
+}
+
 // -----
 
 // Verifies that broadcast_dimensions validity checks are valid.
 // CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
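
Note: the tests above exercise the op's generic form. Per the assemblyFormat declared in chlo_ops.td, chlo.broadcast_select also round-trips in a custom printed form; a minimal sketch (hypothetical SSA names, types mirroring the @selectv2 case):

  // Illustrative only: %pred, %on_true, %on_false are made-up value names.
  %0 = chlo.broadcast_select %pred, %on_true, %on_false : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>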