Updates LLVM usage to match
[c3acda0798f9](https://github.com/llvm/llvm-project/commit/c3acda0798f9)

PiperOrigin-RevId: 348896724
This commit is contained in:
A. Unique TensorFlower 2020-12-23 23:53:08 -08:00 committed by TensorFlow MLIR Team
parent e3754d7b5c
commit b0bf2ef45b
15 changed files with 159 additions and 83 deletions

5
BUILD
View File

@ -464,6 +464,7 @@ cc_library(
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
@ -688,6 +689,7 @@ cc_library(
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
@ -727,6 +729,7 @@ cc_library(
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:ViewLikeInterface",
],
@ -972,6 +975,7 @@ cc_library(
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
@ -1038,6 +1042,7 @@ cc_library(
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,

View File

@ -15,9 +15,9 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
LLVM_COMMIT = "1b97cdf885d6455841280b8da858835e641ee941"
LLVM_COMMIT = "c3acda0798f9b10ac3187ad941bbd8af82fb84a1"
LLVM_SHA256 = "80d5036ba734fcb700a5699e2f99e5a0de5808dde01a1df3c4fae04510bc8e23"
LLVM_SHA256 = "bd707c585368c86a4d9de1f262d39adb230f7dac889aa786b2721bf67b447a8c"
LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)

View File

@ -1,2 +1,2 @@
1b97cdf885d6455841280b8da858835e641ee941
c3acda0798f9b10ac3187ad941bbd8af82fb84a1

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
@ -66,7 +67,8 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
loc, extent_tensor_type, transformed.operand());
Type shape_ty =
RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
Value shape = rewriter.create<TensorCastOp>(loc, shape_ty, uncasted_shape);
Value shape =
rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
op, result_ty, constant, shape, rewriter.getI64TensorAttr({}));
return success();

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"
namespace mlir {
namespace mhlo {
@ -43,6 +44,7 @@ struct ChloLegalizeToHloPass
// The conversion uses helpers from the standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
conversionTarget.addLegalDialect<mlir::tensor::TensorDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();

View File

@ -70,6 +70,34 @@ bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
: llvm::all_of(op->getResults(), verify_type);
}
// TODO(pifon): Migrate to InitTensorOp when available.
template <bool isLHLO>
Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
SmallVectorImpl<Value>& dyn_sizes) {
if (isLHLO) return nullptr;
return b.create<linalg::InitTensorOp>(loc, dyn_sizes, type.getShape(),
type.getElementType());
}
template <bool isLHLO>
Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type) {
SmallVector<Value, 0> empty;
return GetInitTensor<isLHLO>(b, loc, type, empty);
}
// TODO(pifon): This logic is used everywhere, the code should be shared.
SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
Value tensor) {
auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
if (!tensor_type) return {};
SmallVector<Value, 2> dyn_sizes;
for (auto& en : llvm::enumerate(tensor_type.getShape())) {
if (en.value() != ShapedType::kDynamicSize) continue;
dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
}
return dyn_sizes;
}
template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
public:
@ -113,18 +141,19 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
for (Value in : inputs)
body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
ValueRange output_buffers(args.take_back(args.size() - num_inputs));
for (Value out : output_buffers)
body_result_types.emplace_back(getElementTypeOrSelf(out.getType()));
if (!isLHLO) {
// HLO operations have return as tensor types.
assert(body_result_types.empty() &&
"When lowering HLO ops result can't be part of arguments");
SmallVector<Value, 4> output_buffers;
if (isLHLO) {
output_buffers.append(args.begin() + num_inputs, args.end());
} else {
Value result = op.getOperation()->getResult(0);
body_result_types.push_back(getElementTypeOrSelf(result));
ShapedType result_type = result.getType().template cast<ShapedType>();
auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
output_buffers.push_back(
GetInitTensor<isLHLO>(rewriter, loc, result_type, dyn_sizes));
op_result_types.push_back(result.getType());
}
body_result_types = llvm::to_vector<4>(llvm::map_range(
output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));
AffineMap common_indexing_map =
nloops ? rewriter.getMultiDimIdentityMap(nloops)
@ -134,8 +163,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
bool failed = false;
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, op_result_types, inputs, output_buffers,
/*initTensors=*/ValueRange{}, indexing_maps,
loc, op_result_types, inputs, output_buffers, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
@ -309,13 +337,19 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
auto nloops = result_type.getRank();
auto loc = op.getLoc();
// TODO(pifon): technically, the op itself could have size operands (e.g.
// broadcast into a dynamic dimension).Handle this case.
auto dyn_sizes = isLHLO ? SmallVector<Value, 2>()
: ExtractDynamicSizes(rewriter, loc, args[0]);
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc,
/*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
/*inputs=*/args.front(),
/*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
/*initTensor=*/ValueRange{}, indexing_maps,
GetNParallelLoopsAttrs(nloops),
/*outputBuffers=*/
isLHLO ? ValueRange{args.back()}
: ValueRange{GetInitTensor<isLHLO>(rewriter, loc, result_type,
dyn_sizes)},
indexing_maps, GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
nested_builder.create<linalg::YieldOp>(loc, *args.begin());
});
@ -712,13 +746,16 @@ class IotaConverter : public OpConversionPattern<OpTy> {
// Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = result_shaped_type.getRank();
Location loc = iota_op.getLoc();
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
iota_op.getLoc(),
loc,
/*resultTensorTypes=*/
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
/*inputs=*/ValueRange{},
/*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
/*initTensors=*/ValueRange{},
/*outputBuffers=*/
isLHLO ? ValueRange{args}
: ValueRange{GetInitTensor<isLHLO>(rewriter, loc,
result_shaped_type)},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
@ -818,8 +855,8 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/ArrayRef<Type>{},
/*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(),
/*initTensors=*/ValueRange{}, maps, types);
/*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(), maps,
types);
rewriter.inlineRegionBefore(reduce_op.body(), linalg_op.region(),
linalg_op.region().end());
{

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@ -110,7 +111,7 @@ class LhloFuseLinalgPass
continue;
}
if (auto tensor_cast = dyn_cast<TensorCastOp>(definingOp)) {
if (auto tensor_cast = dyn_cast<tensor::CastOp>(definingOp)) {
auto alias = tensor_cast.source();
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
@ -228,7 +229,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
OpBuilder if_lhs_scalar_builder =
if_op.getThenBodyBuilder(rewriter.getListener());
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
Value reshaped_lhs = if_lhs_scalar_builder.create<tensor::CastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
@ -247,7 +248,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
if_rhs_scalar_op.getResult(0));
OpBuilder if_rhs_scalar_builder =
if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
Value reshaped_rhs = if_rhs_scalar_builder.create<tensor::CastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
@ -345,12 +346,12 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
Value extended_lhs = if_builder.create<shape::BroadcastOp>(
loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
nullptr);
Value extended_lhs_casted = if_builder.create<TensorCastOp>(
Value extended_lhs_casted = if_builder.create<tensor::CastOp>(
loc, known_rank_extent_tensor_type, extended_lhs);
Value extended_rhs = if_builder.create<shape::BroadcastOp>(
loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
nullptr);
Value extended_rhs_casted = if_builder.create<TensorCastOp>(
Value extended_rhs_casted = if_builder.create<tensor::CastOp>(
loc, known_rank_extent_tensor_type, extended_rhs);
// 1. Reshape operands to the given rank (with the same number of elements)
@ -372,7 +373,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
Value result = if_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type},
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
Value reshaped_result = if_builder.create<TensorCastOp>(
Value reshaped_result = if_builder.create<tensor::CastOp>(
loc, UnrankedTensorType::get(result_element_type), result);
if_builder.create<scf::YieldOp>(loc, reshaped_result);
}
@ -446,7 +447,8 @@ struct TransformUnrankedHloPass
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
shape::ShapeDialect, scf::SCFDialect>();
shape::ShapeDialect, scf::SCFDialect,
tensor::TensorDialect>();
target.addLegalOp<FuncOp>();
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@ -66,7 +67,7 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(
Value result_shape_v = builder.createOrFold<shape::BroadcastOp>(
loc, shape::getExtentTensorType(builder.getContext()), lhs_shape_v,
rhs_shape_v, nullptr /* error */);
return builder.createOrFold<TensorCastOp>(
return builder.createOrFold<tensor::CastOp>(
loc, RankedTensorType::get({result_rank}, builder.getIndexType()),
result_shape_v);
}

View File

@ -19,7 +19,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-DAG: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
// CHECK: %[[RESULT_EXTENTS:.+]] = tensor.cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
@ -40,7 +40,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor.cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
@ -61,7 +61,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
// CHECK: %[[RESULT_EXTENTS:.+]] = tensor.cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>

View File

@ -16,7 +16,7 @@ func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> {
func @constant_like_dynamic_shape(%arg : tensor<?x?xi64>) -> tensor<?x?xf32> {
// CHECK: %[[CONSTANT:.*]] = mhlo.constant dense<3.200000e+00> : tensor<f32>
// CHECK: %[[UNCASTED_SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x?xi64> -> tensor<?xindex>
// CHECK: %[[SHAPE:.*]] = tensor_cast %[[UNCASTED_SHAPE]] : tensor<?xindex> to tensor<2xindex>
// CHECK: %[[SHAPE:.*]] = tensor.cast %[[UNCASTED_SHAPE]] : tensor<?xindex> to tensor<2xindex>
// CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: return %[[BROADCASTED_CONSTANT]] : tensor<?x?xf32>
%result = "chlo.constant_like"(%arg) { value = 3.2 : f32 }

View File

@ -628,7 +628,7 @@ func @shape_assuming_tensor(%arg0: tensor<?xf16>) -> tensor<?xf16> {
// CHECK: shape.assuming %{{.*}} -> (memref<?xf16>)
%2 = shape.assuming %1 -> (tensor<?xf16>) {
%3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
%4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
%4 = tensor.cast %3 : tensor<?xindex> to tensor<1xindex>
%5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
// CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
@ -638,3 +638,5 @@ func @shape_assuming_tensor(%arg0: tensor<?xf16>) -> tensor<?xf16> {
}
return %2 : tensor<?xf16>
}

View File

@ -249,8 +249,9 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1>
}
// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi1>
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: i1):
// CHECK-NEXT: %[[RESULT:.*]] = cmpf "oeq", %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
@ -263,8 +264,9 @@ func @float_cmp_ne(%lhs: tensor<2x2xf32>,
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1>
}
// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi1>
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: i1):
// CHECK-NEXT: %[[RESULT:.*]] = cmpf "une", %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
@ -277,8 +279,9 @@ func @int_cmp(%lhs: tensor<2x2xi32>,
: (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
return %0 : tensor<2x2xi1>
}
// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi1>
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i1):
// CHECK-NEXT: %[[RESULT:.*]] = cmpi "slt", %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
@ -335,8 +338,9 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
return %0 : tensor<2x2xf32>
}
// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xf32>
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32):
// CHECK-NEXT: %[[RESULT:.*]] = select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
@ -349,8 +353,9 @@ func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
%0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
return %0: tensor<4x2x1xf32>
}
// CHECK: linalg.init_tensor [4, 2, 1] : tensor<4x2x1xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
@ -362,8 +367,11 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
%0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
return %0: tensor<4x2x1x4x?x16xf32>
}
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
// CHECK: linalg.init_tensor [4, 2, 1, 4, %[[D1]], 16] : tensor<4x2x1x4x?x16xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
@ -377,8 +385,9 @@ func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
: (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32>
return %0 : tensor<7x10x6x4x5xf32>
}
// CHECK: linalg.init_tensor [7, 10, 6, 4, 5] : tensor<7x10x6x4x5xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
@ -393,8 +402,9 @@ func @broadcast_in_dim_with_one_to_one(
: (tensor<1xf32>) -> tensor<1x5xf32>
return %0 : tensor<1x5xf32>
}
// CHECK: linalg.init_tensor [1, 5] : tensor<1x5xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
@ -408,8 +418,9 @@ func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
: (tensor<f32>) -> tensor<7x10x6xf32>
return %0 : tensor<7x10x6xf32>
}
// CHECK: linalg.init_tensor [7, 10, 6] : tensor<7x10x6xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
@ -499,8 +510,9 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xf32>
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32):
// CHECK-NEXT: %[[CMP:.*]] = cmpf "olt", %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
@ -513,8 +525,9 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi32>
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32):
// CHECK-NEXT: %[[CMP:.*]] = cmpi "sgt", %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
@ -527,9 +540,10 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]]
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
@ -599,8 +613,9 @@ func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32>
}
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32):
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: f32):
// CHECK-NEXT: %[[RESULT:.*]] = sitofp %[[OPERAND_IN]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
@ -611,8 +626,9 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
%result = "mhlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16):
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %{{.*}}: i32):
// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
@ -623,8 +639,9 @@ func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
return %result : tensor<2x2xi16>
}
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32):
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: i16):
// CHECK-NEXT: %[[RESULT:.*]] = trunci %[[OPERAND_IN]] : i32 to i16
// CHECK-NEXT: linalg.yield %[[RESULT]] : i16
@ -635,8 +652,9 @@ func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
return %result : tensor<2x2xf64>
}
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %{{.*}}: f64):
// CHECK-NEXT: %[[RESULT:.*]] = fpext %[[OPERAND_IN]] : f32 to f64
// CHECK-NEXT: linalg.yield %[[RESULT]] : f64
@ -647,8 +665,9 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
%result = "mhlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32>
}
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64):
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64, %{{.*}}: f32):
// CHECK-NEXT: %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
@ -659,8 +678,9 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %{{.*}}: i32):
// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
@ -686,9 +706,10 @@ func @iota() -> tensor<7x10xf32> {
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
return %result : tensor<7x10xf32>
}
// CHECK: linalg.init_tensor
// CHECK: linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index):
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %{{.*}}: f32):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
@ -702,8 +723,9 @@ func @shift_left(%lhs: tensor<2x2xi32>,
return %result : tensor<2x2xi32>
}
// CHECK-LABEL: func @shift_left
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
// CHECK-NEXT: %[[RESULT:.*]] = shift_left %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
@ -716,8 +738,9 @@ func @shift_right_arithmetic(%lhs: tensor<2x2xi32>,
return %result : tensor<2x2xi32>
}
// CHECK-LABEL: func @shift_right_arithmetic
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
// CHECK-NEXT: %[[RESULT:.*]] = shift_right_signed %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
@ -730,8 +753,9 @@ func @shift_right_logical(%lhs: tensor<2x2xi32>,
return %result : tensor<2x2xi32>
}
// CHECK-LABEL: func @shift_right_logical
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
// CHECK-NEXT: %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32

View File

@ -163,7 +163,7 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[LHS_RANK]], %[[C0]] : index
// Handle scalar LHS case
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor.cast %[[LHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex>
@ -177,7 +177,7 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RHS_RANK]], %[[C0]] : index
// Handle scalar RHS case
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor.cast %[[RHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
@ -205,13 +205,13 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor_cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor_cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor_cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
@ -220,13 +220,13 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
@ -235,13 +235,13 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C4:.*]] = constant 4 : index
@ -250,13 +250,13 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
@ -265,13 +265,13 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
@ -280,13 +280,13 @@ func @addUnrankedUnranked(
// Handle rank 6 specialization
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>

View File

@ -375,7 +375,7 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
// -----
// Confirm that tiling information is passed through tensor_load, tensor_cast
// Confirm that tiling information is passed through tensor_load, tensor.cast
// and memref_to_tensor operations.
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
-> memref<?xf32> {
@ -390,7 +390,7 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
linalg.yield %13 : f32
}
%2 = tensor_load %1 : memref<32xf32>
%3 = tensor_cast %2 : tensor<32xf32> to tensor<?xf32>
%3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
%4 = tensor_to_memref %3 : memref<?xf32>
return %4 : memref<?xf32>
}
@ -403,7 +403,7 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
// CHECK: linalg.generic
// CHECK: absf
// CHECK: tensor_load
// CHECK: tensor_cast
// CHECK: tensor.cast
// CHECK: tensor_to_memref
// TILED-LABEL: func @tensor_ops
@ -414,7 +414,7 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
// TILED: linalg.generic
// TILED: absf
// TILED: tensor_load
// TILED: tensor_cast
// TILED: tensor.cast
// TILED: tensor_to_memref
@ -425,5 +425,5 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: tensor_load
// PLOOP: tensor_cast
// PLOOP: tensor.cast
// PLOOP: tensor_to_memref