Scalar / Trivial folding for mhlo.select
This covers the case where the predicate is a splat or the on_true/on_false values are the same. PiperOrigin-RevId: 329622785
This commit is contained in:
parent
158dbba4e5
commit
7c93352a40
|
@ -1153,6 +1153,8 @@ def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods<Infe
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",
|
def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",
|
||||||
|
|
|
@ -1410,6 +1410,29 @@ static LogicalResult Verify(SelectOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
if (on_true() == on_false()) {
|
||||||
|
return on_true();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
|
||||||
|
if (!predicate) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto predicateTy = predicate.getType().cast<ShapedType>();
|
||||||
|
if (!predicateTy.getElementType().isInteger(1)) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (predicate.isSplat()) {
|
||||||
|
return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
|
||||||
|
: on_false();
|
||||||
|
}
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
// Makes it such that a SelectOp that is a non-root operation in a DRR infers
|
// Makes it such that a SelectOp that is a non-root operation in a DRR infers
|
||||||
// the return type based on operand type.
|
// the return type based on operand type.
|
||||||
LogicalResult SelectOp::inferReturnTypes(
|
LogicalResult SelectOp::inferReturnTypes(
|
||||||
|
|
|
@ -633,3 +633,35 @@ func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor<i32> {
|
||||||
// CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
|
// CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
|
||||||
// CHECK-NEXT: return %[[C]]
|
// CHECK-NEXT: return %[[C]]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_select_same
|
||||||
|
func @fold_select_same(%arg0 : tensor<f32>, %arg1 : tensor<i1>) -> tensor<f32> {
|
||||||
|
%1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
// CHECK: return %arg0
|
||||||
|
return %1 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_select_first
|
||||||
|
func @fold_select_first(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<f32> {
|
||||||
|
%0 = mhlo.constant dense<1> : tensor<i1>
|
||||||
|
%1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
// CHECK: return %arg0
|
||||||
|
return %1 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_select_second
|
||||||
|
func @fold_select_second(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<f32> {
|
||||||
|
%0 = mhlo.constant dense<0> : tensor<i1>
|
||||||
|
%1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
// CHECK: return %arg1
|
||||||
|
return %1 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_select_vector
|
||||||
|
func @fold_select_vector(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor<4xf32> {
|
||||||
|
%0 = mhlo.constant dense<1> : tensor<4xi1>
|
||||||
|
%1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
// CHECK: return %arg0
|
||||||
|
return %1 : tensor<4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue