Enhance lowering reshape op to Linalg.
Handle non-expansion and non-collapsion cases by rewriting it to two reshape ops. PiperOrigin-RevId: 327926863
This commit is contained in:
parent
d2c9d03f31
commit
bfd629ecb0
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||
|
||||
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||
|
@ -598,6 +600,7 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|||
unsigned currSrcDim = 0, currDstDim = 0;
|
||||
SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
|
||||
dstShape.size());
|
||||
bool isExpandingOrCollapsing = true;
|
||||
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
|
||||
int64_t dstSize = dstShape[currDstDim];
|
||||
int64_t srcSize = srcShape[currSrcDim];
|
||||
|
@ -619,11 +622,47 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
return failure();
|
||||
isExpandingOrCollapsing = false;
|
||||
break;
|
||||
}
|
||||
currDstDim++;
|
||||
}
|
||||
if (currSrcDim != srcShape.size()) return failure();
|
||||
if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false;
|
||||
|
||||
if (!isExpandingOrCollapsing) {
|
||||
auto getIdentityExprs = [&rewriter](int n) {
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
for (int i = 0; i < n; ++i)
|
||||
exprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
return exprs;
|
||||
};
|
||||
Location loc = reshapeOp.getLoc();
|
||||
int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
|
||||
std::multiplies<int64_t>());
|
||||
auto elemType = operandType.getElementType();
|
||||
SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
|
||||
getIdentityExprs(dstShape.size())};
|
||||
SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
|
||||
getIdentityExprs(srcShape.size())};
|
||||
|
||||
if (isLHLO) {
|
||||
auto collapsedType = MemRefType::get({totalElems}, elemType);
|
||||
Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
|
||||
loc, collapsedType, args[0], collapsingMap);
|
||||
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
|
||||
loc, resultType, collapsedOp, expandingMap);
|
||||
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
|
||||
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
|
||||
/*outputPermutation =*/nullptr);
|
||||
} else {
|
||||
auto collapsedType = RankedTensorType::get({totalElems}, elemType);
|
||||
Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
|
||||
loc, collapsedType, args[0], collapsingMap);
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
||||
reshapeOp, resultType, collapsedOp, expandingMap);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
if (isLHLO) {
|
||||
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
|
||||
|
|
|
@ -373,6 +373,18 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-LABEL: func @reshape_3D_4D
|
||||
func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>
|
||||
return %0 : tensor<1x784x1x1xf32>
|
||||
}
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @minf
|
||||
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%0 = "mhlo.minimum"(%lhs, %rhs)
|
||||
|
|
|
@ -688,6 +688,20 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-LABEL: func @reshape_3D_4D
|
||||
func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
|
||||
"lmhlo.reshape"(%arg0, %arg1)
|
||||
: (memref<1x49x16xf32>, memref<1x784x1x1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
|
||||
// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
|
||||
// CHECK: linalg.copy
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
|
||||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @reverse
|
||||
|
|
Loading…
Reference in New Issue