[KERNEL_GEN] Switch the pipeline to Linalg-on-Tensors.
PiperOrigin-RevId: 347600145
This commit is contained in:
		
							parent
							
								
									79fa36bcbc
								
							
						
					
					
						commit
						ddda2699fb
					
				| 
						 | 
					@ -52,11 +52,6 @@ void PopulateGatherToTorchIndexSelectPatterns(
 | 
				
			||||||
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
 | 
					void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
 | 
				
			||||||
                               MLIRContext *ctx);
 | 
					                               MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect.
 | 
					 | 
				
			||||||
void populateDynamicHLOToLHLOConversionPattern(
 | 
					 | 
				
			||||||
    MLIRContext *context, BufferizeTypeConverter *converter,
 | 
					 | 
				
			||||||
    OwningRewritePatternList *patterns, bool insert_copy = true);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
 | 
					// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
 | 
				
			||||||
void populateHLOToLHLOConversionPattern(MLIRContext *context,
 | 
					void populateHLOToLHLOConversionPattern(MLIRContext *context,
 | 
				
			||||||
                                        BufferizeTypeConverter *converter,
 | 
					                                        BufferizeTypeConverter *converter,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -192,56 +192,24 @@ struct HloToLhloCustomCallOpConverter
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO(pifon): Consider inserting lhlo.copy as in
 | 
					struct HloToLhloDynamicBroadcastInDimOpConverter
 | 
				
			||||||
// HloToLhloDynamicBroadcastInDimOpConverter.
 | 
					 | 
				
			||||||
struct HloToLhloDynamicReshapeConverter
 | 
					 | 
				
			||||||
    : public BaseOpConversion<mhlo::DynamicReshapeOp> {
 | 
					 | 
				
			||||||
 public:
 | 
					 | 
				
			||||||
  using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  LogicalResult matchAndRewrite(
 | 
					 | 
				
			||||||
      mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
 | 
					 | 
				
			||||||
      ConversionPatternRewriter& rewriter) const final {
 | 
					 | 
				
			||||||
    Type result_type;
 | 
					 | 
				
			||||||
    if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
 | 
					 | 
				
			||||||
      result_type =
 | 
					 | 
				
			||||||
          MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
 | 
					 | 
				
			||||||
    } else if (auto unranked_type =
 | 
					 | 
				
			||||||
                   op.getType().dyn_cast<UnrankedTensorType>()) {
 | 
					 | 
				
			||||||
      result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      return failure();
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
 | 
					 | 
				
			||||||
    rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
 | 
					 | 
				
			||||||
        op, result_type, adaptor.operand(), adaptor.output_shape());
 | 
					 | 
				
			||||||
    return success();
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class HloToLhloDynamicBroadcastInDimOpConverter
 | 
					 | 
				
			||||||
    : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
 | 
					    : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter& converter,
 | 
					  using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;
 | 
				
			||||||
                                            MLIRContext* ctx,
 | 
					 | 
				
			||||||
                                            bool insert_copy = true)
 | 
					 | 
				
			||||||
      : BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter, ctx),
 | 
					 | 
				
			||||||
        insert_copy_(insert_copy) {}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  LogicalResult matchAndRewrite(
 | 
					  LogicalResult matchAndRewrite(
 | 
				
			||||||
      mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
 | 
					      mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
 | 
				
			||||||
      ConversionPatternRewriter& rewriter) const final {
 | 
					      ConversionPatternRewriter& rewriter) const final {
 | 
				
			||||||
    Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (insert_copy_) {
 | 
					 | 
				
			||||||
    auto loc = op.getLoc();
 | 
					    auto loc = op.getLoc();
 | 
				
			||||||
      Value result_buffer = InsertDynamicAllocAndDealloc(
 | 
					    Value resultBuffer = InsertDynamicAllocAndDealloc(
 | 
				
			||||||
        loc, op.getResult(), op.output_dimensions(), &rewriter);
 | 
					        loc, op.getResult(), op.output_dimensions(), &rewriter);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      rewriter.create<lmhlo::CopyOp>(loc, result, result_buffer);
 | 
					    Value transformed_operand =
 | 
				
			||||||
      result = result_buffer;
 | 
					        InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
 | 
				
			||||||
    }
 | 
					    rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
 | 
				
			||||||
    rewriter.replaceOp(op, {result});
 | 
					
 | 
				
			||||||
 | 
					    rewriter.replaceOp(op, {resultBuffer});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return success();
 | 
					    return success();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -339,10 +307,31 @@ class HloToLhloDynamicBroadcastInDimOpConverter
 | 
				
			||||||
        static_strides, llvm::None, sizes, strides);
 | 
					        static_strides, llvm::None, sizes, strides);
 | 
				
			||||||
    return transformed_operand;
 | 
					    return transformed_operand;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Keep the copy semantics and allocate a buffer for the result of the memref
 | 
					struct HloToLhloDynamicReshapeConverter
 | 
				
			||||||
  // cast.
 | 
					    : public BaseOpConversion<mhlo::DynamicReshapeOp> {
 | 
				
			||||||
  bool insert_copy_;
 | 
					 public:
 | 
				
			||||||
 | 
					  using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  LogicalResult matchAndRewrite(
 | 
				
			||||||
 | 
					      mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					      ConversionPatternRewriter& rewriter) const final {
 | 
				
			||||||
 | 
					    Type result_type;
 | 
				
			||||||
 | 
					    if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
 | 
				
			||||||
 | 
					      result_type =
 | 
				
			||||||
 | 
					          MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
 | 
				
			||||||
 | 
					    } else if (auto unranked_type =
 | 
				
			||||||
 | 
					                   op.getType().dyn_cast<UnrankedTensorType>()) {
 | 
				
			||||||
 | 
					      result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      return failure();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
 | 
				
			||||||
 | 
					    rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
 | 
				
			||||||
 | 
					        op, result_type, adaptor.operand(), adaptor.output_shape());
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct HloToLhloDotGeneralOpConverter
 | 
					struct HloToLhloDotGeneralOpConverter
 | 
				
			||||||
| 
						 | 
					@ -604,22 +593,15 @@ struct HloLegalizeToLhlo
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void populateDynamicHLOToLHLOConversionPattern(
 | 
					 | 
				
			||||||
    MLIRContext* context, BufferizeTypeConverter* converter,
 | 
					 | 
				
			||||||
    OwningRewritePatternList* patterns, bool insert_copy) {
 | 
					 | 
				
			||||||
  patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
 | 
					 | 
				
			||||||
      *converter, context, insert_copy);
 | 
					 | 
				
			||||||
  patterns->insert<HloToLhloDynamicReshapeConverter>(*converter, context);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
void populateHLOToLHLOConversionPattern(MLIRContext* context,
 | 
					void populateHLOToLHLOConversionPattern(MLIRContext* context,
 | 
				
			||||||
                                        BufferizeTypeConverter* converter,
 | 
					                                        BufferizeTypeConverter* converter,
 | 
				
			||||||
                                        OwningRewritePatternList* patterns) {
 | 
					                                        OwningRewritePatternList* patterns) {
 | 
				
			||||||
  populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
 | 
					 | 
				
			||||||
  // clang-format off
 | 
					  // clang-format off
 | 
				
			||||||
  patterns->insert<
 | 
					  patterns->insert<
 | 
				
			||||||
      HloToLhloCustomCallOpConverter,
 | 
					      HloToLhloCustomCallOpConverter,
 | 
				
			||||||
      HloToLhloDotGeneralOpConverter,
 | 
					      HloToLhloDotGeneralOpConverter,
 | 
				
			||||||
 | 
					      HloToLhloDynamicBroadcastInDimOpConverter,
 | 
				
			||||||
 | 
					      HloToLhloDynamicReshapeConverter,
 | 
				
			||||||
      HloToLhloOpConverter<mhlo::AbsOp>,
 | 
					      HloToLhloOpConverter<mhlo::AbsOp>,
 | 
				
			||||||
      HloToLhloOpConverter<mhlo::AddOp>,
 | 
					      HloToLhloOpConverter<mhlo::AddOp>,
 | 
				
			||||||
      HloToLhloOpConverter<mhlo::AndOp>,
 | 
					      HloToLhloOpConverter<mhlo::AndOp>,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -170,31 +170,24 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
 | 
				
			||||||
  return %rank : index
 | 
					  return %rank : index
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
 | 
					// CHECK: %[[SHAPE:.*]] = tensor_from_elements
 | 
				
			||||||
 | 
					 | 
				
			||||||
// CHECK: %[[C0:.*]] = constant 0 : index
 | 
					// CHECK: %[[C0:.*]] = constant 0 : index
 | 
				
			||||||
// CHECK: %[[C1:.*]] = constant 1 : index
 | 
					 | 
				
			||||||
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
 | 
					 | 
				
			||||||
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
 | 
					 | 
				
			||||||
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
 | 
					// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
 | 
				
			||||||
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
 | 
					// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
 | 
				
			||||||
 | 
					// CHECK: %[[C1:.*]] = constant 1 : index
 | 
				
			||||||
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
 | 
					// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
 | 
				
			||||||
 | 
					 | 
				
			||||||
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
 | 
					// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
 | 
				
			||||||
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
 | 
					 | 
				
			||||||
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// CHECK: %[[C2:.*]] = constant 2 : index
 | 
					// CHECK: %[[C2:.*]] = constant 2 : index
 | 
				
			||||||
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
 | 
					// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
 | 
				
			||||||
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
 | 
					// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
 | 
				
			||||||
 | 
					// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
 | 
				
			||||||
 | 
					// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
 | 
				
			||||||
 | 
					// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
 | 
				
			||||||
 | 
					// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
 | 
				
			||||||
 | 
					// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
 | 
				
			||||||
 | 
					// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
 | 
				
			||||||
// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
 | 
					// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
 | 
				
			||||||
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
 | 
					// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
 | 
				
			||||||
 | 
					 | 
				
			||||||
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
 | 
					// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
 | 
				
			||||||
 | 
					 | 
				
			||||||
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
 | 
					// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
 | 
				
			||||||
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
 | 
					// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue