[MLIR][LHLO] Convert mhlo.dynamic_reshape -> lhlo.reshape_memref_cast.
PiperOrigin-RevId: 320149593
This commit is contained in:
parent
8692fde3f9
commit
e8cfdee592
|
@ -470,14 +470,6 @@ def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
|
||||||
);
|
);
|
||||||
let results = (outs AnyRankedOrUnrankedMemRef:$result);
|
let results = (outs AnyRankedOrUnrankedMemRef:$result);
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
|
||||||
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
|
|
||||||
"Value operand, Value shape", [{
|
|
||||||
result.addOperands(operand);
|
|
||||||
result.addOperands(shape);
|
|
||||||
result.types.push_back(resultType);
|
|
||||||
}]>];
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||||
}];
|
}];
|
||||||
|
|
|
@ -220,6 +220,31 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct HloToLhloDynamicReshapeConverter
|
||||||
|
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
|
||||||
|
public:
|
||||||
|
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
Type result_type;
|
||||||
|
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
|
||||||
|
result_type =
|
||||||
|
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
|
||||||
|
} else if (auto unranked_type =
|
||||||
|
op.getType().dyn_cast<UnrankedTensorType>()) {
|
||||||
|
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
|
||||||
|
} else {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
|
||||||
|
rewriter.replaceOpWithNewOp<xla_lhlo::ReshapeMemRefCastOp>(
|
||||||
|
op, result_type, adaptor.operand(), adaptor.output_shape());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
||||||
public:
|
public:
|
||||||
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
|
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
|
||||||
|
@ -441,6 +466,7 @@ void populateHLOToLHLOConversionPattern(
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
HloToLhloDynamicBroadcastInDimOpConverter,
|
HloToLhloDynamicBroadcastInDimOpConverter,
|
||||||
|
HloToLhloDynamicReshapeConverter,
|
||||||
HloToLhloOpConverter<mhlo::AbsOp>,
|
HloToLhloOpConverter<mhlo::AbsOp>,
|
||||||
HloToLhloOpConverter<mhlo::AddOp>,
|
HloToLhloOpConverter<mhlo::AddOp>,
|
||||||
HloToLhloOpConverter<mhlo::AndOp>,
|
HloToLhloOpConverter<mhlo::AndOp>,
|
||||||
|
|
|
@ -6,3 +6,29 @@ func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
}
|
}
|
||||||
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
|
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
|
||||||
// CHECK-NEXT: return [[ARG]] : memref<*xf32>
|
// CHECK-NEXT: return [[ARG]] : memref<*xf32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_reshape_from_unranked
|
||||||
|
func @dynamic_reshape_from_unranked(
|
||||||
|
%operand: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
|
||||||
|
%reshaped = "mhlo.dynamic_reshape"(%operand, %shape)
|
||||||
|
: (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
|
||||||
|
return %reshaped : tensor<?xf32>
|
||||||
|
}
|
||||||
|
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
|
||||||
|
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
|
||||||
|
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_reshape_to_unranked
|
||||||
|
func @dynamic_reshape_to_unranked(
|
||||||
|
%operand: tensor<?xf32>, %shape: tensor<?xi32>) -> tensor<*xf32> {
|
||||||
|
%reshaped = "mhlo.dynamic_reshape"(%operand, %shape)
|
||||||
|
: (tensor<?xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||||
|
return %reshaped : tensor<*xf32>
|
||||||
|
}
|
||||||
|
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
|
||||||
|
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
|
||||||
|
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
|
||||||
|
|
Loading…
Reference in New Issue