Bitwise and/or/xor folders

Includes both and/or/xor on same inputs, constant all ones/zeros single arg
folder, and constant input folders.

PiperOrigin-RevId: 330610858
This commit is contained in:
Robert Suderman 2020-09-08 16:26:39 -07:00 committed by TensorFlow MLIR Team
parent 9d4273b5a7
commit 81d51d810b
3 changed files with 266 additions and 0 deletions

View File

@ -379,6 +379,8 @@ class HLO_BinaryLogicalElementwiseOp<string mnemonic> :
HLO_PredOrIntTensor:$lhs,
HLO_PredOrIntTensor:$rhs
);
let hasFolder = 1;
}
def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp;

View File

@ -1267,6 +1267,130 @@ static LogicalResult Verify(InfeedOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// Logical Ops
//===----------------------------------------------------------------------===//
OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
if (lhs() == rhs()) return lhs();
auto rType = getType().cast<ShapedType>();
auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (lhsVal && lhsVal.isSplat()) {
if (lhsVal.getSplatValue()
.cast<IntegerAttr>()
.getValue()
.isAllOnesValue()) {
return rhs();
}
if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
return lhsVal;
}
}
if (rhsVal && rhsVal.isSplat()) {
if (rhsVal.getSplatValue()
.cast<IntegerAttr>()
.getValue()
.isAllOnesValue()) {
return lhs();
}
if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
return rhsVal;
}
}
if (!rhsVal || !lhsVal) return {};
llvm::SmallVector<APInt, 4> values;
values.reserve(rhsVal.getNumElements());
for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
values.push_back(std::get<0>(it) & std::get<1>(it));
}
return DenseIntElementsAttr::get(rType, values);
}
OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
if (lhs() == rhs()) return lhs();
auto rType = getType().cast<ShapedType>();
auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (lhsVal && lhsVal.isSplat()) {
if (lhsVal.getSplatValue()
.cast<IntegerAttr>()
.getValue()
.isAllOnesValue()) {
return lhsVal;
}
if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
return rhs();
}
}
if (rhsVal && rhsVal.isSplat()) {
if (rhsVal.getSplatValue()
.cast<IntegerAttr>()
.getValue()
.isAllOnesValue()) {
return rhsVal;
}
if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
return lhs();
}
}
if (!rhsVal || !lhsVal) return {};
llvm::SmallVector<APInt, 4> values;
values.reserve(rhsVal.getNumElements());
for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
values.push_back(std::get<0>(it) | std::get<1>(it));
}
return DenseIntElementsAttr::get(rType, values);
}
OpFoldResult XorOp::fold(ArrayRef<Attribute> operands) {
auto rType = getType().cast<ShapedType>();
if (lhs() == rhs()) {
return DenseIntElementsAttr::get(rType, 0);
}
auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (lhsVal && lhsVal.isSplat()) {
if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
return rhs();
}
}
if (rhsVal && rhsVal.isSplat()) {
if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
return lhs();
}
}
if (!rhsVal || !lhsVal) return {};
llvm::SmallVector<APInt, 4> values;
values.reserve(rhsVal.getNumElements());
for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
values.push_back(std::get<0>(it) ^ std::get<1>(it));
}
return DenseIntElementsAttr::get(rType, values);
}
//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

View File

@ -694,3 +694,143 @@ func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x4xf32
// CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[5, 6, 5]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<5x6x4xf32>
// CHECK: return %[[RET]] : tensor<5x6x4xf32>
}
// CHECK-LABEL: func @fold_and_same
func @fold_and_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = "mhlo.and"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: return %arg0
return %0 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_and_ones
func @fold_and_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = mhlo.constant dense<-1> : tensor<4xi32>
%1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: return %arg0
return %1 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_and_zeros
func @fold_and_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = mhlo.constant dense<0> : tensor<4xi32>
%1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: return %0
return %1 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_and_constant
func @fold_and_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = mhlo.constant dense<7> : tensor<4xi32>
// CHECK: mhlo.and
%1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %1 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_and_constants
func @fold_and_constants() -> tensor<4xi32> {
%0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32>
%1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32>
%2 = "mhlo.and"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: %0 = mhlo.constant dense<[0, 1, 6, 2]> : tensor<4xi32>
// CHECK: return %0
return %2 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_or_same
func @fold_or_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = "mhlo.or"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: return %arg0
return %0 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_or_ones
func @fold_or_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = mhlo.constant dense<-1> : tensor<4xi32>
%1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: return %0
return %1 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_or_zeros
func @fold_or_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = mhlo.constant dense<0> : tensor<4xi32>
%1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: return %arg0
return %1 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_or_constant
func @fold_or_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = mhlo.constant dense<7> : tensor<4xi32>
// CHECK: mhlo.or
%1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %1 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_or_zeros_right
func @fold_or_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = mhlo.constant dense<0> : tensor<4xi32>
%1 = "mhlo.or"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: return %arg0
return %1 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_or_zeros_constants
func @fold_or_zeros_constants() -> tensor<4xi32> {
%0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32>
%1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32>
%2 = "mhlo.or"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: %0 = mhlo.constant dense<[7, 3, 7, 3]> : tensor<4xi32>
// CHECK: return %0
return %2 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_xor_same
func @fold_xor_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = "mhlo.xor"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: %0 = mhlo.constant dense<0> : tensor<4xi32>
// CHECK: return %0
return %0 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_xor_ones_left
func @fold_xor_ones_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = mhlo.constant dense<-1> : tensor<4xi32>
// CHECK: mhlo.xor
%1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %1 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_xor_ones_right
func @fold_xor_ones_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = mhlo.constant dense<-1> : tensor<4xi32>
// CHECK: mhlo.xor
%1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %1 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_xor_zeros_left
func @fold_xor_zeros_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = mhlo.constant dense<0> : tensor<4xi32>
%1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: return %arg0
return %1 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_xor_zeros_right
func @fold_xor_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
%0 = mhlo.constant dense<0> : tensor<4xi32>
%1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: return %arg0
return %1 : tensor<4xi32>
}
// CHECK-LABEL: func @fold_xor_zeros_constants
func @fold_xor_zeros_constants() -> tensor<4xi32> {
%0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32>
%1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32>
%2 = "mhlo.xor"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK: %0 = mhlo.constant dense<[7, 2, 1, 1]> : tensor<4xi32>
// CHECK: return %0
return %2 : tensor<4xi32>
}