Scalar / Trivial folding for mhlo.select

This covers the case where the predicate is a splat or the on_true/on_false
values are the same.

PiperOrigin-RevId: 329622785
This commit is contained in:
Robert Suderman 2020-09-01 18:33:08 -07:00 committed by TensorFlow MLIR Team
parent 158dbba4e5
commit 7c93352a40
3 changed files with 57 additions and 0 deletions

View File

@ -1153,6 +1153,8 @@ def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods<Infe
);
let results = (outs HLO_Tensor);
let hasFolder = 1;
}
def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",

View File

@ -1410,6 +1410,29 @@ static LogicalResult Verify(SelectOp op) {
return success();
}
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
if (on_true() == on_false()) {
return on_true();
}
auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!predicate) {
return {};
}
auto predicateTy = predicate.getType().cast<ShapedType>();
if (!predicateTy.getElementType().isInteger(1)) {
return {};
}
if (predicate.isSplat()) {
return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
: on_false();
}
return {};
}
// Makes it such that a SelectOp that is a non-root operation in a DRR infers
// the return type based on operand type.
LogicalResult SelectOp::inferReturnTypes(

View File

@ -633,3 +633,35 @@ func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor<i32> {
// CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
// CHECK-NEXT: return %[[C]]
}
// CHECK-LABEL: func @fold_select_same
func @fold_select_same(%arg0 : tensor<f32>, %arg1 : tensor<i1>) -> tensor<f32> {
%1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %arg0
return %1 : tensor<f32>
}
// CHECK-LABEL: func @fold_select_first
func @fold_select_first(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<f32> {
%0 = mhlo.constant dense<1> : tensor<i1>
%1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %arg0
return %1 : tensor<f32>
}
// CHECK-LABEL: func @fold_select_second
func @fold_select_second(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<f32> {
%0 = mhlo.constant dense<0> : tensor<i1>
%1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %arg1
return %1 : tensor<f32>
}
// CHECK-LABEL: func @fold_select_vector
func @fold_select_vector(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor<4xf32> {
%0 = mhlo.constant dense<1> : tensor<4xi1>
%1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %arg0
return %1 : tensor<4xf32>
}