Enhance lowering of the reshape op to Linalg.

Handle reshapes that are neither a pure expansion nor a pure collapse by rewriting them into two reshape ops: a collapse to a 1-D intermediate followed by an expansion to the target shape.

PiperOrigin-RevId: 327926863
commit bfd629ecb0
parent d2c9d03f31
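For illustration (not part of the commit itself): a reshape from tensor<1x49x16xf32> to tensor<1x784x1x1xf32>, the case exercised by the new tests below, cannot be expressed as a single expanding or collapsing linalg reshape, so the pattern now collapses to a 1-D intermediate and then expands. A rough sketch of the tensor-path result, assuming a 784-element intermediate type and illustrative SSA names (only the reassociation maps are checked by the tests):

  // mhlo input
  %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>

  // approximate linalg output after conversion: collapse, then expand
  %c = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
      : tensor<1x49x16xf32> into tensor<784xf32>
  %e = linalg.tensor_reshape %c [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
      : tensor<784xf32> into tensor<1x784x1x1xf32>
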
@@ -15,6 +15,8 @@ limitations under the License.
 
 // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
 
+#include <numeric>
+
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
@@ -598,6 +600,7 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
     unsigned currSrcDim = 0, currDstDim = 0;
     SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
         dstShape.size());
+    bool isExpandingOrCollapsing = true;
     while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
       int64_t dstSize = dstShape[currDstDim];
       int64_t srcSize = srcShape[currSrcDim];
@@ -619,11 +622,47 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
           }
         }
       } else {
-        return failure();
+        isExpandingOrCollapsing = false;
+        break;
       }
       currDstDim++;
     }
-    if (currSrcDim != srcShape.size()) return failure();
+    if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false;
 
+    if (!isExpandingOrCollapsing) {
+      auto getIdentityExprs = [&rewriter](int n) {
+        SmallVector<AffineExpr, 4> exprs;
+        for (int i = 0; i < n; ++i)
+          exprs.push_back(rewriter.getAffineDimExpr(i));
+        return exprs;
+      };
+      Location loc = reshapeOp.getLoc();
+      int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
+                                           std::multiplies<int64_t>());
+      auto elemType = operandType.getElementType();
+      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
+          getIdentityExprs(dstShape.size())};
+      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
+          getIdentityExprs(srcShape.size())};
+
+      if (isLHLO) {
+        auto collapsedType = MemRefType::get({totalElems}, elemType);
+        Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
+            loc, collapsedType, args[0], collapsingMap);
+        Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
+            loc, resultType, collapsedOp, expandingMap);
+        rewriter.replaceOpWithNewOp<linalg::CopyOp>(
+            reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
+            /*outputPermutation =*/nullptr);
+      } else {
+        auto collapsedType = RankedTensorType::get({totalElems}, elemType);
+        Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
+            loc, collapsedType, args[0], collapsingMap);
+        rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+            reshapeOp, resultType, collapsedOp, expandingMap);
+      }
+      return success();
+    }
+
     if (isLHLO) {
       Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
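As a sketch of what the new branch emits on the buffer (LHLO) path, using the shapes from the new lmhlo test below; the intermediate memref<784xf32> and the value names are illustrative assumptions, since the test only checks the two reshape maps and the final copy:

  %0 = linalg.reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
      : memref<1x49x16xf32> into memref<784xf32>
  %1 = linalg.reshape %0 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
      : memref<784xf32> into memref<1x784x1x1xf32>
  linalg.copy(%1, %arg1) : memref<1x784x1x1xf32>, memref<1x784x1x1xf32>
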
@@ -373,6 +373,18 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
 
 // -----
 
+// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func @reshape_3D_4D
+func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
+  %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>
+  return %0 : tensor<1x784x1x1xf32>
+}
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
+
+// -----
+
 // CHECK-LABEL: func @minf
 func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %0 = "mhlo.minimum"(%lhs, %rhs)
@@ -688,6 +688,20 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
 
 // -----
 
+// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func @reshape_3D_4D
+func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
+  "lmhlo.reshape"(%arg0, %arg1)
+    : (memref<1x49x16xf32>, memref<1x784x1x1xf32>) -> ()
+  return
+}
+// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
+// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
+// CHECK: linalg.copy
+
+// -----
+
 // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
 // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @reverse