From ddda2699fb50246197cb95f899c05f7d0cc8e5f3 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 15 Dec 2020 06:31:28 -0800 Subject: [PATCH] [KERNEL_GEN] Switch the pipeline to Linalg-on-Tensors. PiperOrigin-RevId: 347600145 --- .../Dialect/mhlo/transforms/rewriters.h | 5 -- .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 90 ++++++++----------- tests/hlo-legalize-to-lhlo.mlir | 21 ++--- 3 files changed, 43 insertions(+), 73 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 1fb9ba6..a2066df 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -52,11 +52,6 @@ void PopulateGatherToTorchIndexSelectPatterns( void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, MLIRContext *ctx); -// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect. -void populateDynamicHLOToLHLOConversionPattern( - MLIRContext *context, BufferizeTypeConverter *converter, - OwningRewritePatternList *patterns, bool insert_copy = true); - // Collection of rewrite patterns for lowering of HLO to LHLO dialect. void populateHLOToLHLOConversionPattern(MLIRContext *context, BufferizeTypeConverter *converter, diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 11f9159..0460a31 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -192,56 +192,24 @@ struct HloToLhloCustomCallOpConverter } }; -// TODO(pifon): Consider inserting lhlo.copy as in -// HloToLhloDynamicBroadcastInDimOpConverter. -struct HloToLhloDynamicReshapeConverter - : public BaseOpConversion { - public: - using BaseOpConversion::BaseOpConversion; - - LogicalResult matchAndRewrite( - mhlo::DynamicReshapeOp op, ArrayRef operands, - ConversionPatternRewriter& rewriter) const final { - Type result_type; - if (auto ranked_type = op.getType().dyn_cast()) { - result_type = - MemRefType::get(ranked_type.getShape(), ranked_type.getElementType()); - } else if (auto unranked_type = - op.getType().dyn_cast()) { - result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0); - } else { - return failure(); - } - mhlo::DynamicReshapeOp::Adaptor adaptor(operands); - rewriter.replaceOpWithNewOp( - op, result_type, adaptor.operand(), adaptor.output_shape()); - return success(); - } -}; - -class HloToLhloDynamicBroadcastInDimOpConverter +struct HloToLhloDynamicBroadcastInDimOpConverter : public BaseOpConversion { public: - HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter& converter, - MLIRContext* ctx, - bool insert_copy = true) - : BaseOpConversion(converter, ctx), - insert_copy_(insert_copy) {} + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( mhlo::DynamicBroadcastInDimOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { - Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); + auto loc = op.getLoc(); + Value resultBuffer = InsertDynamicAllocAndDealloc( + loc, op.getResult(), op.output_dimensions(), &rewriter); - if (insert_copy_) { - auto loc = op.getLoc(); - Value result_buffer = InsertDynamicAllocAndDealloc( - loc, op.getResult(), op.output_dimensions(), &rewriter); + Value transformed_operand = + InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); + rewriter.create(loc, transformed_operand, resultBuffer); + + rewriter.replaceOp(op, {resultBuffer}); - rewriter.create(loc, result, result_buffer); - result = result_buffer; - } - rewriter.replaceOp(op, {result}); return success(); } @@ -339,10 +307,31 @@ class HloToLhloDynamicBroadcastInDimOpConverter static_strides, llvm::None, sizes, strides); return transformed_operand; } +}; - // Keep the copy semantics and allocate a buffer for the result of the memref - // cast. - bool insert_copy_; +struct HloToLhloDynamicReshapeConverter + : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::DynamicReshapeOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Type result_type; + if (auto ranked_type = op.getType().dyn_cast()) { + result_type = + MemRefType::get(ranked_type.getShape(), ranked_type.getElementType()); + } else if (auto unranked_type = + op.getType().dyn_cast()) { + result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0); + } else { + return failure(); + } + mhlo::DynamicReshapeOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp( + op, result_type, adaptor.operand(), adaptor.output_shape()); + return success(); + } }; struct HloToLhloDotGeneralOpConverter @@ -604,22 +593,15 @@ struct HloLegalizeToLhlo }; } // namespace -void populateDynamicHLOToLHLOConversionPattern( - MLIRContext* context, BufferizeTypeConverter* converter, - OwningRewritePatternList* patterns, bool insert_copy) { - patterns->insert( - *converter, context, insert_copy); - patterns->insert(*converter, context); -} - void populateHLOToLHLOConversionPattern(MLIRContext* context, BufferizeTypeConverter* converter, OwningRewritePatternList* patterns) { - populateDynamicHLOToLHLOConversionPattern(context, converter, patterns); // clang-format off patterns->insert< HloToLhloCustomCallOpConverter, HloToLhloDotGeneralOpConverter, + HloToLhloDynamicBroadcastInDimOpConverter, + HloToLhloDynamicReshapeConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 5c05d5e..0c1ee24 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -170,31 +170,24 @@ func @dyn_broadcast(%operand: memref) -> index { return %rank : index } // CHECK: %[[SHAPE:.*]] = tensor_from_elements - // CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[C1:.*]] = constant 1 : index -// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref -// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index -// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref - // CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64> // CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index +// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64> - // CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index -// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index -// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index - // CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64> // CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index +// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref +// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref +// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index +// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref +// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index +// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index // CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index // CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index - // CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref to memref - -// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref - // CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref, memref) -> () // CHECK: dealloc %[[RESULT]] : memref