Enhance lowering reshape op to Linalg.
Handle non-expansion and non-collapsion cases by rewriting it to two reshape ops. PiperOrigin-RevId: 327926863
This commit is contained in:
		
							parent
							
								
									d2c9d03f31
								
							
						
					
					
						commit
						bfd629ecb0
					
				| 
						 | 
					@ -15,6 +15,8 @@ limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
 | 
					// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <numeric>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
 | 
				
			||||||
| 
						 | 
					@ -598,6 +600,7 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 | 
				
			||||||
    unsigned currSrcDim = 0, currDstDim = 0;
 | 
					    unsigned currSrcDim = 0, currDstDim = 0;
 | 
				
			||||||
    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
 | 
					    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
 | 
				
			||||||
        dstShape.size());
 | 
					        dstShape.size());
 | 
				
			||||||
 | 
					    bool isExpandingOrCollapsing = true;
 | 
				
			||||||
    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
 | 
					    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
 | 
				
			||||||
      int64_t dstSize = dstShape[currDstDim];
 | 
					      int64_t dstSize = dstShape[currDstDim];
 | 
				
			||||||
      int64_t srcSize = srcShape[currSrcDim];
 | 
					      int64_t srcSize = srcShape[currSrcDim];
 | 
				
			||||||
| 
						 | 
					@ -619,11 +622,47 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
      } else {
 | 
					      } else {
 | 
				
			||||||
        return failure();
 | 
					        isExpandingOrCollapsing = false;
 | 
				
			||||||
 | 
					        break;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      currDstDim++;
 | 
					      currDstDim++;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    if (currSrcDim != srcShape.size()) return failure();
 | 
					    if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (!isExpandingOrCollapsing) {
 | 
				
			||||||
 | 
					      auto getIdentityExprs = [&rewriter](int n) {
 | 
				
			||||||
 | 
					        SmallVector<AffineExpr, 4> exprs;
 | 
				
			||||||
 | 
					        for (int i = 0; i < n; ++i)
 | 
				
			||||||
 | 
					          exprs.push_back(rewriter.getAffineDimExpr(i));
 | 
				
			||||||
 | 
					        return exprs;
 | 
				
			||||||
 | 
					      };
 | 
				
			||||||
 | 
					      Location loc = reshapeOp.getLoc();
 | 
				
			||||||
 | 
					      int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
 | 
				
			||||||
 | 
					                                           std::multiplies<int64_t>());
 | 
				
			||||||
 | 
					      auto elemType = operandType.getElementType();
 | 
				
			||||||
 | 
					      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
 | 
				
			||||||
 | 
					          getIdentityExprs(dstShape.size())};
 | 
				
			||||||
 | 
					      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
 | 
				
			||||||
 | 
					          getIdentityExprs(srcShape.size())};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      if (isLHLO) {
 | 
				
			||||||
 | 
					        auto collapsedType = MemRefType::get({totalElems}, elemType);
 | 
				
			||||||
 | 
					        Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
 | 
				
			||||||
 | 
					            loc, collapsedType, args[0], collapsingMap);
 | 
				
			||||||
 | 
					        Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
 | 
				
			||||||
 | 
					            loc, resultType, collapsedOp, expandingMap);
 | 
				
			||||||
 | 
					        rewriter.replaceOpWithNewOp<linalg::CopyOp>(
 | 
				
			||||||
 | 
					            reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
 | 
				
			||||||
 | 
					            /*outputPermutation =*/nullptr);
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        auto collapsedType = RankedTensorType::get({totalElems}, elemType);
 | 
				
			||||||
 | 
					        Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
 | 
				
			||||||
 | 
					            loc, collapsedType, args[0], collapsingMap);
 | 
				
			||||||
 | 
					        rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
 | 
				
			||||||
 | 
					            reshapeOp, resultType, collapsedOp, expandingMap);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      return success();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (isLHLO) {
 | 
					    if (isLHLO) {
 | 
				
			||||||
      Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
 | 
					      Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -373,6 +373,18 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 | 
				
			||||||
 | 
					// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @reshape_3D_4D
 | 
				
			||||||
 | 
					func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
 | 
				
			||||||
 | 
					  %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>
 | 
				
			||||||
 | 
					  return %0 : tensor<1x784x1x1xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
 | 
				
			||||||
 | 
					// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: func @minf
 | 
					// CHECK-LABEL: func @minf
 | 
				
			||||||
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
					func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
				
			||||||
  %0 = "mhlo.minimum"(%lhs, %rhs)
 | 
					  %0 = "mhlo.minimum"(%lhs, %rhs)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -688,6 +688,20 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 | 
				
			||||||
 | 
					// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @reshape_3D_4D
 | 
				
			||||||
 | 
					func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
 | 
				
			||||||
 | 
					  "lmhlo.reshape"(%arg0, %arg1)
 | 
				
			||||||
 | 
					   : (memref<1x49x16xf32>, memref<1x784x1x1xf32>) -> ()
 | 
				
			||||||
 | 
					  return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
 | 
				
			||||||
 | 
					// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
 | 
				
			||||||
 | 
					// CHECK: linalg.copy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
 | 
					// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
 | 
				
			||||||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
					// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
				
			||||||
// CHECK-LABEL: func @reverse
 | 
					// CHECK-LABEL: func @reverse
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue