Use the correct shape when converting mhlo.reshape
If mhlo.reshape is not purely collapsing some consecutive operand dimensions into result dimensions, we will generate two linalg reshape op for it: the first one collapses all operand dimensions into one dimension, and the second one expands it to all result dimensions. For this case, the number of collapsed/expanded dimensions should be coming strictly from the operand/result. It is different from the case where we can generate one linalg reshape. For that case, the reassociation map should have rank equal to the largest among operand/result shape. PiperOrigin-RevId: 354293826
This commit is contained in:
parent
e0a7be7fb1
commit
39589add22
|
@ -689,8 +689,9 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|||
ConversionPatternRewriter& rewriter) const final {
|
||||
if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
|
||||
return failure();
|
||||
typename OpTy::Adaptor operands(args);
|
||||
ShapedType operand_type =
|
||||
reshape_op.operand().getType().template cast<ShapedType>();
|
||||
operands.operand().getType().template cast<ShapedType>();
|
||||
ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);
|
||||
|
||||
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
|
||||
|
@ -708,7 +709,11 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|||
unsigned curr_src_dim = 0, curr_dst_dim = 0;
|
||||
SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
|
||||
dst_shape.size());
|
||||
bool is_expanding_or_collapsing = true;
|
||||
|
||||
// First scan all dimensions in the source shapes to see whether we have a
|
||||
// perfect case where consecutive dimensions in source are collapsed. For
|
||||
// such case we can just generate one single linalg.reshape.
|
||||
bool is_collapsing_source = true;
|
||||
while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
|
||||
int64_t dst_size = dst_shape[curr_dst_dim];
|
||||
int64_t src_size = src_shape[curr_src_dim];
|
||||
|
@ -731,15 +736,17 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
is_expanding_or_collapsing = false;
|
||||
is_collapsing_source = false;
|
||||
break;
|
||||
}
|
||||
curr_dst_dim++;
|
||||
}
|
||||
if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
|
||||
is_expanding_or_collapsing = false;
|
||||
is_collapsing_source = false;
|
||||
|
||||
if (!is_expanding_or_collapsing) {
|
||||
// Otherwise, we need to first reduce all source dimensions into one and
|
||||
// then expand to the destination dimensions.
|
||||
if (!is_collapsing_source) {
|
||||
auto get_identity_exprs = [&rewriter](int n) {
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
for (int i = 0; i < n; ++i)
|
||||
|
@ -751,9 +758,13 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|||
1, std::multiplies<int64_t>());
|
||||
auto elem_type = operand_type.getElementType();
|
||||
SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
|
||||
get_identity_exprs(dst_shape.size())};
|
||||
// Use operand_type here because we need to collapse all operands
|
||||
// dimensions.
|
||||
get_identity_exprs(operand_type.getShape().size())};
|
||||
SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
|
||||
get_identity_exprs(src_shape.size())};
|
||||
// Use result_type here because we need to expand to all result
|
||||
// dimensions.
|
||||
get_identity_exprs(result_type.getShape().size())};
|
||||
|
||||
if (isLHLO) {
|
||||
auto collapsed_type = MemRefType::get({total_elems}, elem_type);
|
||||
|
|
|
@ -508,6 +508,18 @@ func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-LABEL: func @reshape_4D_3D
|
||||
func @reshape_4D_3D(%arg0: tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> {
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32>
|
||||
return %0 : tensor<1x240x1xf32>
|
||||
}
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-LABEL: func @reshape1_4D_4D
|
||||
func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> {
|
||||
|
|
|
@ -811,6 +811,20 @@ func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-LABEL: func @reshape_4D_3D
|
||||
func @reshape_4D_3D(%arg0: memref<1x8x10x3xf32>, %arg1: memref<1x240x1xf32>) {
|
||||
"lmhlo.reshape"(%arg0, %arg1)
|
||||
: (memref<1x8x10x3xf32>, memref<1x240x1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
|
||||
// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
|
||||
// CHECK: linalg.copy
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-LABEL: func @reshape1_4D_4D
|
||||
func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>,
|
||||
|
|
Loading…
Reference in New Issue