From f4414fcd666b59d3cff6737bf9587791f99e6c58 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Thu, 29 Apr 2021 02:26:25 -0700 Subject: [PATCH] [MHLO:Linalg] Add support for lowering unsigned ops This strips away the signedness with a type converter, using unrealized conversion casts. The rest is mostly mechanically pushing the original op down the pipeline so lowerings can see the original types. Signed types stay signless for now. This can be changed in the HLO bridge later. I did a pass over all ops and added unsigned lowerings where they were missing. There may be more. Currently the lowering will die at a later stage because it doesn't understand the unrealized casts. PiperOrigin-RevId: 371077494 --- .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 353 +++++++++++------- .../Dialect/mhlo/transforms/rewriters.h | 4 + .../mhlo/transforms/legalize_to_linalg.cc | 109 +++++- tests/hlo-legalize-to-linalg.mlir | 55 +++ tests/lhlo-legalize-to-linalg.mlir | 15 + 5 files changed, 381 insertions(+), 155 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index d6bbc17..91ed8f0 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -44,41 +44,50 @@ template <> struct LhloToScalarOp { using FOp = ::mlir::AddFOp; using IOp = ::mlir::AddIOp; + using UOp = ::mlir::AddIOp; using COp = ::mlir::complex::AddOp; }; template <> struct LhloToScalarOp { using FOp = ::mlir::CmpFOp; using IOp = ::mlir::CmpIOp; + using UOp = ::mlir::CmpIOp; }; template <> struct LhloToScalarOp { using FOp = ::mlir::DivFOp; using IOp = ::mlir::SignedDivIOp; + using UOp = ::mlir::UnsignedDivIOp; }; template <> struct LhloToScalarOp { using FOp = ::mlir::MulFOp; using IOp = ::mlir::MulIOp; + using UOp = ::mlir::MulIOp; }; template <> struct LhloToScalarOp { using FOp = ::mlir::RemFOp; using IOp = 
::mlir::SignedRemIOp; + using UOp = ::mlir::UnsignedRemIOp; }; template <> struct LhloToScalarOp { using FOp = ::mlir::SubFOp; using IOp = ::mlir::SubIOp; + using UOp = ::mlir::SubIOp; using COp = ::mlir::complex::SubOp; }; // Alias for the map from LHLO binary op type to STD floating-point op type. template using ScalarFOp = typename LhloToScalarOp::FOp; -// Alias for the map from LHLO binary op type to STD integer op type. +// Alias for the map from LHLO binary op type to STD signed integer op type. template using ScalarIOp = typename LhloToScalarOp::IOp; +// Alias for the map from LHLO binary op type to STD unsigned integer op type. +template +using ScalarUOp = typename LhloToScalarOp::UOp; // Alias for the map from LHLO binary op type to STD complex op type. template using ScalarCOp = typename LhloToScalarOp::COp; @@ -86,7 +95,8 @@ using ScalarCOp = typename LhloToScalarOp::COp; template struct MapLhloOpToScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { + ArrayRef arg_types, ArrayRef args, + OpBuilder* b) { return nullptr; } }; @@ -94,7 +104,8 @@ struct MapLhloOpToScalarOpImpl { template struct MapLhloOpToScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { + ArrayRef arg_types, ArrayRef args, + OpBuilder* b) { return b->template create(loc, result_types, args, mlir::None); } }; @@ -102,41 +113,69 @@ struct MapLhloOpToScalarOpImpl { template struct MapLhloOpToScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { - Type element_type = getElementTypeOrSelf(args.front().getType()); - if (element_type.isa()) { + ArrayRef arg_types, ArrayRef args, + OpBuilder* b) { + Type element_type = getElementTypeOrSelf(arg_types.front()); + if (SupportedType{}(element_type)) { return b->template create(loc, result_types, args, mlir::None); } - return MapLhloOpToScalarOpImpl{}(loc, result_types, args, b); + return 
MapLhloOpToScalarOpImpl{}(loc, result_types, arg_types, + args, b); } }; +struct isAnyIntegerType { + bool operator()(Type t) { return t.isa(); } +}; + +struct isSignedIntegerType { + bool operator()(Type t) { + // Pretend that signless is signed. This will change eventually. + return t.isa() && !t.isUnsignedInteger(); + } +}; + +struct isUnsignedIntegerType { + bool operator()(Type t) { return t.isUnsignedInteger(); } +}; + +struct isFloatType { + bool operator()(Type t) { return t.isa(); } +}; + +struct isComplexType { + bool operator()(Type t) { return t.isa(); } +}; + // Inserts the computation that corresponds to the body of the loop for lowered // LHLO unary/binary op. Returns the value for the result. template inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl, FloatType, - ScalarFOp>{}(loc, result_types, args, - b); + return MapLhloOpToScalarOpImpl, + isUnsignedIntegerType, ScalarUOp, + isFloatType, ScalarFOp>{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - Type element_type = getElementTypeOrSelf(args.front().getType()); + Type element_type = getElementTypeOrSelf(arg_types.front()); if (element_type.isa()) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } if (element_type.isa()) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } - if (element_type.isa()) { + if (element_type.isSignlessInteger() || element_type.isSignedInteger()) { // lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); @@ -156,40 +195,44 @@ inline Value 
MapLhloOpToStdScalarOp(Location loc, template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl, - FloatType, ScalarFOp, - ComplexType, ScalarCOp>{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl, + isFloatType, ScalarFOp, + isComplexType, ScalarCOp>{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template -inline Optional getCmpPredicate(StringRef comparison_direction) { +inline Optional getCmpPredicate(StringRef, bool) { return llvm::None; } template <> inline Optional getCmpPredicate( - StringRef comparison_direction) { + StringRef comparison_direction, bool is_signed) { + assert(is_signed && "cannot have an unsigned float!"); return llvm::StringSwitch>(comparison_direction) .Case("EQ", CmpFPredicate::OEQ) .Case("NE", CmpFPredicate::UNE) @@ -202,14 +245,14 @@ inline Optional getCmpPredicate( template <> inline Optional getCmpPredicate( - StringRef comparison_direction) { + StringRef comparison_direction, bool is_signed) { return llvm::StringSwitch>(comparison_direction) .Case("EQ", CmpIPredicate::eq) .Case("NE", CmpIPredicate::ne) - .Case("GE", CmpIPredicate::sge) - .Case("GT", CmpIPredicate::sgt) - .Case("LE", CmpIPredicate::sle) - .Case("LT", CmpIPredicate::slt) + .Case("GE", is_signed ? CmpIPredicate::sge : CmpIPredicate::uge) + .Case("GT", is_signed ? 
CmpIPredicate::sgt : CmpIPredicate::ugt) + .Case("LE", is_signed ? CmpIPredicate::sle : CmpIPredicate::ule) + .Case("LT", is_signed ? CmpIPredicate::slt : CmpIPredicate::ult) .Default(llvm::None); } @@ -217,20 +260,21 @@ template inline Value MapCompareOpToStdScalarOp(Location loc, StringRef comparison_direction, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { const auto& lhs = args[0]; const auto& rhs = args[1]; - Type element_type = getElementTypeOrSelf(lhs.getType()); - if (element_type.isSignlessInteger()) { - Optional predicate = - getCmpPredicate(comparison_direction); + Type element_type = getElementTypeOrSelf(arg_types.front()); + if (element_type.isa()) { + Optional predicate = getCmpPredicate( + comparison_direction, !element_type.isUnsignedInteger()); assert(predicate.hasValue() && "expected valid comparison direction"); return b->create>(loc, predicate.getValue(), lhs, rhs); } if (element_type.isa()) { - Optional predicate = - getCmpPredicate(comparison_direction); + Optional predicate = getCmpPredicate( + comparison_direction, /*is_signed=*/true); assert(predicate.hasValue() && "expected valid comparison direction"); return b->create>(loc, predicate.getValue(), lhs, rhs); @@ -241,6 +285,7 @@ inline Value MapCompareOpToStdScalarOp(Location loc, template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { return args.front(); @@ -249,59 +294,66 @@ inline Value MapLhloOpToStdScalarOp(Location loc, template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - 
loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}(loc, result_types, args, - b); + Location loc, ArrayRef result_types, ArrayRef arg_types, + ArrayRef args, OpBuilder* b) { + return MapLhloOpToScalarOpImpl{}(loc, result_types, + arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}(loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}(loc, result_types, arg_types, + args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}(loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}(loc, result_types, arg_types, + args, b); } template <> inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { - Type sourceType = getElementTypeOrSelf(args.front().getType()); + Location loc, ArrayRef result_types, ArrayRef arg_types, + ArrayRef args, OpBuilder* b) { + Type sourceType = getElementTypeOrSelf(arg_types.front()); Type targetType = getElementTypeOrSelf(result_types.front()); // A boolean value is considered to be unsigned when converting to @@ -342,7 +394,7 @@ inline Value MapLhloOpToStdScalarOp( zero); } } - if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) { + if (sourceType.isa() && targetType.isa()) { IntegerType src = 
sourceType.cast(); IntegerType res = targetType.cast(); if (src.getWidth() > res.getWidth()) { @@ -352,6 +404,10 @@ inline Value MapLhloOpToStdScalarOp( return b->create(loc, result_types, args, mlir::None); } else if (src.getWidth() < res.getWidth()) { + if (src.isUnsignedInteger()) { + return b->create(loc, result_types, args, + mlir::None); + } return b->create(loc, result_types, args, mlir::None); } @@ -367,6 +423,7 @@ inline Value MapLhloOpToStdScalarOp( template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { // Dot Op converter from lhlo to affine only accepts float and integer types. @@ -375,16 +432,16 @@ inline Value MapLhloOpToStdScalarOp(Location loc, const auto& result = args[2]; Type element_type = lhs.getType(); if (element_type.isa()) { - Value float_mul = MapLhloOpToScalarOpImpl{}( - loc, result_types, {lhs, rhs}, b); - return MapLhloOpToScalarOpImpl{}( - loc, result_types, {float_mul, result}, b); + Value float_mul = MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, {lhs, rhs}, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, {float_mul, result}, b); } if (element_type.isa()) { - Value int_mul = MapLhloOpToScalarOpImpl{}( - loc, result_types, {lhs, rhs}, b); - return MapLhloOpToScalarOpImpl{}( - loc, result_types, {int_mul, result}, b); + Value int_mul = MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, {lhs, rhs}, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, {int_mul, result}, b); } return nullptr; } @@ -392,34 +449,37 @@ inline Value MapLhloOpToStdScalarOp(Location loc, template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value 
MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { + Location loc, ArrayRef result_types, ArrayRef arg_types, + ArrayRef args, OpBuilder* b) { if (args[0].getType().isa()) { auto pos_inf = APFloat::getInf( args[0].getType().cast().getFloatSemantics()); @@ -437,8 +497,8 @@ inline Value MapLhloOpToStdScalarOp( template struct CompareSelectOpToStdScalarOp { static Value map(Location loc, StringRef comparison_direction, - ArrayRef result_types, ArrayRef args, - OpBuilder* b) { + ArrayRef result_types, ArrayRef arg_types, + ArrayRef args, OpBuilder* b) { return nullptr; } }; @@ -450,28 +510,30 @@ template { static Value map(Location loc, StringRef comparison_direction, - ArrayRef result_types, ArrayRef args, - OpBuilder* b) { - Type element_type = getElementTypeOrSelf(args.front().getType()); + ArrayRef result_types, ArrayRef arg_types, + ArrayRef args, OpBuilder* b) { + Type element_type = getElementTypeOrSelf(arg_types.front()); if (element_type.isa()) { - auto predicate = getCmpPredicate(comparison_direction); + auto predicate = getCmpPredicate( + comparison_direction, !element_type.isUnsignedInteger()); assert(predicate.hasValue() && "expected valid comparison direction"); auto cmp = b->template create(loc, predicate.getValue(), args[0], args[1]); return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]); } - return CompareSelectOpToStdScalarOp::map(loc, 
comparison_direction, - result_types, args, b); + return CompareSelectOpToStdScalarOp::map( + loc, comparison_direction, result_types, arg_types, args, b); } }; template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef args, Location loc, @@ -493,8 +555,8 @@ inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef args, Location loc, template <> inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { + Location loc, ArrayRef result_types, ArrayRef arg_types, + ArrayRef args, OpBuilder* b) { auto ty = result_types.front().cast(); Value one = b->create(loc, b->getFloatAttr(ty, 1.0)); Value x = args.front(); @@ -507,43 +569,47 @@ inline Value MapLhloOpToStdScalarOp( template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { return LhloAlwaysPropagateNaN( CompareSelectOpToStdScalarOp< IntegerType, ScalarIOp, CmpIPredicate, FloatType, ScalarFOp, CmpFPredicate>::map(loc, "GT", - result_types, args, - b), + result_types, + arg_types, args, b), args, loc, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { return LhloAlwaysPropagateNaN( CompareSelectOpToStdScalarOp< IntegerType, ScalarIOp, CmpIPredicate, FloatType, ScalarFOp, CmpFPredicate>::map(loc, "LT", - result_types, args, - b), + result_types, + 
arg_types, args, b), args, loc, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { assert(args.size() == 3 && "expected 3 arguments"); @@ -552,21 +618,22 @@ inline Value MapLhloOpToStdScalarOp(Location loc, Value ub = args[2]; // clamp(lb, x, ub) = max(min(x, ub), lb) - Value min_x_ub = - MapLhloOpToStdScalarOp(loc, result_types, {x, ub}, b); - return MapLhloOpToStdScalarOp(loc, result_types, {min_x_ub, lb}, - b); + Value min_x_ub = MapLhloOpToStdScalarOp(loc, result_types, + arg_types, {x, ub}, b); + return MapLhloOpToStdScalarOp(loc, result_types, arg_types, + {min_x_ub, lb}, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } if (element_type.isa()) { // lmhlo.neg(x, result) -> result = sub(0, x) @@ -586,6 +653,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { Type element_type = getElementTypeOrSelf(args.front().getType()); @@ -604,24 +672,27 @@ inline Value MapLhloOpToStdScalarOp(Location loc, template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}(loc, result_types, - args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return 
MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { lmhlo::PowOp::Adaptor adaptor(args); @@ -630,7 +701,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, auto result_type = result_types.front(); if (result_type.isa<::mlir::FloatType>()) return MapLhloOpToScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types, - args, b); + arg_types, args, b); assert(result_type.isa<::mlir::IntegerType>() && "only float and integer `pow` is supported right now"); @@ -699,39 +770,41 @@ inline Value MapLhloOpToStdScalarOp(Location loc, template <> inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { - return MapLhloOpToScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, - b); + Location loc, ArrayRef result_types, ArrayRef arg_types, + ArrayRef args, OpBuilder* b) { + return MapLhloOpToScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, + arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + Location loc, ArrayRef result_types, ArrayRef arg_types, + ArrayRef args, OpBuilder* b) { + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + Location loc, ArrayRef result_types, ArrayRef arg_types, + ArrayRef args, OpBuilder* b) { + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + Location 
loc, ArrayRef result_types, ArrayRef arg_types, + ArrayRef args, OpBuilder* b) { + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { Type element_type = getElementTypeOrSelf(args.front().getType()); @@ -780,39 +853,43 @@ inline Value MapLhloOpToStdScalarOp(Location loc, template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl, - FloatType, ScalarFOp, - ComplexType, ScalarCOp>{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl, + isFloatType, ScalarFOp, + isComplexType, ScalarCOp>{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef arg_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}( + loc, result_types, arg_types, args, b); } } // namespace impl @@ -826,8 +903,9 @@ struct HloOpToStdScalarOp { std::false_type>::value>> static Value map(HloOpTy op, ArrayRef result_types, ArrayRef args, OpBuilder* b, unsigned i = 0) { - return impl::MapLhloOpToStdScalarOp(op.getLoc(), result_types, - args, b); + return impl::MapLhloOpToStdScalarOp( + 
op.getLoc(), result_types, llvm::to_vector<4>(op->getOperandTypes()), + args, b); } // Implementation for HLO ops except mhlo::CompareOp. @@ -837,8 +915,9 @@ struct HloOpToStdScalarOp { !std::is_same::value>> static Value map(HloOpTy op, ArrayRef result_types, ArrayRef args, OpBuilder* b, int i = 0) { - return impl::MapLhloOpToStdScalarOp(op.getLoc(), result_types, - args, b); + return impl::MapLhloOpToStdScalarOp( + op.getLoc(), result_types, llvm::to_vector<4>(op->getOperandTypes()), + args, b); } // Implementation for lmhlo::CompareOp. @@ -848,7 +927,8 @@ struct HloOpToStdScalarOp { ArrayRef args, OpBuilder* b) { auto comparison_direction = op.comparison_direction(); return impl::MapCompareOpToStdScalarOp( - op.getLoc(), comparison_direction, result_types, args, b); + op.getLoc(), comparison_direction, result_types, + llvm::to_vector<4>(op->getOperandTypes()), args, b); } // Implementation for mhlo::CompareOp. @@ -859,7 +939,8 @@ struct HloOpToStdScalarOp { ArrayRef args, OpBuilder* b) { auto comparison_direction = op.comparison_direction(); return impl::MapCompareOpToStdScalarOp( - op.getLoc(), comparison_direction, result_types, args, b); + op.getLoc(), comparison_direction, result_types, + llvm::to_vector<4>(op->getOperandTypes()), args, b); } // Implementation for LHLO ops except lmhlo::CompareOp. @@ -869,18 +950,20 @@ struct HloOpToStdScalarOp { std::is_same, std::false_type>::value>> static Value map(Location loc, ArrayRef result_types, - ArrayRef args, OpBuilder* b, unsigned i = 0) { - return impl::MapLhloOpToStdScalarOp(loc, result_types, args, b); + ArrayRef arg_types, ArrayRef args, OpBuilder* b, + unsigned i = 0) { + return impl::MapLhloOpToStdScalarOp(loc, result_types, arg_types, + args, b); } // Implementation for lmhlo::CompareOp. 
template ::value>> static Value map(Location loc, StringRef comparison_direction, - ArrayRef result_types, ArrayRef args, - OpBuilder* b) { + ArrayRef result_types, ArrayRef arg_types, + ArrayRef args, OpBuilder* b) { return impl::MapCompareOpToStdScalarOp( - loc, comparison_direction, result_types, args, b); + loc, comparison_direction, result_types, arg_types, args, b); } }; diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 672711d..bd12379 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -57,8 +57,12 @@ void populateHLOToLHLOConversionPattern(MLIRContext *context, // Collection of rewrite patterns for lowering of HLO to Linalg dialect. void populateHLOToLinalgConversionPattern(MLIRContext *context, + TypeConverter &typeConverter, OwningRewritePatternList *patterns); +// Converter to signless integers to be used with linalg conversion patterns. +std::unique_ptr createHloToLinalgSignedIntegerConverter(); + // Sets up legality definitions for materializing broadcasts. 
void SetupMaterializeBroadcastsLegality(MLIRContext *context, ConversionTarget *conversionTarget); diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 6ce42ed..ec57da4 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -241,16 +241,15 @@ class PointwiseToLinalgConverter : public OpConversionPattern { !(t.getElementType().isSignlessIntOrFloat() || t.getElementType().isa()); }; - if (llvm::any_of(args, - [&](Value v) { - return fail(v.getType().dyn_cast()); - }) || - llvm::any_of(op.getOperation()->getResultTypes(), - [&](Type t) { return fail(t.dyn_cast()); })) + if (llvm::any_of(op.getOperation()->getResultTypes(), [&](Type t) { + return fail(this->typeConverter->convertType(t) + .template dyn_cast()); + })) { return emitError(loc, - "lhlo to linalg conversion expects ranked args of " + "hlo to linalg conversion expects ranked args of " "signless int, float or complex element type with ") << nloops << " parallel iterators: " << *(op.getOperation()); + } // Construct the indexing maps needed for linalg.generic ops. 
SmallVector body_arg_types, body_result_types, op_result_types; @@ -270,12 +269,12 @@ class PointwiseToLinalgConverter : public OpConversionPattern { if (isLHLO) { output_buffers.append(args.begin() + num_inputs, args.end()); } else { - Value result = op.getOperation()->getResult(0); - ShapedType result_type = result.getType().template cast(); + Type result_type = this->typeConverter->convertType( + op.getOperation()->getResult(0).getType()); auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]); - output_buffers.push_back( - GetInitTensor(rewriter, loc, result_type, dyn_sizes)); - op_result_types.push_back(result.getType()); + output_buffers.push_back(GetInitTensor( + rewriter, loc, result_type.cast(), dyn_sizes)); + op_result_types.push_back(result_type); } body_result_types = llvm::to_vector<4>(llvm::map_range( output_buffers, [](Value v) { return getElementTypeOrSelf(v); })); @@ -1279,8 +1278,10 @@ class DynamicSliceConverter : public OpConversionPattern { // map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it // to an lmhlo op and call the lmhlo implementation. 
start_index = lmhlo::HloOpToStdScalarOp::map( - loc, start_index.getType(), ArrayRef{zero, start_index, ub}, - &rewriter); + loc, start_index.getType(), + ArrayRef{start_index.getType(), start_index.getType(), + start_index.getType()}, + ArrayRef{zero, start_index, ub}, &rewriter); start_indices.push_back( rewriter.create(loc, index_type, start_index) .getResult()); @@ -2073,6 +2074,7 @@ struct TorchIndexSelectOpOnTensorsConversion }; void populateLHLOToLinalgConversionPattern(MLIRContext* context, + TypeConverter& typeConverter, OwningRewritePatternList* patterns) { // clang-format off patterns->insert, @@ -2128,10 +2130,65 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, ScalarPointwiseToStandardConverter, SliceConverter, TransposeConverter - >(context); + >(typeConverter, context); // clang-format on } +// Converter that turns signed/unsigned integers types into signless types. +class RemoveSignTypeConverter : public TypeConverter { + public: + RemoveSignTypeConverter() { + addConversion([](Type type) { return type; }); + + addConversion(convertInteger); + addConversion(convertShapedType); + + addArgumentMaterialization(materializeCastFromIllegal); + addSourceMaterialization(materializeCastToIllegal); + addTargetMaterialization(materializeCastFromIllegal); + } + + private: + static Type convertInteger(IntegerType int_type) { + return IntegerType::get(int_type.getContext(), + int_type.getIntOrFloatBitWidth()); + } + + static Type convertShapedType(ShapedType shaped_type) { + if (auto int_type = shaped_type.getElementType().dyn_cast()) + return shaped_type.clone(convertInteger(int_type)); + return shaped_type; + } + + static llvm::Optional materializeCastFromIllegal(OpBuilder& builder, + Type type, + ValueRange inputs, + Location loc) { + Type from_type = getElementTypeOrSelf(inputs[0].getType()); + Type to_type = getElementTypeOrSelf(type); + if ((!from_type.isSignedInteger() && !from_type.isUnsignedInteger()) || + !to_type.isSignlessInteger()) 
+ return llvm::None; + // Use unrealized conversion casts to do signful->signless conversions. + return builder.create(loc, type, inputs[0]) + ->getResult(0); + } + + static llvm::Optional materializeCastToIllegal(OpBuilder& builder, + Type type, + ValueRange inputs, + Location loc) { + Type from_type = getElementTypeOrSelf(inputs[0].getType()); + Type to_type = getElementTypeOrSelf(type); + if (!from_type.isSignlessInteger() || + (!to_type.isSignedInteger() && !to_type.isUnsignedInteger())) + return llvm::None; + // Use unrealized conversion casts to do signless->signful conversions. + return builder.create(loc, type, inputs[0]) + ->getResult(0); + } +}; + // Converts LHLO ops to Linalg generic. // Sample result for lmhlo::AddOp. // @@ -2163,9 +2220,12 @@ struct LhloLegalizeToLinalgPass target.addLegalDialect(); + target.addLegalOp(); + RemoveSignTypeConverter type_converter; auto func = getFunction(); - populateLHLOToLinalgConversionPattern(func.getContext(), &patterns); + populateLHLOToLinalgConversionPattern(func.getContext(), type_converter, + &patterns); if (failed(applyPartialConversion(func, target, std::move(patterns)))) { signalPassFailure(); } @@ -2181,17 +2241,20 @@ struct HloLegalizeToLinalgPass } void runOnFunction() override { - OwningRewritePatternList patterns(&getContext()); - ConversionTarget target(getContext()); + MLIRContext& ctx = getContext(); + OwningRewritePatternList patterns(&ctx); + ConversionTarget target(ctx); target.addLegalDialect(); // TODO: DimOp shouldn't be in MemRefDialect target.addLegalOp(); + target.addLegalOp(); + RemoveSignTypeConverter type_converter; auto func = getFunction(); - mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); + mhlo::populateHLOToLinalgConversionPattern(&ctx, type_converter, &patterns); if (failed(applyPartialConversion(func, target, std::move(patterns)))) { signalPassFailure(); } @@ -2209,6 +2272,7 @@ std::unique_ptr> createLegalizeLhloToLinalgPass() { namespace mhlo { void 
populateHLOToLinalgConversionPattern(MLIRContext* context, + TypeConverter& type_converter, OwningRewritePatternList* patterns) { // clang-format off patterns->insert< @@ -2272,7 +2336,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, ReduceOnTensorsConversion, ReduceWindowOpOnTensorsConversion, TorchIndexSelectOpOnTensorsConversion, - PadOpOnTensorsConversion>(context); + PadOpOnTensorsConversion>(type_converter, context); // clang-format on patterns->insert, ReduceRegionXLAOpConversion, @@ -2287,5 +2351,10 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, std::unique_ptr> createLegalizeHloToLinalgPass() { return std::make_unique(); } + +std::unique_ptr createHloToLinalgSignedIntegerConverter() { + return std::make_unique(); +} + } // namespace mhlo } // namespace mlir diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index d4ab359..e7dc9c5 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -603,6 +603,21 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // ----- +// CHECK-LABEL: func @maxu +func @maxu(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> { + %0 = "mhlo.maximum"(%lhs, %rhs) + : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32> + return %0 : tensor<2x2xui32> +} +// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi32> +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32): +// CHECK-NEXT: %[[CMP:.*]] = cmpi ugt, %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + // CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()> // CHECK-LABEL: func @add_scalar func @add_scalar(%lhs: tensor, %rhs: tensor) -> tensor { @@ -2196,3 +2211,43 @@ func @concatenate(%a: tensor, %b: tensor, %c: tensor) } : (tensor, tensor, tensor) -> tensor return %concat 
: tensor } + +// ----- + +// CHECK-LABEL: unsigned_divide +func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> { + // CHECK: linalg.generic + // CHECK: divi_unsigned + %0 = "mhlo.divide"(%lhs, %rhs) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32> + return %0 : tensor<2x2xui32> +} + +// ----- + +// CHECK-LABEL: unsigned_remainder +func @unsigned_remainder(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> { + // CHECK: linalg.generic + // CHECK: remi_unsigned + %0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32> + return %0 : tensor<2x2xui32> +} + +// ----- + +// CHECK-LABEL: unsigned_convert +func @unsigned_convert(%in: tensor<2x2xui32>) -> tensor<2x2xui64> { + // CHECK: linalg.generic + // CHECK: zexti + %0 = "mhlo.convert"(%in) : (tensor<2x2xui32>) -> tensor<2x2xui64> + return %0 : tensor<2x2xui64> +} + +// ----- + +// CHECK-LABEL: unsigned_compare +func @unsigned_compare(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xi1> { + // CHECK: linalg.generic + // CHECK: cmpi ugt + %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xi1> + return %0 : tensor<2x2xi1> +} diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 1e0f3f4..246a7a5 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -93,6 +93,21 @@ func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // ----- +// CHECK-LABEL: func @maxu +func @maxu(%lhs: memref<2x2xui32>, %rhs: memref<2x2xui32>, + %result: memref<2x2xui32>) { + "lmhlo.maximum"(%lhs, %rhs, %result) + : (memref<2x2xui32>, memref<2x2xui32>, memref<2x2xui32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i32): +// CHECK-NEXT: %[[CMP:.*]] = cmpi ugt, %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = 
select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + // CHECK-LABEL: func @and func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi32>) {