Enhance lowering reshape op to Linalg.
Handle non-expansion and non-collapsion cases by rewriting it to two reshape ops. PiperOrigin-RevId: 327926863
This commit is contained in:
		
							parent
							
								
									d2c9d03f31
								
							
						
					
					
						commit
						bfd629ecb0
					
				| 
						 | 
				
			
			@ -15,6 +15,8 @@ limitations under the License.
 | 
			
		|||
 | 
			
		||||
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
 | 
			
		||||
 | 
			
		||||
#include <numeric>
 | 
			
		||||
 | 
			
		||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 | 
			
		||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 | 
			
		||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -598,6 +600,7 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 | 
			
		|||
    unsigned currSrcDim = 0, currDstDim = 0;
 | 
			
		||||
    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
 | 
			
		||||
        dstShape.size());
 | 
			
		||||
    bool isExpandingOrCollapsing = true;
 | 
			
		||||
    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
 | 
			
		||||
      int64_t dstSize = dstShape[currDstDim];
 | 
			
		||||
      int64_t srcSize = srcShape[currSrcDim];
 | 
			
		||||
| 
						 | 
				
			
			@ -619,11 +622,47 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 | 
			
		|||
          }
 | 
			
		||||
        }
 | 
			
		||||
      } else {
 | 
			
		||||
        return failure();
 | 
			
		||||
        isExpandingOrCollapsing = false;
 | 
			
		||||
        break;
 | 
			
		||||
      }
 | 
			
		||||
      currDstDim++;
 | 
			
		||||
    }
 | 
			
		||||
    if (currSrcDim != srcShape.size()) return failure();
 | 
			
		||||
    if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false;
 | 
			
		||||
 | 
			
		||||
    if (!isExpandingOrCollapsing) {
 | 
			
		||||
      auto getIdentityExprs = [&rewriter](int n) {
 | 
			
		||||
        SmallVector<AffineExpr, 4> exprs;
 | 
			
		||||
        for (int i = 0; i < n; ++i)
 | 
			
		||||
          exprs.push_back(rewriter.getAffineDimExpr(i));
 | 
			
		||||
        return exprs;
 | 
			
		||||
      };
 | 
			
		||||
      Location loc = reshapeOp.getLoc();
 | 
			
		||||
      int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
 | 
			
		||||
                                           std::multiplies<int64_t>());
 | 
			
		||||
      auto elemType = operandType.getElementType();
 | 
			
		||||
      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
 | 
			
		||||
          getIdentityExprs(dstShape.size())};
 | 
			
		||||
      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
 | 
			
		||||
          getIdentityExprs(srcShape.size())};
 | 
			
		||||
 | 
			
		||||
      if (isLHLO) {
 | 
			
		||||
        auto collapsedType = MemRefType::get({totalElems}, elemType);
 | 
			
		||||
        Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
 | 
			
		||||
            loc, collapsedType, args[0], collapsingMap);
 | 
			
		||||
        Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
 | 
			
		||||
            loc, resultType, collapsedOp, expandingMap);
 | 
			
		||||
        rewriter.replaceOpWithNewOp<linalg::CopyOp>(
 | 
			
		||||
            reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
 | 
			
		||||
            /*outputPermutation =*/nullptr);
 | 
			
		||||
      } else {
 | 
			
		||||
        auto collapsedType = RankedTensorType::get({totalElems}, elemType);
 | 
			
		||||
        Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
 | 
			
		||||
            loc, collapsedType, args[0], collapsingMap);
 | 
			
		||||
        rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
 | 
			
		||||
            reshapeOp, resultType, collapsedOp, expandingMap);
 | 
			
		||||
      }
 | 
			
		||||
      return success();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (isLHLO) {
 | 
			
		||||
      Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -373,6 +373,18 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
 | 
			
		|||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 | 
			
		||||
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 | 
			
		||||
// CHECK-LABEL: func @reshape_3D_4D
 | 
			
		||||
func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
 | 
			
		||||
  %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>
 | 
			
		||||
  return %0 : tensor<1x784x1x1xf32>
 | 
			
		||||
}
 | 
			
		||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
 | 
			
		||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @minf
 | 
			
		||||
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %0 = "mhlo.minimum"(%lhs, %rhs)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -688,6 +688,20 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
 | 
			
		|||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 | 
			
		||||
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 | 
			
		||||
// CHECK-LABEL: func @reshape_3D_4D
 | 
			
		||||
func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
 | 
			
		||||
  "lmhlo.reshape"(%arg0, %arg1)
 | 
			
		||||
   : (memref<1x49x16xf32>, memref<1x784x1x1xf32>) -> ()
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
 | 
			
		||||
// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
 | 
			
		||||
// CHECK: linalg.copy
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
 | 
			
		||||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
			
		||||
// CHECK-LABEL: func @reverse
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue