[KERNEL_GEN] Switch the pipeline to Linalg-on-Tensors.

PiperOrigin-RevId: 347781190
This commit is contained in:
Alexander Belyaev 2020-12-16 01:50:12 -08:00 committed by TensorFlow MLIR Team
parent 5ab94a00a7
commit e6e8920921
3 changed files with 74 additions and 44 deletions

View File

@ -52,6 +52,11 @@ void PopulateGatherToTorchIndexSelectPatterns(
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);
// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect.
void populateDynamicHLOToLHLOConversionPattern(
MLIRContext *context, BufferizeTypeConverter *converter,
OwningRewritePatternList *patterns, bool insert_copy = true);
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
void populateHLOToLHLOConversionPattern(MLIRContext *context,
BufferizeTypeConverter *converter,

View File

@ -192,24 +192,56 @@ struct HloToLhloCustomCallOpConverter
}
};
struct HloToLhloDynamicBroadcastInDimOpConverter
// TODO(pifon): Consider inserting lhlo.copy as in
// HloToLhloDynamicBroadcastInDimOpConverter.
struct HloToLhloDynamicReshapeConverter
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
public:
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Type result_type;
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
result_type =
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
} else if (auto unranked_type =
op.getType().dyn_cast<UnrankedTensorType>()) {
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
} else {
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}
};
class HloToLhloDynamicBroadcastInDimOpConverter
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
public:
using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;
HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter& converter,
MLIRContext* ctx,
bool insert_copy = true)
: BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter, ctx),
insert_copy_(insert_copy) {}
LogicalResult matchAndRewrite(
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
if (insert_copy_) {
auto loc = op.getLoc();
Value resultBuffer = InsertDynamicAllocAndDealloc(
Value result_buffer = InsertDynamicAllocAndDealloc(
loc, op.getResult(), op.output_dimensions(), &rewriter);
Value transformed_operand =
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
rewriter.replaceOp(op, {resultBuffer});
rewriter.create<lmhlo::CopyOp>(loc, result, result_buffer);
result = result_buffer;
}
rewriter.replaceOp(op, {result});
return success();
}
@ -307,31 +339,10 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
static_strides, llvm::None, sizes, strides);
return transformed_operand;
}
};
struct HloToLhloDynamicReshapeConverter
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
public:
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Type result_type;
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
result_type =
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
} else if (auto unranked_type =
op.getType().dyn_cast<UnrankedTensorType>()) {
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
} else {
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}
// Keep the copy semantics and allocate a buffer for the result of the memref
// cast.
bool insert_copy_;
};
struct HloToLhloDotGeneralOpConverter
@ -593,15 +604,22 @@ struct HloLegalizeToLhlo
};
} // namespace
void populateDynamicHLOToLHLOConversionPattern(
MLIRContext* context, BufferizeTypeConverter* converter,
OwningRewritePatternList* patterns, bool insert_copy) {
patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
*converter, context, insert_copy);
patterns->insert<HloToLhloDynamicReshapeConverter>(*converter, context);
}
void populateHLOToLHLOConversionPattern(MLIRContext* context,
BufferizeTypeConverter* converter,
OwningRewritePatternList* patterns) {
populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
// clang-format off
patterns->insert<
HloToLhloCustomCallOpConverter,
HloToLhloDotGeneralOpConverter,
HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloDynamicReshapeConverter,
HloToLhloOpConverter<mhlo::AbsOp>,
HloToLhloOpConverter<mhlo::AddOp>,
HloToLhloOpConverter<mhlo::AndOp>,

View File

@ -170,24 +170,31 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
return %rank : index
}
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>