[MLIR] Extend unranked transformation to CHLO dialect
PiperOrigin-RevId: 332026604
This commit is contained in:
		
							parent
							
								
									2aa07b0091
								
							
						
					
					
						commit
						69b80d8deb
					
				|  | @ -344,14 +344,16 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp< | |||
| //===----------------------------------------------------------------------===// | ||||
| 
 | ||||
| class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits, | ||||
|     Type TensorType> : HLOClient_Op<mnemonic, | ||||
|       !listconcat(traits, [InferFusibilityOpInterface])> { | ||||
|     Type TensorType> : HLOClient_Op<mnemonic, !listconcat(traits, [ | ||||
|     InferFusibilityOpInterface, NoSideEffect, SameOperandsAndResultType])> { | ||||
|   let arguments = (ins TensorType:$operand); | ||||
|   let results = (outs TensorType); | ||||
|   let results = (outs TensorType:$result); | ||||
| 
 | ||||
|   let assemblyFormat = "$operand attr-dict `:` type($operand)"; | ||||
| } | ||||
| 
 | ||||
| def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", | ||||
|     [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { | ||||
| def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [], | ||||
|     HLO_FpOrComplexTensor> { | ||||
|   let summary = "Acos operator"; | ||||
| 
 | ||||
|   let description = [{ | ||||
|  | @ -364,8 +366,8 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", | |||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", | ||||
|     [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { | ||||
| def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [], | ||||
|     HLO_FpOrComplexTensor> { | ||||
|   let summary = "Tan operation"; | ||||
| 
 | ||||
|   let description = [{ | ||||
|  |  | |||
|  | @ -30,6 +30,9 @@ template <typename T> | |||
| class OperationPass; | ||||
| class Pass; | ||||
| 
 | ||||
| // Transforms unranked HLO operations to ranked ones where possible.
 | ||||
| std::unique_ptr<FunctionPass> createTransformUnrankedHloPass(); | ||||
| 
 | ||||
| namespace mhlo { | ||||
| 
 | ||||
| /// Lowers HLO control flow ops to the Standard dialect.
 | ||||
|  | @ -52,9 +55,6 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass( | |||
| // Lowers from HLO dialect to Linalg dialect.
 | ||||
| std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass(); | ||||
| 
 | ||||
| // Transforms unranked HLO operations to ranked ones where possible.
 | ||||
| std::unique_ptr<FunctionPass> createTransformUnrankedHloPass(); | ||||
| 
 | ||||
| // Sinks constants implicitly captured in control flow regions. This is
 | ||||
| // necessary to export to XLA.
 | ||||
| std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass(); | ||||
|  |  | |||
|  | @ -28,6 +28,12 @@ class LLVMTypeConverter; | |||
| class LowerToLLVMOptions; | ||||
| class OwningRewritePatternList; | ||||
| class BufferAssignmentPlacer; | ||||
| 
 | ||||
| // Populates a collection of rewrite patterns to realize element-wise operations
 | ||||
| // on ranked tensors where possible.
 | ||||
| void PopulateTransformUnrankedHloPatterns(MLIRContext *context, | ||||
|                                           OwningRewritePatternList *patterns); | ||||
| 
 | ||||
| namespace mhlo { | ||||
| 
 | ||||
| // Collection of rewrite patterns for lowering a general dot product.
 | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ limitations under the License. | |||
| 
 | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" | ||||
| #include "mlir/Dialect/Shape/IR/Shape.h" | ||||
|  | @ -27,7 +28,6 @@ limitations under the License. | |||
| #include "mlir/Transforms/DialectConversion.h" | ||||
| 
 | ||||
| namespace mlir { | ||||
| namespace mhlo { | ||||
| namespace { | ||||
| 
 | ||||
| // TODO(herhut): Generate these out of op definitions.
 | ||||
|  | @ -46,6 +46,9 @@ namespace { | |||
|           sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp)              \ | ||||
|               sep fn(ShiftRightLogicalOp) sep fn(SubOp) | ||||
| 
 | ||||
| // TODO(herhut): Generate these out of op definitions.
 | ||||
| #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) fn(TanOp) sep fn(AcosOp) | ||||
| 
 | ||||
| template <typename OpTy> | ||||
| inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { | ||||
|   target->addDynamicallyLegalOp<OpTy>([](OpTy op) { | ||||
|  | @ -101,8 +104,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> { | |||
|           operand.getType().template cast<ShapedType>().getElementType(); | ||||
|       Type flatTy = | ||||
|           RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy); | ||||
|       Value flat = | ||||
|           rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape); | ||||
|       Value flat = rewriter.create<mhlo::DynamicReshapeOp>(loc, flatTy, operand, | ||||
|                                                            flatShape); | ||||
|       flatOperands.push_back(flat); | ||||
|     } | ||||
| 
 | ||||
|  | @ -115,8 +118,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> { | |||
|         rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs()); | ||||
| 
 | ||||
|     // Restore original shape.
 | ||||
|     rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult, | ||||
|                                                   shape); | ||||
|     rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(), | ||||
|                                                         flatResult, shape); | ||||
| 
 | ||||
|     return success(); | ||||
|   } | ||||
|  | @ -132,13 +135,16 @@ struct TransformUnrankedHloPass | |||
|     // Setup conversion target.
 | ||||
|     MLIRContext &ctx = getContext(); | ||||
|     ConversionTarget target(ctx); | ||||
|     target.addLegalDialect<MhloDialect, StandardOpsDialect, | ||||
|     target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect, | ||||
|                            shape::ShapeDialect>(); | ||||
|     target.addLegalOp<FuncOp>(); | ||||
| #define ADD_LEGAL(op) AddLegalOpOnRankedTensor<op>(&target) | ||||
|     MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL, ;); | ||||
|     MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL, ;); | ||||
| #undef ADD_LEGAL | ||||
| #define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target) | ||||
| #define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target) | ||||
|     MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;); | ||||
|     MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;); | ||||
|     MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;); | ||||
| #undef ADD_LEGAL_MHLO | ||||
| #undef ADD_LEGAL_CHLO | ||||
| 
 | ||||
|     // Populate rewrite patterns.
 | ||||
|     OwningRewritePatternList patterns; | ||||
|  | @ -154,16 +160,19 @@ struct TransformUnrankedHloPass | |||
| 
 | ||||
| void PopulateTransformUnrankedHloPatterns(MLIRContext *context, | ||||
|                                           OwningRewritePatternList *patterns) { | ||||
| #define MAP_UNARY(op) ElementwiseOpConversion<op> | ||||
| #define MAP_BINARY(op) ElementwiseOpConversion<op> | ||||
| #define MAP_UNARY(op) ElementwiseOpConversion<mhlo::op> | ||||
| #define MAP_BINARY(op) ElementwiseOpConversion<mhlo::op> | ||||
| #define MAP_CHLO_UNARY(op) ElementwiseOpConversion<chlo::op> | ||||
| #define COMMA , | ||||
|   // clang-format off
 | ||||
|   patterns->insert< | ||||
|       MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), | ||||
|       MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context); | ||||
|       MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA), | ||||
|       MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA)>(context); | ||||
|   // clang-format on
 | ||||
| #undef MAP_UNARY | ||||
| #undef MAP_BINARY | ||||
| #undef MAP_CHLO_UNARY | ||||
| #undef COMMA | ||||
| } | ||||
| 
 | ||||
|  | @ -171,5 +180,4 @@ std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() { | |||
|   return std::make_unique<TransformUnrankedHloPass>(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace mhlo
 | ||||
| }  // namespace mlir
 | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| // RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s | ||||
| // RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s | ||||
| 
 | ||||
| // Check the validity of expected IR. | ||||
| // CHECK-LABEL: @sqr_transform_result | ||||
|  | @ -80,3 +80,19 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { | |||
|   %result = mhlo.add %a, %b : tensor<*xf32> | ||||
|   return %result : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: @tan | ||||
| // CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) -> tensor<*xf32> | ||||
| func @tan(%a : tensor<*xf32>) -> tensor<*xf32> { | ||||
|   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex> | ||||
|   // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] | ||||
|   // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> | ||||
|   // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
|   // CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32> | ||||
|   // CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
|   // CHECK: return %[[B]] : tensor<*xf32> | ||||
|   %result = chlo.tan %a : tensor<*xf32> | ||||
|   return %result : tensor<*xf32> | ||||
| } | ||||
		Loading…
	
		Reference in New Issue