From 5235eceea055e37a33b93888bdb5e6907612fa2a Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 7 Dec 2020 13:01:25 -0800 Subject: [PATCH] Lower mhlo shifts to linalg PiperOrigin-RevId: 346161253 --- .../mhlo/transforms/map_hlo_to_lhlo_op.h | 3 ++ .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 24 +++++++++++ .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 3 ++ .../mhlo/transforms/legalize_to_linalg.cc | 6 +++ tests/hlo-legalize-to-lhlo.mlir | 42 +++++++++++++++++++ tests/hlo-legalize-to-linalg.mlir | 42 +++++++++++++++++++ 6 files changed, 120 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index ac67619..26f8afd 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -71,6 +71,9 @@ MAP_HLO_TO_LHLO(ReshapeOp); MAP_HLO_TO_LHLO(RemOp); MAP_HLO_TO_LHLO(RsqrtOp); MAP_HLO_TO_LHLO(SelectOp); +MAP_HLO_TO_LHLO(ShiftLeftOp); +MAP_HLO_TO_LHLO(ShiftRightArithmeticOp); +MAP_HLO_TO_LHLO(ShiftRightLogicalOp); MAP_HLO_TO_LHLO(SignOp); MAP_HLO_TO_LHLO(SinOp); MAP_HLO_TO_LHLO(SliceOp); diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index d59dfd4..cf00805 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -498,6 +498,30 @@ inline Value MapLhloOpToStdScalarOp( b); } +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 76b91f7..ec3a192 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -634,6 +634,9 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index c9eeefd..4b15de3 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -937,6 +937,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -1049,6 +1052,9 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index cb88019..d7adf6c 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -425,6 +425,48 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @shift_left +func @shift_left(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, + %result: memref<2x2xi32>) { + %tensor_lhs = tensor_load %lhs : memref<2x2xi32> + %tensor_rhs = tensor_load %rhs : memref<2x2xi32> + %tensor_result = "mhlo.shift_left"(%tensor_lhs, %tensor_rhs) + : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + // CHECK: "lmhlo.shift_left"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xi32> + return +} + +// ----- + +// CHECK-LABEL: func @shift_right_arithmetic +func @shift_right_arithmetic(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, + %result: memref<2x2xi32>) { + %tensor_lhs = tensor_load %lhs : memref<2x2xi32> + %tensor_rhs = tensor_load %rhs : memref<2x2xi32> + %tensor_result = "mhlo.shift_right_arithmetic"(%tensor_lhs, %tensor_rhs) + : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + // CHECK: "lmhlo.shift_right_arithmetic"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xi32> + return +} + +// ----- + +// CHECK-LABEL: func @shift_right_logical +func @shift_right_logical(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, + %result: memref<2x2xi32>) { + %tensor_lhs = tensor_load %lhs : memref<2x2xi32> + %tensor_rhs = tensor_load %rhs : memref<2x2xi32> + %tensor_result = "mhlo.shift_right_logical"(%tensor_lhs, %tensor_rhs) + : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + // CHECK: "lmhlo.shift_right_logical"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xi32> + return +} + +// ----- + // CHECK-LABEL: func @tanh func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index c4413ed..0749345 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -630,3 +630,45 @@ func @iota() -> tensor<7x10xf32> { // CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 + +// ----- + +func @shift_left(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + %result = "mhlo.shift_left"(%lhs, %rhs) + : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %result : tensor<2x2xi32> +} +// CHECK-LABEL: func @shift_left +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = shift_left %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +func @shift_right_arithmetic(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + %result = "mhlo.shift_right_arithmetic"(%lhs, %rhs) + : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %result : tensor<2x2xi32> +} +// CHECK-LABEL: func @shift_right_arithmetic +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = shift_right_signed %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +func @shift_right_logical(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + %result = "mhlo.shift_right_logical"(%lhs, %rhs) + : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %result : tensor<2x2xi32> +} +// CHECK-LABEL: func @shift_right_logical +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32