Updates LLVM usage to match
[da3ed58b97c1](https://github.com/llvm/llvm-project/commit/da3ed58b97c1)

PiperOrigin-RevId: 377432380
This commit is contained in:
A. Unique TensorFlower 2021-06-03 20:44:08 -07:00 committed by TensorFlow MLIR Team
parent aba16adfa5
commit db05388a3c
7 changed files with 73 additions and 56 deletions

View File

@ -15,9 +15,9 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
LLVM_COMMIT = "c89dff5855bb32d47751cce087537c2b12a90f1b"
LLVM_COMMIT = "da3ed58b97c1cc1356b7732d5dcbb6e4de3057da"
LLVM_SHA256 = "900067ffc67a11fd1f650d8852e7706c7642d86f5cc81bbd6cd67996fae58116"
LLVM_SHA256 = "d0766a8638c50daf167d699a71982bcbec3a0b41bc86054bcc642a40755dca32"
LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)

View File

@ -1,2 +1,2 @@
c89dff5855bb32d47751cce087537c2b12a90f1b
da3ed58b97c1cc1356b7732d5dcbb6e4de3057da

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
namespace mlir {

View File

@ -701,16 +701,16 @@ class LhloBroadcastInDimConverter
collapsed_dims_list.back().push_back(dims);
}
// `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
// reduced.
// `linalg.collapse_shape` is inserted only if necessary, i.e. when the rank
// can be reduced.
if (new_shape.size() < operand_shape.size()) {
auto new_memref_type = MemRefType::get(
new_shape, operand_type.getElementType(),
makeStridedLinearLayoutMap(new_strides, operand_offset,
rewriter.getContext()));
operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
operand_adaptor.operand(),
collapsed_dims_list);
operand = rewriter.create<linalg::CollapseShapeOp>(
op.getLoc(), new_memref_type, operand_adaptor.operand(),
collapsed_dims_list);
}
return std::make_pair(operand, broadcast_dims);
}
@ -868,30 +868,45 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
if (isLHLO) {
auto collapsed_type = MemRefType::get({total_elems}, elem_type);
Value collapsed_op = rewriter.create<linalg::ReshapeOp>(
Value collapsed_op = rewriter.create<linalg::CollapseShapeOp>(
loc, collapsed_type, args[0], collapsing_map);
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
Value reshape_buffer = rewriter.create<linalg::ExpandShapeOp>(
loc, result_type, collapsed_op, expanding_map);
rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
args[1]);
} else {
auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
Value collapsed_op = rewriter.create<linalg::TensorCollapseShapeOp>(
loc, collapsed_type, args[0], collapsing_map);
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
reshape_op, result_type, collapsed_op, expanding_map);
}
return success();
}
bool isCollapsing =
result_type.getRank() < args[0].getType().cast<ShapedType>().getRank();
if (isLHLO) {
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
reshape_op.getLoc(), result_type, args[0], reassociation_map);
Value reshape_buffer = isCollapsing ? rewriter
.create<linalg::CollapseShapeOp>(
reshape_op.getLoc(), result_type,
args[0], reassociation_map)
.getResult()
: rewriter
.create<linalg::ExpandShapeOp>(
reshape_op.getLoc(), result_type,
args[0], reassociation_map)
.getResult();
rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
args[1]);
} else {
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshape_op, result_type, args[0], reassociation_map);
if (isCollapsing) {
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
reshape_op, result_type, args[0], reassociation_map);
} else {
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
reshape_op, result_type, args[0], reassociation_map);
}
}
return success();
}
@ -1910,7 +1925,7 @@ struct DepthwiseConvOpOnTensorsConversion
SmallVector<linalg::ReassociationIndices, 4> collapsed_dim_list = {
get_indices_vector(0, 1), get_indices_vector(1, 2),
get_indices_vector(2, 3), get_indices_vector(3, 5)};
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
op, result_type, conv.getResult(0), collapsed_dim_list);
} else {
// For cases where channel multiplier == 1
@ -1936,7 +1951,7 @@ struct DepthwiseConvOpOnTensorsConversion
get_indices_vector(0, 1), get_indices_vector(1, 2),
get_indices_vector(2, 4)};
Value reshaped_filter = rewriter.create<linalg::TensorReshapeOp>(
Value reshaped_filter = rewriter.create<linalg::TensorCollapseShapeOp>(
loc, filter_shape, filter, collapsed_dim_list);
rewriter.replaceOpWithNewOp<linalg::DepthwiseConvInputNHWCFilterHWCOp>(

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {

View File

@ -543,7 +543,7 @@ func @reshape_0D_1D(%arg0: tensor<i32>) -> tensor<1xi32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<i32> into tensor<1xi32>
// CHECK: linalg.tensor_expand_shape %{{.*}} [] : tensor<i32> into tensor<1xi32>
// -----
@ -554,7 +554,7 @@ func @reshape_0D_1D_unsigned(%arg0: tensor<ui32>) -> tensor<1xui32> {
// CHECK-LABEL: func @reshape_0D_1D_unsigned
// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]]
// CHECK: %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<ui32> to tensor<i32>
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor<i32> into tensor<1xi32>
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_expand_shape %[[ARG_SIGNLESS]] [] : tensor<i32> into tensor<1xi32>
// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<1xi32> to tensor<1xui32>
// CHECK: return %[[RET_UNSIGNED]] : tensor<1xui32>
@ -565,7 +565,7 @@ func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32>
return %0 : tensor<i32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// CHECK: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// -----
@ -576,7 +576,7 @@ func @reshape_1D_0D_unsigned(%arg0: tensor<1xui32>) -> tensor<ui32> {
// CHECK-LABEL: func @reshape_1D_0D_unsigned
// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]]
// CHECK: %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<1xui32> to tensor<1xi32>
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor<1xi32> into tensor<i32>
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_collapse_shape %[[ARG_SIGNLESS]] [] : tensor<1xi32> into tensor<i32>
// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<i32> to tensor<ui32>
// CHECK: return %[[RET_UNSIGNED]] : tensor<ui32>
@ -587,7 +587,7 @@ func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
// -----
@ -596,7 +596,7 @@ func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0], [1, 2, 3]]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
// -----
@ -605,7 +605,7 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
return %0 : tensor<12x1x42x1xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2, 3]]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
// -----
@ -614,8 +614,8 @@ func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>
return %0 : tensor<1x784x1x1xf32>
}
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2]]
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2, 3]]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2]]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1, 2, 3]]
// -----
@ -624,8 +624,8 @@ func @reshape_4D_3D(%arg0: tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32>
return %0 : tensor<1x240x1xf32>
}
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2, 3]
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1, 2]
// -----
@ -634,8 +634,8 @@ func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32>
return %0 : tensor<1x4x1x512xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2, 3]
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2, 3]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1, 2, 3]
// -----
@ -644,8 +644,8 @@ func @reshape2_4D_4D(%arg0: tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32>
return %0 : tensor<4x1024x1x1xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2, 3]
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2, 3]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1, 2, 3]
// -----
@ -718,7 +718,7 @@ func @reshape_collapse_single_dim
return %0 : tensor<1x784xf32>
}
// CHECK-LABEL: func @reshape_collapse_single_dim
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0], [1, 2, 3]]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
// -----
@ -727,7 +727,7 @@ func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
return %0 : tensor<2x4x3xf32>
}
// CHECK-LABEL: func @reshape_collapse
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0], [1, 2], [3]]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0], [1, 2], [3]]
// -----
@ -736,7 +736,7 @@ func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
return %0 : tensor<2x4x2xf32>
}
// CHECK-LABEL: func @reshape_expand
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0], [1, 2]]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0], [1, 2]]
// -----
@ -745,7 +745,7 @@ func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
return %0 : tensor<1x4x2xf32>
}
// CHECK-LABEL: func @reshape_single_expand
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2]]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1, 2]]
// -----
@ -755,7 +755,7 @@ func @reshape_multiple_collapse
return %0 : tensor<1x4x5x6xf32>
}
// CHECK-LABEL: func @reshape_multiple_collapse
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0], [1, 2], [3], [4, 5]]
// -----
@ -2062,7 +2062,7 @@ func @depthwise_conv(%arg0: tensor<2x4x5x2xf32>,
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%[[IN]], %[[FILTER]] : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
// CHECK: %{{.+}} = linalg.tensor_reshape %[[OUT]]
// CHECK: %{{.+}} = linalg.tensor_collapse_shape %[[OUT]]
// CHECK-SAME: [0], [1], [2], [3, 4]
// CHECK-SAME: : tensor<2x3x4x2x3xf32> into tensor<2x3x4x6xf32>
@ -2095,7 +2095,7 @@ func @depthwise_conv_multiplier_1(%arg0: tensor<1x113x113x96xf32>,
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
// CHECK: %[[CST:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[CST]]) : tensor<1x56x56x96xf32>, f32 -> tensor<1x56x56x96xf32>
// CHECK: %[[RESHAPED_FILTER:.+]] = linalg.tensor_reshape %[[FILTER]]
// CHECK: %[[RESHAPED_FILTER:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK-SAME: : tensor<3x3x1x96xf32> into tensor<3x3x96xf32>
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwc

View File

@ -346,7 +346,7 @@ func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
} : (memref<5xf32>, memref<5x10xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK-NOT: linalg.{{.*}}shape
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
@ -363,7 +363,7 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
} : (memref<1x5xf32>, memref<5x10x100xf32>) -> ()
return
}
// CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}} {{\[}}[0, 1]]
// CHECK: %[[RESHAPED_ARG:.*]] = linalg.collapse_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: memref<1x5xf32> into memref<5xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps =
// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
@ -383,7 +383,7 @@ func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
} : (memref<f32>, memref<5x10xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK-NOT: linalg.{{.*}}shape
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP_0]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[CONST:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[CONST]] : f32
@ -400,7 +400,7 @@ func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
} : (memref<1xf32>, memref<1x5xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK-NOT: linalg.{{.*}}shape
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
@ -416,7 +416,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
} : (memref<1xf32>, memref<5x5xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK-NOT: linalg.{{.*}}shape
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[VALUE:.*]] = memref.load %{{.*}}[[C0]]
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
@ -881,7 +881,7 @@ func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
: (memref<12x1x42xi32>, memref<12x42xi32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK-NEXT: linalg.copy
// -----
@ -892,7 +892,7 @@ func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
: (memref<12x42x1x1xi32>, memref<12x42xi32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} {{\[}}[0], [1, 2, 3]]
// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
// CHECK-NEXT: linalg.copy
// -----
@ -903,7 +903,7 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
: (memref<12x42xi32>, memref<12x1x42x1xi32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1], [2, 3]]
// CHECK: linalg.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
// CHECK-NEXT: linalg.copy
// -----
@ -914,8 +914,8 @@ func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
: (memref<1x49x16xf32>, memref<1x784x1x1xf32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2]]
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2, 3]]
// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1, 2]]
// CHECK: linalg.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3]]
// CHECK: linalg.copy
// -----
@ -926,8 +926,8 @@ func @reshape_4D_3D(%arg0: memref<1x8x10x3xf32>, %arg1: memref<1x240x1xf32>) {
: (memref<1x8x10x3xf32>, memref<1x240x1xf32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2, 3]]
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2]]
// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3]]
// CHECK: linalg.expand_shape %{{.*}} {{\[}}[0, 1, 2]]
// CHECK: linalg.copy
// -----
@ -939,8 +939,8 @@ func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>,
: (memref<4x512x1x1xi32>, memref<1x4x1x512xi32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2, 3]]
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2, 3]]
// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3]]
// CHECK: linalg.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3]]
// -----
@ -951,8 +951,8 @@ func @reshape2_4D_4D(%arg0: memref<4x1x1x1024xi32>,
: (memref<4x1x1x1024xi32>, memref<4x1024x1x1xi32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2, 3]]
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2, 3]]
// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3]]
// CHECK: linalg.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3]]
// -----