diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index a2066df..1fb9ba6 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -52,6 +52,11 @@ void PopulateGatherToTorchIndexSelectPatterns( void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, MLIRContext *ctx); +// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect. +void populateDynamicHLOToLHLOConversionPattern( + MLIRContext *context, BufferizeTypeConverter *converter, + OwningRewritePatternList *patterns, bool insert_copy = true); + // Collection of rewrite patterns for lowering of HLO to LHLO dialect. void populateHLOToLHLOConversionPattern(MLIRContext *context, BufferizeTypeConverter *converter, diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 0d5e52c..822fa56 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -192,24 +192,56 @@ struct HloToLhloCustomCallOpConverter } }; -struct HloToLhloDynamicBroadcastInDimOpConverter +// TODO(pifon): Consider inserting lhlo.copy as in +// HloToLhloDynamicBroadcastInDimOpConverter. +struct HloToLhloDynamicReshapeConverter + : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::DynamicReshapeOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Type result_type; + if (auto ranked_type = op.getType().dyn_cast()) { + result_type = + MemRefType::get(ranked_type.getShape(), ranked_type.getElementType()); + } else if (auto unranked_type = + op.getType().dyn_cast()) { + result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0); + } else { + return failure(); + } + mhlo::DynamicReshapeOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp( + op, result_type, adaptor.operand(), adaptor.output_shape()); + return success(); + } +}; + +class HloToLhloDynamicBroadcastInDimOpConverter : public BaseOpConversion { public: - using BaseOpConversion::BaseOpConversion; + HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter& converter, + MLIRContext* ctx, + bool insert_copy = true) + : BaseOpConversion(converter, ctx), + insert_copy_(insert_copy) {} LogicalResult matchAndRewrite( mhlo::DynamicBroadcastInDimOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { - auto loc = op.getLoc(); - Value resultBuffer = InsertDynamicAllocAndDealloc( - loc, op.getResult(), op.output_dimensions(), &rewriter); + Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); - Value transformed_operand = - InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); - rewriter.create(loc, transformed_operand, resultBuffer); - - rewriter.replaceOp(op, {resultBuffer}); + if (insert_copy_) { + auto loc = op.getLoc(); + Value result_buffer = InsertDynamicAllocAndDealloc( + loc, op.getResult(), op.output_dimensions(), &rewriter); + rewriter.create(loc, result, result_buffer); + result = result_buffer; + } + rewriter.replaceOp(op, {result}); return success(); } @@ -307,31 +339,10 @@ struct HloToLhloDynamicBroadcastInDimOpConverter static_strides, llvm::None, sizes, strides); return transformed_operand; } -}; -struct HloToLhloDynamicReshapeConverter - : public BaseOpConversion { - public: - using BaseOpConversion::BaseOpConversion; - - LogicalResult matchAndRewrite( - mhlo::DynamicReshapeOp op, ArrayRef operands, - ConversionPatternRewriter& rewriter) const final { - Type result_type; - if (auto ranked_type = op.getType().dyn_cast()) { - result_type = - MemRefType::get(ranked_type.getShape(), ranked_type.getElementType()); - } else if (auto unranked_type = - op.getType().dyn_cast()) { - result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0); - } else { - return failure(); - } - mhlo::DynamicReshapeOp::Adaptor adaptor(operands); - rewriter.replaceOpWithNewOp( - op, result_type, adaptor.operand(), adaptor.output_shape()); - return success(); - } + // Keep the copy semantics and allocate a buffer for the result of the memref + // cast. + bool insert_copy_; }; struct HloToLhloDotGeneralOpConverter @@ -593,15 +604,22 @@ struct HloLegalizeToLhlo }; } // namespace +void populateDynamicHLOToLHLOConversionPattern( + MLIRContext* context, BufferizeTypeConverter* converter, + OwningRewritePatternList* patterns, bool insert_copy) { + patterns->insert( + *converter, context, insert_copy); + patterns->insert(*converter, context); +} + void populateHLOToLHLOConversionPattern(MLIRContext* context, BufferizeTypeConverter* converter, OwningRewritePatternList* patterns) { + populateDynamicHLOToLHLOConversionPattern(context, converter, patterns); // clang-format off patterns->insert< HloToLhloCustomCallOpConverter, HloToLhloDotGeneralOpConverter, - HloToLhloDynamicBroadcastInDimOpConverter, - HloToLhloDynamicReshapeConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 0c1ee24..5c05d5e 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -170,24 +170,31 @@ func @dyn_broadcast(%operand: memref) -> index { return %rank : index } // CHECK: %[[SHAPE:.*]] = tensor_from_elements + // CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64> -// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index // CHECK: %[[C1:.*]] = constant 1 : index -// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64> -// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index -// CHECK: %[[C2:.*]] = constant 2 : index -// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64> -// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index -// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref // CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref // CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index // CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref + +// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64> +// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index +// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64> + +// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index + +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64> +// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index // CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index // CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index + // CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref to memref + +// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref + // CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref, memref) -> () // CHECK: dealloc %[[RESULT]] : memref