[KERNEL_GEN] Switch the pipeline to Linalg-on-Tensors.
PiperOrigin-RevId: 347781190
This commit is contained in:
parent
5ab94a00a7
commit
e6e8920921
|
@ -52,6 +52,11 @@ void PopulateGatherToTorchIndexSelectPatterns(
|
||||||
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
|
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
|
||||||
MLIRContext *ctx);
|
MLIRContext *ctx);
|
||||||
|
|
||||||
|
// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect.
|
||||||
|
void populateDynamicHLOToLHLOConversionPattern(
|
||||||
|
MLIRContext *context, BufferizeTypeConverter *converter,
|
||||||
|
OwningRewritePatternList *patterns, bool insert_copy = true);
|
||||||
|
|
||||||
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
|
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
|
||||||
void populateHLOToLHLOConversionPattern(MLIRContext *context,
|
void populateHLOToLHLOConversionPattern(MLIRContext *context,
|
||||||
BufferizeTypeConverter *converter,
|
BufferizeTypeConverter *converter,
|
||||||
|
|
|
@ -192,24 +192,56 @@ struct HloToLhloCustomCallOpConverter
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct HloToLhloDynamicBroadcastInDimOpConverter
|
// TODO(pifon): Consider inserting lhlo.copy as in
|
||||||
|
// HloToLhloDynamicBroadcastInDimOpConverter.
|
||||||
|
struct HloToLhloDynamicReshapeConverter
|
||||||
|
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
|
||||||
|
public:
|
||||||
|
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
Type result_type;
|
||||||
|
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
|
||||||
|
result_type =
|
||||||
|
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
|
||||||
|
} else if (auto unranked_type =
|
||||||
|
op.getType().dyn_cast<UnrankedTensorType>()) {
|
||||||
|
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
|
||||||
|
} else {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
|
||||||
|
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
|
||||||
|
op, result_type, adaptor.operand(), adaptor.output_shape());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class HloToLhloDynamicBroadcastInDimOpConverter
|
||||||
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
|
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
|
||||||
public:
|
public:
|
||||||
using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;
|
HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter& converter,
|
||||||
|
MLIRContext* ctx,
|
||||||
|
bool insert_copy = true)
|
||||||
|
: BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter, ctx),
|
||||||
|
insert_copy_(insert_copy) {}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
auto loc = op.getLoc();
|
Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
|
||||||
Value resultBuffer = InsertDynamicAllocAndDealloc(
|
|
||||||
loc, op.getResult(), op.output_dimensions(), &rewriter);
|
|
||||||
|
|
||||||
Value transformed_operand =
|
if (insert_copy_) {
|
||||||
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
|
auto loc = op.getLoc();
|
||||||
rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
|
Value result_buffer = InsertDynamicAllocAndDealloc(
|
||||||
|
loc, op.getResult(), op.output_dimensions(), &rewriter);
|
||||||
rewriter.replaceOp(op, {resultBuffer});
|
|
||||||
|
|
||||||
|
rewriter.create<lmhlo::CopyOp>(loc, result, result_buffer);
|
||||||
|
result = result_buffer;
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, {result});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -307,31 +339,10 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||||
static_strides, llvm::None, sizes, strides);
|
static_strides, llvm::None, sizes, strides);
|
||||||
return transformed_operand;
|
return transformed_operand;
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
struct HloToLhloDynamicReshapeConverter
|
// Keep the copy semantics and allocate a buffer for the result of the memref
|
||||||
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
|
// cast.
|
||||||
public:
|
bool insert_copy_;
|
||||||
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
|
||||||
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
|
||||||
Type result_type;
|
|
||||||
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
|
|
||||||
result_type =
|
|
||||||
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
|
|
||||||
} else if (auto unranked_type =
|
|
||||||
op.getType().dyn_cast<UnrankedTensorType>()) {
|
|
||||||
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
|
|
||||||
} else {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
|
|
||||||
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
|
|
||||||
op, result_type, adaptor.operand(), adaptor.output_shape());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct HloToLhloDotGeneralOpConverter
|
struct HloToLhloDotGeneralOpConverter
|
||||||
|
@ -593,15 +604,22 @@ struct HloLegalizeToLhlo
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
void populateDynamicHLOToLHLOConversionPattern(
|
||||||
|
MLIRContext* context, BufferizeTypeConverter* converter,
|
||||||
|
OwningRewritePatternList* patterns, bool insert_copy) {
|
||||||
|
patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
|
||||||
|
*converter, context, insert_copy);
|
||||||
|
patterns->insert<HloToLhloDynamicReshapeConverter>(*converter, context);
|
||||||
|
}
|
||||||
|
|
||||||
void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||||
BufferizeTypeConverter* converter,
|
BufferizeTypeConverter* converter,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
|
populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
HloToLhloCustomCallOpConverter,
|
HloToLhloCustomCallOpConverter,
|
||||||
HloToLhloDotGeneralOpConverter,
|
HloToLhloDotGeneralOpConverter,
|
||||||
HloToLhloDynamicBroadcastInDimOpConverter,
|
|
||||||
HloToLhloDynamicReshapeConverter,
|
|
||||||
HloToLhloOpConverter<mhlo::AbsOp>,
|
HloToLhloOpConverter<mhlo::AbsOp>,
|
||||||
HloToLhloOpConverter<mhlo::AddOp>,
|
HloToLhloOpConverter<mhlo::AddOp>,
|
||||||
HloToLhloOpConverter<mhlo::AndOp>,
|
HloToLhloOpConverter<mhlo::AndOp>,
|
||||||
|
|
|
@ -170,24 +170,31 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
|
||||||
return %rank : index
|
return %rank : index
|
||||||
}
|
}
|
||||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
|
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
|
||||||
|
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
|
|
||||||
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
|
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
|
|
||||||
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
|
|
||||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
|
||||||
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
|
|
||||||
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
|
|
||||||
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
|
|
||||||
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
|
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
|
||||||
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
|
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
|
||||||
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
|
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
|
||||||
|
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||||
|
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
|
||||||
|
|
||||||
|
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
|
||||||
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
|
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
|
||||||
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
|
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
|
||||||
|
|
||||||
|
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||||
|
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
|
||||||
|
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
|
||||||
// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
|
// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
|
||||||
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
|
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
|
||||||
|
|
||||||
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
|
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
|
||||||
|
|
||||||
|
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
|
||||||
|
|
||||||
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
|
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
|
||||||
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
|
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue