parent
eaa21130e8
commit
5235eceea0
|
@ -71,6 +71,9 @@ MAP_HLO_TO_LHLO(ReshapeOp);
|
||||||
MAP_HLO_TO_LHLO(RemOp);
|
MAP_HLO_TO_LHLO(RemOp);
|
||||||
MAP_HLO_TO_LHLO(RsqrtOp);
|
MAP_HLO_TO_LHLO(RsqrtOp);
|
||||||
MAP_HLO_TO_LHLO(SelectOp);
|
MAP_HLO_TO_LHLO(SelectOp);
|
||||||
|
MAP_HLO_TO_LHLO(ShiftLeftOp);
|
||||||
|
MAP_HLO_TO_LHLO(ShiftRightArithmeticOp);
|
||||||
|
MAP_HLO_TO_LHLO(ShiftRightLogicalOp);
|
||||||
MAP_HLO_TO_LHLO(SignOp);
|
MAP_HLO_TO_LHLO(SignOp);
|
||||||
MAP_HLO_TO_LHLO(SinOp);
|
MAP_HLO_TO_LHLO(SinOp);
|
||||||
MAP_HLO_TO_LHLO(SliceOp);
|
MAP_HLO_TO_LHLO(SliceOp);
|
||||||
|
|
|
@ -498,6 +498,30 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
|
||||||
b);
|
b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
|
|
@ -634,6 +634,9 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||||
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
||||||
HloToLhloOpConverter<mhlo::ReshapeOp>,
|
HloToLhloOpConverter<mhlo::ReshapeOp>,
|
||||||
HloToLhloOpConverter<mhlo::SelectOp>,
|
HloToLhloOpConverter<mhlo::SelectOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::ShiftLeftOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::ShiftRightArithmeticOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::ShiftRightLogicalOp>,
|
||||||
HloToLhloOpConverter<mhlo::SignOp>,
|
HloToLhloOpConverter<mhlo::SignOp>,
|
||||||
HloToLhloOpConverter<mhlo::SinOp>,
|
HloToLhloOpConverter<mhlo::SinOp>,
|
||||||
HloToLhloOpConverter<mhlo::SliceOp>,
|
HloToLhloOpConverter<mhlo::SliceOp>,
|
||||||
|
|
|
@ -937,6 +937,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::SelectOp>,
|
PointwiseToLinalgConverter<lmhlo::SelectOp>,
|
||||||
|
PointwiseToLinalgConverter<lmhlo::ShiftLeftOp>,
|
||||||
|
PointwiseToLinalgConverter<lmhlo::ShiftRightArithmeticOp>,
|
||||||
|
PointwiseToLinalgConverter<lmhlo::ShiftRightLogicalOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::SignOp>,
|
PointwiseToLinalgConverter<lmhlo::SignOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::SinOp>,
|
PointwiseToLinalgConverter<lmhlo::SinOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
|
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
|
||||||
|
@ -1049,6 +1052,9 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
|
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
|
||||||
|
PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
|
||||||
|
PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
|
||||||
|
PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::SinOp, false>,
|
PointwiseToLinalgConverter<mhlo::SinOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
||||||
|
|
|
@ -425,6 +425,48 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @shift_left
|
||||||
|
func @shift_left(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||||
|
%result: memref<2x2xi32>) {
|
||||||
|
%tensor_lhs = tensor_load %lhs : memref<2x2xi32>
|
||||||
|
%tensor_rhs = tensor_load %rhs : memref<2x2xi32>
|
||||||
|
%tensor_result = "mhlo.shift_left"(%tensor_lhs, %tensor_rhs)
|
||||||
|
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
// CHECK: "lmhlo.shift_left"(%{{.*}}, %{{.*}})
|
||||||
|
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @shift_right_arithmetic
|
||||||
|
func @shift_right_arithmetic(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||||
|
%result: memref<2x2xi32>) {
|
||||||
|
%tensor_lhs = tensor_load %lhs : memref<2x2xi32>
|
||||||
|
%tensor_rhs = tensor_load %rhs : memref<2x2xi32>
|
||||||
|
%tensor_result = "mhlo.shift_right_arithmetic"(%tensor_lhs, %tensor_rhs)
|
||||||
|
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
// CHECK: "lmhlo.shift_right_arithmetic"(%{{.*}}, %{{.*}})
|
||||||
|
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @shift_right_logical
|
||||||
|
func @shift_right_logical(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||||
|
%result: memref<2x2xi32>) {
|
||||||
|
%tensor_lhs = tensor_load %lhs : memref<2x2xi32>
|
||||||
|
%tensor_rhs = tensor_load %rhs : memref<2x2xi32>
|
||||||
|
%tensor_result = "mhlo.shift_right_logical"(%tensor_lhs, %tensor_rhs)
|
||||||
|
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
// CHECK: "lmhlo.shift_right_logical"(%{{.*}}, %{{.*}})
|
||||||
|
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @tanh
|
// CHECK-LABEL: func @tanh
|
||||||
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||||
|
|
|
@ -630,3 +630,45 @@ func @iota() -> tensor<7x10xf32> {
|
||||||
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
|
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
|
||||||
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
|
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
|
||||||
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
|
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @shift_left(%lhs: tensor<2x2xi32>,
|
||||||
|
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||||
|
%result = "mhlo.shift_left"(%lhs, %rhs)
|
||||||
|
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
return %result : tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @shift_left
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = shift_left %[[LHS]], %[[RHS]] : i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @shift_right_arithmetic(%lhs: tensor<2x2xi32>,
|
||||||
|
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||||
|
%result = "mhlo.shift_right_arithmetic"(%lhs, %rhs)
|
||||||
|
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
return %result : tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @shift_right_arithmetic
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = shift_right_signed %[[LHS]], %[[RHS]] : i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @shift_right_logical(%lhs: tensor<2x2xi32>,
|
||||||
|
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||||
|
%result = "mhlo.shift_right_logical"(%lhs, %rhs)
|
||||||
|
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
return %result : tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @shift_right_logical
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||||
|
|
Loading…
Reference in New Issue