Use the correct shape when converting mhlo.reshape
If mhlo.reshape is not purely collapsing some consecutive operand dimensions into result dimensions, we will generate two linalg reshape op for it: the first one collapses all operand dimensions into one dimension, and the second one expands it to all result dimensions. For this case, the number of collapsed/expanded dimensions should be coming strictly from the operand/result. It is different from the case where we can generate one linalg reshape. For that case, the reassociation map should have rank equal to the largest among operand/result shape. PiperOrigin-RevId: 354293826
This commit is contained in:
		
							parent
							
								
									e0a7be7fb1
								
							
						
					
					
						commit
						39589add22
					
				| 
						 | 
					@ -689,8 +689,9 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 | 
				
			||||||
      ConversionPatternRewriter& rewriter) const final {
 | 
					      ConversionPatternRewriter& rewriter) const final {
 | 
				
			||||||
    if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
 | 
					    if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
 | 
				
			||||||
      return failure();
 | 
					      return failure();
 | 
				
			||||||
 | 
					    typename OpTy::Adaptor operands(args);
 | 
				
			||||||
    ShapedType operand_type =
 | 
					    ShapedType operand_type =
 | 
				
			||||||
        reshape_op.operand().getType().template cast<ShapedType>();
 | 
					        operands.operand().getType().template cast<ShapedType>();
 | 
				
			||||||
    ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);
 | 
					    ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
 | 
					    if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
 | 
				
			||||||
| 
						 | 
					@ -708,7 +709,11 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 | 
				
			||||||
    unsigned curr_src_dim = 0, curr_dst_dim = 0;
 | 
					    unsigned curr_src_dim = 0, curr_dst_dim = 0;
 | 
				
			||||||
    SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
 | 
					    SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
 | 
				
			||||||
        dst_shape.size());
 | 
					        dst_shape.size());
 | 
				
			||||||
    bool is_expanding_or_collapsing = true;
 | 
					
 | 
				
			||||||
 | 
					    // First scan all dimensions in the source shapes to see whether we have a
 | 
				
			||||||
 | 
					    // perfect case where consecutive dimensions in source are collapsed. For
 | 
				
			||||||
 | 
					    // such case we can just generate one single linalg.reshape.
 | 
				
			||||||
 | 
					    bool is_collapsing_source = true;
 | 
				
			||||||
    while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
 | 
					    while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
 | 
				
			||||||
      int64_t dst_size = dst_shape[curr_dst_dim];
 | 
					      int64_t dst_size = dst_shape[curr_dst_dim];
 | 
				
			||||||
      int64_t src_size = src_shape[curr_src_dim];
 | 
					      int64_t src_size = src_shape[curr_src_dim];
 | 
				
			||||||
| 
						 | 
					@ -731,15 +736,17 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
      } else {
 | 
					      } else {
 | 
				
			||||||
        is_expanding_or_collapsing = false;
 | 
					        is_collapsing_source = false;
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      curr_dst_dim++;
 | 
					      curr_dst_dim++;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
 | 
					    if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
 | 
				
			||||||
      is_expanding_or_collapsing = false;
 | 
					      is_collapsing_source = false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (!is_expanding_or_collapsing) {
 | 
					    // Otherwise, we need to first reduce all source dimensions into one and
 | 
				
			||||||
 | 
					    // then expand to the destination dimensions.
 | 
				
			||||||
 | 
					    if (!is_collapsing_source) {
 | 
				
			||||||
      auto get_identity_exprs = [&rewriter](int n) {
 | 
					      auto get_identity_exprs = [&rewriter](int n) {
 | 
				
			||||||
        SmallVector<AffineExpr, 4> exprs;
 | 
					        SmallVector<AffineExpr, 4> exprs;
 | 
				
			||||||
        for (int i = 0; i < n; ++i)
 | 
					        for (int i = 0; i < n; ++i)
 | 
				
			||||||
| 
						 | 
					@ -751,9 +758,13 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 | 
				
			||||||
                                            1, std::multiplies<int64_t>());
 | 
					                                            1, std::multiplies<int64_t>());
 | 
				
			||||||
      auto elem_type = operand_type.getElementType();
 | 
					      auto elem_type = operand_type.getElementType();
 | 
				
			||||||
      SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
 | 
					      SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
 | 
				
			||||||
          get_identity_exprs(dst_shape.size())};
 | 
					          // Use operand_type here because we need to collapse all operands
 | 
				
			||||||
 | 
					          // dimensions.
 | 
				
			||||||
 | 
					          get_identity_exprs(operand_type.getShape().size())};
 | 
				
			||||||
      SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
 | 
					      SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
 | 
				
			||||||
          get_identity_exprs(src_shape.size())};
 | 
					          // Use result_type here because we need to expand to all result
 | 
				
			||||||
 | 
					          // dimensions.
 | 
				
			||||||
 | 
					          get_identity_exprs(result_type.getShape().size())};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      if (isLHLO) {
 | 
					      if (isLHLO) {
 | 
				
			||||||
        auto collapsed_type = MemRefType::get({total_elems}, elem_type);
 | 
					        auto collapsed_type = MemRefType::get({total_elems}, elem_type);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -508,6 +508,18 @@ func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 | 
				
			||||||
 | 
					// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @reshape_4D_3D
 | 
				
			||||||
 | 
					func @reshape_4D_3D(%arg0: tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> {
 | 
				
			||||||
 | 
					  %0 = "mhlo.reshape"(%arg0) : (tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32>
 | 
				
			||||||
 | 
					  return %0 : tensor<1x240x1xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
 | 
				
			||||||
 | 
					// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 | 
					// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 | 
				
			||||||
// CHECK-LABEL: func @reshape1_4D_4D
 | 
					// CHECK-LABEL: func @reshape1_4D_4D
 | 
				
			||||||
func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> {
 | 
					func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -811,6 +811,20 @@ func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 | 
				
			||||||
 | 
					// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @reshape_4D_3D
 | 
				
			||||||
 | 
					func @reshape_4D_3D(%arg0: memref<1x8x10x3xf32>, %arg1: memref<1x240x1xf32>) {
 | 
				
			||||||
 | 
					  "lmhlo.reshape"(%arg0, %arg1)
 | 
				
			||||||
 | 
					   : (memref<1x8x10x3xf32>, memref<1x240x1xf32>) -> ()
 | 
				
			||||||
 | 
					  return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
 | 
				
			||||||
 | 
					// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
 | 
				
			||||||
 | 
					// CHECK: linalg.copy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 | 
					// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 | 
				
			||||||
// CHECK-LABEL: func @reshape1_4D_4D
 | 
					// CHECK-LABEL: func @reshape1_4D_4D
 | 
				
			||||||
func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>,
 | 
					func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue