[mlir][hlo] Add basic rank-specialization for select
This just blows up everything to ranked (up to 6) and is probably quite slow. This is sufficient to make kernelgen compile SelectV2. PiperOrigin-RevId: 358777728
This commit is contained in:
parent
909574e393
commit
a9cc1dcfa0
|
@ -475,6 +475,25 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
|||
}
|
||||
};
|
||||
|
||||
// Rank-specialize chlo.broadcast_select ops.
|
||||
struct ConvertUnrankedDynamicBroadcastSelectOp
|
||||
: public OpConversionPattern<chlo::BroadcastSelectOp> {
|
||||
using OpConversionPattern<chlo::BroadcastSelectOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
chlo::BroadcastSelectOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// For now only do the bare minimum and specialize for every rank. There is
|
||||
// more potential for optimization here. This also is missing the
|
||||
// specialization for rank 0.
|
||||
rewriter.replaceOp(
|
||||
op, {ConvertUnrankedDynamicBroadcastOpHelper<
|
||||
chlo::BroadcastSelectOp,
|
||||
mhlo::SelectOp>::HandleBroadcastAndOp(rewriter, op, operands)});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TransformUnrankedHloPass
|
||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
|
@ -539,6 +558,7 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
|||
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
|
||||
chlo::PopulateForBroadcastingBinaryOp<
|
||||
ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
|
||||
patterns->insert<ConvertUnrankedDynamicBroadcastSelectOp>(context);
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
||||
|
|
|
@ -308,3 +308,53 @@ func @addUnrankedUnranked(
|
|||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[VAL_72:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Unranked select: all three operands have unknown rank; the pass must
// rank-specialize the chlo.broadcast_select.
func @selectUnrankedUnrankedUnranked(
    %pred: tensor<*xi1>,
    %on_true: tensor<*xf32>,
    %on_false: tensor<*xf32>) -> tensor<*xf32> {
  %result = chlo.broadcast_select %pred, %on_true, %on_false
      : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %result : tensor<*xf32>
}
|
||||
|
||||
// CHECK-LABEL: func @selectUnrankedUnrankedUnranked(
|
||||
// CHECK-SAME: %[[PRED:.*]]: tensor<*xi1>,
|
||||
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[PRED_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %c1 = constant 1 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %c1 : index
|
||||
// Handle rank 1 specialization
|
||||
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor<?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||
|
|
Loading…
Reference in New Issue