Integrate LLVM at llvm/llvm-project@c3acda0798

Updates LLVM usage to match [c3acda0798f9](https://github.com/llvm/llvm-project/commit/c3acda0798f9)

PiperOrigin-RevId: 348896724
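Note: the two recurring changes in the diff below track the upstream MLIR work this integrate picks up: the standard-dialect `tensor_cast` op moved into the new `tensor` dialect as `tensor.cast` (with `TensorCastOp` becoming `tensor::CastOp` in C++), and `linalg` ops on tensors now receive their result "init" tensors through the output operands rather than a separate `initTensors` list. A minimal sketch of the rename, with an illustrative operand:

    // Before this integrate (std dialect op):
    %extents = tensor_cast %shape : tensor<?xindex> to tensor<2xindex>
    // After (same cast, now in the tensor dialect):
    %extents = tensor.cast %shape : tensor<?xindex> to tensor<2xindex>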
This commit is contained in:
parent e3754d7b5c
commit b0bf2ef45b

BUILD:
@@ -464,6 +464,7 @@ cc_library(
         "@llvm-project//mlir:SideEffects",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TransformUtils",
         "@llvm-project//mlir:Transforms",
     ],
@@ -688,6 +689,7 @@ cc_library(
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:Shape",
         "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
     ],
     alwayslink = 1,
@@ -727,6 +729,7 @@ cc_library(
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TransformUtils",
         "@llvm-project//mlir:ViewLikeInterface",
     ],
@@ -972,6 +975,7 @@ cc_library(
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:Shape",
         "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
     ],
 )
@@ -1038,6 +1042,7 @@ cc_library(
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:Shape",
         "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
     ],
     alwayslink = 1,

@@ -15,9 +15,9 @@

 load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

-LLVM_COMMIT = "1b97cdf885d6455841280b8da858835e641ee941"
+LLVM_COMMIT = "c3acda0798f9b10ac3187ad941bbd8af82fb84a1"

-LLVM_SHA256 = "80d5036ba734fcb700a5699e2f99e5a0de5808dde01a1df3c4fae04510bc8e23"
+LLVM_SHA256 = "bd707c585368c86a4d9de1f262d39adb230f7dac889aa786b2721bf67b447a8c"

 LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)

@@ -1,2 +1,2 @@
-1b97cdf885d6455841280b8da858835e641ee941
+c3acda0798f9b10ac3187ad941bbd8af82fb84a1

@@ -29,6 +29,7 @@ limitations under the License.
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -66,7 +67,8 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
         loc, extent_tensor_type, transformed.operand());
     Type shape_ty =
         RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
-    Value shape = rewriter.create<TensorCastOp>(loc, shape_ty, uncasted_shape);
+    Value shape =
+        rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
     rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
         op, result_ty, constant, shape, rewriter.getI64TensorAttr({}));
     return success();

@@ -21,6 +21,7 @@ limitations under the License.
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Pass/Pass.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"

 namespace mlir {
 namespace mhlo {
@@ -43,6 +44,7 @@ struct ChloLegalizeToHloPass

     // The conversion uses helpers from the standard dialect.
     conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
+    conversionTarget.addLegalDialect<mlir::tensor::TensorDialect>();
     conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
     conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();

@@ -70,6 +70,34 @@ bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
                 : llvm::all_of(op->getResults(), verify_type);
 }

+// TODO(pifon): Migrate to InitTensorOp when available.
+template <bool isLHLO>
+Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
+                    SmallVectorImpl<Value>& dyn_sizes) {
+  if (isLHLO) return nullptr;
+  return b.create<linalg::InitTensorOp>(loc, dyn_sizes, type.getShape(),
+                                        type.getElementType());
+}
+
+template <bool isLHLO>
+Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type) {
+  SmallVector<Value, 0> empty;
+  return GetInitTensor<isLHLO>(b, loc, type, empty);
+}
+
+// TODO(pifon): This logic is used everywhere, the code should be shared.
+SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
+                                          Value tensor) {
+  auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
+  if (!tensor_type) return {};
+  SmallVector<Value, 2> dyn_sizes;
+  for (auto& en : llvm::enumerate(tensor_type.getShape())) {
+    if (en.value() != ShapedType::kDynamicSize) continue;
+    dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
+  }
+  return dyn_sizes;
+}
+
 template <typename OpTy, bool isLHLO = true>
 class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
  public:
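The helpers added above build the init tensor that tensor-semantics linalg ops now require, reading any dynamic extents off an operand with `dim`. A hedged sketch of the IR this produces for an illustrative tensor<4x?x16xf32> operand (value names are made up; the updated tests later in this diff check this exact pattern for mhlo.broadcast):

    %c1 = constant 1 : index
    %d1 = dim %operand, %c1 : tensor<4x?x16xf32>
    %init = linalg.init_tensor [4, %d1, 16] : tensor<4x?x16xf32>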
@@ -113,18 +141,19 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
     for (Value in : inputs)
       body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));

-    ValueRange output_buffers(args.take_back(args.size() - num_inputs));
-    for (Value out : output_buffers)
-      body_result_types.emplace_back(getElementTypeOrSelf(out.getType()));
-
-    if (!isLHLO) {
-      // HLO operations have return as tensor types.
-      assert(body_result_types.empty() &&
-             "When lowering HLO ops result can't be part of arguments");
+    SmallVector<Value, 4> output_buffers;
+    if (isLHLO) {
+      output_buffers.append(args.begin() + num_inputs, args.end());
+    } else {
       Value result = op.getOperation()->getResult(0);
-      body_result_types.push_back(getElementTypeOrSelf(result));
+      ShapedType result_type = result.getType().template cast<ShapedType>();
+      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
+      output_buffers.push_back(
+          GetInitTensor<isLHLO>(rewriter, loc, result_type, dyn_sizes));
       op_result_types.push_back(result.getType());
     }
+    body_result_types = llvm::to_vector<4>(llvm::map_range(
+        output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));

     AffineMap common_indexing_map =
         nloops ? rewriter.getMultiDimIdentityMap(nloops)
@@ -134,8 +163,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {

     bool failed = false;
     auto linalg_op = rewriter.create<linalg::GenericOp>(
-        loc, op_result_types, inputs, output_buffers,
-        /*initTensors=*/ValueRange{}, indexing_maps,
+        loc, op_result_types, inputs, output_buffers, indexing_maps,
         GetNParallelLoopsAttrs(nloops),
         [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
           // TODO(ravishankarm) : For now use the method in lmhlo namespace.
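With the `initTensors` argument gone, the init tensor rides along as the output operand of the generic op, and the region gains a trailing block argument for the output element. A minimal sketch of the resulting IR for an elementwise add on tensors, assuming the post-change syntax, an identity #map, and made-up value names:

    %init = linalg.init_tensor [2, 2] : tensor<2x2xf32>
    %sum = linalg.generic
        {indexing_maps = [#map, #map, #map],
         iterator_types = ["parallel", "parallel"]}
        ins(%lhs, %rhs : tensor<2x2xf32>, tensor<2x2xf32>)
        outs(%init : tensor<2x2xf32>) {
      ^bb0(%l: f32, %r: f32, %out: f32):
        %0 = addf %l, %r : f32
        linalg.yield %0 : f32
    } -> tensor<2x2xf32>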
@@ -309,13 +337,19 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {

     auto nloops = result_type.getRank();
     auto loc = op.getLoc();
+    // TODO(pifon): technically, the op itself could have size operands (e.g.
+    // broadcast into a dynamic dimension). Handle this case.
+    auto dyn_sizes = isLHLO ? SmallVector<Value, 2>()
+                            : ExtractDynamicSizes(rewriter, loc, args[0]);
     auto linalg_op = rewriter.create<linalg::GenericOp>(
         loc,
         /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
         /*inputs=*/args.front(),
-        /*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
-        /*initTensor=*/ValueRange{}, indexing_maps,
-        GetNParallelLoopsAttrs(nloops),
+        /*outputBuffers=*/
+        isLHLO ? ValueRange{args.back()}
+               : ValueRange{GetInitTensor<isLHLO>(rewriter, loc, result_type,
+                                                  dyn_sizes)},
+        indexing_maps, GetNParallelLoopsAttrs(nloops),
         [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
           nested_builder.create<linalg::YieldOp>(loc, *args.begin());
         });
@@ -712,13 +746,16 @@ class IotaConverter : public OpConversionPattern<OpTy> {
     // Construct the indexing maps needed for linalg.generic ops.
     unsigned nloops = result_shaped_type.getRank();

+    Location loc = iota_op.getLoc();
     auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
-        iota_op.getLoc(),
+        loc,
         /*resultTensorTypes=*/
         isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
         /*inputs=*/ValueRange{},
-        /*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
-        /*initTensors=*/ValueRange{},
+        /*outputBuffers=*/
+        isLHLO ? ValueRange{args}
+               : ValueRange{GetInitTensor<isLHLO>(rewriter, loc,
+                                                  result_shaped_type)},
         llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
         GetNParallelLoopsAttrs(nloops),
         [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
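IotaConverter follows the same scheme for linalg.indexed_generic: on tensors it now yields into an init tensor instead of taking no outputs. A rough sketch of the lowered form the updated iota test below expects (names illustrative, syntax assumed for this snapshot):

    %init = linalg.init_tensor [7, 10] : tensor<7x10xf32>
    %result = linalg.indexed_generic
        {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
        outs(%init : tensor<7x10xf32>) {
      ^bb0(%i: index, %j: index, %out: f32):
        %0 = index_cast %j : index to i32
        %1 = sitofp %0 : i32 to f32
        linalg.yield %1 : f32
    } -> tensor<7x10xf32>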
@@ -818,8 +855,8 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {

     auto linalg_op = rewriter.create<linalg::GenericOp>(
         loc, /*resultTensorTypes=*/ArrayRef<Type>{},
-        /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(),
-        /*initTensors=*/ValueRange{}, maps, types);
+        /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(), maps,
+        types);
     rewriter.inlineRegionBefore(reduce_op.body(), linalg_op.region(),
                                 linalg_op.region().end());
     {

@@ -24,6 +24,7 @@ limitations under the License.
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -110,7 +111,7 @@ class LhloFuseLinalgPass
         continue;
       }

-      if (auto tensor_cast = dyn_cast<TensorCastOp>(definingOp)) {
+      if (auto tensor_cast = dyn_cast<tensor::CastOp>(definingOp)) {
         auto alias = tensor_cast.source();
         if (result_buffers.insert(alias).second) {
           worklist.push_back(alias);

@@ -21,6 +21,7 @@ limitations under the License.
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -228,7 +229,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
         loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
     OpBuilder if_lhs_scalar_builder =
         if_op.getThenBodyBuilder(rewriter.getListener());
-    Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
+    Value reshaped_lhs = if_lhs_scalar_builder.create<tensor::CastOp>(
         loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
     Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
@@ -247,7 +248,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
                                                  if_rhs_scalar_op.getResult(0));
     OpBuilder if_rhs_scalar_builder =
         if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
-    Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
+    Value reshaped_rhs = if_rhs_scalar_builder.create<tensor::CastOp>(
         loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
     Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
@@ -345,12 +346,12 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     Value extended_lhs = if_builder.create<shape::BroadcastOp>(
         loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
         nullptr);
-    Value extended_lhs_casted = if_builder.create<TensorCastOp>(
+    Value extended_lhs_casted = if_builder.create<tensor::CastOp>(
         loc, known_rank_extent_tensor_type, extended_lhs);
     Value extended_rhs = if_builder.create<shape::BroadcastOp>(
         loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
         nullptr);
-    Value extended_rhs_casted = if_builder.create<TensorCastOp>(
+    Value extended_rhs_casted = if_builder.create<tensor::CastOp>(
         loc, known_rank_extent_tensor_type, extended_rhs);

     // 1. Reshape operands to the given rank (with the same number of elements)
@@ -372,7 +373,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     Value result = if_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{result_type},
         ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
-    Value reshaped_result = if_builder.create<TensorCastOp>(
+    Value reshaped_result = if_builder.create<tensor::CastOp>(
         loc, UnrankedTensorType::get(result_element_type), result);
     if_builder.create<scf::YieldOp>(loc, reshaped_result);
   }
@@ -446,7 +447,8 @@ struct TransformUnrankedHloPass
     MLIRContext &ctx = getContext();
     ConversionTarget target(ctx);
     target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
-                           shape::ShapeDialect, scf::SCFDialect>();
+                           shape::ShapeDialect, scf::SCFDialect,
+                           tensor::TensorDialect>();
     target.addLegalOp<FuncOp>();
 #define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
 #define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)

@@ -21,6 +21,7 @@ limitations under the License.
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"

@@ -66,7 +67,7 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(
     Value result_shape_v = builder.createOrFold<shape::BroadcastOp>(
         loc, shape::getExtentTensorType(builder.getContext()), lhs_shape_v,
         rhs_shape_v, nullptr /* error */);
-    return builder.createOrFold<TensorCastOp>(
+    return builder.createOrFold<tensor::CastOp>(
         loc, RankedTensorType::get({result_rank}, builder.getIndexType()),
         result_shape_v);
   }

@@ -19,7 +19,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
   // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
   // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
   // CHECK-DAG:    %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK:        %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
+  // CHECK:        %[[RESULT_EXTENTS:.+]] = tensor.cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
   // CHECK-DAG:    %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
   // CHECK-DAG:    %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
   // CHECK-NEXT:   %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
@@ -40,7 +40,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
   // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
   // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
   // CHECK-NEXT:   %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-NEXT:   %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
+  // CHECK-NEXT:   %[[RESULT_EXTENTS:.+]] = tensor.cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
   // CHECK-DAG:    %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK-DAG:    %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK-NEXT:   %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
@@ -61,7 +61,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
   // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
   // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
   // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
+  // CHECK: %[[RESULT_EXTENTS:.+]] = tensor.cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
   // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>

@@ -16,7 +16,7 @@ func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> {
 func @constant_like_dynamic_shape(%arg : tensor<?x?xi64>) -> tensor<?x?xf32> {
   // CHECK: %[[CONSTANT:.*]] = mhlo.constant dense<3.200000e+00> : tensor<f32>
   // CHECK: %[[UNCASTED_SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x?xi64> -> tensor<?xindex>
-  // CHECK: %[[SHAPE:.*]] = tensor_cast %[[UNCASTED_SHAPE]] : tensor<?xindex> to tensor<2xindex>
+  // CHECK: %[[SHAPE:.*]] = tensor.cast %[[UNCASTED_SHAPE]] : tensor<?xindex> to tensor<2xindex>
   // CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK: return %[[BROADCASTED_CONSTANT]] : tensor<?x?xf32>
   %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 }

@@ -628,7 +628,7 @@ func @shape_assuming_tensor(%arg0: tensor<?xf16>) -> tensor<?xf16> {
   // CHECK: shape.assuming %{{.*}} -> (memref<?xf16>)
   %2 = shape.assuming %1 -> (tensor<?xf16>) {
     %3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
-    %4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
+    %4 = tensor.cast %3 : tensor<?xindex> to tensor<1xindex>
     %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
     %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
     // CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
@@ -638,3 +638,5 @@ func @shape_assuming_tensor(%arg0: tensor<?xf16>) -> tensor<?xf16> {
   }
   return %2 : tensor<?xf16>
 }
+
+

@@ -249,8 +249,9 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
           : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
   return %0 : tensor<2x2xi1>
 }
+// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi1>
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: i1):
 // CHECK-NEXT:   %[[RESULT:.*]] = cmpf "oeq", %[[LHS_IN]], %[[RHS_IN]] : f32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : i1

@@ -263,8 +264,9 @@ func @float_cmp_ne(%lhs: tensor<2x2xf32>,
           : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
   return %0 : tensor<2x2xi1>
 }
+// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi1>
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: i1):
 // CHECK-NEXT:   %[[RESULT:.*]] = cmpf "une", %[[LHS_IN]], %[[RHS_IN]] : f32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : i1

@@ -277,8 +279,9 @@ func @int_cmp(%lhs: tensor<2x2xi32>,
           : (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
   return %0 : tensor<2x2xi1>
 }
+// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi1>
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i1):
 // CHECK-NEXT:   %[[RESULT:.*]] = cmpi "slt", %[[LHS_IN]], %[[RHS_IN]] : i32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : i1

@@ -335,8 +338,9 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
          : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
   return %0 : tensor<2x2xf32>
 }
+// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xf32>
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32):
 // CHECK-NEXT:   %[[RESULT:.*]] = select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : f32

@@ -349,8 +353,9 @@ func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
   %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
   return %0: tensor<4x2x1xf32>
 }
+// CHECK: linalg.init_tensor [4, 2, 1] : tensor<4x2x1xf32>
 // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
 // CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32

 // -----

@@ -362,8 +367,11 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
   %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
   return %0: tensor<4x2x1x4x?x16xf32>
 }
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
+// CHECK: linalg.init_tensor [4, 2, 1, 4, %[[D1]], 16] : tensor<4x2x1x4x?x16xf32>
 // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
 // CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32

 // -----

@@ -377,8 +385,9 @@ func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
          : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32>
   return %0 : tensor<7x10x6x4x5xf32>
 }
+// CHECK: linalg.init_tensor [7, 10, 6, 4, 5] : tensor<7x10x6x4x5xf32>
 // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
 // CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32

 // -----

@@ -393,8 +402,9 @@ func @broadcast_in_dim_with_one_to_one(
          : (tensor<1xf32>) -> tensor<1x5xf32>
   return %0 : tensor<1x5xf32>
 }
+// CHECK: linalg.init_tensor [1, 5] : tensor<1x5xf32>
 // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
 // CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32

 // -----

@@ -408,8 +418,9 @@ func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
         : (tensor<f32>) -> tensor<7x10x6xf32>
   return %0 : tensor<7x10x6xf32>
 }
+// CHECK: linalg.init_tensor [7, 10, 6] : tensor<7x10x6xf32>
 // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
 // CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32

 // -----

@@ -499,8 +510,9 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
           : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
   return %0 : tensor<2x2xf32>
 }
+// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xf32>
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32):
 // CHECK-NEXT:   %[[CMP:.*]] = cmpf "olt", %[[LHS_IN]], %[[RHS_IN]] : f32
 // CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : f32

@@ -513,8 +525,9 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
           : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
   return %0 : tensor<2x2xi32>
 }
+// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi32>
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32):
 // CHECK-NEXT:   %[[CMP:.*]] = cmpi "sgt", %[[LHS_IN]], %[[RHS_IN]] : i32
 // CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32

@@ -527,9 +540,10 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
   %0 = "mhlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
+// CHECK: linalg.init_tensor
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
-// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
 // CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]]
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : f32

@@ -599,8 +613,9 @@ func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
   %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
   return %result : tensor<2x2xf32>
 }
+// CHECK: linalg.init_tensor
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32):
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: f32):
 // CHECK-NEXT:   %[[RESULT:.*]] = sitofp %[[OPERAND_IN]] : i32 to f32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : f32

@@ -611,8 +626,9 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
   %result = "mhlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
   return %result : tensor<2x2xi32>
 }
+// CHECK: linalg.init_tensor
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16):
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %{{.*}}: i32):
 // CHECK-NEXT:   %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32

@@ -623,8 +639,9 @@ func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
   %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
   return %result : tensor<2x2xi16>
 }
+// CHECK: linalg.init_tensor
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32):
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: i16):
 // CHECK-NEXT:   %[[RESULT:.*]] = trunci %[[OPERAND_IN]] : i32 to i16
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : i16

@@ -635,8 +652,9 @@ func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
   %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
   return %result : tensor<2x2xf64>
 }
+// CHECK: linalg.init_tensor
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %{{.*}}: f64):
 // CHECK-NEXT:   %[[RESULT:.*]] = fpext %[[OPERAND_IN]] : f32 to f64
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : f64

@@ -647,8 +665,9 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
   %result = "mhlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
   return %result : tensor<2x2xf32>
 }
+// CHECK: linalg.init_tensor
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64):
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64, %{{.*}}: f32):
 // CHECK-NEXT:   %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : f32

@@ -659,8 +678,9 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
   %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
   return %result : tensor<2x2xi32>
 }
+// CHECK: linalg.init_tensor
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %{{.*}}: i32):
 // CHECK-NEXT:   %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32

@@ -686,9 +706,10 @@ func @iota() -> tensor<7x10xf32> {
   %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
   return %result : tensor<7x10xf32>
 }
+// CHECK: linalg.init_tensor
 // CHECK: linalg.indexed_generic
 // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index):
+// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %{{.*}}: f32):
 // CHECK-NEXT:   %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
 // CHECK-NEXT:   %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
 // CHECK-NEXT:   linalg.yield %[[FLOAT_CAST]] : f32
@@ -702,8 +723,9 @@ func @shift_left(%lhs: tensor<2x2xi32>,
   return %result : tensor<2x2xi32>
 }
 // CHECK-LABEL: func @shift_left
+// CHECK: linalg.init_tensor
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
 // CHECK-NEXT:   %[[RESULT:.*]] = shift_left %[[LHS]], %[[RHS]] : i32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32

@@ -716,8 +738,9 @@ func @shift_right_arithmetic(%lhs: tensor<2x2xi32>,
   return %result : tensor<2x2xi32>
 }
 // CHECK-LABEL: func @shift_right_arithmetic
+// CHECK: linalg.init_tensor
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
 // CHECK-NEXT:   %[[RESULT:.*]] = shift_right_signed %[[LHS]], %[[RHS]] : i32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32

@@ -730,8 +753,9 @@ func @shift_right_logical(%lhs: tensor<2x2xi32>,
   return %result : tensor<2x2xi32>
 }
 // CHECK-LABEL: func @shift_right_logical
+// CHECK: linalg.init_tensor
 // CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
 // CHECK-NEXT:   %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32

@@ -163,7 +163,7 @@ func @addUnrankedUnranked(
 // CHECK-NEXT:           %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[LHS_RANK]], %[[C0]] : index
 //                       Handle scalar LHS case
 // CHECK-NEXT:           %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
-// CHECK-NEXT:             %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
+// CHECK-NEXT:             %[[SCALAR_LHS:.*]] = tensor.cast %[[LHS]] : tensor<*xf32> to tensor<f32>
 // CHECK-NEXT:             %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
 // CHECK-NEXT:             %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
 // CHECK-NEXT:             %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex>

@@ -177,7 +177,7 @@ func @addUnrankedUnranked(
 // CHECK-NEXT:             %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RHS_RANK]], %[[C0]] : index
 //                         Handle scalar RHS case
 // CHECK-NEXT:             %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
-// CHECK-NEXT:               %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
+// CHECK-NEXT:               %[[SCALAR_RHS:.*]] = tensor.cast %[[RHS]] : tensor<*xf32> to tensor<f32>
 // CHECK-NEXT:               %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
 // CHECK-NEXT:               %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex>
 // CHECK-NEXT:               %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

@@ -205,13 +205,13 @@ func @addUnrankedUnranked(
 // CHECK-NEXT:                 %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
 // CHECK-NEXT:                   %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
 // CHECK-NEXT:                   %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
-// CHECK-NEXT:                   %[[CASTED_LHS_1:.*]] = tensor_cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
+// CHECK-NEXT:                   %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
 // CHECK-NEXT:                   %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
-// CHECK-NEXT:                   %[[CASTED_RHS_1:.*]] = tensor_cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
+// CHECK-NEXT:                   %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
 // CHECK-NEXT:                   %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT:                   %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT:                   %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-// CHECK-NEXT:                   %[[RESULT_1:.*]] = tensor_cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
+// CHECK-NEXT:                   %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
 // CHECK-NEXT:                   scf.yield %[[RESULT_1]] : tensor<*xf32>
 // CHECK-NEXT:                 } else {
 // CHECK-NEXT:                   %[[C2:.*]] = constant 2 : index

@@ -220,13 +220,13 @@ func @addUnrankedUnranked(
 // CHECK-NEXT:                   %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
 // CHECK-NEXT:                     %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
 // CHECK-NEXT:                     %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
-// CHECK-NEXT:                     %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
+// CHECK-NEXT:                     %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
 // CHECK-NEXT:                     %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
-// CHECK-NEXT:                     %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
+// CHECK-NEXT:                     %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
 // CHECK-NEXT:                     %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-NEXT:                     %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-NEXT:                     %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NEXT:                     %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
+// CHECK-NEXT:                     %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
 // CHECK-NEXT:                     scf.yield %[[RESULT_2]] : tensor<*xf32>
 // CHECK-NEXT:                   } else {
 // CHECK-NEXT:                     %[[C3:.*]] = constant 3 : index

@@ -235,13 +235,13 @@ func @addUnrankedUnranked(
 // CHECK-NEXT:                     %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
 // CHECK-NEXT:                       %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
 // CHECK-NEXT:                       %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
-// CHECK-NEXT:                       %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
+// CHECK-NEXT:                       %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
 // CHECK-NEXT:                       %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
-// CHECK-NEXT:                       %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
+// CHECK-NEXT:                       %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
 // CHECK-NEXT:                       %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 // CHECK-NEXT:                       %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 // CHECK-NEXT:                       %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK-NEXT:                       %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
+// CHECK-NEXT:                       %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
 // CHECK-NEXT:                       scf.yield %[[RESULT_3]] : tensor<*xf32>
 // CHECK-NEXT:                     } else {
 // CHECK-NEXT:                       %[[C4:.*]] = constant 4 : index

@@ -250,13 +250,13 @@ func @addUnrankedUnranked(
 // CHECK-NEXT:                       %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
 // CHECK-NEXT:                         %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
 // CHECK-NEXT:                         %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
-// CHECK-NEXT:                         %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
+// CHECK-NEXT:                         %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
 // CHECK-NEXT:                         %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
-// CHECK-NEXT:                         %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
+// CHECK-NEXT:                         %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
 // CHECK-NEXT:                         %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT:                         %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT:                         %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-// CHECK-NEXT:                         %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
+// CHECK-NEXT:                         %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
 // CHECK-NEXT:                         scf.yield %[[RESULT_4]] : tensor<*xf32>
 // CHECK-NEXT:                       } else {
 // CHECK-NEXT:                         %[[C5:.*]] = constant 5 : index

@@ -265,13 +265,13 @@ func @addUnrankedUnranked(
 // CHECK-NEXT:                         %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
 // CHECK-NEXT:                           %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
 // CHECK-NEXT:                           %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
-// CHECK-NEXT:                           %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK-NEXT:                           %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
 // CHECK-NEXT:                           %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
-// CHECK-NEXT:                           %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK-NEXT:                           %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
 // CHECK-NEXT:                           %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-// CHECK-NEXT:                           %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
+// CHECK-NEXT:                           %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
 // CHECK-NEXT:                           scf.yield %[[RESULT_5]] : tensor<*xf32>
 // CHECK-NEXT:                         } else {
 // CHECK-NEXT:                           %[[C6:.*]] = constant 6 : index

@@ -280,13 +280,13 @@ func @addUnrankedUnranked(
 //                                       Handle rank 6 specialization
 // CHECK-NEXT:                           %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
 // CHECK-NEXT:                           %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
-// CHECK-NEXT:                           %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
+// CHECK-NEXT:                           %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
 // CHECK-NEXT:                           %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
-// CHECK-NEXT:                           %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
+// CHECK-NEXT:                           %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
 // CHECK-NEXT:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
-// CHECK-NEXT:                           %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
+// CHECK-NEXT:                           %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
 // CHECK-NEXT:                           scf.yield %[[RESULT_6]] : tensor<*xf32>
 // CHECK-NEXT:                         }
 // CHECK-NEXT:                         scf.yield %[[VAL_65:.*]] : tensor<*xf32>

@@ -375,7 +375,7 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index

 // -----

-// Confirm that tiling information is passed through tensor_load, tensor_cast
+// Confirm that tiling information is passed through tensor_load, tensor.cast
 // and memref_to_tensor  operations.
 func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
     -> memref<?xf32> {
@@ -390,7 +390,7 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
     linalg.yield %13 : f32
   }
   %2 = tensor_load %1 : memref<32xf32>
-  %3 = tensor_cast %2 : tensor<32xf32> to tensor<?xf32>
+  %3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
   %4 = tensor_to_memref %3 : memref<?xf32>
   return %4 : memref<?xf32>
 }
@@ -403,7 +403,7 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
 //       CHECK:      linalg.generic
 //       CHECK:        absf
 //       CHECK:  tensor_load
-//       CHECK:  tensor_cast
+//       CHECK:  tensor.cast
 //       CHECK:  tensor_to_memref

 // TILED-LABEL: func @tensor_ops
@@ -414,7 +414,7 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
 //       TILED:      linalg.generic
 //       TILED:        absf
 //       TILED:  tensor_load
-//       TILED:  tensor_cast
+//       TILED:  tensor.cast
 //       TILED:  tensor_to_memref


@@ -425,5 +425,5 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
 //       PLOOP:      linalg.generic
 //       PLOOP:        absf
 //       PLOOP:  tensor_load
-//       PLOOP:  tensor_cast
+//       PLOOP:  tensor.cast
 //       PLOOP:  tensor_to_memref