Enhance lowering reshape op to Linalg.

Handle non-expansion and non-collapsion cases by rewriting it to two reshape
ops.

PiperOrigin-RevId: 327926863
This commit is contained in:
Hanhan Wang 2020-08-21 23:26:35 -07:00 committed by TensorFlow MLIR Team
parent d2c9d03f31
commit bfd629ecb0
3 changed files with 67 additions and 2 deletions

View File

@ -15,6 +15,8 @@ limitations under the License.
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
#include <numeric>
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
@ -598,6 +600,7 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
unsigned currSrcDim = 0, currDstDim = 0;
SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
dstShape.size());
bool isExpandingOrCollapsing = true;
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
int64_t dstSize = dstShape[currDstDim];
int64_t srcSize = srcShape[currSrcDim];
@ -619,11 +622,47 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
}
}
} else {
return failure();
isExpandingOrCollapsing = false;
break;
}
currDstDim++;
}
if (currSrcDim != srcShape.size()) return failure();
if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false;
if (!isExpandingOrCollapsing) {
auto getIdentityExprs = [&rewriter](int n) {
SmallVector<AffineExpr, 4> exprs;
for (int i = 0; i < n; ++i)
exprs.push_back(rewriter.getAffineDimExpr(i));
return exprs;
};
Location loc = reshapeOp.getLoc();
int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
std::multiplies<int64_t>());
auto elemType = operandType.getElementType();
SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
getIdentityExprs(dstShape.size())};
SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
getIdentityExprs(srcShape.size())};
if (isLHLO) {
auto collapsedType = MemRefType::get({totalElems}, elemType);
Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
loc, collapsedType, args[0], collapsingMap);
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
loc, resultType, collapsedOp, expandingMap);
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
/*outputPermutation =*/nullptr);
} else {
auto collapsedType = RankedTensorType::get({totalElems}, elemType);
Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
loc, collapsedType, args[0], collapsingMap);
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshapeOp, resultType, collapsedOp, expandingMap);
}
return success();
}
if (isLHLO) {
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(

View File

@ -373,6 +373,18 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
// -----
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @reshape_3D_4D
func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>
return %0 : tensor<1x784x1x1xf32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
// -----
// CHECK-LABEL: func @minf
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = "mhlo.minimum"(%lhs, %rhs)

View File

@ -688,6 +688,20 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
// -----
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @reshape_3D_4D
func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
"lmhlo.reshape"(%arg0, %arg1)
: (memref<1x49x16xf32>, memref<1x784x1x1xf32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
// CHECK: linalg.copy
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse