Bitwise and/or/xor folders
Includes both and/or/xor on same inputs, constant all ones/zeros single arg folder, and constant input folders. PiperOrigin-RevId: 330610858
This commit is contained in:
parent
9d4273b5a7
commit
81d51d810b
|
@ -379,6 +379,8 @@ class HLO_BinaryLogicalElementwiseOp<string mnemonic> :
|
||||||
HLO_PredOrIntTensor:$lhs,
|
HLO_PredOrIntTensor:$lhs,
|
||||||
HLO_PredOrIntTensor:$rhs
|
HLO_PredOrIntTensor:$rhs
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp;
|
def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp;
|
||||||
|
|
|
@ -1267,6 +1267,130 @@ static LogicalResult Verify(InfeedOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Logical Ops
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
if (lhs() == rhs()) return lhs();
|
||||||
|
|
||||||
|
auto rType = getType().cast<ShapedType>();
|
||||||
|
auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
|
||||||
|
if (lhsVal && lhsVal.isSplat()) {
|
||||||
|
if (lhsVal.getSplatValue()
|
||||||
|
.cast<IntegerAttr>()
|
||||||
|
.getValue()
|
||||||
|
.isAllOnesValue()) {
|
||||||
|
return rhs();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
|
||||||
|
return lhsVal;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rhsVal && rhsVal.isSplat()) {
|
||||||
|
if (rhsVal.getSplatValue()
|
||||||
|
.cast<IntegerAttr>()
|
||||||
|
.getValue()
|
||||||
|
.isAllOnesValue()) {
|
||||||
|
return lhs();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
|
||||||
|
return rhsVal;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!rhsVal || !lhsVal) return {};
|
||||||
|
|
||||||
|
llvm::SmallVector<APInt, 4> values;
|
||||||
|
values.reserve(rhsVal.getNumElements());
|
||||||
|
for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
|
||||||
|
values.push_back(std::get<0>(it) & std::get<1>(it));
|
||||||
|
}
|
||||||
|
|
||||||
|
return DenseIntElementsAttr::get(rType, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
if (lhs() == rhs()) return lhs();
|
||||||
|
|
||||||
|
auto rType = getType().cast<ShapedType>();
|
||||||
|
auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
|
||||||
|
if (lhsVal && lhsVal.isSplat()) {
|
||||||
|
if (lhsVal.getSplatValue()
|
||||||
|
.cast<IntegerAttr>()
|
||||||
|
.getValue()
|
||||||
|
.isAllOnesValue()) {
|
||||||
|
return lhsVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
|
||||||
|
return rhs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rhsVal && rhsVal.isSplat()) {
|
||||||
|
if (rhsVal.getSplatValue()
|
||||||
|
.cast<IntegerAttr>()
|
||||||
|
.getValue()
|
||||||
|
.isAllOnesValue()) {
|
||||||
|
return rhsVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
|
||||||
|
return lhs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!rhsVal || !lhsVal) return {};
|
||||||
|
|
||||||
|
llvm::SmallVector<APInt, 4> values;
|
||||||
|
values.reserve(rhsVal.getNumElements());
|
||||||
|
for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
|
||||||
|
values.push_back(std::get<0>(it) | std::get<1>(it));
|
||||||
|
}
|
||||||
|
|
||||||
|
return DenseIntElementsAttr::get(rType, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
OpFoldResult XorOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
auto rType = getType().cast<ShapedType>();
|
||||||
|
if (lhs() == rhs()) {
|
||||||
|
return DenseIntElementsAttr::get(rType, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
|
||||||
|
if (lhsVal && lhsVal.isSplat()) {
|
||||||
|
if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
|
||||||
|
return rhs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rhsVal && rhsVal.isSplat()) {
|
||||||
|
if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
|
||||||
|
return lhs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!rhsVal || !lhsVal) return {};
|
||||||
|
|
||||||
|
llvm::SmallVector<APInt, 4> values;
|
||||||
|
values.reserve(rhsVal.getNumElements());
|
||||||
|
for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
|
||||||
|
values.push_back(std::get<0>(it) ^ std::get<1>(it));
|
||||||
|
}
|
||||||
|
|
||||||
|
return DenseIntElementsAttr::get(rType, values);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// MapOp
|
// MapOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -694,3 +694,143 @@ func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x4xf32
|
||||||
// CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[5, 6, 5]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<5x6x4xf32>
|
// CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[5, 6, 5]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<5x6x4xf32>
|
||||||
// CHECK: return %[[RET]] : tensor<5x6x4xf32>
|
// CHECK: return %[[RET]] : tensor<5x6x4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_and_same
|
||||||
|
func @fold_and_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = "mhlo.and"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: return %arg0
|
||||||
|
return %0 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_and_ones
|
||||||
|
func @fold_and_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<-1> : tensor<4xi32>
|
||||||
|
%1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: return %arg0
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_and_zeros
|
||||||
|
func @fold_and_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<0> : tensor<4xi32>
|
||||||
|
%1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: return %0
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_and_constant
|
||||||
|
func @fold_and_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<7> : tensor<4xi32>
|
||||||
|
// CHECK: mhlo.and
|
||||||
|
%1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_and_constants
|
||||||
|
func @fold_and_constants() -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32>
|
||||||
|
%1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32>
|
||||||
|
%2 = "mhlo.and"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<[0, 1, 6, 2]> : tensor<4xi32>
|
||||||
|
// CHECK: return %0
|
||||||
|
return %2 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_or_same
|
||||||
|
func @fold_or_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = "mhlo.or"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: return %arg0
|
||||||
|
return %0 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_or_ones
|
||||||
|
func @fold_or_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<-1> : tensor<4xi32>
|
||||||
|
%1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: return %0
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_or_zeros
|
||||||
|
func @fold_or_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<0> : tensor<4xi32>
|
||||||
|
%1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: return %arg0
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_or_constant
|
||||||
|
func @fold_or_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<7> : tensor<4xi32>
|
||||||
|
// CHECK: mhlo.or
|
||||||
|
%1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_or_zeros_right
|
||||||
|
func @fold_or_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<0> : tensor<4xi32>
|
||||||
|
%1 = "mhlo.or"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: return %arg0
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_or_zeros_constants
|
||||||
|
func @fold_or_zeros_constants() -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32>
|
||||||
|
%1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32>
|
||||||
|
%2 = "mhlo.or"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<[7, 3, 7, 3]> : tensor<4xi32>
|
||||||
|
// CHECK: return %0
|
||||||
|
return %2 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_xor_same
|
||||||
|
func @fold_xor_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = "mhlo.xor"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<0> : tensor<4xi32>
|
||||||
|
// CHECK: return %0
|
||||||
|
return %0 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_xor_ones_left
|
||||||
|
func @fold_xor_ones_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<-1> : tensor<4xi32>
|
||||||
|
// CHECK: mhlo.xor
|
||||||
|
%1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_xor_ones_right
|
||||||
|
func @fold_xor_ones_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<-1> : tensor<4xi32>
|
||||||
|
// CHECK: mhlo.xor
|
||||||
|
%1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_xor_zeros_left
|
||||||
|
func @fold_xor_zeros_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<0> : tensor<4xi32>
|
||||||
|
%1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: return %arg0
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_xor_zeros_right
|
||||||
|
func @fold_xor_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<0> : tensor<4xi32>
|
||||||
|
%1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: return %arg0
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_xor_zeros_constants
|
||||||
|
func @fold_xor_zeros_constants() -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32>
|
||||||
|
%1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32>
|
||||||
|
%2 = "mhlo.xor"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<[7, 2, 1, 1]> : tensor<4xi32>
|
||||||
|
// CHECK: return %0
|
||||||
|
return %2 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue