Use the correct shape when converting mhlo.reshape

If mhlo.reshape is not purely collapsing some consecutive operand
dimensions into result dimensions, we will generate two linalg
reshape op for it: the first one collapses all operand dimensions
into one dimension, and the second one expands it to all result
dimensions. For this case, the number of collapsed/expanded dimensions
should be coming strictly from the operand/result. It is different
from the case where we can generate one linalg reshape. For that case,
the reassociation map should have rank equal to the largest among
operand/result shape.

PiperOrigin-RevId: 354293826
This commit is contained in:
Lei Zhang 2021-01-28 05:36:37 -08:00 committed by TensorFlow MLIR Team
parent e0a7be7fb1
commit 39589add22
3 changed files with 44 additions and 7 deletions

View File

@ -689,8 +689,9 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op)) if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
return failure(); return failure();
typename OpTy::Adaptor operands(args);
ShapedType operand_type = ShapedType operand_type =
reshape_op.operand().getType().template cast<ShapedType>(); operands.operand().getType().template cast<ShapedType>();
ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op); ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
@ -708,7 +709,11 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
unsigned curr_src_dim = 0, curr_dst_dim = 0; unsigned curr_src_dim = 0, curr_dst_dim = 0;
SmallVector<linalg::ReassociationExprs, 4> reassociation_map( SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
dst_shape.size()); dst_shape.size());
bool is_expanding_or_collapsing = true;
// First scan all dimensions in the source shapes to see whether we have a
// perfect case where consecutive dimensions in source are collapsed. For
// such case we can just generate one single linalg.reshape.
bool is_collapsing_source = true;
while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) { while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
int64_t dst_size = dst_shape[curr_dst_dim]; int64_t dst_size = dst_shape[curr_dst_dim];
int64_t src_size = src_shape[curr_src_dim]; int64_t src_size = src_shape[curr_src_dim];
@ -731,15 +736,17 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
} }
} }
} else { } else {
is_expanding_or_collapsing = false; is_collapsing_source = false;
break; break;
} }
curr_dst_dim++; curr_dst_dim++;
} }
if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size()) if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
is_expanding_or_collapsing = false; is_collapsing_source = false;
if (!is_expanding_or_collapsing) { // Otherwise, we need to first reduce all source dimensions into one and
// then expand to the destination dimensions.
if (!is_collapsing_source) {
auto get_identity_exprs = [&rewriter](int n) { auto get_identity_exprs = [&rewriter](int n) {
SmallVector<AffineExpr, 4> exprs; SmallVector<AffineExpr, 4> exprs;
for (int i = 0; i < n; ++i) for (int i = 0; i < n; ++i)
@ -751,9 +758,13 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
1, std::multiplies<int64_t>()); 1, std::multiplies<int64_t>());
auto elem_type = operand_type.getElementType(); auto elem_type = operand_type.getElementType();
SmallVector<linalg::ReassociationExprs, 4> collapsing_map = { SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
get_identity_exprs(dst_shape.size())}; // Use operand_type here because we need to collapse all operands
// dimensions.
get_identity_exprs(operand_type.getShape().size())};
SmallVector<linalg::ReassociationExprs, 4> expanding_map = { SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
get_identity_exprs(src_shape.size())}; // Use result_type here because we need to expand to all result
// dimensions.
get_identity_exprs(result_type.getShape().size())};
if (isLHLO) { if (isLHLO) {
auto collapsed_type = MemRefType::get({total_elems}, elem_type); auto collapsed_type = MemRefType::get({total_elems}, elem_type);

View File

@ -508,6 +508,18 @@ func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
// ----- // -----
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @reshape_4D_3D
func @reshape_4D_3D(%arg0: tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32>
return %0 : tensor<1x240x1xf32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
// -----
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @reshape1_4D_4D // CHECK-LABEL: func @reshape1_4D_4D
func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> { func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> {

View File

@ -811,6 +811,20 @@ func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
// ----- // -----
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @reshape_4D_3D
func @reshape_4D_3D(%arg0: memref<1x8x10x3xf32>, %arg1: memref<1x240x1xf32>) {
"lmhlo.reshape"(%arg0, %arg1)
: (memref<1x8x10x3xf32>, memref<1x240x1xf32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
// CHECK: linalg.copy
// -----
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @reshape1_4D_4D // CHECK-LABEL: func @reshape1_4D_4D
func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>, func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>,