Permit vector types in lmhlo to std lowering.

PiperOrigin-RevId: 337523303
This commit is contained in:
A. Unique TensorFlower 2020-10-16 09:46:12 -07:00 committed by TensorFlow MLIR Team
parent 2e30b59ddc
commit 706718b4fb
2 changed files with 58 additions and 18 deletions

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir {
namespace lmhlo {
@ -96,7 +97,7 @@ template <typename SupportedType, typename StdScalarOp, typename... Args>
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
Type element_type = args.front().getType();
Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<SupportedType>()) {
return b->template create<StdScalarOp>(loc, result_types, args,
mlir::None);
@ -120,7 +121,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
loc, result_types, args, b);
@ -130,8 +131,11 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>();
auto zero_intval =
Value zero_intval =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
}
auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
lhs, zero_intval);
auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
@ -196,7 +200,7 @@ inline Value MapCompareOpToStdScalarOp(Location loc,
ArrayRef<Value> args, OpBuilder* b) {
const auto& lhs = args[0];
const auto& rhs = args[1];
Type element_type = lhs.getType();
Type element_type = getElementTypeOrSelf(lhs.getType());
if (element_type.isSignlessInteger()) {
Optional<CmpIPredicate> predicate =
getCmpPredicate<CmpIPredicate>(comparison_direction);
@ -268,8 +272,8 @@ template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type sourceType = args.front().getType();
Type targetType = result_types.front();
Type sourceType = getElementTypeOrSelf(args.front().getType());
Type targetType = getElementTypeOrSelf(result_types.front());
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
@ -390,7 +394,7 @@ struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<SupportedType>()) {
auto predicate = getCmpPredicate<Predicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction");
@ -439,7 +443,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
loc, result_types, args, b);
@ -449,8 +453,11 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>();
auto zero_intval =
Value zero_intval =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
}
return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
}
return nullptr;
@ -461,11 +468,14 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
Type element_type = getElementTypeOrSelf(args.front().getType());
if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// lmhlo.not(x) -> x ^ -1
auto all_ones =
Value all_ones =
b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones);
}
return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
}
return nullptr;
@ -493,26 +503,35 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
Type element_type = getElementTypeOrSelf(args.front().getType());
if (auto float_type = element_type.dyn_cast<FloatType>()) {
bool ignored;
APFloat one_apfloat(1.0f);
one_apfloat.convert(float_type.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &ignored);
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
one = b->create<::mlir::SplatOp>(loc, vec_type, one);
}
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
Value zero =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
Value cmp =
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
loc, integer_type.getWidth() - 1, integer_type.getWidth());
Value ashr =
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
Value one =
b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
bitwidth_minus_one =
b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one);
one = b->create<::mlir::SplatOp>(loc, vec_type, one);
}
Value cmp =
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
Value ashr =
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
}
@ -583,6 +602,27 @@ struct HloOpToStdScalarOp {
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
// Implementation for LHLO ops except lmhlo::CompareOp.
template <typename LhloOpTy,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, args, b);
}
// Implementation for lmhlo::CompareOp.
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
LhloOpTy, lmhlo::CompareOp>::value>>
static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
loc, comparison_direction, result_types, args, b);
}
};
} // namespace lmhlo

View File

@ -621,10 +621,10 @@ func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) {
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[C0:.*]] = constant 0 : i16
// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16
// CHECK-NEXT: %[[C15:.*]] = constant 15 : i16
// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16
// CHECK-NEXT: %[[C1:.*]] = constant 1 : i16
// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16
// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16
// CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16
// CHECK-NEXT: linalg.yield %[[RESULT]] : i16