Permit vector types in lmhlo to std lowering.
PiperOrigin-RevId: 337523303
This commit is contained in:
parent
2e30b59ddc
commit
706718b4fb
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
|
@ -96,7 +97,7 @@ template <typename SupportedType, typename StdScalarOp, typename... Args>
|
|||
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
if (element_type.isa<SupportedType>()) {
|
||||
return b->template create<StdScalarOp>(loc, result_types, args,
|
||||
mlir::None);
|
||||
|
@ -120,7 +121,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
|||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
if (element_type.isa<FloatType>()) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
|
@ -130,8 +131,11 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
|||
Value lhs = args[0];
|
||||
auto integer_type = element_type.dyn_cast<IntegerType>();
|
||||
|
||||
auto zero_intval =
|
||||
Value zero_intval =
|
||||
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
|
||||
zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
|
||||
}
|
||||
auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
|
||||
lhs, zero_intval);
|
||||
auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
|
||||
|
@ -196,7 +200,7 @@ inline Value MapCompareOpToStdScalarOp(Location loc,
|
|||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
const auto& lhs = args[0];
|
||||
const auto& rhs = args[1];
|
||||
Type element_type = lhs.getType();
|
||||
Type element_type = getElementTypeOrSelf(lhs.getType());
|
||||
if (element_type.isSignlessInteger()) {
|
||||
Optional<CmpIPredicate> predicate =
|
||||
getCmpPredicate<CmpIPredicate>(comparison_direction);
|
||||
|
@ -268,8 +272,8 @@ template <>
|
|||
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type sourceType = args.front().getType();
|
||||
Type targetType = result_types.front();
|
||||
Type sourceType = getElementTypeOrSelf(args.front().getType());
|
||||
Type targetType = getElementTypeOrSelf(result_types.front());
|
||||
|
||||
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
|
||||
return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
|
||||
|
@ -390,7 +394,7 @@ struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
|
|||
static Value map(Location loc, StringRef comparison_direction,
|
||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
if (element_type.isa<SupportedType>()) {
|
||||
auto predicate = getCmpPredicate<Predicate>(comparison_direction);
|
||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||
|
@ -439,7 +443,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
|||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
if (element_type.isa<FloatType>()) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
|
@ -449,8 +453,11 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
|||
Value lhs = args[0];
|
||||
auto integer_type = element_type.dyn_cast<IntegerType>();
|
||||
|
||||
auto zero_intval =
|
||||
Value zero_intval =
|
||||
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
|
||||
zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
|
||||
}
|
||||
return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -461,11 +468,14 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
|||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
|
||||
// lmhlo.not(x) -> x ^ -1
|
||||
auto all_ones =
|
||||
Value all_ones =
|
||||
b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
|
||||
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
|
||||
all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones);
|
||||
}
|
||||
return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -493,26 +503,35 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
|||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
if (auto float_type = element_type.dyn_cast<FloatType>()) {
|
||||
bool ignored;
|
||||
APFloat one_apfloat(1.0f);
|
||||
one_apfloat.convert(float_type.getFloatSemantics(),
|
||||
APFloat::rmNearestTiesToEven, &ignored);
|
||||
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
|
||||
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
|
||||
one = b->create<::mlir::SplatOp>(loc, vec_type, one);
|
||||
}
|
||||
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
|
||||
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
|
||||
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
|
||||
Value zero =
|
||||
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||
Value cmp =
|
||||
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
|
||||
Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
|
||||
loc, integer_type.getWidth() - 1, integer_type.getWidth());
|
||||
Value ashr =
|
||||
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
|
||||
Value one =
|
||||
b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
|
||||
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
|
||||
zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
|
||||
bitwidth_minus_one =
|
||||
b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one);
|
||||
one = b->create<::mlir::SplatOp>(loc, vec_type, one);
|
||||
}
|
||||
Value cmp =
|
||||
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
|
||||
Value ashr =
|
||||
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
|
||||
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
|
||||
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
|
||||
}
|
||||
|
@ -583,6 +602,27 @@ struct HloOpToStdScalarOp {
|
|||
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
||||
op.getLoc(), comparison_direction, result_types, args, b);
|
||||
}
|
||||
|
||||
// Implementation for LHLO ops except lmhlo::CompareOp.
|
||||
template <typename LhloOpTy,
|
||||
typename = std::enable_if_t<
|
||||
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
|
||||
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
|
||||
std::false_type>::value>>
|
||||
static Value map(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
|
||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, args, b);
|
||||
}
|
||||
|
||||
// Implementation for lmhlo::CompareOp.
|
||||
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
|
||||
LhloOpTy, lmhlo::CompareOp>::value>>
|
||||
static Value map(Location loc, StringRef comparison_direction,
|
||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
||||
loc, comparison_direction, result_types, args, b);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace lmhlo
|
||||
|
|
|
@ -621,10 +621,10 @@ func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) {
|
|||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]):
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : i16
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16
|
||||
// CHECK-NEXT: %[[C15:.*]] = constant 15 : i16
|
||||
// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant 1 : i16
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16
|
||||
// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16
|
||||
// CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i16
|
||||
|
|
Loading…
Reference in New Issue