[KERNEL_GEN] Switch the pipeline to Linalg-on-Tensors.
PiperOrigin-RevId: 347368063
This commit is contained in:
parent
6b439f7eee
commit
8b35a75d4a
|
@ -52,6 +52,11 @@ void PopulateGatherToTorchIndexSelectPatterns(
|
|||
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
|
||||
MLIRContext *ctx);
|
||||
|
||||
// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect.
|
||||
void populateDynamicHLOToLHLOConversionPattern(
|
||||
MLIRContext *context, BufferizeTypeConverter *converter,
|
||||
OwningRewritePatternList *patterns, bool insert_copy = true);
|
||||
|
||||
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
|
||||
void populateHLOToLHLOConversionPattern(MLIRContext *context,
|
||||
BufferizeTypeConverter *converter,
|
||||
|
|
|
@ -192,24 +192,56 @@ struct HloToLhloCustomCallOpConverter
|
|||
}
|
||||
};
|
||||
|
||||
struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
// TODO(pifon): Consider inserting lhlo.copy as in
|
||||
// HloToLhloDynamicBroadcastInDimOpConverter.
|
||||
struct HloToLhloDynamicReshapeConverter
|
||||
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
|
||||
public:
|
||||
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Type result_type;
|
||||
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
|
||||
result_type =
|
||||
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
|
||||
} else if (auto unranked_type =
|
||||
op.getType().dyn_cast<UnrankedTensorType>()) {
|
||||
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
|
||||
op, result_type, adaptor.operand(), adaptor.output_shape());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class HloToLhloDynamicBroadcastInDimOpConverter
|
||||
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
|
||||
public:
|
||||
using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;
|
||||
HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter& converter,
|
||||
MLIRContext* ctx,
|
||||
bool insert_copy = true)
|
||||
: BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter, ctx),
|
||||
insert_copy_(insert_copy) {}
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = op.getLoc();
|
||||
Value resultBuffer = InsertDynamicAllocAndDealloc(
|
||||
loc, op.getResult(), op.output_dimensions(), &rewriter);
|
||||
Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
|
||||
|
||||
Value transformed_operand =
|
||||
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
|
||||
rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
|
||||
|
||||
rewriter.replaceOp(op, {resultBuffer});
|
||||
if (insert_copy_) {
|
||||
auto loc = op.getLoc();
|
||||
Value result_buffer = InsertDynamicAllocAndDealloc(
|
||||
loc, op.getResult(), op.output_dimensions(), &rewriter);
|
||||
|
||||
rewriter.create<lmhlo::CopyOp>(loc, result, result_buffer);
|
||||
result = result_buffer;
|
||||
}
|
||||
rewriter.replaceOp(op, {result});
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -307,31 +339,10 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
|||
static_strides, llvm::None, sizes, strides);
|
||||
return transformed_operand;
|
||||
}
|
||||
};
|
||||
|
||||
struct HloToLhloDynamicReshapeConverter
|
||||
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
|
||||
public:
|
||||
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Type result_type;
|
||||
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
|
||||
result_type =
|
||||
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
|
||||
} else if (auto unranked_type =
|
||||
op.getType().dyn_cast<UnrankedTensorType>()) {
|
||||
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
|
||||
op, result_type, adaptor.operand(), adaptor.output_shape());
|
||||
return success();
|
||||
}
|
||||
// Keep the copy semantics and allocate a buffer for the result of the memref
|
||||
// cast.
|
||||
bool insert_copy_;
|
||||
};
|
||||
|
||||
struct HloToLhloDotGeneralOpConverter
|
||||
|
@ -593,15 +604,22 @@ struct HloLegalizeToLhlo
|
|||
};
|
||||
} // namespace
|
||||
|
||||
void populateDynamicHLOToLHLOConversionPattern(
|
||||
MLIRContext* context, BufferizeTypeConverter* converter,
|
||||
OwningRewritePatternList* patterns, bool insert_copy) {
|
||||
patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
|
||||
*converter, context, insert_copy);
|
||||
patterns->insert<HloToLhloDynamicReshapeConverter>(*converter, context);
|
||||
}
|
||||
|
||||
void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||
BufferizeTypeConverter* converter,
|
||||
OwningRewritePatternList* patterns) {
|
||||
populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
HloToLhloCustomCallOpConverter,
|
||||
HloToLhloDotGeneralOpConverter,
|
||||
HloToLhloDynamicBroadcastInDimOpConverter,
|
||||
HloToLhloDynamicReshapeConverter,
|
||||
HloToLhloOpConverter<mhlo::AbsOp>,
|
||||
HloToLhloOpConverter<mhlo::AddOp>,
|
||||
HloToLhloOpConverter<mhlo::AndOp>,
|
||||
|
|
|
@ -170,24 +170,31 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
|
|||
return %rank : index
|
||||
}
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
|
||||
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
|
||||
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
|
||||
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
|
||||
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
|
||||
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
|
||||
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
|
||||
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
|
||||
|
||||
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
|
||||
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
|
||||
|
||||
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
|
||||
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
|
||||
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
|
||||
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
|
||||
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
|
||||
// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
|
||||
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
|
||||
|
||||
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
|
||||
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
|
||||
|
||||
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
|
||||
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
|
||||
|
||||
|
|
Loading…
Reference in New Issue