[mlir][hlo] Add basic rank-specialization for select

This just blows up everything to ranked (up to 6) and is probably quite slow.
This is sufficient to make kernelgen compile SelectV2.

PiperOrigin-RevId: 358777728
This commit is contained in:
Benjamin Kramer 2021-02-22 02:39:56 -08:00 committed by TensorFlow MLIR Team
parent 909574e393
commit a9cc1dcfa0
2 changed files with 70 additions and 0 deletions

View File

@ -475,6 +475,25 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
} }
}; };
// Rank-specialize chlo.broadcast_select ops.
struct ConvertUnrankedDynamicBroadcastSelectOp
: public OpConversionPattern<chlo::BroadcastSelectOp> {
using OpConversionPattern<chlo::BroadcastSelectOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
chlo::BroadcastSelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// For now only do the bare minimum and specialize for every rank. There is
// more potential for optimization here. This also is missing the
// specialization for rank 0.
rewriter.replaceOp(
op, {ConvertUnrankedDynamicBroadcastOpHelper<
chlo::BroadcastSelectOp,
mhlo::SelectOp>::HandleBroadcastAndOp(rewriter, op, operands)});
return success();
}
};
struct TransformUnrankedHloPass struct TransformUnrankedHloPass
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> { : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
@ -539,6 +558,7 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns); ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
chlo::PopulateForBroadcastingBinaryOp< chlo::PopulateForBroadcastingBinaryOp<
ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns); ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
patterns->insert<ConvertUnrankedDynamicBroadcastSelectOp>(context);
} }
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() { std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {

View File

@ -308,3 +308,53 @@ func @addUnrankedUnranked(
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: return %[[VAL_72:.*]] : tensor<*xf32> // CHECK-NEXT: return %[[VAL_72:.*]] : tensor<*xf32>
// CHECK-NEXT: } // CHECK-NEXT: }
// -----
func @selectUnrankedUnrankedUnranked(
%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>)
-> tensor<*xf32> {
%0 = chlo.broadcast_select %arg0, %arg1, %arg2
: (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @selectUnrankedUnrankedUnranked(
// CHECK-SAME: %[[PRED:.*]]: tensor<*xi1>,
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[PRED_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %c1 : index
// Handle rank 1 specialization
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>