parent
							
								
									eaa21130e8
								
							
						
					
					
						commit
						5235eceea0
					
				| 
						 | 
					@ -71,6 +71,9 @@ MAP_HLO_TO_LHLO(ReshapeOp);
 | 
				
			||||||
MAP_HLO_TO_LHLO(RemOp);
 | 
					MAP_HLO_TO_LHLO(RemOp);
 | 
				
			||||||
MAP_HLO_TO_LHLO(RsqrtOp);
 | 
					MAP_HLO_TO_LHLO(RsqrtOp);
 | 
				
			||||||
MAP_HLO_TO_LHLO(SelectOp);
 | 
					MAP_HLO_TO_LHLO(SelectOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(ShiftLeftOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(ShiftRightArithmeticOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(ShiftRightLogicalOp);
 | 
				
			||||||
MAP_HLO_TO_LHLO(SignOp);
 | 
					MAP_HLO_TO_LHLO(SignOp);
 | 
				
			||||||
MAP_HLO_TO_LHLO(SinOp);
 | 
					MAP_HLO_TO_LHLO(SinOp);
 | 
				
			||||||
MAP_HLO_TO_LHLO(SliceOp);
 | 
					MAP_HLO_TO_LHLO(SliceOp);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -498,6 +498,30 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
 | 
				
			||||||
                                                        b);
 | 
					                                                        b);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
 | 
				
			||||||
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
 | 
					    OpBuilder* b) {
 | 
				
			||||||
 | 
					  return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
 | 
				
			||||||
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
 | 
				
			||||||
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
 | 
					    OpBuilder* b) {
 | 
				
			||||||
 | 
					  return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
 | 
				
			||||||
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
 | 
				
			||||||
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
 | 
					    OpBuilder* b) {
 | 
				
			||||||
 | 
					  return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
 | 
				
			||||||
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
 | 
					inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
 | 
				
			||||||
                                                   ArrayRef<Type> result_types,
 | 
					                                                   ArrayRef<Type> result_types,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -634,6 +634,9 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
 | 
				
			||||||
      HloToLhloOpConverter<mhlo::RsqrtOp>,
 | 
					      HloToLhloOpConverter<mhlo::RsqrtOp>,
 | 
				
			||||||
      HloToLhloOpConverter<mhlo::ReshapeOp>,
 | 
					      HloToLhloOpConverter<mhlo::ReshapeOp>,
 | 
				
			||||||
      HloToLhloOpConverter<mhlo::SelectOp>,
 | 
					      HloToLhloOpConverter<mhlo::SelectOp>,
 | 
				
			||||||
 | 
					      HloToLhloOpConverter<mhlo::ShiftLeftOp>,
 | 
				
			||||||
 | 
					      HloToLhloOpConverter<mhlo::ShiftRightArithmeticOp>,
 | 
				
			||||||
 | 
					      HloToLhloOpConverter<mhlo::ShiftRightLogicalOp>,
 | 
				
			||||||
      HloToLhloOpConverter<mhlo::SignOp>,
 | 
					      HloToLhloOpConverter<mhlo::SignOp>,
 | 
				
			||||||
      HloToLhloOpConverter<mhlo::SinOp>,
 | 
					      HloToLhloOpConverter<mhlo::SinOp>,
 | 
				
			||||||
      HloToLhloOpConverter<mhlo::SliceOp>,
 | 
					      HloToLhloOpConverter<mhlo::SliceOp>,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -937,6 +937,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::RemOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::RemOp>,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::SelectOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::SelectOp>,
 | 
				
			||||||
 | 
					                   PointwiseToLinalgConverter<lmhlo::ShiftLeftOp>,
 | 
				
			||||||
 | 
					                   PointwiseToLinalgConverter<lmhlo::ShiftRightArithmeticOp>,
 | 
				
			||||||
 | 
					                   PointwiseToLinalgConverter<lmhlo::ShiftRightLogicalOp>,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::SignOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::SignOp>,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::SinOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::SinOp>,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::SqrtOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::SqrtOp>,
 | 
				
			||||||
| 
						 | 
					@ -1049,6 +1052,9 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::RemOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::RemOp, false>,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::SelectOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::SelectOp, false>,
 | 
				
			||||||
 | 
					               PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
 | 
				
			||||||
 | 
					               PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
 | 
				
			||||||
 | 
					               PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::SinOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::SinOp, false>,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::SubOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::SubOp, false>,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -425,6 +425,48 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @shift_left
 | 
				
			||||||
 | 
					func @shift_left(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
 | 
				
			||||||
 | 
					                 %result: memref<2x2xi32>) {
 | 
				
			||||||
 | 
					  %tensor_lhs = tensor_load %lhs : memref<2x2xi32>
 | 
				
			||||||
 | 
					  %tensor_rhs = tensor_load %rhs : memref<2x2xi32>
 | 
				
			||||||
 | 
					  %tensor_result = "mhlo.shift_left"(%tensor_lhs, %tensor_rhs)
 | 
				
			||||||
 | 
					      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
				
			||||||
 | 
					  // CHECK: "lmhlo.shift_left"(%{{.*}}, %{{.*}})
 | 
				
			||||||
 | 
					  tensor_store %tensor_result, %result : memref<2x2xi32>
 | 
				
			||||||
 | 
					  return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @shift_right_arithmetic
 | 
				
			||||||
 | 
					func @shift_right_arithmetic(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
 | 
				
			||||||
 | 
					                             %result: memref<2x2xi32>) {
 | 
				
			||||||
 | 
					  %tensor_lhs = tensor_load %lhs : memref<2x2xi32>
 | 
				
			||||||
 | 
					  %tensor_rhs = tensor_load %rhs : memref<2x2xi32>
 | 
				
			||||||
 | 
					  %tensor_result = "mhlo.shift_right_arithmetic"(%tensor_lhs, %tensor_rhs)
 | 
				
			||||||
 | 
					      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
				
			||||||
 | 
					  // CHECK: "lmhlo.shift_right_arithmetic"(%{{.*}}, %{{.*}})
 | 
				
			||||||
 | 
					  tensor_store %tensor_result, %result : memref<2x2xi32>
 | 
				
			||||||
 | 
					  return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @shift_right_logical
 | 
				
			||||||
 | 
					func @shift_right_logical(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
 | 
				
			||||||
 | 
					                          %result: memref<2x2xi32>) {
 | 
				
			||||||
 | 
					  %tensor_lhs = tensor_load %lhs : memref<2x2xi32>
 | 
				
			||||||
 | 
					  %tensor_rhs = tensor_load %rhs : memref<2x2xi32>
 | 
				
			||||||
 | 
					  %tensor_result = "mhlo.shift_right_logical"(%tensor_lhs, %tensor_rhs)
 | 
				
			||||||
 | 
					      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
				
			||||||
 | 
					  // CHECK: "lmhlo.shift_right_logical"(%{{.*}}, %{{.*}})
 | 
				
			||||||
 | 
					  tensor_store %tensor_result, %result : memref<2x2xi32>
 | 
				
			||||||
 | 
					  return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: func @tanh
 | 
					// CHECK-LABEL: func @tanh
 | 
				
			||||||
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
					func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
				
			||||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
					  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -630,3 +630,45 @@ func @iota() -> tensor<7x10xf32> {
 | 
				
			||||||
// CHECK-NEXT:   %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
 | 
					// CHECK-NEXT:   %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
 | 
				
			||||||
// CHECK-NEXT:   %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
 | 
					// CHECK-NEXT:   %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
 | 
				
			||||||
// CHECK-NEXT:   linalg.yield %[[FLOAT_CAST]] : f32
 | 
					// CHECK-NEXT:   linalg.yield %[[FLOAT_CAST]] : f32
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @shift_left(%lhs: tensor<2x2xi32>,
 | 
				
			||||||
 | 
					                 %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
 | 
				
			||||||
 | 
					  %result = "mhlo.shift_left"(%lhs, %rhs)
 | 
				
			||||||
 | 
					      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
				
			||||||
 | 
					  return %result : tensor<2x2xi32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @shift_left
 | 
				
			||||||
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
 | 
					// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
 | 
				
			||||||
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = shift_left %[[LHS]], %[[RHS]] : i32
 | 
				
			||||||
 | 
					// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @shift_right_arithmetic(%lhs: tensor<2x2xi32>,
 | 
				
			||||||
 | 
					                             %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
 | 
				
			||||||
 | 
					  %result = "mhlo.shift_right_arithmetic"(%lhs, %rhs)
 | 
				
			||||||
 | 
					      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
				
			||||||
 | 
					  return %result : tensor<2x2xi32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @shift_right_arithmetic
 | 
				
			||||||
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
 | 
					// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
 | 
				
			||||||
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = shift_right_signed %[[LHS]], %[[RHS]] : i32
 | 
				
			||||||
 | 
					// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @shift_right_logical(%lhs: tensor<2x2xi32>,
 | 
				
			||||||
 | 
					                          %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
 | 
				
			||||||
 | 
					  %result = "mhlo.shift_right_logical"(%lhs, %rhs)
 | 
				
			||||||
 | 
					      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
				
			||||||
 | 
					  return %result : tensor<2x2xi32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @shift_right_logical
 | 
				
			||||||
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
 | 
					// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
 | 
				
			||||||
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32
 | 
				
			||||||
 | 
					// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue