diff --git a/BUILD b/BUILD index 44e369e..605d232 100644 --- a/BUILD +++ b/BUILD @@ -464,6 +464,7 @@ cc_library( "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], @@ -688,6 +689,7 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], alwayslink = 1, @@ -727,6 +729,7 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:ViewLikeInterface", ], @@ -972,6 +975,7 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], ) @@ -1038,6 +1042,7 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], alwayslink = 1, diff --git a/WORKSPACE b/WORKSPACE index 424bc54..65663d8 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -15,9 +15,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "1b97cdf885d6455841280b8da858835e641ee941" +LLVM_COMMIT = "c3acda0798f9b10ac3187ad941bbd8af82fb84a1" -LLVM_SHA256 = "80d5036ba734fcb700a5699e2f99e5a0de5808dde01a1df3c4fae04510bc8e23" +LLVM_SHA256 = "bd707c585368c86a4d9de1f262d39adb230f7dac889aa786b2721bf67b447a8c" LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT) diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index 42543b5..df102d5 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1,2 +1,2 @@ -1b97cdf885d6455841280b8da858835e641ee941 +c3acda0798f9b10ac3187ad941bbd8af82fb84a1 diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index a9102cb..e954917 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -66,7 +67,8 @@ struct ConvertConstantLikeOp : public OpConversionPattern { loc, extent_tensor_type, transformed.operand()); Type shape_ty = RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType()); - Value shape = rewriter.create(loc, shape_ty, uncasted_shape); + Value shape = + rewriter.create(loc, shape_ty, uncasted_shape); rewriter.replaceOpWithNewOp( op, result_ty, constant, shape, rewriter.getI64TensorAttr({})); return success(); diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index ce63f6b..d1f3913 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" namespace mlir { namespace mhlo { @@ -43,6 +44,7 @@ struct ChloLegalizeToHloPass // The conversion uses helpers from the standard dialect. conversionTarget.addLegalDialect(); + conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 1a153dd..9ea80fd 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -70,6 +70,34 @@ bool VerifyHloOpBufferOrTensorSemantics(Operation* op) { : llvm::all_of(op->getResults(), verify_type); } +// TODO(pifon): Migrate to InitTensorOp when available. +template +Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type, + SmallVectorImpl& dyn_sizes) { + if (isLHLO) return nullptr; + return b.create(loc, dyn_sizes, type.getShape(), + type.getElementType()); +} + +template +Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type) { + SmallVector empty; + return GetInitTensor(b, loc, type, empty); +} + +// TODO(pifon): This logic is used everywhere, the code should be shared. +SmallVector ExtractDynamicSizes(OpBuilder& b, Location loc, + Value tensor) { + auto tensor_type = tensor.getType().dyn_cast(); + if (!tensor_type) return {}; + SmallVector dyn_sizes; + for (auto& en : llvm::enumerate(tensor_type.getShape())) { + if (en.value() != ShapedType::kDynamicSize) continue; + dyn_sizes.push_back(b.create(loc, tensor, en.index())); + } + return dyn_sizes; +} + template class PointwiseToLinalgConverter : public OpConversionPattern { public: @@ -113,18 +141,19 @@ class PointwiseToLinalgConverter : public OpConversionPattern { for (Value in : inputs) body_arg_types.emplace_back(getElementTypeOrSelf(in.getType())); - ValueRange output_buffers(args.take_back(args.size() - num_inputs)); - for (Value out : output_buffers) - body_result_types.emplace_back(getElementTypeOrSelf(out.getType())); - - if (!isLHLO) { - // HLO operations have return as tensor types. - assert(body_result_types.empty() && - "When lowering HLO ops result can't be part of arguments"); + SmallVector output_buffers; + if (isLHLO) { + output_buffers.append(args.begin() + num_inputs, args.end()); + } else { Value result = op.getOperation()->getResult(0); - body_result_types.push_back(getElementTypeOrSelf(result)); + ShapedType result_type = result.getType().template cast(); + auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]); + output_buffers.push_back( + GetInitTensor(rewriter, loc, result_type, dyn_sizes)); op_result_types.push_back(result.getType()); } + body_result_types = llvm::to_vector<4>(llvm::map_range( + output_buffers, [](Value v) { return getElementTypeOrSelf(v); })); AffineMap common_indexing_map = nloops ? rewriter.getMultiDimIdentityMap(nloops) @@ -134,8 +163,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern { bool failed = false; auto linalg_op = rewriter.create( - loc, op_result_types, inputs, output_buffers, - /*initTensors=*/ValueRange{}, indexing_maps, + loc, op_result_types, inputs, output_buffers, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) { // TODO(ravishankarm) : For now use the method in lmhlo namespace. 
@@ -309,13 +337,19 @@ class DataMovementOpConverter : public OpConversionPattern { auto nloops = result_type.getRank(); auto loc = op.getLoc(); + // TODO(pifon): technically, the op itself could have size operands (e.g. + // broadcast into a dynamic dimension).Handle this case. + auto dyn_sizes = isLHLO ? SmallVector() + : ExtractDynamicSizes(rewriter, loc, args[0]); auto linalg_op = rewriter.create( loc, /*resultTensorTypes=*/isLHLO ? ArrayRef{} : result_type, /*inputs=*/args.front(), - /*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{}, - /*initTensor=*/ValueRange{}, indexing_maps, - GetNParallelLoopsAttrs(nloops), + /*outputBuffers=*/ + isLHLO ? ValueRange{args.back()} + : ValueRange{GetInitTensor(rewriter, loc, result_type, + dyn_sizes)}, + indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) { nested_builder.create(loc, *args.begin()); }); @@ -712,13 +746,16 @@ class IotaConverter : public OpConversionPattern { // Construct the indexing maps needed for linalg.generic ops. unsigned nloops = result_shaped_type.getRank(); + Location loc = iota_op.getLoc(); auto linalg_op = rewriter.create( - iota_op.getLoc(), + loc, /*resultTensorTypes=*/ isLHLO ? ArrayRef{} : ArrayRef{result_shaped_type}, /*inputs=*/ValueRange{}, - /*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{}, - /*initTensors=*/ValueRange{}, + /*outputBuffers=*/ + isLHLO ? ValueRange{args} + : ValueRange{GetInitTensor(rewriter, loc, + result_shaped_type)}, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs, @@ -818,8 +855,8 @@ class ReduceConverter : public OpConversionPattern { auto linalg_op = rewriter.create( loc, /*resultTensorTypes=*/ArrayRef{}, - /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(), - /*initTensors=*/ValueRange{}, maps, types); + /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(), maps, + types); rewriter.inlineRegionBefore(reduce_op.body(), linalg_op.region(), linalg_op.region().end()); { diff --git a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index d01f7d0..98a6ab5 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -110,7 +111,7 @@ class LhloFuseLinalgPass continue; } - if (auto tensor_cast = dyn_cast(definingOp)) { + if (auto tensor_cast = dyn_cast(definingOp)) { auto alias = tensor_cast.source(); if (result_buffers.insert(alias).second) { worklist.push_back(alias); diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index b2fef91..ecfc6ae 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -228,7 +229,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp loc, result_type, IsScalarTensor(rewriter, op, lhs), true); OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder(rewriter.getListener()); - Value reshaped_lhs = if_lhs_scalar_builder.create( + Value reshaped_lhs = if_lhs_scalar_builder.create( loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs); Value if_lhs_scalar_result = if_lhs_scalar_builder.create( loc, ArrayRef{result_type}, ArrayRef{reshaped_lhs, rhs}, @@ -247,7 +248,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp if_rhs_scalar_op.getResult(0)); OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener()); - Value reshaped_rhs = if_rhs_scalar_builder.create( + Value reshaped_rhs = if_rhs_scalar_builder.create( loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs); Value if_rhs_scalar_result = if_rhs_scalar_builder.create( loc, ArrayRef{result_type}, ArrayRef{lhs, reshaped_rhs}, @@ -345,12 +346,12 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp Value extended_lhs = if_builder.create( loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val, nullptr); - Value extended_lhs_casted = if_builder.create( + Value extended_lhs_casted = if_builder.create( loc, known_rank_extent_tensor_type, extended_lhs); Value extended_rhs = if_builder.create( loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val, nullptr); - Value extended_rhs_casted = if_builder.create( + Value extended_rhs_casted = if_builder.create( loc, known_rank_extent_tensor_type, extended_rhs); // 1. Reshape operands to the given rank (with the same number of elements) @@ -372,7 +373,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp Value result = if_builder.create( loc, ArrayRef{result_type}, ArrayRef{reshaped_lhs, reshaped_rhs}, op.getAttrs()); - Value reshaped_result = if_builder.create( + Value reshaped_result = if_builder.create( loc, UnrankedTensorType::get(result_element_type), result); if_builder.create(loc, reshaped_result); } @@ -446,7 +447,8 @@ struct TransformUnrankedHloPass MLIRContext &ctx = getContext(); ConversionTarget target(ctx); target.addLegalDialect(); + shape::ShapeDialect, scf::SCFDialect, + tensor::TensorDialect>(); target.addLegalOp(); #define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor(&target) #define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor(&target) diff --git a/lib/utils/broadcast_utils.cc b/lib/utils/broadcast_utils.cc index bdd66a1..2810d71 100644 --- a/lib/utils/broadcast_utils.cc +++ b/lib/utils/broadcast_utils.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" @@ -66,7 +67,7 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents( Value result_shape_v = builder.createOrFold( loc, shape::getExtentTensorType(builder.getContext()), lhs_shape_v, rhs_shape_v, nullptr /* error */); - return builder.createOrFold( + return builder.createOrFold( loc, RankedTensorType::get({result_rank}, builder.getIndexType()), result_shape_v); } diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir index a83a29f..c05344f 100644 --- a/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -19,7 +19,7 @@ func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor to tensor<2xindex> + // CHECK: %[[RESULT_EXTENTS:.+]] = tensor.cast %[[RESULT_S]] : tensor to tensor<2xindex> // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]] @@ -40,7 +40,7 @@ func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> t // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] // CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] - // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor to tensor<2xindex> + // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor.cast %[[RESULT_S]] : tensor to tensor<2xindex> // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor, tensor) -> tensor> @@ -61,7 +61,7 @@ func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> t // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] - // CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor to tensor<2xindex> + // CHECK: %[[RESULT_EXTENTS:.+]] = tensor.cast %[[RESULT_S]] : tensor to tensor<2xindex> // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor diff --git a/tests/chlo_legalize_to_mhlo.mlir b/tests/chlo_legalize_to_mhlo.mlir index 5a5197e..42a1154 100644 --- a/tests/chlo_legalize_to_mhlo.mlir +++ b/tests/chlo_legalize_to_mhlo.mlir @@ -16,7 +16,7 @@ func 
@constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> { func @constant_like_dynamic_shape(%arg : tensor) -> tensor { // CHECK: %[[CONSTANT:.*]] = mhlo.constant dense<3.200000e+00> : tensor // CHECK: %[[UNCASTED_SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor -> tensor - // CHECK: %[[SHAPE:.*]] = tensor_cast %[[UNCASTED_SHAPE]] : tensor to tensor<2xindex> + // CHECK: %[[SHAPE:.*]] = tensor.cast %[[UNCASTED_SHAPE]] : tensor to tensor<2xindex> // CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK: return %[[BROADCASTED_CONSTANT]] : tensor %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 } diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 80e958b..2834861 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -628,7 +628,7 @@ func @shape_assuming_tensor(%arg0: tensor) -> tensor { // CHECK: shape.assuming %{{.*}} -> (memref) %2 = shape.assuming %1 -> (tensor) { %3 = shape.shape_of %arg0 : tensor -> tensor - %4 = tensor_cast %3 : tensor to tensor<1xindex> + %4 = tensor.cast %3 : tensor to tensor<1xindex> %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref, memref, memref) -> () @@ -638,3 +638,5 @@ func @shape_assuming_tensor(%arg0: tensor) -> tensor { } return %2 : tensor } + + diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 71a8b79..63abc02 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -249,8 +249,9 @@ func @float_cmp(%lhs: tensor<2x2xf32>, : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> return %0 : tensor<2x2xi1> } +// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi1> // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: i1): // CHECK-NEXT: %[[RESULT:.*]] = cmpf "oeq", %[[LHS_IN]], %[[RHS_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i1 @@ -263,8 +264,9 @@ func @float_cmp_ne(%lhs: tensor<2x2xf32>, : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> return %0 : tensor<2x2xi1> } +// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi1> // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: i1): // CHECK-NEXT: %[[RESULT:.*]] = cmpf "une", %[[LHS_IN]], %[[RHS_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i1 @@ -277,8 +279,9 @@ func @int_cmp(%lhs: tensor<2x2xi32>, : (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>) return %0 : tensor<2x2xi1> } +// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi1> // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i1): // CHECK-NEXT: %[[RESULT:.*]] = cmpi "slt", %[[LHS_IN]], %[[RHS_IN]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i1 @@ -335,8 +338,9 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>) return %0 : tensor<2x2xf32> } +// CHECK: 
linalg.init_tensor [2, 2] : tensor<2x2xf32> // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: %[[RESULT:.*]] = select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 @@ -349,8 +353,9 @@ func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor) -> tensor<4x2x1xf32> return %0: tensor<4x2x1xf32> } +// CHECK: linalg.init_tensor [4, 2, 1] : tensor<4x2x1xf32> // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 // ----- @@ -362,8 +367,11 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> return %0: tensor<4x2x1x4x?x16xf32> } +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[D1:.*]] = dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32> +// CHECK: linalg.init_tensor [4, 2, 1, 4, %[[D1]], 16] : tensor<4x2x1x4x?x16xf32> // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 // ----- @@ -377,8 +385,9 @@ func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> return %0 : tensor<7x10x6x4x5xf32> } +// CHECK: linalg.init_tensor [7, 10, 6, 4, 5] : tensor<7x10x6x4x5xf32> // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 // ----- @@ -393,8 +402,9 @@ func @broadcast_in_dim_with_one_to_one( : (tensor<1xf32>) -> tensor<1x5xf32> return %0 : tensor<1x5xf32> } +// CHECK: linalg.init_tensor [1, 5] : tensor<1x5xf32> // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 // ----- @@ -408,8 +418,9 @@ func @broadcast_scalar(%operand: tensor) -> tensor<7x10x6xf32> { : (tensor) -> tensor<7x10x6xf32> return %0 : tensor<7x10x6xf32> } +// CHECK: linalg.init_tensor [7, 10, 6] : tensor<7x10x6xf32> // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 // ----- @@ -499,8 +510,9 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } +// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xf32> // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: %[[CMP:.*]] = cmpf "olt", %[[LHS_IN]], %[[RHS_IN]] : f32 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 // 
CHECK-NEXT: linalg.yield %[[RESULT]] : f32 @@ -513,8 +525,9 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } +// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi32> // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32): // CHECK-NEXT: %[[CMP:.*]] = cmpi "sgt", %[[LHS_IN]], %[[RHS_IN]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 @@ -527,9 +540,10 @@ func @add_scalar(%lhs: tensor, %rhs: tensor) -> tensor { %0 = "mhlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor return %0 : tensor } +// CHECK: linalg.init_tensor // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]] -// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32): // CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]] // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 @@ -599,8 +613,9 @@ func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> return %result : tensor<2x2xf32> } +// CHECK: linalg.init_tensor // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32): +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: f32): // CHECK-NEXT: %[[RESULT:.*]] = sitofp %[[OPERAND_IN]] : i32 to f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 @@ -611,8 +626,9 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> { %result = "mhlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32> return %result : tensor<2x2xi32> } +// CHECK: linalg.init_tensor // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16): +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %{{.*}}: i32): // CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 @@ -623,8 +639,9 @@ func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> { %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16> return %result : tensor<2x2xi16> } +// CHECK: linalg.init_tensor // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32): +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: i16): // CHECK-NEXT: %[[RESULT:.*]] = trunci %[[OPERAND_IN]] : i32 to i16 // CHECK-NEXT: linalg.yield %[[RESULT]] : i16 @@ -635,8 +652,9 @@ func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> { %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64> return %result : tensor<2x2xf64> } +// CHECK: linalg.init_tensor // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %{{.*}}: f64): // CHECK-NEXT: %[[RESULT:.*]] = fpext %[[OPERAND_IN]] : f32 to f64 // CHECK-NEXT: linalg.yield %[[RESULT]] : f64 @@ -647,8 +665,9 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { %result = "mhlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32> return %result : tensor<2x2xf32> } +// CHECK: linalg.init_tensor // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64): +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64, %{{.*}}: f32): // CHECK-NEXT: %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 @@ -659,8 +678,9 @@ func 
@convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> return %result : tensor<2x2xi32> } +// CHECK: linalg.init_tensor // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %{{.*}}: i32): // CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 @@ -686,9 +706,10 @@ func @iota() -> tensor<7x10xf32> { %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>) return %result : tensor<7x10xf32> } +// CHECK: linalg.init_tensor // CHECK: linalg.indexed_generic // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index): +// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %{{.*}}: f32): // CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 @@ -702,8 +723,9 @@ func @shift_left(%lhs: tensor<2x2xi32>, return %result : tensor<2x2xi32> } // CHECK-LABEL: func @shift_left +// CHECK: linalg.init_tensor // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32): +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32): // CHECK-NEXT: %[[RESULT:.*]] = shift_left %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 @@ -716,8 +738,9 @@ func @shift_right_arithmetic(%lhs: tensor<2x2xi32>, return %result : tensor<2x2xi32> } // CHECK-LABEL: func @shift_right_arithmetic +// CHECK: linalg.init_tensor // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32): +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32): // CHECK-NEXT: %[[RESULT:.*]] = shift_right_signed %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 @@ -730,8 +753,9 @@ func @shift_right_logical(%lhs: tensor<2x2xi32>, return %result : tensor<2x2xi32> } // CHECK-LABEL: func @shift_right_logical +// CHECK: linalg.init_tensor // CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32): +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32): // CHECK-NEXT: %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir index af83b4a..7c96e1e 100644 --- a/tests/hlo-transform-unranked.mlir +++ b/tests/hlo-transform-unranked.mlir @@ -163,7 +163,7 @@ func @addUnrankedUnranked( // CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[LHS_RANK]], %[[C0]] : index // Handle scalar LHS case // CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor +// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor.cast %[[LHS]] : tensor<*xf32> to tensor // CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor // CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor -> index // CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex> @@ -177,7 +177,7 @@ func @addUnrankedUnranked( // CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RHS_RANK]], %[[C0]] : index // Handle scalar RHS case // CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : 
tensor<*xf32> to tensor +// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor.cast %[[RHS]] : tensor<*xf32> to tensor // CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor -> index // CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex> // CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor @@ -205,13 +205,13 @@ func @addUnrankedUnranked( // CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] // CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor_cast %[[BROADCASTED_LHS_1]] : tensor to tensor<1xindex> +// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor to tensor<1xindex> // CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor_cast %[[BROADCASTED_RHS_1]] : tensor to tensor<1xindex> +// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor to tensor<1xindex> // CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_1:.*]] = tensor_cast %[[RESULT_RANK_1]] : tensor to tensor<*xf32> +// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor to tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32> // CHECK-NEXT: } else { // CHECK-NEXT: %[[C2:.*]] = constant 2 : index @@ -220,13 +220,13 @@ func @addUnrankedUnranked( // CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] // CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor to tensor<2xindex> +// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor to tensor<2xindex> // CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor to tensor<2xindex> +// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor to tensor<2xindex> // CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor // CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor // CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor to tensor<*xf32> +// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor to tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32> // CHECK-NEXT: } else { // CHECK-NEXT: %[[C3:.*]] 
= constant 3 : index @@ -235,13 +235,13 @@ func @addUnrankedUnranked( // CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] // CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor to tensor<3xindex> +// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor to tensor<3xindex> // CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor to tensor<3xindex> +// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor to tensor<3xindex> // CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor // CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor // CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor to tensor<*xf32> +// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor to tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32> // CHECK-NEXT: } else { // CHECK-NEXT: %[[C4:.*]] = constant 4 : index @@ -250,13 +250,13 @@ func @addUnrankedUnranked( // CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] // CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor to tensor<4xindex> +// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor to tensor<4xindex> // CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor to tensor<4xindex> +// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor to tensor<4xindex> // CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor // CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor // CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor to tensor<*xf32> +// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor to tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32> // CHECK-NEXT: } else { // CHECK-NEXT: %[[C5:.*]] = constant 5 : index @@ -265,13 +265,13 @@ func @addUnrankedUnranked( // CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] // CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor -// CHECK-NEXT: 
%[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor to tensor<5xindex> +// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor to tensor<5xindex> // CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor to tensor<5xindex> +// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor to tensor<5xindex> // CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor // CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor // CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor to tensor<*xf32> +// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor to tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32> // CHECK-NEXT: } else { // CHECK-NEXT: %[[C6:.*]] = constant 6 : index @@ -280,13 +280,13 @@ func @addUnrankedUnranked( // Handle rank 6 specialization // CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] // CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor to tensor<6xindex> +// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor to tensor<6xindex> // CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor to tensor<6xindex> +// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor to tensor<6xindex> // CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor // CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor // CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor to tensor<*xf32> +// CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor to tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32> // CHECK-NEXT: } // CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32> diff --git a/tests/lhlo-fuse-linalg.mlir b/tests/lhlo-fuse-linalg.mlir index 8b4a712..54aceaf 100644 --- a/tests/lhlo-fuse-linalg.mlir +++ b/tests/lhlo-fuse-linalg.mlir @@ -375,7 +375,7 @@ func @branching_result(%arg0: memref, %arg1: memref, %arg2: inde // ----- -// Confirm that tiling information is passed through tensor_load, tensor_cast +// Confirm that tiling information is passed through tensor_load, tensor.cast // and memref_to_tensor operations. 
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>) -> memref { @@ -390,7 +390,7 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>) linalg.yield %13 : f32 } %2 = tensor_load %1 : memref<32xf32> - %3 = tensor_cast %2 : tensor<32xf32> to tensor + %3 = tensor.cast %2 : tensor<32xf32> to tensor %4 = tensor_to_memref %3 : memref return %4 : memref } @@ -403,7 +403,7 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>) // CHECK: linalg.generic // CHECK: absf // CHECK: tensor_load -// CHECK: tensor_cast +// CHECK: tensor.cast // CHECK: tensor_to_memref // TILED-LABEL: func @tensor_ops @@ -414,7 +414,7 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>) // TILED: linalg.generic // TILED: absf // TILED: tensor_load -// TILED: tensor_cast +// TILED: tensor.cast // TILED: tensor_to_memref @@ -425,5 +425,5 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>) // PLOOP: linalg.generic // PLOOP: absf // PLOOP: tensor_load -// PLOOP: tensor_cast +// PLOOP: tensor.cast // PLOOP: tensor_to_memref
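
For reference, the net effect of the legalize_to_linalg.cc changes on the emitted IR: instead of passing an empty initTensors list to the linalg.generic builder, the lowering now materializes a linalg.init_tensor (sized via ExtractDynamicSizes for dynamic dimensions) and passes it as the output operand, which also adds an output block argument to the generic body. The snippet below is a rough sketch assembled from the CHECK patterns above for a static-shaped mhlo.add on tensors; the function and map names are illustrative, and the exact printed form of linalg.generic (the ins/outs clauses in particular) depends on the MLIR revision pinned by this change.

#map = affine_map<(d0, d1) -> (d0, d1)>
func @add_example(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // Tensor-path "output buffer": an init_tensor of the result shape replaces
  // the former empty initTensors operand list.
  %init = linalg.init_tensor [2, 2] : tensor<2x2xf32>
  %0 = linalg.generic {indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel", "parallel"]}
      ins(%lhs, %rhs : tensor<2x2xf32>, tensor<2x2xf32>)
      outs(%init : tensor<2x2xf32>) {
  // The body now carries a third block argument for the output element,
  // matching the updated ^bb0 CHECK lines above.
  ^bb0(%l: f32, %r: f32, %out: f32):
    %sum = addf %l, %r : f32
    linalg.yield %sum : f32
  } -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

For a dynamic result type such as tensor<4x?x16xf32>, ExtractDynamicSizes emits a dim op per dynamic extent and GetInitTensor forwards those values into the init_tensor, e.g. linalg.init_tensor [4, 2, 1, 4, %d1, 16], as the @broadcast test above checks.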