diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 9cf1b6c..d59dfd4 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -22,6 +22,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/TypeUtilities.h" namespace mlir { namespace lmhlo { @@ -96,7 +97,7 @@ template struct MapLhloOpToStdScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { return b->template create(loc, result_types, args, mlir::None); @@ -120,7 +121,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); @@ -130,8 +131,11 @@ inline Value MapLhloOpToStdScalarOp(Location loc, Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); - auto zero_intval = + Value zero_intval = b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); + } auto lhs_gt_zero = b->create>(loc, CmpIPredicate::sge, lhs, zero_intval); auto neg_val = b->create>(loc, zero_intval, lhs); @@ -196,7 +200,7 @@ inline Value MapCompareOpToStdScalarOp(Location loc, ArrayRef args, OpBuilder* b) { const auto& lhs = args[0]; const auto& rhs = args[1]; - Type element_type = lhs.getType(); + Type element_type = getElementTypeOrSelf(lhs.getType()); if (element_type.isSignlessInteger()) { Optional predicate = getCmpPredicate(comparison_direction); @@ -268,8 +272,8 @@ template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type sourceType = args.front().getType(); - Type targetType = result_types.front(); + Type sourceType = getElementTypeOrSelf(args.front().getType()); + Type targetType = getElementTypeOrSelf(result_types.front()); if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { return b->create(loc, result_types, args, mlir::None); @@ -390,7 +394,7 @@ struct CompareSelectOpToStdScalarOp result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { auto predicate = getCmpPredicate(comparison_direction); assert(predicate.hasValue() && "expected valid comparison direction"); @@ -439,7 +443,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); @@ -449,8 +453,11 @@ inline Value MapLhloOpToStdScalarOp(Location loc, Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); - auto zero_intval = + Value zero_intval = b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); + } return b->create>(loc, zero_intval, lhs); } return nullptr; @@ -461,11 +468,14 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (auto integer_type = element_type.dyn_cast()) { // lmhlo.not(x) -> x ^ -1 - auto all_ones = + Value all_ones = b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones); + } return b->create<::mlir::XOrOp>(loc, all_ones, args[0]); } return nullptr; @@ -493,26 +503,35 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (auto float_type = element_type.dyn_cast()) { bool ignored; APFloat one_apfloat(1.0f); one_apfloat.convert(float_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, &ignored); Value one = b->create(loc, one_apfloat, float_type); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + one = b->create<::mlir::SplatOp>(loc, vec_type, one); + } return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); } else if (auto integer_type = element_type.dyn_cast()) { // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) Value zero = b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); - Value cmp = - b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero); Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>( loc, integer_type.getWidth() - 1, integer_type.getWidth()); - Value ashr = - b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); Value one = b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); + bitwidth_minus_one = + b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one); + one = b->create<::mlir::SplatOp>(loc, vec_type, one); + } + Value cmp = + b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero); + Value ashr = + b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); } @@ -583,6 +602,27 @@ struct HloOpToStdScalarOp { return impl::MapCompareOpToStdScalarOp( op.getLoc(), comparison_direction, result_types, args, b); } + + // Implementation for LHLO ops except lmhlo::CompareOp. + template ::value && + std::is_same, + std::false_type>::value>> + static Value map(Location loc, ArrayRef result_types, + ArrayRef args, OpBuilder* b, unsigned i = 0) { + return impl::MapLhloOpToStdScalarOp(loc, result_types, args, b); + } + + // Implementation for lmhlo::CompareOp. + template ::value>> + static Value map(Location loc, StringRef comparison_direction, + ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return impl::MapCompareOpToStdScalarOp( + loc, comparison_direction, result_types, args, b); + } }; } // namespace lmhlo diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index debb035..4715108 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -621,10 +621,10 @@ func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) { // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]): // CHECK-NEXT: %[[C0:.*]] = constant 0 : i16 -// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16 // CHECK-NEXT: %[[C15:.*]] = constant 15 : i16 -// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16 // CHECK-NEXT: %[[C1:.*]] = constant 1 : i16 +// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16 +// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16 // CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16 // CHECK-NEXT: linalg.yield %[[RESULT]] : i16