Scalar / Trivial folding for mhlo.select
This covers the case where the predicate is a splat or the on_true/on_false values are the same. PiperOrigin-RevId: 329622785
This commit is contained in:
parent
158dbba4e5
commit
7c93352a40
|
@ -1153,6 +1153,8 @@ def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods<Infe
|
|||
);
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",
|
||||
|
|
|
@ -1410,6 +1410,29 @@ static LogicalResult Verify(SelectOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (on_true() == on_false()) {
|
||||
return on_true();
|
||||
}
|
||||
|
||||
auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
|
||||
if (!predicate) {
|
||||
return {};
|
||||
}
|
||||
|
||||
auto predicateTy = predicate.getType().cast<ShapedType>();
|
||||
if (!predicateTy.getElementType().isInteger(1)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
if (predicate.isSplat()) {
|
||||
return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
|
||||
: on_false();
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
// Makes it such that a SelectOp that is a non-root operation in a DRR infers
|
||||
// the return type based on operand type.
|
||||
LogicalResult SelectOp::inferReturnTypes(
|
||||
|
|
|
@ -633,3 +633,35 @@ func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor<i32> {
|
|||
// CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
|
||||
// CHECK-NEXT: return %[[C]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_select_same
|
||||
func @fold_select_same(%arg0 : tensor<f32>, %arg1 : tensor<i1>) -> tensor<f32> {
|
||||
%1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: return %arg0
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_select_first
|
||||
func @fold_select_first(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<f32> {
|
||||
%0 = mhlo.constant dense<1> : tensor<i1>
|
||||
%1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: return %arg0
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_select_second
|
||||
func @fold_select_second(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<f32> {
|
||||
%0 = mhlo.constant dense<0> : tensor<i1>
|
||||
%1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: return %arg1
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_select_vector
|
||||
func @fold_select_vector(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor<4xf32> {
|
||||
%0 = mhlo.constant dense<1> : tensor<4xi1>
|
||||
%1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: return %arg0
|
||||
return %1 : tensor<4xf32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue