Lower mhlo shifts to linalg

PiperOrigin-RevId: 346161253
This commit is contained in:
Benjamin Kramer 2020-12-07 13:01:25 -08:00 committed by TensorFlow MLIR Team
parent eaa21130e8
commit 5235eceea0
6 changed files with 120 additions and 0 deletions

View File

@ -71,6 +71,9 @@ MAP_HLO_TO_LHLO(ReshapeOp);
MAP_HLO_TO_LHLO(RemOp); MAP_HLO_TO_LHLO(RemOp);
MAP_HLO_TO_LHLO(RsqrtOp); MAP_HLO_TO_LHLO(RsqrtOp);
MAP_HLO_TO_LHLO(SelectOp); MAP_HLO_TO_LHLO(SelectOp);
MAP_HLO_TO_LHLO(ShiftLeftOp);
MAP_HLO_TO_LHLO(ShiftRightArithmeticOp);
MAP_HLO_TO_LHLO(ShiftRightLogicalOp);
MAP_HLO_TO_LHLO(SignOp); MAP_HLO_TO_LHLO(SignOp);
MAP_HLO_TO_LHLO(SinOp); MAP_HLO_TO_LHLO(SinOp);
MAP_HLO_TO_LHLO(SliceOp); MAP_HLO_TO_LHLO(SliceOp);

View File

@ -498,6 +498,30 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
b); b);
} }
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
loc, result_types, args, b);
}
template <> template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc, inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,

View File

@ -634,6 +634,9 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::RsqrtOp>, HloToLhloOpConverter<mhlo::RsqrtOp>,
HloToLhloOpConverter<mhlo::ReshapeOp>, HloToLhloOpConverter<mhlo::ReshapeOp>,
HloToLhloOpConverter<mhlo::SelectOp>, HloToLhloOpConverter<mhlo::SelectOp>,
HloToLhloOpConverter<mhlo::ShiftLeftOp>,
HloToLhloOpConverter<mhlo::ShiftRightArithmeticOp>,
HloToLhloOpConverter<mhlo::ShiftRightLogicalOp>,
HloToLhloOpConverter<mhlo::SignOp>, HloToLhloOpConverter<mhlo::SignOp>,
HloToLhloOpConverter<mhlo::SinOp>, HloToLhloOpConverter<mhlo::SinOp>,
HloToLhloOpConverter<mhlo::SliceOp>, HloToLhloOpConverter<mhlo::SliceOp>,

View File

@ -937,6 +937,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::RemOp>, PointwiseToLinalgConverter<lmhlo::RemOp>,
PointwiseToLinalgConverter<lmhlo::RsqrtOp>, PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
PointwiseToLinalgConverter<lmhlo::SelectOp>, PointwiseToLinalgConverter<lmhlo::SelectOp>,
PointwiseToLinalgConverter<lmhlo::ShiftLeftOp>,
PointwiseToLinalgConverter<lmhlo::ShiftRightArithmeticOp>,
PointwiseToLinalgConverter<lmhlo::ShiftRightLogicalOp>,
PointwiseToLinalgConverter<lmhlo::SignOp>, PointwiseToLinalgConverter<lmhlo::SignOp>,
PointwiseToLinalgConverter<lmhlo::SinOp>, PointwiseToLinalgConverter<lmhlo::SinOp>,
PointwiseToLinalgConverter<lmhlo::SqrtOp>, PointwiseToLinalgConverter<lmhlo::SqrtOp>,
@ -1049,6 +1052,9 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::RemOp, false>, PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>, PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>, PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>, PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>, PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>, PointwiseToLinalgConverter<mhlo::SubOp, false>,

View File

@ -425,6 +425,48 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// ----- // -----
// CHECK-LABEL: func @shift_left
func @shift_left(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xi32>
%tensor_rhs = tensor_load %rhs : memref<2x2xi32>
%tensor_result = "mhlo.shift_left"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.shift_left"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @shift_right_arithmetic
func @shift_right_arithmetic(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xi32>
%tensor_rhs = tensor_load %rhs : memref<2x2xi32>
%tensor_result = "mhlo.shift_right_arithmetic"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.shift_right_arithmetic"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @shift_right_logical
func @shift_right_logical(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xi32>
%tensor_rhs = tensor_load %rhs : memref<2x2xi32>
%tensor_result = "mhlo.shift_right_logical"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.shift_right_logical"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @tanh // CHECK-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>

View File

@ -630,3 +630,45 @@ func @iota() -> tensor<7x10xf32> {
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32 // CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
// -----
func @shift_left(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%result = "mhlo.shift_left"(%lhs, %rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK-LABEL: func @shift_left
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = shift_left %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
func @shift_right_arithmetic(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%result = "mhlo.shift_right_arithmetic"(%lhs, %rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK-LABEL: func @shift_right_arithmetic
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = shift_right_signed %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
func @shift_right_logical(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%result = "mhlo.shift_right_logical"(%lhs, %rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK-LABEL: func @shift_right_logical
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32