Folders for mhlo.compare
Constant evaluation of compare for the case where inputs are either the same variable or the values are constant. PiperOrigin-RevId: 333342328
This commit is contained in:
parent
0259f982bf
commit
233f1a8a1a
|
@ -689,6 +689,8 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
|
||||||
);
|
);
|
||||||
let results = (outs HLO_PredTensor);
|
let results = (outs HLO_PredTensor);
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
|
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
|
||||||
"StringAttr comparison_direction"
|
"StringAttr comparison_direction"
|
||||||
|
|
|
@ -2501,8 +2501,108 @@ LogicalResult CompareOp::reifyReturnTypeShapes(
|
||||||
&reifiedReturnShapes);
|
&reifiedReturnShapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct less : std::less<T> {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct less<APInt> {
|
||||||
|
bool operator()(const APInt& a, const APInt& b) const { return a.slt(b); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct less_equal : std::less_equal<T> {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct less_equal<APInt> {
|
||||||
|
bool operator()(const APInt& a, const APInt& b) const { return a.sle(b); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct greater : std::greater<T> {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct greater<APInt> {
|
||||||
|
bool operator()(const APInt& a, const APInt& b) const { return a.sgt(b); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct greater_equal : std::greater_equal<T> {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct greater_equal<APInt> {
|
||||||
|
bool operator()(const APInt& a, const APInt& b) const { return a.sge(b); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Op, typename ElementType, typename SrcType, typename Convert>
|
||||||
|
static Attribute CompareFolder(CompareOp op, ArrayRef<Attribute> attrs) {
|
||||||
|
if (!attrs[0] || !attrs[1]) return {};
|
||||||
|
|
||||||
|
DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
|
||||||
|
DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
|
||||||
|
if (!lhs || !rhs) return {};
|
||||||
|
|
||||||
|
ShapedType operand_type =
|
||||||
|
op.getOperand(0).getType().template cast<ShapedType>();
|
||||||
|
if (!operand_type.hasStaticShape()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!operand_type.getElementType().isa<ElementType>()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<bool, 6> values;
|
||||||
|
values.reserve(lhs.getNumElements());
|
||||||
|
for (const auto zip :
|
||||||
|
llvm::zip(lhs.getValues<SrcType>(), rhs.getValues<SrcType>())) {
|
||||||
|
values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result_ty = op.getType().cast<ShapedType>();
|
||||||
|
return DenseElementsAttr::get(result_ty, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
auto result_ty = getType().cast<ShapedType>();
|
||||||
|
if (!result_ty.hasStaticShape()) return {};
|
||||||
|
|
||||||
|
auto direction = comparison_direction();
|
||||||
|
if (lhs() == rhs()) {
|
||||||
|
if (direction == "LE" || direction == "EQ" || direction == "GE") {
|
||||||
|
return DenseIntElementsAttr::get(result_ty, {true});
|
||||||
|
}
|
||||||
|
|
||||||
|
return DenseIntElementsAttr::get(result_ty, {false});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!operands[0] || !operands[1]) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
#define COMPARE_FOLDER(Op, comparison, Func) \
|
||||||
|
if (direction == comparison) { \
|
||||||
|
if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \
|
||||||
|
*this, operands)) \
|
||||||
|
return folded; \
|
||||||
|
if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>( \
|
||||||
|
*this, operands)) \
|
||||||
|
return folded; \
|
||||||
|
}
|
||||||
|
|
||||||
|
COMPARE_FOLDER(CompareOp, "EQ", std::equal_to);
|
||||||
|
COMPARE_FOLDER(CompareOp, "NE", std::not_equal_to);
|
||||||
|
COMPARE_FOLDER(CompareOp, "LT", less);
|
||||||
|
COMPARE_FOLDER(CompareOp, "LE", less_equal);
|
||||||
|
COMPARE_FOLDER(CompareOp, "GT", greater);
|
||||||
|
COMPARE_FOLDER(CompareOp, "GE", greater_equal);
|
||||||
|
#undef COMPARE_FOLDER
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mhlo
|
} // namespace mhlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
|
||||||
|
|
||||||
|
|
|
@ -583,6 +583,262 @@ func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
|
||||||
return %arg0 : tensor<i64>
|
return %arg0 : tensor<i64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_same_eq
|
||||||
|
func @fold_compare_same_eq(%arg0: tensor<i64>) -> tensor<i1> {
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
return %0 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_same_le
|
||||||
|
func @fold_compare_same_le(%arg0: tensor<i64>) -> tensor<i1> {
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
return %0 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_same_ge
|
||||||
|
func @fold_compare_same_ge(%arg0: tensor<i64>) -> tensor<i1> {
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
return %0 : tensor<i1>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: fold_compare_same_ne
|
||||||
|
func @fold_compare_same_ne(%arg0: tensor<i64>) -> tensor<i1> {
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
return %0 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_same_lt
|
||||||
|
func @fold_compare_same_lt(%arg0: tensor<i64>) -> tensor<i1> {
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
return %0 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_same_gt
|
||||||
|
func @fold_compare_same_gt(%arg0: tensor<i64>) -> tensor<i1> {
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||||
|
return %0 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_eq
|
||||||
|
func @fold_compare_false_eq() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: fold_compare_true_eq
|
||||||
|
func @fold_compare_true_eq() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_eq_float
|
||||||
|
func @fold_compare_false_eq_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_true_eq_float
|
||||||
|
func @fold_compare_true_eq_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_ne
|
||||||
|
func @fold_compare_false_ne() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_true_ne
|
||||||
|
func @fold_compare_true_ne() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<0> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_ne_float
|
||||||
|
func @fold_compare_false_ne_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_true_ne_float
|
||||||
|
func @fold_compare_true_ne_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_lt
|
||||||
|
func @fold_compare_false_lt() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_true_lt
|
||||||
|
func @fold_compare_true_lt() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_lt_float
|
||||||
|
func @fold_compare_false_lt_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_true_lt_float
|
||||||
|
func @fold_compare_true_lt_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_le
|
||||||
|
func @fold_compare_false_le() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<0> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_true_le
|
||||||
|
func @fold_compare_true_le() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_le_float
|
||||||
|
func @fold_compare_false_le_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<0.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_true_le_float
|
||||||
|
func @fold_compare_true_le_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_gt
|
||||||
|
func @fold_compare_false_gt() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<0> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_true_gt
|
||||||
|
func @fold_compare_true_gt() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<0> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_gt_float
|
||||||
|
func @fold_compare_false_gt_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<0.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_true_gt_float
|
||||||
|
func @fold_compare_true_gt_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<0.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_ge
|
||||||
|
func @fold_compare_false_ge() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_true_ge
|
||||||
|
func @fold_compare_true_ge() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||||
|
%1 = mhlo.constant dense<0> : tensor<i32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_false_ge_float
|
||||||
|
func @fold_compare_false_ge_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_true_ge_float
|
||||||
|
func @fold_compare_true_ge_float() -> tensor<i1> {
|
||||||
|
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||||
|
%1 = mhlo.constant dense<0.> : tensor<f32>
|
||||||
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: unpack_repack_same_tuple
|
// CHECK-LABEL: unpack_repack_same_tuple
|
||||||
// CHECK-SAME: ([[ARG0:%.*]]: tuple<tensor<i32>, !mhlo.token, tensor<f32>>)
|
// CHECK-SAME: ([[ARG0:%.*]]: tuple<tensor<i32>, !mhlo.token, tensor<f32>>)
|
||||||
func @unpack_repack_same_tuple(%arg0: tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tuple<tensor<i32>, !mhlo.token, tensor<f32>> {
|
func @unpack_repack_same_tuple(%arg0: tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tuple<tensor<i32>, !mhlo.token, tensor<f32>> {
|
||||||
|
|
|
@ -51,38 +51,38 @@ func @unary_ops_float(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
|
// CHECK-LABEL: func @compare_int
|
||||||
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
|
func @compare_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
|
||||||
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
|
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg1 : tensor<4xi32>
|
||||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||||
// CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32>
|
// CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg1 : tensor<4xi32>
|
||||||
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
%1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||||
// CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32>
|
// CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg1 : tensor<4xi32>
|
||||||
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
%2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||||
// CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32>
|
// CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg1 : tensor<4xi32>
|
||||||
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
%3 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||||
// CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32>
|
// CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg1 : tensor<4xi32>
|
||||||
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
%4 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||||
// CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32>
|
// CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg1 : tensor<4xi32>
|
||||||
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
%5 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||||
// CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
// CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
||||||
return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @compare_float
|
// CHECK-LABEL: func @compare_float
|
||||||
func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
|
func @compare_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
|
||||||
// CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32>
|
// CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg1 : tensor<4xf32>
|
||||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||||
// CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32>
|
// CHECK-NEXT: %1 = cmpf "une", %arg0, %arg1 : tensor<4xf32>
|
||||||
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
%1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||||
// CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32>
|
// CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg1 : tensor<4xf32>
|
||||||
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
%2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||||
// CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32>
|
// CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg1 : tensor<4xf32>
|
||||||
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
%3 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||||
// CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32>
|
// CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg1 : tensor<4xf32>
|
||||||
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
%4 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||||
// CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32>
|
// CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg1 : tensor<4xf32>
|
||||||
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
%5 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||||
return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue