[MLIR][LHLO] Convert mhlo.dynamic_reshape -> lhlo.reshape_memref_cast.

PiperOrigin-RevId: 320149593
This commit is contained in:
Alexander Belyaev 2020-07-08 09:11:30 +00:00 committed by Mehdi Amini
parent 8692fde3f9
commit e8cfdee592
3 changed files with 52 additions and 8 deletions

View File

@ -470,14 +470,6 @@ def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
);
let results = (outs AnyRankedOrUnrankedMemRef:$result);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
"Value operand, Value shape", [{
result.addOperands(operand);
result.addOperands(shape);
result.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
}];

View File

@ -220,6 +220,31 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
}
};
struct HloToLhloDynamicReshapeConverter
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
public:
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Type result_type;
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
result_type =
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
} else if (auto unranked_type =
op.getType().dyn_cast<UnrankedTensorType>()) {
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
} else {
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<xla_lhlo::ReshapeMemRefCastOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}
};
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
public:
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
@ -441,6 +466,7 @@ void populateHLOToLHLOConversionPattern(
// clang-format off
patterns->insert<
HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloDynamicReshapeConverter,
HloToLhloOpConverter<mhlo::AbsOp>,
HloToLhloOpConverter<mhlo::AddOp>,
HloToLhloOpConverter<mhlo::AndOp>,

View File

@ -6,3 +6,29 @@ func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
// CHECK-NEXT: return [[ARG]] : memref<*xf32>
// -----
// CHECK-LABEL: func @dynamic_reshape_from_unranked
func @dynamic_reshape_from_unranked(
%operand: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
%reshaped = "mhlo.dynamic_reshape"(%operand, %shape)
: (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
return %reshaped : tensor<?xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// -----
// CHECK-LABEL: func @dynamic_reshape_to_unranked
func @dynamic_reshape_to_unranked(
%operand: tensor<?xf32>, %shape: tensor<?xi32>) -> tensor<*xf32> {
%reshaped = "mhlo.dynamic_reshape"(%operand, %shape)
: (tensor<?xf32>, tensor<?xi32>) -> tensor<*xf32>
return %reshaped : tensor<*xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>