[MLIR][LHLO] Convert mhlo.dynamic_reshape -> lhlo.reshape_memref_cast.
PiperOrigin-RevId: 320149593
Parent: 8692fde3f9
Commit: e8cfdee592
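In IR terms, the change lets the HLO-to-LHLO bufferization rewrite the tensor-level dynamic reshape into its buffer-level counterpart. A minimal before/after sketch (value names are illustrative; the xla_lhlo prefix is taken from the C++ namespace used in the new pattern, and the exact printed form may differ):

// Before: mhlo on tensors.
%r = "mhlo.dynamic_reshape"(%operand, %shape)
    : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>

// After: lhlo on memrefs. The operand buffer is reinterpreted with the
// runtime shape held in %shape_buf; the cast is intended to reinterpret
// the descriptor rather than copy data.
%m = xla_lhlo.reshape_memref_cast %operand_buf(%shape_buf)
    : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>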
@@ -470,14 +470,6 @@ def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
   );
   let results = (outs AnyRankedOrUnrankedMemRef:$result);

-  let builders = [OpBuilder<
-    "OpBuilder &builder, OperationState &result, MemRefType resultType, " #
-    "Value operand, Value shape", [{
-      result.addOperands(operand);
-      result.addOperands(shape);
-      result.types.push_back(resultType);
-    }]>];
-
   let extraClassDeclaration = [{
     MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
   }];
@@ -220,6 +220,31 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
   }
 };

+struct HloToLhloDynamicReshapeConverter
+    : public BaseOpConversion<mhlo::DynamicReshapeOp> {
+ public:
+  using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
+
+  LogicalResult matchAndRewrite(
+      mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    Type result_type;
+    if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
+      result_type =
+          MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
+    } else if (auto unranked_type =
+                   op.getType().dyn_cast<UnrankedTensorType>()) {
+      result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
+    } else {
+      return failure();
+    }
+    mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<xla_lhlo::ReshapeMemRefCastOp>(
+        op, result_type, adaptor.operand(), adaptor.output_shape());
+    return success();
+  }
+};
+
 struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
  public:
   using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
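The new converter chooses the buffer type purely from the op's result tensor type; roughly (example types chosen for illustration, and the trailing 0 passed to UnrankedMemRefType::get is the default memory space):

// tensor<?x4xf32> (ranked result)    ->  memref<?x4xf32>
// tensor<*xf32>   (unranked result)  ->  memref<*xf32> in memory space 0
// any other result type              ->  the pattern returns failure()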
@@ -441,6 +466,7 @@ void populateHLOToLHLOConversionPattern(
   // clang-format off
   patterns->insert<
       HloToLhloDynamicBroadcastInDimOpConverter,
+      HloToLhloDynamicReshapeConverter,
       HloToLhloOpConverter<mhlo::AbsOp>,
       HloToLhloOpConverter<mhlo::AddOp>,
       HloToLhloOpConverter<mhlo::AndOp>,
@@ -6,3 +6,29 @@ func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
 }
 // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
 // CHECK-NEXT: return [[ARG]] : memref<*xf32>
+
+// -----
+
+// CHECK-LABEL: func @dynamic_reshape_from_unranked
+func @dynamic_reshape_from_unranked(
+    %operand: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
+  %reshaped = "mhlo.dynamic_reshape"(%operand, %shape)
+      : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+  return %reshaped : tensor<?xf32>
+}
+// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
+// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
+// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
+
+// -----
+
+// CHECK-LABEL: func @dynamic_reshape_to_unranked
+func @dynamic_reshape_to_unranked(
+    %operand: tensor<?xf32>, %shape: tensor<?xi32>) -> tensor<*xf32> {
+  %reshaped = "mhlo.dynamic_reshape"(%operand, %shape)
+      : (tensor<?xf32>, tensor<?xi32>) -> tensor<*xf32>
+  return %reshaped : tensor<*xf32>
+}
+// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
+// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
+// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
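The RUN header of the test file is outside this hunk; for checks like these it would typically be something along the lines of the directive below (tool name and pass flag are assumptions, not taken from this commit):

// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s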