Use the correct shape when converting mhlo.reshape
If mhlo.reshape is not purely collapsing consecutive operand dimensions into result dimensions, we generate two linalg reshape ops for it: the first collapses all operand dimensions into one dimension, and the second expands that dimension into all result dimensions. In this case, the number of collapsed/expanded dimensions must come strictly from the operand/result type. This differs from the case where a single linalg reshape suffices; there, the reassociation map has rank equal to the larger of the operand and result ranks.

PiperOrigin-RevId: 354293826
Commit: 39589add22
Parent: e0a7be7fb1
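For concreteness, below is a hand-written sketch of the two-step lowering that the new 4D->3D tensor test expects, using the linalg.tensor_reshape syntax of this era. The function name and the intermediate tensor<240xf32> type are my own (inferred from 1*8*10*3 = 240); the test itself only checks the reassociation maps.

// 1x8x10x3 -> 1x240x1 cannot be formed by collapsing consecutive operand
// dimensions, so the pattern goes through a fully collapsed 1-D intermediate.
func @reshape_4D_3D_sketch(%arg0: tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> {
  // Collapse: one reassociation group over all 4 operand dimensions, so the
  // map rank comes from the operand (RESHAPE_MAP1 in the tests below).
  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
      : tensor<1x8x10x3xf32> into tensor<240xf32>
  // Expand: one reassociation group over all 3 result dimensions, so the
  // map rank comes from the result (RESHAPE_MAP2 in the tests below).
  %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
      : tensor<240xf32> into tensor<1x240x1xf32>
  return %1 : tensor<1x240x1xf32>
}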
@@ -689,8 +689,9 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
       ConversionPatternRewriter& rewriter) const final {
     if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
       return failure();
+    typename OpTy::Adaptor operands(args);
     ShapedType operand_type =
-        reshape_op.operand().getType().template cast<ShapedType>();
+        operands.operand().getType().template cast<ShapedType>();
     ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);
 
     if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
@@ -708,7 +709,11 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
     unsigned curr_src_dim = 0, curr_dst_dim = 0;
     SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
         dst_shape.size());
-    bool is_expanding_or_collapsing = true;
+
+    // First scan all dimensions in the source shapes to see whether we have a
+    // perfect case where consecutive dimensions in source are collapsed. For
+    // such case we can just generate one single linalg.reshape.
+    bool is_collapsing_source = true;
     while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
       int64_t dst_size = dst_shape[curr_dst_dim];
       int64_t src_size = src_shape[curr_src_dim];
@@ -731,15 +736,17 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
         }
       }
     } else {
-      is_expanding_or_collapsing = false;
+      is_collapsing_source = false;
       break;
     }
     curr_dst_dim++;
   }
   if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
-    is_expanding_or_collapsing = false;
+    is_collapsing_source = false;
 
-  if (!is_expanding_or_collapsing) {
+  // Otherwise, we need to first reduce all source dimensions into one and
+  // then expand to the destination dimensions.
+  if (!is_collapsing_source) {
     auto get_identity_exprs = [&rewriter](int n) {
       SmallVector<AffineExpr, 4> exprs;
       for (int i = 0; i < n; ++i)
@@ -751,9 +758,13 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
                                           1, std::multiplies<int64_t>());
     auto elem_type = operand_type.getElementType();
     SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
-        get_identity_exprs(dst_shape.size())};
+        // Use operand_type here because we need to collapse all operands
+        // dimensions.
+        get_identity_exprs(operand_type.getShape().size())};
     SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
-        get_identity_exprs(src_shape.size())};
+        // Use result_type here because we need to expand to all result
+        // dimensions.
+        get_identity_exprs(result_type.getShape().size())};
 
     if (isLHLO) {
       auto collapsed_type = MemRefType::get({total_elems}, elem_type);
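Note on the hunk above: assuming src_shape and dst_shape mirror the operand and result shapes (as the scanning loop suggests), the old code sized collapsing_map from the result rank and expanding_map from the operand rank, exactly backwards whenever the two ranks differ, and harmless when they match. That is why the new tests below pair a rank-4 operand with a rank-3 result.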
@@ -508,6 +508,18 @@ func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
 
 // -----
 
+// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @reshape_4D_3D
+func @reshape_4D_3D(%arg0: tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> {
+  %0 = "mhlo.reshape"(%arg0) : (tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32>
+  return %0 : tensor<1x240x1xf32>
+}
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
+
+// -----
+
 // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func @reshape1_4D_4D
 func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> {
@@ -811,6 +811,20 @@ func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
 
 // -----
 
+// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @reshape_4D_3D
+func @reshape_4D_3D(%arg0: memref<1x8x10x3xf32>, %arg1: memref<1x240x1xf32>) {
+  "lmhlo.reshape"(%arg0, %arg1)
+    : (memref<1x8x10x3xf32>, memref<1x240x1xf32>) -> ()
+  return
+}
+// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
+// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
+// CHECK: linalg.copy
+
+// -----
+
 // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func @reshape1_4D_4D
 func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>,