[KERNEL_GEN] Switch the pipeline to Linalg-on-Tensors.

PiperOrigin-RevId: 347368063
This commit is contained in:
Alexander Belyaev 2020-12-14 05:46:09 -08:00 committed by TensorFlow MLIR Team
parent 6b439f7eee
commit 8b35a75d4a
3 changed files with 74 additions and 44 deletions

View File

@ -52,6 +52,11 @@ void PopulateGatherToTorchIndexSelectPatterns(
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx); MLIRContext *ctx);
// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect.
void populateDynamicHLOToLHLOConversionPattern(
MLIRContext *context, BufferizeTypeConverter *converter,
OwningRewritePatternList *patterns, bool insert_copy = true);
// Collection of rewrite patterns for lowering of HLO to LHLO dialect. // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
void populateHLOToLHLOConversionPattern(MLIRContext *context, void populateHLOToLHLOConversionPattern(MLIRContext *context,
BufferizeTypeConverter *converter, BufferizeTypeConverter *converter,

View File

@ -192,24 +192,56 @@ struct HloToLhloCustomCallOpConverter
} }
}; };
struct HloToLhloDynamicBroadcastInDimOpConverter // TODO(pifon): Consider inserting lhlo.copy as in
// HloToLhloDynamicBroadcastInDimOpConverter.
struct HloToLhloDynamicReshapeConverter
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
public:
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Type result_type;
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
result_type =
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
} else if (auto unranked_type =
op.getType().dyn_cast<UnrankedTensorType>()) {
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
} else {
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}
};
class HloToLhloDynamicBroadcastInDimOpConverter
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> { : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
public: public:
using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion; HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter& converter,
MLIRContext* ctx,
bool insert_copy = true)
: BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter, ctx),
insert_copy_(insert_copy) {}
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands, mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
if (insert_copy_) {
auto loc = op.getLoc(); auto loc = op.getLoc();
Value resultBuffer = InsertDynamicAllocAndDealloc( Value result_buffer = InsertDynamicAllocAndDealloc(
loc, op.getResult(), op.output_dimensions(), &rewriter); loc, op.getResult(), op.output_dimensions(), &rewriter);
Value transformed_operand = rewriter.create<lmhlo::CopyOp>(loc, result, result_buffer);
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); result = result_buffer;
rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer); }
rewriter.replaceOp(op, {result});
rewriter.replaceOp(op, {resultBuffer});
return success(); return success();
} }
@ -307,31 +339,10 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
static_strides, llvm::None, sizes, strides); static_strides, llvm::None, sizes, strides);
return transformed_operand; return transformed_operand;
} }
};
struct HloToLhloDynamicReshapeConverter // Keep the copy semantics and allocate a buffer for the result of the memref
: public BaseOpConversion<mhlo::DynamicReshapeOp> { // cast.
public: bool insert_copy_;
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Type result_type;
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
result_type =
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
} else if (auto unranked_type =
op.getType().dyn_cast<UnrankedTensorType>()) {
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
} else {
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}
}; };
struct HloToLhloDotGeneralOpConverter struct HloToLhloDotGeneralOpConverter
@ -593,15 +604,22 @@ struct HloLegalizeToLhlo
}; };
} // namespace } // namespace
void populateDynamicHLOToLHLOConversionPattern(
MLIRContext* context, BufferizeTypeConverter* converter,
OwningRewritePatternList* patterns, bool insert_copy) {
patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
*converter, context, insert_copy);
patterns->insert<HloToLhloDynamicReshapeConverter>(*converter, context);
}
void populateHLOToLHLOConversionPattern(MLIRContext* context, void populateHLOToLHLOConversionPattern(MLIRContext* context,
BufferizeTypeConverter* converter, BufferizeTypeConverter* converter,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
// clang-format off // clang-format off
patterns->insert< patterns->insert<
HloToLhloCustomCallOpConverter, HloToLhloCustomCallOpConverter,
HloToLhloDotGeneralOpConverter, HloToLhloDotGeneralOpConverter,
HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloDynamicReshapeConverter,
HloToLhloOpConverter<mhlo::AbsOp>, HloToLhloOpConverter<mhlo::AbsOp>,
HloToLhloOpConverter<mhlo::AddOp>, HloToLhloOpConverter<mhlo::AddOp>,
HloToLhloOpConverter<mhlo::AndOp>, HloToLhloOpConverter<mhlo::AndOp>,

View File

@ -170,24 +170,31 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
return %rank : index return %rank : index
} }
// CHECK: %[[SHAPE:.*]] = tensor_from_elements // CHECK: %[[SHAPE:.*]] = tensor_from_elements
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32> // CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index // CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32> // CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index // CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index // CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map> // CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> () // CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32> // CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>