Integrate LLVM at llvm/llvm-project@678241795c
Updates LLVM usage to match [678241795c95](https://github.com/llvm/llvm-project/commit/678241795c95).

PiperOrigin-RevId: 363257913
This commit is contained in:
		
							parent
							
								
									2be112a603
								
							
						
					
					
						commit
						c54527fe88
					
				
							
								
								
									
										5
									
								
								BUILD
								
								
								
								
							
							
						
						
									
										5
									
								
								BUILD
								
								
								
								
							| 
						 | 
				
			
			@ -23,6 +23,7 @@ td_library(
 | 
			
		|||
    ],
 | 
			
		||||
    includes = ["include"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "@llvm-project//mlir:MemRefOpsTdFiles",
 | 
			
		||||
        "@llvm-project//mlir:OpBaseTdFiles",
 | 
			
		||||
        "@llvm-project//mlir:SideEffectTdFiles",
 | 
			
		||||
    ],
 | 
			
		||||
| 
						 | 
				
			
			@ -462,6 +463,7 @@ cc_library(
 | 
			
		|||
        "@llvm-project//mlir:Analysis",
 | 
			
		||||
        "@llvm-project//mlir:CopyOpInterface",
 | 
			
		||||
        "@llvm-project//mlir:IR",
 | 
			
		||||
        "@llvm-project//mlir:MemRefDialect",
 | 
			
		||||
        "@llvm-project//mlir:Pass",
 | 
			
		||||
        "@llvm-project//mlir:SideEffects",
 | 
			
		||||
        "@llvm-project//mlir:StandardOps",
 | 
			
		||||
| 
						 | 
				
			
			@ -615,6 +617,7 @@ cc_library(
 | 
			
		|||
        "@llvm-project//llvm:Support",
 | 
			
		||||
        "@llvm-project//mlir:IR",
 | 
			
		||||
        "@llvm-project//mlir:LinalgOps",
 | 
			
		||||
        "@llvm-project//mlir:MemRefDialect",
 | 
			
		||||
        "@llvm-project//mlir:Pass",
 | 
			
		||||
        "@llvm-project//mlir:SCFDialect",
 | 
			
		||||
        "@llvm-project//mlir:StandardOps",
 | 
			
		||||
| 
						 | 
				
			
			@ -724,6 +727,7 @@ cc_library(
 | 
			
		|||
        "@llvm-project//mlir:Affine",
 | 
			
		||||
        "@llvm-project//mlir:IR",
 | 
			
		||||
        "@llvm-project//mlir:LinalgTransforms",
 | 
			
		||||
        "@llvm-project//mlir:MemRefDialect",
 | 
			
		||||
        "@llvm-project//mlir:Pass",
 | 
			
		||||
        "@llvm-project//mlir:SCFDialect",
 | 
			
		||||
        "@llvm-project//mlir:StandardOps",
 | 
			
		||||
| 
						 | 
				
			
			@ -951,6 +955,7 @@ cc_library(
 | 
			
		|||
        ":hlo",
 | 
			
		||||
        "@llvm-project//llvm:Support",
 | 
			
		||||
        "@llvm-project//mlir:IR",
 | 
			
		||||
        "@llvm-project//mlir:MemRefDialect",
 | 
			
		||||
        "@llvm-project//mlir:StandardOps",
 | 
			
		||||
        "@llvm-project//mlir:TensorDialect",
 | 
			
		||||
        "@llvm-project//mlir:Transforms",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,9 +15,9 @@
 | 
			
		|||
 | 
			
		||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
 | 
			
		||||
 | 
			
		||||
LLVM_COMMIT = "6878be5dc3ec7031d0deec3e321310115bd71103"
 | 
			
		||||
LLVM_COMMIT = "678241795c957b18bc473045e48abe3f2a61ff5c"
 | 
			
		||||
 | 
			
		||||
LLVM_SHA256 = "f55187a3329fd97fd62fd0714783524d50a3be934a35484bd4442195fb25f0e5"
 | 
			
		||||
LLVM_SHA256 = "58fd00a9ed7841f36aa7042bb8c98323b030dee98abe36757eea9ddf4fd5ea75"
 | 
			
		||||
 | 
			
		||||
LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,2 +1,2 @@
 | 
			
		|||
6878be5dc3ec7031d0deec3e321310115bd71103
 | 
			
		||||
678241795c957b18bc473045e48abe3f2a61ff5c
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,6 +21,7 @@ limitations under the License.
 | 
			
		|||
#include "llvm/ADT/StringRef.h"
 | 
			
		||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
 | 
			
		||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h"
 | 
			
		||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
 | 
			
		||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
			
		||||
#include "mlir/IR/Attributes.h"
 | 
			
		||||
#include "mlir/IR/BuiltinTypes.h"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -33,6 +33,7 @@ limitations under the License.
 | 
			
		|||
#ifndef LHLO_OPS
 | 
			
		||||
#define LHLO_OPS
 | 
			
		||||
 | 
			
		||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 | 
			
		||||
include "mlir/IR/OpBase.td"
 | 
			
		||||
include "mlir/Interfaces/CopyOpInterface.td"
 | 
			
		||||
include "mlir/Interfaces/SideEffectInterfaces.td"
 | 
			
		||||
| 
						 | 
				
			
			@ -685,7 +686,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
 | 
			
		|||
  let extraClassDeclaration = [{
 | 
			
		||||
    SmallVector<Value, 4> getInputBuffers() {
 | 
			
		||||
      SmallVector<Value, 4> buffers;
 | 
			
		||||
      this->region().walk([&](TensorLoadOp load) {
 | 
			
		||||
      this->region().walk([&](memref::TensorLoadOp load) {
 | 
			
		||||
        if (load.memref().getParentRegion()->isProperAncestor(&region()))
 | 
			
		||||
          buffers.push_back(load.memref());
 | 
			
		||||
      });
 | 
			
		||||
| 
						 | 
				
			
			@ -694,7 +695,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
 | 
			
		|||
 | 
			
		||||
    SmallVector<Value, 4> getOutputBuffers() {
 | 
			
		||||
      SmallVector<Value, 4> buffers;
 | 
			
		||||
      this->region().walk([&](TensorStoreOp store) {
 | 
			
		||||
      this->region().walk([&](memref::TensorStoreOp store) {
 | 
			
		||||
        if (store.memref().getParentRegion()->isProperAncestor(&region()))
 | 
			
		||||
          buffers.push_back(store.memref());
 | 
			
		||||
      });
 | 
			
		||||
| 
						 | 
				
			
			@ -703,7 +704,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
 | 
			
		|||
 | 
			
		||||
    SmallVector<Value, 4> getFusionParameters() {
 | 
			
		||||
      SmallVector<Value, 4> buffers;
 | 
			
		||||
      this->region().walk([&](TensorLoadOp load) {
 | 
			
		||||
      this->region().walk([&](memref::TensorLoadOp load) {
 | 
			
		||||
        if (load.memref().getParentRegion()->isProperAncestor(&region()))
 | 
			
		||||
          buffers.push_back(load);
 | 
			
		||||
      });
 | 
			
		||||
| 
						 | 
				
			
			@ -712,7 +713,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
 | 
			
		|||
 | 
			
		||||
    SmallVector<Value, 4> getFusionResults() {
 | 
			
		||||
      SmallVector<Value, 4> buffers;
 | 
			
		||||
      this->region().walk([&](TensorStoreOp store) {
 | 
			
		||||
      this->region().walk([&](memref::TensorStoreOp store) {
 | 
			
		||||
        if (store.memref().getParentRegion()->isProperAncestor(&region()))
 | 
			
		||||
          buffers.push_back(store.tensor());
 | 
			
		||||
      });
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,6 +16,7 @@ limitations under the License.
 | 
			
		|||
#ifndef LHLO_OPS_BASE
 | 
			
		||||
#define LHLO_OPS_BASE
 | 
			
		||||
 | 
			
		||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 | 
			
		||||
include "mlir/IR/OpBase.td"
 | 
			
		||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -33,6 +33,7 @@ limitations under the License.
 | 
			
		|||
#include "llvm/Support/FormatVariadic.h"
 | 
			
		||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
 | 
			
		||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
 | 
			
		||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
 | 
			
		||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
			
		||||
#include "mlir/IR/Attributes.h"
 | 
			
		||||
#include "mlir/IR/Builders.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -156,13 +157,13 @@ struct EraseConstOp : public OpRewritePattern<ConstOp> {
 | 
			
		|||
  LogicalResult matchAndRewrite(ConstOp op,
 | 
			
		||||
                                PatternRewriter& rewriter) const override {
 | 
			
		||||
    Value memref = op.output();
 | 
			
		||||
    if (!memref.getDefiningOp<AllocOp>()) {
 | 
			
		||||
    if (!memref.getDefiningOp<memref::AllocOp>()) {
 | 
			
		||||
      return failure();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Check that all uses of the memref are either DeallocOps or this op.
 | 
			
		||||
    for (Operation* user : memref.getUsers())
 | 
			
		||||
      if (user != op && !isa<DeallocOp>(user)) return failure();
 | 
			
		||||
      if (user != op && !isa<memref::DeallocOp>(user)) return failure();
 | 
			
		||||
 | 
			
		||||
    rewriter.eraseOp(op);
 | 
			
		||||
    return success();
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -71,7 +71,7 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
 | 
			
		|||
    dynamic_operands.push_back(alloc_operand);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
 | 
			
		||||
  return rewriter->create<memref::AllocOp>(loc, memref_type, dynamic_operands);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Value InsertAlloc(Location loc, OpResult result,
 | 
			
		||||
| 
						 | 
				
			
			@ -85,7 +85,7 @@ Value InsertAlloc(Location loc, OpResult result,
 | 
			
		|||
      MemRefType::get(result_type.getShape(), result_type.getElementType());
 | 
			
		||||
  OpBuilder::InsertionGuard guard(*rewriter);
 | 
			
		||||
  rewriter->setInsertionPoint(result.getDefiningOp());
 | 
			
		||||
  auto alloc = rewriter->create<AllocOp>(loc, memref_type);
 | 
			
		||||
  auto alloc = rewriter->create<memref::AllocOp>(loc, memref_type);
 | 
			
		||||
  return alloc;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -207,7 +207,7 @@ class HloToLhloReshapeUnrankedConverter
 | 
			
		|||
    if (unranked_operand_type == nullptr) return failure();
 | 
			
		||||
 | 
			
		||||
    auto result_type = op.getType().cast<RankedTensorType>();
 | 
			
		||||
    rewriter.replaceOpWithNewOp<MemRefCastOp>(
 | 
			
		||||
    rewriter.replaceOpWithNewOp<memref::CastOp>(
 | 
			
		||||
        op, adaptor.operand(),
 | 
			
		||||
        MemRefType::get(result_type.getShape(), result_type.getElementType()));
 | 
			
		||||
    return success();
 | 
			
		||||
| 
						 | 
				
			
			@ -235,7 +235,7 @@ class HloToLhloDynamicReshapeConverter
 | 
			
		|||
      return failure();
 | 
			
		||||
    }
 | 
			
		||||
    mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
 | 
			
		||||
    rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
 | 
			
		||||
    rewriter.replaceOpWithNewOp<memref::ReshapeOp>(
 | 
			
		||||
        op, result_type, adaptor.operand(), adaptor.output_shape());
 | 
			
		||||
    return success();
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			@ -273,7 +273,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
 | 
			
		|||
  // Inserts dynamic memref to change the layout of the memref to put 0-stride
 | 
			
		||||
  // and size of the target dimension if size-1 dimension expansion is
 | 
			
		||||
  // necessary.
 | 
			
		||||
  MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
 | 
			
		||||
  memref::ReinterpretCastOp InsertDynamicMemrefCastOp(
 | 
			
		||||
      mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
 | 
			
		||||
    auto loc = op.getLoc();
 | 
			
		||||
    auto operand_type = operand.getType().cast<MemRefType>();
 | 
			
		||||
| 
						 | 
				
			
			@ -295,7 +295,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
 | 
			
		|||
    for (int i = operand_rank - 1; i >= 0; --i) {
 | 
			
		||||
      Value operand_dim_size =
 | 
			
		||||
          ShapedType::isDynamic(operand_shape[i])
 | 
			
		||||
              ? b->create<DimOp>(loc, operand, i).getResult()
 | 
			
		||||
              ? b->create<memref::DimOp>(loc, operand, i).getResult()
 | 
			
		||||
              : b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
 | 
			
		||||
      operand_sizes[i] = operand_dim_size;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -355,7 +355,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
 | 
			
		|||
        makeStridedLinearLayoutMap(dynamic_layout,
 | 
			
		||||
                                   /*offset=*/0, b->getContext()));
 | 
			
		||||
 | 
			
		||||
    auto transformed_operand = b->create<MemRefReinterpretCastOp>(
 | 
			
		||||
    auto transformed_operand = b->create<memref::ReinterpretCastOp>(
 | 
			
		||||
        loc, type_erased_memref_type, operand,
 | 
			
		||||
        /*offset=*/b->getI64IntegerAttr(0), sizes, strides);
 | 
			
		||||
    return transformed_operand;
 | 
			
		||||
| 
						 | 
				
			
			@ -484,12 +484,12 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
 | 
			
		|||
 | 
			
		||||
// TODO(b/175789537) Remove this pattern.
 | 
			
		||||
class HloToLhloTensorStoreOpLegacyConverter
 | 
			
		||||
    : public BaseOpConversion<mlir::TensorStoreOp> {
 | 
			
		||||
    : public BaseOpConversion<mlir::memref::TensorStoreOp> {
 | 
			
		||||
 public:
 | 
			
		||||
  using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
 | 
			
		||||
  using BaseOpConversion<mlir::memref::TensorStoreOp>::BaseOpConversion;
 | 
			
		||||
 | 
			
		||||
  LogicalResult matchAndRewrite(
 | 
			
		||||
      mlir::TensorStoreOp op, ArrayRef<Value> operands,
 | 
			
		||||
      mlir::memref::TensorStoreOp op, ArrayRef<Value> operands,
 | 
			
		||||
      ConversionPatternRewriter& rewriter) const final {
 | 
			
		||||
    rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
 | 
			
		||||
                                               operands.back());
 | 
			
		||||
| 
						 | 
				
			
			@ -577,14 +577,16 @@ struct HloLegalizeToLhlo
 | 
			
		|||
    ConversionTarget target(context);
 | 
			
		||||
    target.addLegalDialect<lmhlo::LmhloDialect>();
 | 
			
		||||
    target.addLegalDialect<StandardOpsDialect>();
 | 
			
		||||
    target.addLegalDialect<memref::MemRefDialect>();
 | 
			
		||||
    target.addLegalDialect<shape::ShapeDialect>();
 | 
			
		||||
    target.addLegalDialect<tensor::TensorDialect>();
 | 
			
		||||
    target.addIllegalDialect<mhlo::MhloDialect>();
 | 
			
		||||
    // Declare tensor_load and tensor_store illegal.
 | 
			
		||||
    target.addIllegalOp<mlir::TensorLoadOp, mlir::TensorStoreOp>();
 | 
			
		||||
    // tensor_to_memref is illegal if it has uses.
 | 
			
		||||
    // TODO(b/175670649) Make tensor_to_memref illegal.
 | 
			
		||||
    target.addDynamicallyLegalOp<mlir::TensorToMemrefOp>(
 | 
			
		||||
    target.addIllegalOp<mlir::memref::TensorLoadOp,
 | 
			
		||||
                        mlir::memref::TensorStoreOp>();
 | 
			
		||||
    // buffer_cast is illegal if it has uses.
 | 
			
		||||
    // TODO(b/175670649) Make buffer_cast illegal.
 | 
			
		||||
    target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
 | 
			
		||||
        [](auto op) { return op->use_empty(); });
 | 
			
		||||
 | 
			
		||||
    BufferizeTypeConverter converter;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -108,7 +108,7 @@ SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
 | 
			
		|||
      dyn_sizes.push_back(
 | 
			
		||||
          b.create<IndexCastOp>(loc, b.getIndexType(), extract));
 | 
			
		||||
    } else {
 | 
			
		||||
      dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
 | 
			
		||||
      dyn_sizes.push_back(b.create<memref::DimOp>(loc, tensor, en.index()));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return dyn_sizes;
 | 
			
		||||
| 
						 | 
				
			
			@ -324,13 +324,13 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    // Create two loads from the input.
 | 
			
		||||
    auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
 | 
			
		||||
    auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
 | 
			
		||||
    auto lhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.lhs());
 | 
			
		||||
    auto rhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.rhs());
 | 
			
		||||
    // TODO(ravishankarm) : Move this method out of lmhlo namespace.
 | 
			
		||||
    Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
 | 
			
		||||
        lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
 | 
			
		||||
        &rewriter);
 | 
			
		||||
    rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
 | 
			
		||||
    rewriter.create<memref::StoreOp>(loc, op_result, lhlo_op.out());
 | 
			
		||||
    rewriter.eraseOp(lhlo_op);
 | 
			
		||||
    return success();
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			@ -590,8 +590,8 @@ class LhloBroadcastInDimConverter
 | 
			
		|||
        operand_type.getDimSize(0) <
 | 
			
		||||
            result_type.getDimSize(broadcast_dims.front())) {
 | 
			
		||||
      Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
 | 
			
		||||
      Value val =
 | 
			
		||||
          rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
 | 
			
		||||
      Value val = rewriter.create<memref::LoadOp>(loc, operand,
 | 
			
		||||
                                                  llvm::makeArrayRef({zero}));
 | 
			
		||||
      rewriter.create<linalg::GenericOp>(
 | 
			
		||||
          loc, /*inputs=*/ValueRange{},
 | 
			
		||||
          /*outputBuffers=*/ValueRange{operand_adaptor.output()},
 | 
			
		||||
| 
						 | 
				
			
			@ -971,7 +971,8 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    // First fill the output buffer with the init value.
 | 
			
		||||
    Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
 | 
			
		||||
    Value init_value =
 | 
			
		||||
        rewriter.create<memref::LoadOp>(loc, adaptor.init_values()[0]);
 | 
			
		||||
    rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);
 | 
			
		||||
 | 
			
		||||
    DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
 | 
			
		||||
| 
						 | 
				
			
			@ -1011,9 +1012,9 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 | 
			
		|||
      // expects scalar SSA values. Add some allocs around the original op to
 | 
			
		||||
      // make it compatible.
 | 
			
		||||
      auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
 | 
			
		||||
      Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
 | 
			
		||||
      Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
 | 
			
		||||
      Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);
 | 
			
		||||
      Value alloc_a = rewriter.create<memref::AllocaOp>(loc, arg_type);
 | 
			
		||||
      Value alloc_b = rewriter.create<memref::AllocaOp>(loc, arg_type);
 | 
			
		||||
      Value alloc_res = rewriter.create<memref::AllocaOp>(loc, arg_type);
 | 
			
		||||
 | 
			
		||||
      // Now turn the existing signature
 | 
			
		||||
      //   (memref<X>, memref<X>, memref<X>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			@ -1030,13 +1031,15 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 | 
			
		|||
 | 
			
		||||
      // Store the arguments into the newly allocated buffers.
 | 
			
		||||
      rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
 | 
			
		||||
      rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
 | 
			
		||||
      rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
 | 
			
		||||
      rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(0),
 | 
			
		||||
                                       alloc_a);
 | 
			
		||||
      rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(1),
 | 
			
		||||
                                       alloc_b);
 | 
			
		||||
      rewriter.replaceOp(entry_block->getTerminator(), {});
 | 
			
		||||
 | 
			
		||||
      // Load & yield the result.
 | 
			
		||||
      rewriter.setInsertionPointToEnd(entry_block);
 | 
			
		||||
      auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
 | 
			
		||||
      auto load_res = rewriter.create<memref::LoadOp>(loc, alloc_res);
 | 
			
		||||
      rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1099,8 +1102,8 @@ class SliceConverter : public OpConversionPattern<OpTy> {
 | 
			
		|||
          slice_op.strides().template getValue<int64_t>(i)));
 | 
			
		||||
    }
 | 
			
		||||
    if (isLHLO) {
 | 
			
		||||
      auto linalg_op =
 | 
			
		||||
          rewriter.create<SubViewOp>(loc, args[0], offsets, sizes, strides);
 | 
			
		||||
      auto linalg_op = rewriter.create<memref::SubViewOp>(loc, args[0], offsets,
 | 
			
		||||
                                                          sizes, strides);
 | 
			
		||||
      rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
 | 
			
		||||
      rewriter.eraseOp(slice_op);
 | 
			
		||||
    } else {
 | 
			
		||||
| 
						 | 
				
			
			@ -1149,14 +1152,14 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
 | 
			
		|||
  switch (type) {
 | 
			
		||||
    case DotOperationType::kMatrixMatrix: {
 | 
			
		||||
      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
 | 
			
		||||
        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
 | 
			
		||||
        dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
 | 
			
		||||
      if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
 | 
			
		||||
        dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
 | 
			
		||||
        dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 1));
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    case DotOperationType::kMatrixVector: {
 | 
			
		||||
      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
 | 
			
		||||
        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
 | 
			
		||||
        dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    case DotOperationType::kVectorDot:
 | 
			
		||||
| 
						 | 
				
			
			@ -1203,11 +1206,11 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
 | 
			
		|||
    OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
 | 
			
		||||
  SmallVector<Value, 8> dyn_shape;
 | 
			
		||||
  if (result_type.isDynamicDim(0))
 | 
			
		||||
    dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
 | 
			
		||||
    dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
 | 
			
		||||
  if (result_type.isDynamicDim(1))
 | 
			
		||||
    dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1));
 | 
			
		||||
    dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 1));
 | 
			
		||||
  if (result_type.isDynamicDim(2))
 | 
			
		||||
    dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2));
 | 
			
		||||
    dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 2));
 | 
			
		||||
  return dyn_shape;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1307,7 +1310,7 @@ SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
 | 
			
		|||
  for (int i = 0, j = 0; i < rank; ++i) {
 | 
			
		||||
    if (s.count(i)) continue;
 | 
			
		||||
    if (!result_type.isDynamicDim(j++)) continue;
 | 
			
		||||
    dyn_shape.push_back(b.create<DimOp>(loc, arg, i));
 | 
			
		||||
    dyn_shape.push_back(b.create<memref::DimOp>(loc, arg, i));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return dyn_shape;
 | 
			
		||||
| 
						 | 
				
			
			@ -1467,7 +1470,7 @@ struct NormalConvOpOnTensorsConversion
 | 
			
		|||
    // The output shape is N spatial_dims F.
 | 
			
		||||
    SmallVector<Value, 8> dyn_sizes;
 | 
			
		||||
    if (result_type.isDynamicDim(0)) {
 | 
			
		||||
      dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, 0));
 | 
			
		||||
      dyn_sizes.push_back(rewriter.create<memref::DimOp>(loc, input, 0));
 | 
			
		||||
    }
 | 
			
		||||
    for (int64_t i = 1, e = rank - 1; i < e; ++i) {
 | 
			
		||||
      if (result_type.isDynamicDim(i)) {
 | 
			
		||||
| 
						 | 
				
			
			@ -1476,7 +1479,8 @@ struct NormalConvOpOnTensorsConversion
 | 
			
		|||
      }
 | 
			
		||||
    }
 | 
			
		||||
    if (result_type.isDynamicDim(rank - 1)) {
 | 
			
		||||
      dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
 | 
			
		||||
      dyn_sizes.push_back(
 | 
			
		||||
          rewriter.create<memref::DimOp>(loc, filter, rank - 1));
 | 
			
		||||
    }
 | 
			
		||||
    Value init_tensor = rewriter.create<linalg::InitTensorOp>(
 | 
			
		||||
        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
 | 
			
		||||
| 
						 | 
				
			
			@ -1856,8 +1860,8 @@ struct LhloLegalizeToLinalgPass
 | 
			
		|||
    OwningRewritePatternList patterns;
 | 
			
		||||
    ConversionTarget target(getContext());
 | 
			
		||||
    target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
 | 
			
		||||
                           math::MathDialect, StandardOpsDialect,
 | 
			
		||||
                           AffineDialect>();
 | 
			
		||||
                           math::MathDialect, memref::MemRefDialect,
 | 
			
		||||
                           StandardOpsDialect, AffineDialect>();
 | 
			
		||||
 | 
			
		||||
    auto func = getFunction();
 | 
			
		||||
    populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
 | 
			
		||||
| 
						 | 
				
			
			@ -1881,6 +1885,9 @@ struct HloLegalizeToLinalgPass
 | 
			
		|||
                           math::MathDialect, StandardOpsDialect,
 | 
			
		||||
                           tensor::TensorDialect, scf::SCFDialect>();
 | 
			
		||||
 | 
			
		||||
    // TODO: DimOp shouldn't be in MemRefDialect
 | 
			
		||||
    target.addLegalOp<memref::DimOp>();
 | 
			
		||||
 | 
			
		||||
    auto func = getFunction();
 | 
			
		||||
    mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
 | 
			
		||||
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -22,6 +22,7 @@ limitations under the License.
 | 
			
		|||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
 | 
			
		||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 | 
			
		||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 | 
			
		||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
 | 
			
		||||
#include "mlir/Dialect/SCF/SCF.h"
 | 
			
		||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
			
		||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -95,7 +96,7 @@ class LhloFuseLinalgPass
 | 
			
		|||
        continue;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
 | 
			
		||||
      if (auto tensor_load = dyn_cast<memref::TensorLoadOp>(definingOp)) {
 | 
			
		||||
        auto alias = tensor_load.memref();
 | 
			
		||||
        if (result_buffers.insert(alias).second) {
 | 
			
		||||
          worklist.push_back(alias);
 | 
			
		||||
| 
						 | 
				
			
			@ -103,7 +104,7 @@ class LhloFuseLinalgPass
 | 
			
		|||
        continue;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
 | 
			
		||||
      if (auto tensor_to_memref = dyn_cast<memref::BufferCastOp>(definingOp)) {
 | 
			
		||||
        auto alias = tensor_to_memref.tensor();
 | 
			
		||||
        if (result_buffers.insert(alias).second) {
 | 
			
		||||
          worklist.push_back(alias);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -96,9 +96,10 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
 | 
			
		|||
 | 
			
		||||
      // Load the initial value and store it to the output.
 | 
			
		||||
      for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
 | 
			
		||||
        auto init_value = rewriter.create<mlir::LoadOp>(loc, std::get<0>(pair));
 | 
			
		||||
        rewriter.create<mlir::StoreOp>(loc, init_value, std::get<1>(pair),
 | 
			
		||||
                                       ArrayRef<Value>{index});
 | 
			
		||||
        auto init_value =
 | 
			
		||||
            rewriter.create<mlir::memref::LoadOp>(loc, std::get<0>(pair));
 | 
			
		||||
        rewriter.create<mlir::memref::StoreOp>(
 | 
			
		||||
            loc, init_value, std::get<1>(pair), ArrayRef<Value>{index});
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Insert a loop into the body to compute the reduction. The loop ranges
 | 
			
		||||
| 
						 | 
				
			
			@ -128,8 +129,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
 | 
			
		|||
      auto oneAttr = rewriter.getI64IntegerAttr(1);
 | 
			
		||||
      OpFoldResult size = oneAttr;
 | 
			
		||||
      OpFoldResult stride = oneAttr;
 | 
			
		||||
      auto accumulator = rewriter.create<SubViewOp>(loc, resType, output,
 | 
			
		||||
                                                    offset, size, stride);
 | 
			
		||||
      auto accumulator = rewriter.create<memref::SubViewOp>(
 | 
			
		||||
          loc, resType, output, offset, size, stride);
 | 
			
		||||
      llvm::SmallVector<Value, 4> indexings;
 | 
			
		||||
      auto input_buffer = *reduce_op.operands().begin();
 | 
			
		||||
      auto input_type_rank =
 | 
			
		||||
| 
						 | 
				
			
			@ -143,8 +144,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
 | 
			
		|||
          }));
 | 
			
		||||
      SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
 | 
			
		||||
      SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
 | 
			
		||||
      auto rhs = rewriter.create<SubViewOp>(loc, accumulator.getType(), input,
 | 
			
		||||
                                            offsets, sizes, strides);
 | 
			
		||||
      auto rhs = rewriter.create<memref::SubViewOp>(
 | 
			
		||||
          loc, accumulator.getType(), input, offsets, sizes, strides);
 | 
			
		||||
 | 
			
		||||
      // Now copy over the actual body of the reduction, leaving out the
 | 
			
		||||
      // terminator.
 | 
			
		||||
| 
						 | 
				
			
			@ -179,8 +180,9 @@ struct LhloLegalizeToGpuPass
 | 
			
		|||
  void runOnFunction() override {
 | 
			
		||||
    OwningRewritePatternList patterns;
 | 
			
		||||
    ConversionTarget target(getContext());
 | 
			
		||||
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
 | 
			
		||||
                           gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
 | 
			
		||||
    target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
 | 
			
		||||
                           StandardOpsDialect, gpu::GPUDialect, scf::SCFDialect,
 | 
			
		||||
                           LmhloDialect>();
 | 
			
		||||
    target.addIllegalOp<ReduceOp>();
 | 
			
		||||
    auto func = getFunction();
 | 
			
		||||
    patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,6 +18,7 @@ limitations under the License.
 | 
			
		|||
#include "llvm/ADT/SmallVector.h"
 | 
			
		||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 | 
			
		||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 | 
			
		||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
 | 
			
		||||
#include "mlir/Dialect/SCF/SCF.h"
 | 
			
		||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
			
		||||
#include "mlir/IR/BuiltinTypes.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -43,10 +44,11 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
 | 
			
		|||
                                Block* lhlo_block, OpBuilder* b) {
 | 
			
		||||
  SmallVector<Value, 2> arg_bufs;
 | 
			
		||||
  for (auto arg_type : lhlo_block->getArgumentTypes()) {
 | 
			
		||||
    arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
 | 
			
		||||
    arg_bufs.push_back(
 | 
			
		||||
        b->create<memref::AllocOp>(loc, arg_type.cast<MemRefType>()));
 | 
			
		||||
  }
 | 
			
		||||
  for (auto operand : llvm::enumerate(operands)) {
 | 
			
		||||
    b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
 | 
			
		||||
    b->create<memref::StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
 | 
			
		||||
  }
 | 
			
		||||
  // Clone the ops from `lhlo_block`.
 | 
			
		||||
  BlockAndValueMapping mapping;
 | 
			
		||||
| 
						 | 
				
			
			@ -55,7 +57,7 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
 | 
			
		|||
    auto clone = b->clone(nested, mapping);
 | 
			
		||||
    mapping.map(nested.getResults(), clone->getResults());
 | 
			
		||||
  }
 | 
			
		||||
  return b->create<LoadOp>(loc, arg_bufs.back());
 | 
			
		||||
  return b->create<memref::LoadOp>(loc, arg_bufs.back());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Converts a block with LHLO ops and with signature:
 | 
			
		||||
| 
						 | 
				
			
			@ -78,7 +80,8 @@ void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
 | 
			
		|||
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
 | 
			
		||||
                            size_t dim_index, int64_t dim, OpBuilder* b) {
 | 
			
		||||
  return dim == ShapedType::kDynamicSize
 | 
			
		||||
             ? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
 | 
			
		||||
             ? b->create<memref::DimOp>(loc, shaped_value, dim_index)
 | 
			
		||||
                   .getResult()
 | 
			
		||||
             : b->create<ConstantIndexOp>(loc, dim);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -249,8 +252,8 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 | 
			
		|||
      (is_reducing_dim ? reduce_step : parallel_step).push_back(step);
 | 
			
		||||
    }
 | 
			
		||||
    // Load initial value from memref<element_type>.
 | 
			
		||||
    SmallVector<Value, 1> init_value = {
 | 
			
		||||
        rewriter->create<LoadOp>(loc, *reduce_op.init_values().begin())};
 | 
			
		||||
    SmallVector<Value, 1> init_value = {rewriter->create<memref::LoadOp>(
 | 
			
		||||
        loc, *reduce_op.init_values().begin())};
 | 
			
		||||
    // Outer ParallelOp is not needed if it is a reduction across all dims.
 | 
			
		||||
    scf::ParallelOp outer;
 | 
			
		||||
    if (!parallel_lower.empty()) {
 | 
			
		||||
| 
						 | 
				
			
			@ -272,7 +275,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 | 
			
		|||
      out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);
 | 
			
		||||
    rewriter->create<memref::StoreOp>(loc, reduction_result, out, out_indices);
 | 
			
		||||
 | 
			
		||||
    // Load the element to reduce.
 | 
			
		||||
    SmallVector<Value, 2> indices;
 | 
			
		||||
| 
						 | 
				
			
			@ -290,7 +293,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    rewriter->setInsertionPointToStart(inner.getBody());
 | 
			
		||||
    Value elem = rewriter->create<mlir::LoadOp>(
 | 
			
		||||
    Value elem = rewriter->create<mlir::memref::LoadOp>(
 | 
			
		||||
        loc, *reduce_op.operands().begin(), indices);
 | 
			
		||||
    return rewriter->create<scf::ReduceOp>(loc, elem);
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			@ -385,7 +388,7 @@ class ReduceWindowOpConverter
 | 
			
		|||
      ConversionPatternRewriter* rewriter) const {
 | 
			
		||||
    auto loc = reduce_window_op.getLoc();
 | 
			
		||||
    Value init_value =
 | 
			
		||||
        rewriter->create<LoadOp>(loc, reduce_window_op.init_value());
 | 
			
		||||
        rewriter->create<memref::LoadOp>(loc, reduce_window_op.init_value());
 | 
			
		||||
 | 
			
		||||
    Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
 | 
			
		||||
    Value one = rewriter->create<ConstantIndexOp>(loc, 1);
 | 
			
		||||
| 
						 | 
				
			
			@ -408,7 +411,8 @@ class ReduceWindowOpConverter
 | 
			
		|||
 | 
			
		||||
    Value reduction_result = *window_loop.getResults().begin();
 | 
			
		||||
    auto output_ivs = output_loop.getInductionVars();
 | 
			
		||||
    rewriter->create<StoreOp>(loc, reduction_result, output, output_ivs);
 | 
			
		||||
    rewriter->create<memref::StoreOp>(loc, reduction_result, output,
 | 
			
		||||
                                      output_ivs);
 | 
			
		||||
    return std::make_pair(output_loop, window_loop);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -439,7 +443,7 @@ class ReduceWindowOpConverter
 | 
			
		|||
 | 
			
		||||
    OpBuilder then_builder =
 | 
			
		||||
        elem_or_init.getThenBodyBuilder(rewriter->getListener());
 | 
			
		||||
    Value elem = then_builder.create<mlir::LoadOp>(
 | 
			
		||||
    Value elem = then_builder.create<mlir::memref::LoadOp>(
 | 
			
		||||
        loc, reduce_window_op.operand(), mapped_ivs.ivs);
 | 
			
		||||
    then_builder.create<scf::YieldOp>(loc, elem);
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -497,8 +501,8 @@ class SelectAndScatterOpConverter
 | 
			
		|||
    auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);
 | 
			
		||||
 | 
			
		||||
    // Load `source[selected_ivs]`.
 | 
			
		||||
    auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(),
 | 
			
		||||
                                            loop_over_src.getInductionVars());
 | 
			
		||||
    auto src_elem = rewriter.create<memref::LoadOp>(
 | 
			
		||||
        loc, s_and_s_op.source(), loop_over_src.getInductionVars());
 | 
			
		||||
 | 
			
		||||
    // Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`.
 | 
			
		||||
    auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
 | 
			
		||||
| 
						 | 
				
			
			@ -517,13 +521,13 @@ class SelectAndScatterOpConverter
 | 
			
		|||
  void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
 | 
			
		||||
                        OpBuilder* b) const {
 | 
			
		||||
    auto loc = s_and_s_op.getLoc();
 | 
			
		||||
    Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
 | 
			
		||||
    Value init_value = b->create<memref::LoadOp>(loc, s_and_s_op.init_value());
 | 
			
		||||
 | 
			
		||||
    scf::ParallelOp loop_over_output =
 | 
			
		||||
        MakeLoopOverShape(loc, s_and_s_op.out(), b);
 | 
			
		||||
    OpBuilder::InsertionGuard guard(*b);
 | 
			
		||||
    b->setInsertionPointToStart(loop_over_output.getBody());
 | 
			
		||||
    b->create<StoreOp>(loc, init_value, s_and_s_op.out(),
 | 
			
		||||
    b->create<memref::StoreOp>(loc, init_value, s_and_s_op.out(),
 | 
			
		||||
                               loop_over_output.getInductionVars());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -647,7 +651,7 @@ class SelectAndScatterOpConverter
 | 
			
		|||
 | 
			
		||||
    TypeRange iter_arg_types{ivs_val_flag->to_vector()};
 | 
			
		||||
    Value operand_elem =
 | 
			
		||||
        b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
 | 
			
		||||
        b->create<memref::LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
 | 
			
		||||
    auto if_init =
 | 
			
		||||
        b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
 | 
			
		||||
                             /*withElseRegion=*/true);
 | 
			
		||||
| 
						 | 
				
			
			@ -712,8 +716,8 @@ struct LhloLegalizeToParallelLoopsPass
 | 
			
		|||
    // clang-format on
 | 
			
		||||
 | 
			
		||||
    ConversionTarget target(getContext());
 | 
			
		||||
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
 | 
			
		||||
                           scf::SCFDialect, LmhloDialect>();
 | 
			
		||||
    target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
 | 
			
		||||
                           StandardOpsDialect, scf::SCFDialect, LmhloDialect>();
 | 
			
		||||
    target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
 | 
			
		||||
                        lmhlo::SelectAndScatterOp>();
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,6 +15,7 @@ limitations under the License.
 | 
			
		|||
 | 
			
		||||
#include "llvm/ADT/SmallVector.h"
 | 
			
		||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 | 
			
		||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
 | 
			
		||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
			
		||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
 | 
			
		||||
#include "mlir/IR/Attributes.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -58,7 +59,8 @@ Value CalculateShapeValue(Location loc, Value operand,
 | 
			
		|||
  int64_t rank = result_type.getRank();
 | 
			
		||||
  shape_values.reserve(rank);
 | 
			
		||||
  for (int64_t i = 0; i < rank; ++i) {
 | 
			
		||||
    shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i));
 | 
			
		||||
    shape_values.push_back(
 | 
			
		||||
        rewriter.create<mlir::memref::DimOp>(loc, operand, i));
 | 
			
		||||
  }
 | 
			
		||||
  return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -967,10 +967,10 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tupl
 | 
			
		|||
 | 
			
		||||
// CHECK-LABEL: func @erase_dead_lhlo_constant
 | 
			
		||||
func @erase_dead_lhlo_constant() {
 | 
			
		||||
  %M = alloc() : memref<256x1024xf32>
 | 
			
		||||
  %M = memref.alloc() : memref<256x1024xf32>
 | 
			
		||||
  // CHECK-NEXT: return
 | 
			
		||||
  "lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
 | 
			
		||||
  dealloc %M : memref<256x1024xf32>
 | 
			
		||||
  memref.dealloc %M : memref<256x1024xf32>
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -979,9 +979,9 @@ func @erase_dead_lhlo_constant() {
 | 
			
		|||
func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
 | 
			
		||||
  // CHECK-NEXT: lmhlo.constant
 | 
			
		||||
  "lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
 | 
			
		||||
  // CHECK-NEXT: alloc
 | 
			
		||||
  // CHECK-NEXT: memref.alloc
 | 
			
		||||
  // CHECK-NEXT: lmhlo.constant
 | 
			
		||||
  %N = alloc() : memref<256x1024xf32>
 | 
			
		||||
  %N = memref.alloc() : memref<256x1024xf32>
 | 
			
		||||
  "lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
 | 
			
		||||
  return %N : memref<256x1024xf32>
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,7 +27,7 @@ func private @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface }
 | 
			
		|||
func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
 | 
			
		||||
 | 
			
		||||
func @trivial_broadcast_wrapper() {
 | 
			
		||||
  %input_buf = alloc() : memref<3xf32>
 | 
			
		||||
  %input_buf = memref.alloc() : memref<3xf32>
 | 
			
		||||
 | 
			
		||||
  %c1f32 = constant 1.0 : f32
 | 
			
		||||
  %c2f32 = constant 2.0 : f32
 | 
			
		||||
| 
						 | 
				
			
			@ -36,19 +36,19 @@ func @trivial_broadcast_wrapper() {
 | 
			
		|||
  %c0 = constant 0 : index
 | 
			
		||||
  %c1 = constant 1 : index
 | 
			
		||||
  %c2 = constant 2 : index
 | 
			
		||||
  store %c1f32, %input_buf[%c0] : memref<3xf32>
 | 
			
		||||
  store %c2f32, %input_buf[%c1] : memref<3xf32>
 | 
			
		||||
  store %c3f32, %input_buf[%c2] : memref<3xf32>
 | 
			
		||||
  %input = tensor_load %input_buf : memref<3xf32>
 | 
			
		||||
  memref.store %c1f32, %input_buf[%c0] : memref<3xf32>
 | 
			
		||||
  memref.store %c2f32, %input_buf[%c1] : memref<3xf32>
 | 
			
		||||
  memref.store %c3f32, %input_buf[%c2] : memref<3xf32>
 | 
			
		||||
  %input = memref.tensor_load %input_buf : memref<3xf32>
 | 
			
		||||
 | 
			
		||||
  // Test BroadcastInDimOp.
 | 
			
		||||
  %output = "mhlo.broadcast_in_dim"(%input) {
 | 
			
		||||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
  %output_buf = memref.buffer_cast %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
  // CHECK-NEXT: [1,   1,   1,   1]
 | 
			
		||||
| 
						 | 
				
			
			@ -63,9 +63,9 @@ func @trivial_broadcast_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
  %unranked_dyn_output = memref.cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
| 
						 | 
				
			
			@ -76,29 +76,29 @@ func @trivial_broadcast_wrapper() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func @broadcast_in_X_dim_wrapper() {
 | 
			
		||||
  %input_buf = alloc() : memref<1x4xf32>
 | 
			
		||||
  %input_buf = memref.alloc() : memref<1x4xf32>
 | 
			
		||||
  %c1f32 = constant 1.0 : f32
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  store %c1f32, %input_buf[%c0, %c0] : memref<1x4xf32>
 | 
			
		||||
  memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x4xf32>
 | 
			
		||||
  %c2f32 = constant 2.0 : f32
 | 
			
		||||
  %c1 = constant 1 : index
 | 
			
		||||
  store %c2f32, %input_buf[%c0, %c1] : memref<1x4xf32>
 | 
			
		||||
  memref.store %c2f32, %input_buf[%c0, %c1] : memref<1x4xf32>
 | 
			
		||||
  %c3f32 = constant 3.0 : f32
 | 
			
		||||
  %c2 = constant 2 : index
 | 
			
		||||
  store %c3f32, %input_buf[%c0, %c2] : memref<1x4xf32>
 | 
			
		||||
  memref.store %c3f32, %input_buf[%c0, %c2] : memref<1x4xf32>
 | 
			
		||||
  %c4f32 = constant 4.0 : f32
 | 
			
		||||
  %c3 = constant 3 : index
 | 
			
		||||
  store %c4f32, %input_buf[%c0, %c3] : memref<1x4xf32>
 | 
			
		||||
  %input = tensor_load %input_buf : memref<1x4xf32>
 | 
			
		||||
  memref.store %c4f32, %input_buf[%c0, %c3] : memref<1x4xf32>
 | 
			
		||||
  %input = memref.tensor_load %input_buf : memref<1x4xf32>
 | 
			
		||||
 | 
			
		||||
  // Test BroadcastInDimOp.
 | 
			
		||||
  %output = "mhlo.broadcast_in_dim"(%input) {
 | 
			
		||||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x4xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
  %output_buf = memref.buffer_cast %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
  // CHECK-NEXT: [1,   2,   3,   4]
 | 
			
		||||
| 
						 | 
				
			
			@ -112,9 +112,9 @@ func @broadcast_in_X_dim_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
  %unranked_dyn_output = memref.cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
| 
						 | 
				
			
			@ -125,26 +125,26 @@ func @broadcast_in_X_dim_wrapper() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func @broadcast_in_Y_dim_wrapper() {
 | 
			
		||||
  %input_buf = alloc() : memref<3x1xf32>
 | 
			
		||||
  %input_buf = memref.alloc() : memref<3x1xf32>
 | 
			
		||||
  %c1f32 = constant 1.0 : f32
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  store %c1f32, %input_buf[%c0, %c0] : memref<3x1xf32>
 | 
			
		||||
  memref.store %c1f32, %input_buf[%c0, %c0] : memref<3x1xf32>
 | 
			
		||||
  %c2f32 = constant 2.0 : f32
 | 
			
		||||
  %c1 = constant 1 : index
 | 
			
		||||
  store %c2f32, %input_buf[%c1, %c0] : memref<3x1xf32>
 | 
			
		||||
  memref.store %c2f32, %input_buf[%c1, %c0] : memref<3x1xf32>
 | 
			
		||||
  %c3f32 = constant 3.0 : f32
 | 
			
		||||
  %c2 = constant 2 : index
 | 
			
		||||
  store %c3f32, %input_buf[%c2, %c0] : memref<3x1xf32>
 | 
			
		||||
  %input = tensor_load %input_buf : memref<3x1xf32>
 | 
			
		||||
  memref.store %c3f32, %input_buf[%c2, %c0] : memref<3x1xf32>
 | 
			
		||||
  %input = memref.tensor_load %input_buf : memref<3x1xf32>
 | 
			
		||||
 | 
			
		||||
  // Test BroadcastInDimOp.
 | 
			
		||||
  %output = "mhlo.broadcast_in_dim"(%input) {
 | 
			
		||||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<3x1xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
  %output_buf = memref.buffer_cast %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
  // CHECK-NEXT: [1,   1,   1,   1]
 | 
			
		||||
| 
						 | 
				
			
			@ -159,9 +159,9 @@ func @broadcast_in_Y_dim_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
  %unranked_dyn_output = memref.cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
| 
						 | 
				
			
			@ -172,29 +172,29 @@ func @broadcast_in_Y_dim_wrapper() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func @broadcast_in_X_dim_transpose_wrapper() {
 | 
			
		||||
  %input_buf = alloc() : memref<4x1xf32>
 | 
			
		||||
  %input_buf = memref.alloc() : memref<4x1xf32>
 | 
			
		||||
  %c1f32 = constant 1.0 : f32
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  store %c1f32, %input_buf[%c0, %c0] : memref<4x1xf32>
 | 
			
		||||
  memref.store %c1f32, %input_buf[%c0, %c0] : memref<4x1xf32>
 | 
			
		||||
  %c2f32 = constant 2.0 : f32
 | 
			
		||||
  %c1 = constant 1 : index
 | 
			
		||||
  store %c2f32, %input_buf[%c1, %c0] : memref<4x1xf32>
 | 
			
		||||
  memref.store %c2f32, %input_buf[%c1, %c0] : memref<4x1xf32>
 | 
			
		||||
  %c3f32 = constant 3.0 : f32
 | 
			
		||||
  %c2 = constant 2 : index
 | 
			
		||||
  store %c3f32, %input_buf[%c2, %c0] : memref<4x1xf32>
 | 
			
		||||
  memref.store %c3f32, %input_buf[%c2, %c0] : memref<4x1xf32>
 | 
			
		||||
  %c4f32 = constant 4.0 : f32
 | 
			
		||||
  %c3 = constant 3 : index
 | 
			
		||||
  store %c4f32, %input_buf[%c3, %c0] : memref<4x1xf32>
 | 
			
		||||
  %input = tensor_load %input_buf : memref<4x1xf32>
 | 
			
		||||
  memref.store %c4f32, %input_buf[%c3, %c0] : memref<4x1xf32>
 | 
			
		||||
  %input = memref.tensor_load %input_buf : memref<4x1xf32>
 | 
			
		||||
 | 
			
		||||
  // Test BroadcastInDimOp.
 | 
			
		||||
  %output = "mhlo.broadcast_in_dim"(%input) {
 | 
			
		||||
    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<4x1xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
  %output_buf = memref.buffer_cast %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
  // CHECK-NEXT: [1,   2,   3,   4]
 | 
			
		||||
| 
						 | 
				
			
			@ -208,9 +208,9 @@ func @broadcast_in_X_dim_transpose_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
  %unranked_dyn_output = memref.cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
| 
						 | 
				
			
			@ -221,26 +221,26 @@ func @broadcast_in_X_dim_transpose_wrapper() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func @broadcast_in_Y_dim_transpose_wrapper() {
 | 
			
		||||
  %input_buf = alloc() : memref<1x3xf32>
 | 
			
		||||
  %input_buf = memref.alloc() : memref<1x3xf32>
 | 
			
		||||
  %c1f32 = constant 1.0 : f32
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  store %c1f32, %input_buf[%c0, %c0] : memref<1x3xf32>
 | 
			
		||||
  memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x3xf32>
 | 
			
		||||
  %c2f32 = constant 2.0 : f32
 | 
			
		||||
  %c1 = constant 1 : index
 | 
			
		||||
  store %c2f32, %input_buf[%c0, %c1] : memref<1x3xf32>
 | 
			
		||||
  memref.store %c2f32, %input_buf[%c0, %c1] : memref<1x3xf32>
 | 
			
		||||
  %c3f32 = constant 3.0 : f32
 | 
			
		||||
  %c2 = constant 2 : index
 | 
			
		||||
  store %c3f32, %input_buf[%c0, %c2] : memref<1x3xf32>
 | 
			
		||||
  %input = tensor_load %input_buf : memref<1x3xf32>
 | 
			
		||||
  memref.store %c3f32, %input_buf[%c0, %c2] : memref<1x3xf32>
 | 
			
		||||
  %input = memref.tensor_load %input_buf : memref<1x3xf32>
 | 
			
		||||
 | 
			
		||||
  // Test BroadcastInDimOp.
 | 
			
		||||
  %output = "mhlo.broadcast_in_dim"(%input) {
 | 
			
		||||
    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x3xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
  %output_buf = memref.buffer_cast %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
  // CHECK-NEXT: [1,   1,   1,   1]
 | 
			
		||||
| 
						 | 
				
			
			@ -255,9 +255,9 @@ func @broadcast_in_Y_dim_transpose_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
  %unranked_dyn_output = memref.cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
| 
						 | 
				
			
			@ -268,20 +268,20 @@ func @broadcast_in_Y_dim_transpose_wrapper() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func @broadcast_scalar_1d_wrapper() {
 | 
			
		||||
  %input_buf = alloc() : memref<1xf32>
 | 
			
		||||
  %input_buf = memref.alloc() : memref<1xf32>
 | 
			
		||||
  %c1f32 = constant 1.0 : f32
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  store %c1f32, %input_buf[%c0] : memref<1xf32>
 | 
			
		||||
  %input = tensor_load %input_buf : memref<1xf32>
 | 
			
		||||
  memref.store %c1f32, %input_buf[%c0] : memref<1xf32>
 | 
			
		||||
  %input = memref.tensor_load %input_buf : memref<1xf32>
 | 
			
		||||
 | 
			
		||||
  // Test BroadcastInDimOp.
 | 
			
		||||
  %output = "mhlo.broadcast_in_dim"(%input) {
 | 
			
		||||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<1xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
  %output_buf = memref.buffer_cast %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
  // CHECK-NEXT: [1, 1, 1, 1]
 | 
			
		||||
| 
						 | 
				
			
			@ -296,9 +296,9 @@ func @broadcast_scalar_1d_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
  %unranked_dyn_output = memref.cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
| 
						 | 
				
			
			@ -309,20 +309,20 @@ func @broadcast_scalar_1d_wrapper() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func @broadcast_scalar_2d_wrapper() {
 | 
			
		||||
  %input_buf = alloc() : memref<1x1xf32>
 | 
			
		||||
  %input_buf = memref.alloc() : memref<1x1xf32>
 | 
			
		||||
  %c1f32 = constant 1.0 : f32
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  store %c1f32, %input_buf[%c0, %c0] : memref<1x1xf32>
 | 
			
		||||
  %input = tensor_load %input_buf : memref<1x1xf32>
 | 
			
		||||
  memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x1xf32>
 | 
			
		||||
  %input = memref.tensor_load %input_buf : memref<1x1xf32>
 | 
			
		||||
 | 
			
		||||
  // Test BroadcastInDimOp.
 | 
			
		||||
  %output = "mhlo.broadcast_in_dim"(%input) {
 | 
			
		||||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x1xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
  %output_buf = memref.buffer_cast %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
  // CHECK-NEXT: [1, 1, 1, 1]
 | 
			
		||||
| 
						 | 
				
			
			@ -337,9 +337,9 @@ func @broadcast_scalar_2d_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
  %unranked_dyn_output = memref.cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
 | 
			
		||||
| 
						 | 
				
			
			@ -350,7 +350,7 @@ func @broadcast_scalar_2d_wrapper() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func @broadcast_to_the_same_shape() {
 | 
			
		||||
  %input_buf = alloc() : memref<2x3xf32>
 | 
			
		||||
  %input_buf = memref.alloc() : memref<2x3xf32>
 | 
			
		||||
 | 
			
		||||
  %c1f32 = constant 1.0 : f32
 | 
			
		||||
  %c2f32 = constant 2.0 : f32
 | 
			
		||||
| 
						 | 
				
			
			@ -360,22 +360,22 @@ func @broadcast_to_the_same_shape() {
 | 
			
		|||
  %c1 = constant 1 : index
 | 
			
		||||
  %c2 = constant 2 : index
 | 
			
		||||
  %c3 = constant 3 : index
 | 
			
		||||
  store %c1f32, %input_buf[%c0, %c0] : memref<2x3xf32>
 | 
			
		||||
  store %c1f32, %input_buf[%c1, %c0] : memref<2x3xf32>
 | 
			
		||||
  store %c2f32, %input_buf[%c0, %c1] : memref<2x3xf32>
 | 
			
		||||
  store %c2f32, %input_buf[%c1, %c1] : memref<2x3xf32>
 | 
			
		||||
  store %c3f32, %input_buf[%c0, %c2] : memref<2x3xf32>
 | 
			
		||||
  store %c3f32, %input_buf[%c1, %c2] : memref<2x3xf32>
 | 
			
		||||
  %input = tensor_load %input_buf : memref<2x3xf32>
 | 
			
		||||
  memref.store %c1f32, %input_buf[%c0, %c0] : memref<2x3xf32>
 | 
			
		||||
  memref.store %c1f32, %input_buf[%c1, %c0] : memref<2x3xf32>
 | 
			
		||||
  memref.store %c2f32, %input_buf[%c0, %c1] : memref<2x3xf32>
 | 
			
		||||
  memref.store %c2f32, %input_buf[%c1, %c1] : memref<2x3xf32>
 | 
			
		||||
  memref.store %c3f32, %input_buf[%c0, %c2] : memref<2x3xf32>
 | 
			
		||||
  memref.store %c3f32, %input_buf[%c1, %c2] : memref<2x3xf32>
 | 
			
		||||
  %input = memref.tensor_load %input_buf : memref<2x3xf32>
 | 
			
		||||
 | 
			
		||||
  // Test BroadcastInDimOp.
 | 
			
		||||
  %output = "mhlo.broadcast_in_dim"(%input) {
 | 
			
		||||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<2x3xf32>) -> tensor<2x3xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<2x3xf32>
 | 
			
		||||
  %output_buf = memref.buffer_cast %output : memref<2x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unraked_output = memref_cast %output_buf : memref<2x3xf32> to memref<*xf32>
 | 
			
		||||
  %unraked_output = memref.cast %output_buf : memref<2x3xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
 | 
			
		||||
  // CHECK-NEXT: [1,   2,   3]
 | 
			
		||||
| 
						 | 
				
			
			@ -387,9 +387,9 @@ func @broadcast_to_the_same_shape() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<2x3xf32>
 | 
			
		||||
  %dyn_output_buf = memref.buffer_cast %dyn_output : memref<2x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
  %unranked_dyn_output = memref.cast %dyn_output_buf
 | 
			
		||||
    : memref<2x3xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
 | 
			
		||||
| 
						 | 
				
			
			@ -399,7 +399,7 @@ func @broadcast_to_the_same_shape() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func @broadcast_1d_to_2d() {
 | 
			
		||||
  %input_buf = alloc() : memref<3xf32>
 | 
			
		||||
  %input_buf = memref.alloc() : memref<3xf32>
 | 
			
		||||
 | 
			
		||||
  %c1f32 = constant 1.0 : f32
 | 
			
		||||
  %c2f32 = constant 2.0 : f32
 | 
			
		||||
| 
						 | 
				
			
			@ -408,19 +408,19 @@ func @broadcast_1d_to_2d() {
 | 
			
		|||
  %c0 = constant 0 : index
 | 
			
		||||
  %c1 = constant 1 : index
 | 
			
		||||
  %c2 = constant 2 : index
 | 
			
		||||
  store %c1f32, %input_buf[%c0] : memref<3xf32>
 | 
			
		||||
  store %c2f32, %input_buf[%c1] : memref<3xf32>
 | 
			
		||||
  store %c3f32, %input_buf[%c2] : memref<3xf32>
 | 
			
		||||
  %input = tensor_load %input_buf : memref<3xf32>
 | 
			
		||||
  memref.store %c1f32, %input_buf[%c0] : memref<3xf32>
 | 
			
		||||
  memref.store %c2f32, %input_buf[%c1] : memref<3xf32>
 | 
			
		||||
  memref.store %c3f32, %input_buf[%c2] : memref<3xf32>
 | 
			
		||||
  %input = memref.tensor_load %input_buf : memref<3xf32>
 | 
			
		||||
 | 
			
		||||
  // Test BroadcastInDimOp.
 | 
			
		||||
  %output = "mhlo.broadcast_in_dim"(%input) {
 | 
			
		||||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>) -> tensor<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x3xf32>
 | 
			
		||||
  %output_buf = memref.buffer_cast %output : memref<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unraked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
 | 
			
		||||
  %unraked_output = memref.cast %output_buf : memref<3x3xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
 | 
			
		||||
  // CHECK-NEXT: [1,   1,   1]
 | 
			
		||||
| 
						 | 
				
			
			@ -435,9 +435,9 @@ func @broadcast_1d_to_2d() {
 | 
			
		|||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x3xf32>
 | 
			
		||||
  %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
  %unranked_dyn_output = memref.cast %dyn_output_buf
 | 
			
		||||
    : memref<3x3xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
 | 
			
		||||
| 
						 | 
				
			
			@ -448,7 +448,7 @@ func @broadcast_1d_to_2d() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func @broadcast_1d_to_2d_with_transpose() {
 | 
			
		||||
  %input_buf = alloc() : memref<3xf32>
 | 
			
		||||
  %input_buf = memref.alloc() : memref<3xf32>
 | 
			
		||||
 | 
			
		||||
  %c1f32 = constant 1.0 : f32
 | 
			
		||||
  %c2f32 = constant 2.0 : f32
 | 
			
		||||
| 
						 | 
				
			
			@ -457,19 +457,19 @@ func @broadcast_1d_to_2d_with_transpose() {
 | 
			
		|||
  %c0 = constant 0 : index
 | 
			
		||||
  %c1 = constant 1 : index
 | 
			
		||||
  %c2 = constant 2 : index
 | 
			
		||||
  store %c1f32, %input_buf[%c0] : memref<3xf32>
 | 
			
		||||
  store %c2f32, %input_buf[%c1] : memref<3xf32>
 | 
			
		||||
  store %c3f32, %input_buf[%c2] : memref<3xf32>
 | 
			
		||||
  %input = tensor_load %input_buf : memref<3xf32>
 | 
			
		||||
  memref.store %c1f32, %input_buf[%c0] : memref<3xf32>
 | 
			
		||||
  memref.store %c2f32, %input_buf[%c1] : memref<3xf32>
 | 
			
		||||
  memref.store %c3f32, %input_buf[%c2] : memref<3xf32>
 | 
			
		||||
  %input = memref.tensor_load %input_buf : memref<3xf32>
 | 
			
		||||
 | 
			
		||||
  // Test BroadcastInDimOp.
 | 
			
		||||
  %output = "mhlo.broadcast_in_dim"(%input) {
 | 
			
		||||
    broadcast_dimensions = dense<1> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>) -> tensor<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x3xf32>
 | 
			
		||||
  %output_buf = memref.buffer_cast %output : memref<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unraked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
 | 
			
		||||
  %unraked_output = memref.cast %output_buf : memref<3x3xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
 | 
			
		||||
  // CHECK-NEXT: [1,   2,   3]
 | 
			
		||||
| 
						 | 
				
			
			@ -483,9 +483,9 @@ func @broadcast_1d_to_2d_with_transpose() {
 | 
			
		|||
    broadcast_dimensions = dense<1> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x3xf32>
 | 
			
		||||
  %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
  %unranked_dyn_output = memref.cast %dyn_output_buf
 | 
			
		||||
    : memref<3x3xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,10 +4,10 @@ func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface
 | 
			
		|||
 | 
			
		||||
// Helper function to print scalar values.
 | 
			
		||||
func @print_f32(%arg : f32) {
 | 
			
		||||
  %mem = alloca() : memref<1xf32>
 | 
			
		||||
  %mem = memref.alloca() : memref<1xf32>
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  store %arg, %mem[%c0] : memref<1xf32>
 | 
			
		||||
  %mem_unranked = memref_cast %mem : memref<1xf32> to memref<*xf32>
 | 
			
		||||
  memref.store %arg, %mem[%c0] : memref<1xf32>
 | 
			
		||||
  %mem_unranked = memref.cast %mem : memref<1xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%mem_unranked) : (memref<*xf32>) -> ()
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,21 +21,21 @@ func @reduce_add() {
 | 
			
		|||
  %c1 = constant 1 : index
 | 
			
		||||
 | 
			
		||||
  // Initialize input.
 | 
			
		||||
  %input = alloc() : memref<2x3xf32>
 | 
			
		||||
  %dim_x = dim %input, %c0 : memref<2x3xf32>
 | 
			
		||||
  %dim_y = dim %input, %c1 : memref<2x3xf32>
 | 
			
		||||
  %input = memref.alloc() : memref<2x3xf32>
 | 
			
		||||
  %dim_x = memref.dim %input, %c0 : memref<2x3xf32>
 | 
			
		||||
  %dim_y = memref.dim %input, %c1 : memref<2x3xf32>
 | 
			
		||||
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
 | 
			
		||||
    %i_i64 = index_cast %i : index to i64
 | 
			
		||||
    %i_f32 = sitofp %i_i64 : i64 to f32
 | 
			
		||||
    store %i_f32, %input[%i, %j] : memref<2x3xf32>
 | 
			
		||||
    memref.store %i_f32, %input[%i, %j] : memref<2x3xf32>
 | 
			
		||||
  }
 | 
			
		||||
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
 | 
			
		||||
  %unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
 | 
			
		||||
  // CHECK: [0,   0,   0]
 | 
			
		||||
  // CHECK: [1,   1,   1]
 | 
			
		||||
 | 
			
		||||
  %in = tensor_load %input : memref<2x3xf32>
 | 
			
		||||
  %in = memref.tensor_load %input : memref<2x3xf32>
 | 
			
		||||
  %init = mhlo.constant dense<0.000000e+00> : tensor<f32>
 | 
			
		||||
 | 
			
		||||
  %reduce = "mhlo.reduce"(%in, %init) ( {
 | 
			
		||||
| 
						 | 
				
			
			@ -45,8 +45,8 @@ func @reduce_add() {
 | 
			
		|||
  }) {dimensions = dense<1> : tensor<1xi64>}
 | 
			
		||||
      : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
 | 
			
		||||
 | 
			
		||||
  %output = tensor_to_memref %reduce : memref<2xf32>
 | 
			
		||||
  %unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
 | 
			
		||||
  %output = memref.buffer_cast %reduce : memref<2xf32>
 | 
			
		||||
  %unranked_output = memref.cast %output : memref<2xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
 | 
			
		||||
  // CHECK: [0,  3]
 | 
			
		||||
| 
						 | 
				
			
			@ -58,21 +58,21 @@ func @reduce_max() {
 | 
			
		|||
  %c1 = constant 1 : index
 | 
			
		||||
 | 
			
		||||
  // Initialize input.
 | 
			
		||||
  %input = alloc() : memref<2x3xf32>
 | 
			
		||||
  %dim_x = dim %input, %c0 : memref<2x3xf32>
 | 
			
		||||
  %dim_y = dim %input, %c1 : memref<2x3xf32>
 | 
			
		||||
  %input = memref.alloc() : memref<2x3xf32>
 | 
			
		||||
  %dim_x = memref.dim %input, %c0 : memref<2x3xf32>
 | 
			
		||||
  %dim_y = memref.dim %input, %c1 : memref<2x3xf32>
 | 
			
		||||
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
 | 
			
		||||
    %i_i64 = index_cast %i : index to i64
 | 
			
		||||
    %i_f32 = sitofp %i_i64 : i64 to f32
 | 
			
		||||
    store %i_f32, %input[%i, %j] : memref<2x3xf32>
 | 
			
		||||
    memref.store %i_f32, %input[%i, %j] : memref<2x3xf32>
 | 
			
		||||
  }
 | 
			
		||||
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
 | 
			
		||||
  %unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
 | 
			
		||||
  // CHECK: [0,   0,   0]
 | 
			
		||||
  // CHECK: [1,   1,   1]
 | 
			
		||||
 | 
			
		||||
  %in = tensor_load %input : memref<2x3xf32>
 | 
			
		||||
  %in = memref.tensor_load %input : memref<2x3xf32>
 | 
			
		||||
  %init = mhlo.constant dense<0xff800000> : tensor<f32>
 | 
			
		||||
 | 
			
		||||
  %reduce = "mhlo.reduce"(%in, %init) ( {
 | 
			
		||||
| 
						 | 
				
			
			@ -82,8 +82,8 @@ func @reduce_max() {
 | 
			
		|||
  }) {dimensions = dense<1> : tensor<1xi64>}
 | 
			
		||||
      : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
 | 
			
		||||
 | 
			
		||||
  %output = tensor_to_memref %reduce : memref<2xf32>
 | 
			
		||||
  %unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
 | 
			
		||||
  %output = memref.buffer_cast %reduce : memref<2xf32>
 | 
			
		||||
  %unranked_output = memref.cast %output : memref<2xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
 | 
			
		||||
  // CHECK: [0,  1]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked(
 | 
			
		|||
  return %reshaped : tensor<?xf32>
 | 
			
		||||
}
 | 
			
		||||
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
 | 
			
		||||
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
 | 
			
		||||
// CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]])
 | 
			
		||||
// CHECK-SAME:   : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -30,7 +30,7 @@ func @dynamic_reshape_to_unranked(
 | 
			
		|||
  return %reshaped : tensor<*xf32>
 | 
			
		||||
}
 | 
			
		||||
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
 | 
			
		||||
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
 | 
			
		||||
// CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]])
 | 
			
		||||
// CHECK-SAME:   : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -41,4 +41,4 @@ func @reshape_unranked(%operand: tensor<*xf32>) -> tensor<f32> {
 | 
			
		|||
  return %reshaped : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
 | 
			
		||||
// CHECK-NEXT: memref_cast [[ARG]] : memref<*xf32> to memref<f32>
 | 
			
		||||
// CHECK-NEXT: memref.cast [[ARG]] : memref<*xf32> to memref<f32>
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -31,20 +31,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
 | 
			
		|||
  return %5 : tensor<4xf32>
 | 
			
		||||
}
 | 
			
		||||
//       CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: %[[MAX_RESULT:.*]] = memref.alloc() : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
 | 
			
		||||
//  CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: %[[ADD_RESULT:.*]] = memref.alloc() : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
 | 
			
		||||
//  CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: memref.dealloc %[[MAX_RESULT]] : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: %[[MIN_RESULT:.*]] = memref.alloc() : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
 | 
			
		||||
//  CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: %[[SUB_RESULT:.*]] = memref.alloc() : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
 | 
			
		||||
//  CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: memref.dealloc %[[MIN_RESULT]] : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: %[[MUL_RESULT:.*]] = memref.alloc() : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
 | 
			
		||||
//  CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: memref.dealloc %[[SUB_RESULT]] : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: memref.dealloc %[[ADD_RESULT]] : memref<4xf32>
 | 
			
		||||
//  CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -53,15 +53,15 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
 | 
			
		|||
func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>,
 | 
			
		||||
             %summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}})
 | 
			
		||||
  // CHECK-NEXT:  %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
 | 
			
		||||
  // CHECK-NEXT:  %[[ADD_RESULT:.*]] = memref.alloc() : memref<2x2xf32>
 | 
			
		||||
  %sum = "mhlo.add"(%summand_1, %summand_2)
 | 
			
		||||
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
 | 
			
		||||
  // CHECK-NEXT:  %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
 | 
			
		||||
  // CHECK-NEXT:  %[[MUL_RESULT:.*]] = memref.alloc() : memref<2x2xf32>
 | 
			
		||||
  %result = "mhlo.multiply"(%sum, %multiplier)
 | 
			
		||||
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
 | 
			
		||||
  // CHECK-NEXT:  dealloc %[[ADD_RESULT]] : memref<2x2xf32>
 | 
			
		||||
  // CHECK-NEXT:  memref.dealloc %[[ADD_RESULT]] : memref<2x2xf32>
 | 
			
		||||
  // CHECK-NEXT:  return %[[MUL_RESULT]] : memref<2x2xf32>
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -154,9 +154,9 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
 | 
			
		|||
 | 
			
		||||
// CHECK: %[[C0:.*]] = constant 0 : index
 | 
			
		||||
// CHECK: %[[C1:.*]] = constant 1 : index
 | 
			
		||||
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
 | 
			
		||||
// CHECK: %[[OPER_DIM_1:.*]] = memref.dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
 | 
			
		||||
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
 | 
			
		||||
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
 | 
			
		||||
// CHECK: %[[OPER_DIM_0:.*]] = memref.dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
 | 
			
		||||
 | 
			
		||||
// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
 | 
			
		||||
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
 | 
			
		||||
| 
						 | 
				
			
			@ -172,9 +172,9 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
 | 
			
		|||
// CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index
 | 
			
		||||
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
 | 
			
		||||
 | 
			
		||||
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
 | 
			
		||||
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
 | 
			
		||||
 | 
			
		||||
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
 | 
			
		||||
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
 | 
			
		||||
 | 
			
		||||
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
 | 
			
		||||
// CHECK: return %[[RESULT]] : memref<?x?x?xf32>
 | 
			
		||||
| 
						 | 
				
			
			@ -469,7 +469,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
 | 
			
		|||
  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
 | 
			
		||||
  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
 | 
			
		||||
  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
 | 
			
		||||
  // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
 | 
			
		||||
  // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
 | 
			
		||||
  // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
 | 
			
		||||
  return %result : tensor<?x?xf32>
 | 
			
		||||
  // CHECK: return %[[RESULT]]
 | 
			
		||||
| 
						 | 
				
			
			@ -485,7 +485,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
 | 
			
		|||
  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
 | 
			
		||||
  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
 | 
			
		||||
  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
 | 
			
		||||
  // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
 | 
			
		||||
  // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
 | 
			
		||||
  // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
 | 
			
		||||
  return %result : tensor<?x?xf32>
 | 
			
		||||
  // CHECK: return %[[RESULT]]
 | 
			
		||||
| 
						 | 
				
			
			@ -496,7 +496,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
 | 
			
		|||
// CHECK-LABEL: func @dot
 | 
			
		||||
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
 | 
			
		||||
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
 | 
			
		||||
// CHECK-NEXT: %[[ALLOC:.*]] = alloc
 | 
			
		||||
// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc
 | 
			
		||||
//      CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
 | 
			
		||||
//        dot_dimension_numbers = {
 | 
			
		||||
//          lhs_batching_dimensions = dense<> : tensor<0xi64>,
 | 
			
		||||
| 
						 | 
				
			
			@ -517,7 +517,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
 | 
			
		|||
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
 | 
			
		||||
    -> tensor<3x5x5x4xf32> {
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  // CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
 | 
			
		||||
  // CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
 | 
			
		||||
  // CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
 | 
			
		||||
  // CHECK-SAME: padding = dense<[
 | 
			
		||||
  // CHECK-SAME:                  [0, 1], [0, 1]]> : tensor<2x2xi64>
 | 
			
		||||
| 
						 | 
				
			
			@ -548,11 +548,11 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
 | 
			
		|||
 | 
			
		||||
// CHECK-LABEL: func @reduce
 | 
			
		||||
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
 | 
			
		||||
  // CHECK: %[[OUT:.*]] = alloc() : memref<1xf32>
 | 
			
		||||
  // CHECK: %[[OUT:.*]] = memref.alloc() : memref<1xf32>
 | 
			
		||||
  // CHECK:  "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
 | 
			
		||||
  // CHECK:  ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
 | 
			
		||||
  // CHECK-SAME:  %[[ARG3:.*]]: memref<f32>):
 | 
			
		||||
  // CHECK:    %[[TMP:.*]] = alloc() : memref<f32>
 | 
			
		||||
  // CHECK:    %[[TMP:.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
  // CHECK:    "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
 | 
			
		||||
  // CHECK:    "lmhlo.copy"(%[[TMP]], %[[ARG3]])
 | 
			
		||||
  // CHECK:    "lmhlo.terminator"() : () -> ()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -404,7 +404,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
 | 
			
		|||
  return %0: tensor<4x2x1x4x?x16xf32>
 | 
			
		||||
}
 | 
			
		||||
// CHECK: %[[C1:.*]] = constant 1 : index
 | 
			
		||||
// CHECK: %[[D1:.*]] = dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
 | 
			
		||||
// CHECK: %[[D1:.*]] = memref.dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
 | 
			
		||||
// CHECK: linalg.init_tensor [4, 2, 1, 4, %[[D1]], 16] : tensor<4x2x1x4x?x16xf32>
 | 
			
		||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
 | 
			
		||||
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
 | 
			
		||||
| 
						 | 
				
			
			@ -1024,7 +1024,7 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
 | 
			
		|||
// CHECK-LABEL: func @dot_matmul(
 | 
			
		||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
 | 
			
		||||
// CHECK: %[[C1:.*]] = constant 1 : index
 | 
			
		||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 | 
			
		||||
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
 | 
			
		||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 | 
			
		||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
			
		||||
// CHECK: linalg.matmul
 | 
			
		||||
| 
						 | 
				
			
			@ -1040,7 +1040,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
 | 
			
		|||
// CHECK-LABEL: func @dot_matmul_i8_i8_i32(
 | 
			
		||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>)
 | 
			
		||||
// CHECK: %[[C1:.*]] = constant 1 : index
 | 
			
		||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 | 
			
		||||
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
 | 
			
		||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 | 
			
		||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
			
		||||
// CHECK: linalg.matmul
 | 
			
		||||
| 
						 | 
				
			
			@ -1058,7 +1058,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
 | 
			
		|||
// CHECK-LABEL: func @dot_matmul_i16_i16_i32(
 | 
			
		||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>)
 | 
			
		||||
// CHECK: %[[C1:.*]] = constant 1 : index
 | 
			
		||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 | 
			
		||||
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
 | 
			
		||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 | 
			
		||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
			
		||||
// CHECK: linalg.matmul
 | 
			
		||||
| 
						 | 
				
			
			@ -1076,7 +1076,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
 | 
			
		|||
// CHECK-LABEL: func @dot_matmul_i32_i32_i32(
 | 
			
		||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>)
 | 
			
		||||
// CHECK: %[[C1:.*]] = constant 1 : index
 | 
			
		||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 | 
			
		||||
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
 | 
			
		||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 | 
			
		||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
			
		||||
// CHECK: linalg.matmul
 | 
			
		||||
| 
						 | 
				
			
			@ -1094,7 +1094,7 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
 | 
			
		|||
// CHECK-LABEL: func @dot_matvec(
 | 
			
		||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
 | 
			
		||||
// CHECK: %[[C0:.*]] = constant 0 : index
 | 
			
		||||
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
 | 
			
		||||
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
 | 
			
		||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
 | 
			
		||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
			
		||||
// CHECK: linalg.matvec
 | 
			
		||||
| 
						 | 
				
			
			@ -1134,11 +1134,11 @@ func @dot_general_batch_matmul(%arg0: tensor<?x?x3xf32>,
 | 
			
		|||
// CHECK-LABEL: func @dot_general_batch_matmul(
 | 
			
		||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
 | 
			
		||||
// CHECK: %[[C0:.*]] = constant 0 : index
 | 
			
		||||
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
 | 
			
		||||
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
 | 
			
		||||
// CHECK: %[[C1:.*]] = constant 1 : index
 | 
			
		||||
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
 | 
			
		||||
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
 | 
			
		||||
// CHECK: %[[C2:.*]] = constant 2 : index
 | 
			
		||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
 | 
			
		||||
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
 | 
			
		||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 | 
			
		||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
			
		||||
// CHECK: linalg.batch_matmul
 | 
			
		||||
| 
						 | 
				
			
			@ -1163,11 +1163,11 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
 | 
			
		|||
// CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32(
 | 
			
		||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi8>, %[[ARG1:.*]]: tensor<?x3x?xi8>)
 | 
			
		||||
// CHECK: %[[C0:.*]] = constant 0 : index
 | 
			
		||||
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
 | 
			
		||||
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
 | 
			
		||||
// CHECK: %[[C1:.*]] = constant 1 : index
 | 
			
		||||
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
 | 
			
		||||
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
 | 
			
		||||
// CHECK: %[[C2:.*]] = constant 2 : index
 | 
			
		||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
 | 
			
		||||
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
 | 
			
		||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 | 
			
		||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
			
		||||
// CHECK: linalg.batch_matmul
 | 
			
		||||
| 
						 | 
				
			
			@ -1192,11 +1192,11 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
 | 
			
		|||
// CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32(
 | 
			
		||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi16>, %[[ARG1:.*]]: tensor<?x3x?xi16>)
 | 
			
		||||
// CHECK: %[[C0:.*]] = constant 0 : index
 | 
			
		||||
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
 | 
			
		||||
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
 | 
			
		||||
// CHECK: %[[C1:.*]] = constant 1 : index
 | 
			
		||||
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
 | 
			
		||||
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
 | 
			
		||||
// CHECK: %[[C2:.*]] = constant 2 : index
 | 
			
		||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
 | 
			
		||||
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
 | 
			
		||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 | 
			
		||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
			
		||||
// CHECK: linalg.batch_matmul
 | 
			
		||||
| 
						 | 
				
			
			@ -1420,7 +1420,7 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
 | 
			
		|||
// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
 | 
			
		||||
// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
			
		||||
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
 | 
			
		||||
// CHECK-DAG: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
 | 
			
		||||
// CHECK-DAG: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
 | 
			
		||||
// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[DIM1]]]
 | 
			
		||||
// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 | 
			
		||||
// CHECK: linalg.generic
 | 
			
		||||
| 
						 | 
				
			
			@ -1531,9 +1531,9 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x8x?xf32>, %arg1: tenso
 | 
			
		|||
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
 | 
			
		||||
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
 | 
			
		||||
// CHECK:         %[[C0:.+]] = constant 0 : index
 | 
			
		||||
// CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
 | 
			
		||||
// CHECK:         %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
 | 
			
		||||
// CHECK:         %[[C2:.+]] = constant 2 : index
 | 
			
		||||
// CHECK:         %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
 | 
			
		||||
// CHECK:         %[[DIM2:.+]] = memref.dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
 | 
			
		||||
// CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]]
 | 
			
		||||
// CHECK:         %[[ZERO:.+]] = constant 0.000000e+00 : f32
 | 
			
		||||
// CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
 | 
			
		||||
| 
						 | 
				
			
			@ -1571,9 +1571,9 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3
 | 
			
		|||
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
 | 
			
		||||
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
 | 
			
		||||
// CHECK:         %[[C0:.+]] = constant 0 : index
 | 
			
		||||
// CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
 | 
			
		||||
// CHECK:         %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
 | 
			
		||||
// CHECK:         %[[C3:.+]] = constant 3 : index
 | 
			
		||||
// CHECK:         %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
 | 
			
		||||
// CHECK:         %[[DIM3:.+]] = memref.dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
 | 
			
		||||
// CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
 | 
			
		||||
// CHECK:         %[[ZERO:.+]] = constant 0.000000e+00 : f32
 | 
			
		||||
// CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
 | 
			
		||||
| 
						 | 
				
			
			@ -1611,9 +1611,9 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x8x8x8x?xf32>, %arg1: tens
 | 
			
		|||
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
 | 
			
		||||
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
 | 
			
		||||
// CHECK:         %[[C0:.+]] = constant 0 : index
 | 
			
		||||
// CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
 | 
			
		||||
// CHECK:         %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
 | 
			
		||||
// CHECK:         %[[C4:.+]] = constant 4 : index
 | 
			
		||||
// CHECK:         %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
 | 
			
		||||
// CHECK:         %[[DIM4:.+]] = memref.dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
 | 
			
		||||
// CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]]
 | 
			
		||||
// CHECK:         %[[ZERO:.+]] = constant 0.000000e+00 : f32
 | 
			
		||||
// CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,7 +7,7 @@
 | 
			
		|||
                       iterator_types = ["parallel", "parallel"]}
 | 
			
		||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
 | 
			
		||||
             %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
 | 
			
		||||
  %temp_result = alloc() : memref<6x6xf32>
 | 
			
		||||
  %temp_result = memref.alloc() : memref<6x6xf32>
 | 
			
		||||
  linalg.generic #pointwise_2d_trait
 | 
			
		||||
    ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
 | 
			
		||||
   outs(%temp_result : memref<6x6xf32>) {
 | 
			
		||||
| 
						 | 
				
			
			@ -22,7 +22,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
 | 
			
		|||
    %out = mulf %temp_result_in, %multiplier_in : f32
 | 
			
		||||
    linalg.yield %out : f32
 | 
			
		||||
  }
 | 
			
		||||
  dealloc %temp_result : memref<6x6xf32>
 | 
			
		||||
  memref.dealloc %temp_result : memref<6x6xf32>
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
// CHECK-LABEL: func @fusion
 | 
			
		||||
| 
						 | 
				
			
			@ -62,7 +62,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
 | 
			
		|||
func @fusion_of_three(%arg0: memref<100x10xf32>,
 | 
			
		||||
                      %arg1: memref<100xf32>,
 | 
			
		||||
                      %arg2: memref<100x10xf32>) {
 | 
			
		||||
 %0 = alloc() : memref<100x10xf32>
 | 
			
		||||
 %0 = memref.alloc() : memref<100x10xf32>
 | 
			
		||||
 linalg.generic {
 | 
			
		||||
   indexing_maps = [affine_map<(d0, d1) -> (d0)>,
 | 
			
		||||
                    affine_map<(d0, d1) -> (d0, d1)>],
 | 
			
		||||
| 
						 | 
				
			
			@ -72,7 +72,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
 | 
			
		|||
   ^bb0(%arg3: f32, %arg4: f32): // no predecessors
 | 
			
		||||
     linalg.yield %arg3 : f32
 | 
			
		||||
   }
 | 
			
		||||
 %1 = alloc() : memref<100x10xf32>
 | 
			
		||||
 %1 = memref.alloc() : memref<100x10xf32>
 | 
			
		||||
 linalg.generic {
 | 
			
		||||
   indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
 | 
			
		||||
                    affine_map<(d0, d1) -> (d0, d1)>,
 | 
			
		||||
| 
						 | 
				
			
			@ -84,7 +84,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
 | 
			
		|||
       %2 = subf %arg3, %arg4 : f32
 | 
			
		||||
       linalg.yield %2 : f32
 | 
			
		||||
     }
 | 
			
		||||
 dealloc %0 : memref<100x10xf32>
 | 
			
		||||
 memref.dealloc %0 : memref<100x10xf32>
 | 
			
		||||
 linalg.generic {
 | 
			
		||||
   indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
 | 
			
		||||
                    affine_map<(d0, d1) -> (d0, d1)>],
 | 
			
		||||
| 
						 | 
				
			
			@ -95,7 +95,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
 | 
			
		|||
       %2 = math.exp %arg3 : f32
 | 
			
		||||
       linalg.yield %2 : f32
 | 
			
		||||
     }
 | 
			
		||||
 dealloc %1 : memref<100x10xf32>
 | 
			
		||||
 memref.dealloc %1 : memref<100x10xf32>
 | 
			
		||||
 return
 | 
			
		||||
}
 | 
			
		||||
// CHECK-LABEL: func @fusion
 | 
			
		||||
| 
						 | 
				
			
			@ -141,7 +141,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
 | 
			
		|||
                                         "parallel"]}
 | 
			
		||||
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
 | 
			
		||||
             %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
 | 
			
		||||
  %temp_result = alloc() : memref<6x6x6x6xf32>
 | 
			
		||||
  %temp_result = memref.alloc() : memref<6x6x6x6xf32>
 | 
			
		||||
  linalg.generic #pointwise_4d_trait
 | 
			
		||||
    ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
 | 
			
		||||
   outs(%temp_result : memref<6x6x6x6xf32>) {
 | 
			
		||||
| 
						 | 
				
			
			@ -156,7 +156,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
 | 
			
		|||
    %out = mulf %temp_result_in, %multiplier_in : f32
 | 
			
		||||
    linalg.yield %out : f32
 | 
			
		||||
  }
 | 
			
		||||
  dealloc %temp_result : memref<6x6x6x6xf32>
 | 
			
		||||
  memref.dealloc %temp_result : memref<6x6x6x6xf32>
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
// CHECK-LABEL: func @fusion_4d
 | 
			
		||||
| 
						 | 
				
			
			@ -200,7 +200,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
 | 
			
		|||
                       iterator_types = ["parallel", "parallel"]}
 | 
			
		||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
 | 
			
		||||
             %summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
 | 
			
		||||
  %temp_result = alloc() : memref<6x6xf32>
 | 
			
		||||
  %temp_result = memref.alloc() : memref<6x6xf32>
 | 
			
		||||
  linalg.generic #pointwise_2d_trait
 | 
			
		||||
    ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
 | 
			
		||||
   outs(%temp_result : memref<6x6xf32>) {
 | 
			
		||||
| 
						 | 
				
			
			@ -208,7 +208,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
 | 
			
		|||
    %out = addf %summand_1_in, %summand_2_in : f32
 | 
			
		||||
    linalg.yield %out : f32
 | 
			
		||||
  }
 | 
			
		||||
  %result = alloc() : memref<6x6xf32>
 | 
			
		||||
  %result = memref.alloc() : memref<6x6xf32>
 | 
			
		||||
  linalg.generic #pointwise_2d_trait
 | 
			
		||||
    ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
 | 
			
		||||
   outs(%result : memref<6x6xf32>) {
 | 
			
		||||
| 
						 | 
				
			
			@ -216,7 +216,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
 | 
			
		|||
    %out = mulf %temp_result_in, %multiplier_in : f32
 | 
			
		||||
    linalg.yield %out : f32
 | 
			
		||||
  }
 | 
			
		||||
  dealloc %temp_result : memref<6x6xf32>
 | 
			
		||||
  memref.dealloc %temp_result : memref<6x6xf32>
 | 
			
		||||
  return %result : memref<6x6xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -258,7 +258,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
 | 
			
		|||
    -> memref<*xf32> {
 | 
			
		||||
  %c1 = constant 1 : index
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  %1 = alloc(%arg2) : memref<?xf32>
 | 
			
		||||
  %1 = memref.alloc(%arg2) : memref<?xf32>
 | 
			
		||||
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
 | 
			
		||||
                                   affine_map<(d0) -> (d0)>],
 | 
			
		||||
                  iterator_types = ["parallel"]}
 | 
			
		||||
| 
						 | 
				
			
			@ -267,7 +267,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
 | 
			
		|||
    %13 = absf %arg3 : f32
 | 
			
		||||
    linalg.yield %13 : f32
 | 
			
		||||
  }
 | 
			
		||||
  %2 = memref_reshape %1(%arg1)
 | 
			
		||||
  %2 = memref.reshape %1(%arg1)
 | 
			
		||||
      : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
 | 
			
		||||
  return %2 : memref<*xf32>
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -279,7 +279,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
 | 
			
		|||
//   CHECK-NOT:  scf.for
 | 
			
		||||
//       CHECK:      linalg.generic
 | 
			
		||||
//       CHECK:        absf
 | 
			
		||||
//       CHECK:  memref_reshape
 | 
			
		||||
//       CHECK:  memref.reshape
 | 
			
		||||
 | 
			
		||||
// TILED-LABEL: func @view_result
 | 
			
		||||
//   TILED-DAG:  %[[C2:.*]] = constant 2
 | 
			
		||||
| 
						 | 
				
			
			@ -288,7 +288,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
 | 
			
		|||
//   TILED-NOT:  scf.for
 | 
			
		||||
//       TILED:      linalg.generic
 | 
			
		||||
//       TILED:        absf
 | 
			
		||||
//       TILED:  memref_reshape
 | 
			
		||||
//       TILED:  memref.reshape
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// PLOOP-LABEL: func @view_result
 | 
			
		||||
| 
						 | 
				
			
			@ -297,20 +297,20 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
 | 
			
		|||
//   PLOOP-NOT:  scf.parallel
 | 
			
		||||
//       PLOOP:      linalg.generic
 | 
			
		||||
//       PLOOP:        absf
 | 
			
		||||
//       PLOOP:  memref_reshape
 | 
			
		||||
//       PLOOP:  memref.reshape
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Confirm that tiling information is passed through RegionBranchOpInterfaces.
 | 
			
		||||
// This test also uses memref_reshape, just to have a value to return through
 | 
			
		||||
// This test also uses memref.reshape, just to have a value to return through
 | 
			
		||||
// the if statement.
 | 
			
		||||
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
 | 
			
		||||
    -> memref<*xf32> {
 | 
			
		||||
  %c1 = constant 1 : index
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  %1 = alloc(%arg2) : memref<?xf32>
 | 
			
		||||
  %1 = memref.alloc(%arg2) : memref<?xf32>
 | 
			
		||||
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
 | 
			
		||||
                                   affine_map<(d0) -> (d0)>],
 | 
			
		||||
                  iterator_types = ["parallel"]}
 | 
			
		||||
| 
						 | 
				
			
			@ -321,11 +321,11 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
 | 
			
		|||
  }
 | 
			
		||||
  %true = constant 1 : i1
 | 
			
		||||
  %3 = scf.if %true -> memref<*xf32> {
 | 
			
		||||
    %2 = memref_reshape %1(%arg1)
 | 
			
		||||
    %2 = memref.reshape %1(%arg1)
 | 
			
		||||
        : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
 | 
			
		||||
    scf.yield %2 : memref<*xf32>
 | 
			
		||||
  } else {
 | 
			
		||||
    %2 = memref_reshape %1(%arg1)
 | 
			
		||||
    %2 = memref.reshape %1(%arg1)
 | 
			
		||||
        : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
 | 
			
		||||
    scf.yield %2 : memref<*xf32>
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			@ -340,10 +340,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
 | 
			
		|||
//       CHECK:      linalg.generic
 | 
			
		||||
//       CHECK:        absf
 | 
			
		||||
//       CHECK:  scf.if
 | 
			
		||||
//       CHECK:    memref_reshape
 | 
			
		||||
//       CHECK:    memref.reshape
 | 
			
		||||
//       CHECK:    scf.yield
 | 
			
		||||
//       CHECK:  else
 | 
			
		||||
//       CHECK:    memref_reshape
 | 
			
		||||
//       CHECK:    memref.reshape
 | 
			
		||||
//       CHECK:    scf.yield
 | 
			
		||||
 | 
			
		||||
// TILED-LABEL: func @branching_result
 | 
			
		||||
| 
						 | 
				
			
			@ -354,10 +354,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
 | 
			
		|||
//       TILED:      linalg.generic
 | 
			
		||||
//       TILED:        absf
 | 
			
		||||
//       TILED:  scf.if
 | 
			
		||||
//       TILED:    memref_reshape
 | 
			
		||||
//       TILED:    memref.reshape
 | 
			
		||||
//       TILED:    scf.yield
 | 
			
		||||
//       TILED:  else
 | 
			
		||||
//       TILED:    memref_reshape
 | 
			
		||||
//       TILED:    memref.reshape
 | 
			
		||||
//       TILED:    scf.yield
 | 
			
		||||
 | 
			
		||||
// PLOOP-LABEL: func @branching_result
 | 
			
		||||
| 
						 | 
				
			
			@ -367,10 +367,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
 | 
			
		|||
//       PLOOP:      linalg.generic
 | 
			
		||||
//       PLOOP:        absf
 | 
			
		||||
//       PLOOP:  scf.if
 | 
			
		||||
//       PLOOP:    memref_reshape
 | 
			
		||||
//       PLOOP:    memref.reshape
 | 
			
		||||
//       PLOOP:    scf.yield
 | 
			
		||||
//       PLOOP:  else
 | 
			
		||||
//       PLOOP:    memref_reshape
 | 
			
		||||
//       PLOOP:    memref.reshape
 | 
			
		||||
//       PLOOP:    scf.yield
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -380,7 +380,7 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
 | 
			
		|||
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
 | 
			
		||||
    -> memref<?xf32> {
 | 
			
		||||
  %c1 = constant 1 : index
 | 
			
		||||
  %1 = alloc() : memref<32xf32>
 | 
			
		||||
  %1 = memref.alloc() : memref<32xf32>
 | 
			
		||||
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
 | 
			
		||||
                                   affine_map<(d0) -> (d0)>],
 | 
			
		||||
                  iterator_types = ["parallel"]}
 | 
			
		||||
| 
						 | 
				
			
			@ -389,9 +389,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
 | 
			
		|||
    %13 = absf %arg3 : f32
 | 
			
		||||
    linalg.yield %13 : f32
 | 
			
		||||
  }
 | 
			
		||||
  %2 = tensor_load %1 : memref<32xf32>
 | 
			
		||||
  %2 = memref.tensor_load %1 : memref<32xf32>
 | 
			
		||||
  %3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
 | 
			
		||||
  %4 = tensor_to_memref %3 : memref<?xf32>
 | 
			
		||||
  %4 = memref.buffer_cast %3 : memref<?xf32>
 | 
			
		||||
  return %4 : memref<?xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -402,9 +402,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
 | 
			
		|||
//   CHECK-NOT:  scf.for
 | 
			
		||||
//       CHECK:      linalg.generic
 | 
			
		||||
//       CHECK:        absf
 | 
			
		||||
//       CHECK:  tensor_load
 | 
			
		||||
//       CHECK:  memref.tensor_load
 | 
			
		||||
//       CHECK:  tensor.cast
 | 
			
		||||
//       CHECK:  tensor_to_memref
 | 
			
		||||
//       CHECK:  memref.buffer_cast
 | 
			
		||||
 | 
			
		||||
// TILED-LABEL: func @tensor_ops
 | 
			
		||||
//   TILED-DAG:  %[[C2:.*]] = constant 2
 | 
			
		||||
| 
						 | 
				
			
			@ -413,9 +413,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
 | 
			
		|||
//   TILED-NOT:  scf.for
 | 
			
		||||
//       TILED:      linalg.generic
 | 
			
		||||
//       TILED:        absf
 | 
			
		||||
//       TILED:  tensor_load
 | 
			
		||||
//       TILED:  memref.tensor_load
 | 
			
		||||
//       TILED:  tensor.cast
 | 
			
		||||
//       TILED:  tensor_to_memref
 | 
			
		||||
//       TILED:  memref.buffer_cast
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// PLOOP-LABEL: func @tensor_ops
 | 
			
		||||
| 
						 | 
				
			
			@ -424,6 +424,6 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
 | 
			
		|||
//   PLOOP-NOT:  scf.parallel
 | 
			
		||||
//       PLOOP:      linalg.generic
 | 
			
		||||
//       PLOOP:        absf
 | 
			
		||||
//       PLOOP:  tensor_load
 | 
			
		||||
//       PLOOP:  memref.tensor_load
 | 
			
		||||
//       PLOOP:  tensor.cast
 | 
			
		||||
//       PLOOP:  tensor_to_memref
 | 
			
		||||
//       PLOOP:  memref.buffer_cast
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -49,10 +49,10 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
 | 
			
		|||
// CHECK-DAG:  [[CTRUE:%.*]] = constant true
 | 
			
		||||
 | 
			
		||||
// Parallel loop to initialize the output buffer.
 | 
			
		||||
// CHECK:    [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:    [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:    scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
 | 
			
		||||
// CHECK-SAME:          to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) {
 | 
			
		||||
// CHECK:      store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
 | 
			
		||||
// CHECK:      memref.store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
 | 
			
		||||
// CHECK:      scf.yield
 | 
			
		||||
// CHECK:    }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -101,7 +101,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
 | 
			
		|||
 | 
			
		||||
  // INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true
 | 
			
		||||
 | 
			
		||||
  // CHECK: [[ARG_ELEM:%.*]] = load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
 | 
			
		||||
  // CHECK: [[ARG_ELEM:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
 | 
			
		||||
  // CHECK: [[IF_INIT_RES:%.*]]:4
 | 
			
		||||
  // CHECK-SAME:  = scf.if [[SEL_INIT]] -> (index, index, f32, i1) {
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -114,16 +114,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
 | 
			
		|||
 | 
			
		||||
    // Allocate buffers for ARG element, current selected value to adapt LHLO
 | 
			
		||||
    // code.
 | 
			
		||||
    // CHECK:  [[ARG_ELEM_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
    // CHECK:  [[SEL_VAL_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
    // CHECK:  [[PRED_BUF:%.*]] = alloc() : memref<i1>
 | 
			
		||||
    // CHECK:  store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
 | 
			
		||||
    // CHECK:  store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
 | 
			
		||||
    // CHECK:  [[ARG_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
    // CHECK:  [[SEL_VAL_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
    // CHECK:  [[PRED_BUF:%.*]] = memref.alloc() : memref<i1>
 | 
			
		||||
    // CHECK:  memref.store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
 | 
			
		||||
    // CHECK:  memref.store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
 | 
			
		||||
 | 
			
		||||
    // Compute PRED.
 | 
			
		||||
    // CHECK:  "lmhlo.compare"(
 | 
			
		||||
    // CHECK-SAME:     [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
 | 
			
		||||
    // CHECK:      [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1>
 | 
			
		||||
    // CHECK:      [[PRED:%.*]] = memref.load [[PRED_BUF]][] : memref<i1>
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    // Depending on PRED, return ARG ivs & elem or current select ivs and value.
 | 
			
		||||
| 
						 | 
				
			
			@ -165,7 +165,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
 | 
			
		|||
// CHECK:  }
 | 
			
		||||
 | 
			
		||||
// Use selected ivs to load element from the SRC buffer.
 | 
			
		||||
// CHECK: [[SRC_ELEM:%.*]] = load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]
 | 
			
		||||
// CHECK: [[SRC_ELEM:%.*]] = memref.load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]
 | 
			
		||||
 | 
			
		||||
// Update of RESULT[SELECTED_I, SELECTED_J] should be done atomically, because
 | 
			
		||||
// it may happen that several other threads select the same IVs if the windows
 | 
			
		||||
| 
						 | 
				
			
			@ -175,16 +175,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
 | 
			
		|||
// CHECK: ^bb0([[CUR_RES:%.*]]: f32):
 | 
			
		||||
 | 
			
		||||
// Allocate buffers for ARG element, current selected value to adapt LHLO code.
 | 
			
		||||
// CHECK:  [[SRC_ELEM_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:  [[CUR_RES_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:  [[RES_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:  store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:  store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:  [[SRC_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:  [[CUR_RES_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:  [[RES_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:  memref.store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:  memref.store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
 | 
			
		||||
 | 
			
		||||
// Compute scatter value.
 | 
			
		||||
// CHECK:  "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
 | 
			
		||||
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
 | 
			
		||||
// CHECK:  [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:  [[RES:%.*]] = memref.load [[RES_BUF]][] : memref<f32>
 | 
			
		||||
 | 
			
		||||
// Atomic RMW terminator that returns updated value.
 | 
			
		||||
// CHECK:  atomic_yield [[RES]] : f32
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,14 +19,14 @@ func @reduce(%arg: memref<100x10xf32>,
 | 
			
		|||
// CHECK-DAG: %[[C100:.*]] = constant 100 : index
 | 
			
		||||
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
 | 
			
		||||
//     CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) {
 | 
			
		||||
//     CHECK:   %[[ACC:.*]] = load %[[ARG1]][] : memref<f32>
 | 
			
		||||
//     CHECK:   %[[ACC:.*]] = memref.load %[[ARG1]][] : memref<f32>
 | 
			
		||||
//     CHECK:   store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32>
 | 
			
		||||
// CHECK-DAG:   %[[LB:.*]] = constant 0 : index
 | 
			
		||||
// CHECK-DAG:   %[[UB:.*]] = constant 10 : index
 | 
			
		||||
// CHECK-DAG:   %[[STEP:.*]] = constant 1 : index
 | 
			
		||||
//     CHECK:   scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
 | 
			
		||||
//     CHECK:     %[[LHS:.*]] = subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref<f32, #[[$MAP]]>
 | 
			
		||||
//     CHECK:     %[[RHS:.*]] = subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref<f32, #[[$MAP]]>
 | 
			
		||||
//     CHECK:     %[[LHS:.*]] = memref.subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref<f32, #[[$MAP]]>
 | 
			
		||||
//     CHECK:     %[[RHS:.*]] = memref.subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref<f32, #[[$MAP]]>
 | 
			
		||||
//     CHECK:     "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
 | 
			
		||||
//     CHECK:   }
 | 
			
		||||
//     CHECK:   gpu.terminator
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -52,10 +52,10 @@ func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
 | 
			
		|||
      : (memref<f32>, memref<f32>, memref<f32>) -> ()
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
// CHECK: %[[LHS:.*]] = load
 | 
			
		||||
// CHECK: %[[RHS:.*]] = load
 | 
			
		||||
// CHECK: %[[LHS:.*]] = memref.load
 | 
			
		||||
// CHECK: %[[RHS:.*]] = memref.load
 | 
			
		||||
// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]]
 | 
			
		||||
// CHECK: store %[[RES]]
 | 
			
		||||
// CHECK: memref.store %[[RES]]
 | 
			
		||||
// CHECK-NEXT: return
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -347,7 +347,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
 | 
			
		|||
}
 | 
			
		||||
// CHECK-NOT: linalg.reshape
 | 
			
		||||
// CHECK: %[[C0:.*]] = constant 0 : index
 | 
			
		||||
// CHECK: %[[VALUE:.*]] = load %{{.*}}[[C0]]
 | 
			
		||||
// CHECK: %[[VALUE:.*]] = memref.load %{{.*}}[[C0]]
 | 
			
		||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
 | 
			
		||||
// CHECK-NEXT: ^bb0(%{{.+}}: f32):
 | 
			
		||||
// CHECK-NEXT:   linalg.yield %[[VALUE]] : f32
 | 
			
		||||
| 
						 | 
				
			
			@ -785,7 +785,7 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
 | 
			
		|||
  } : (memref<?x?xf32>, memref<?x?xf32>) -> ()
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
// CHECK: %[[RESULT:.*]] = subview %[[IN]][0, 1] [2, 2] [1, 1] : memref<?x?xf32> to memref<2x2xf32, #{{.*}}>
 | 
			
		||||
// CHECK: %[[RESULT:.*]] = memref.subview %[[IN]][0, 1] [2, 2] [1, 1] : memref<?x?xf32> to memref<2x2xf32, #{{.*}}>
 | 
			
		||||
// CHECK: linalg.copy(%[[RESULT]], %[[OUT]])
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -899,7 +899,7 @@ func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
 | 
			
		|||
 | 
			
		||||
func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  %0 = alloc() : memref<3x5x5x4xf32>
 | 
			
		||||
  %0 = memref.alloc() : memref<3x5x5x4xf32>
 | 
			
		||||
  // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
 | 
			
		||||
  // CHECK-SAME: dilations = [1, 2]
 | 
			
		||||
  // CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
 | 
			
		||||
| 
						 | 
				
			
			@ -948,22 +948,22 @@ func @reduce_add(%arg: memref<100x10xf32>,
 | 
			
		|||
      : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
 | 
			
		||||
// CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
 | 
			
		||||
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
 | 
			
		||||
// CHECK: linalg.generic {
 | 
			
		||||
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
 | 
			
		||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
 | 
			
		||||
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
 | 
			
		||||
// CHECK: alloca
 | 
			
		||||
// CHECK-NEXT: alloca
 | 
			
		||||
// CHECK-NEXT: alloca
 | 
			
		||||
// CHECK-NEXT: store
 | 
			
		||||
// CHECK-NEXT: store
 | 
			
		||||
// CHECK-NEXT: load
 | 
			
		||||
// CHECK-NEXT: load
 | 
			
		||||
// CHECK: memref.alloca
 | 
			
		||||
// CHECK-NEXT: memref.alloca
 | 
			
		||||
// CHECK-NEXT: memref.alloca
 | 
			
		||||
// CHECK-NEXT: memref.store
 | 
			
		||||
// CHECK-NEXT: memref.store
 | 
			
		||||
// CHECK-NEXT: memref.load
 | 
			
		||||
// CHECK-NEXT: memref.load
 | 
			
		||||
// CHECK-NEXT: addf
 | 
			
		||||
// CHECK-NEXT: store
 | 
			
		||||
// CHECK-NEXT: load
 | 
			
		||||
// CHECK-NEXT: memref.store
 | 
			
		||||
// CHECK-NEXT: memref.load
 | 
			
		||||
// CHECK-NEXT: linalg.yield
 | 
			
		||||
// CHECK-NEXT: }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -984,22 +984,22 @@ func @reduce_maximum(%arg: memref<100x10xf32>,
 | 
			
		|||
      : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
 | 
			
		||||
// CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
 | 
			
		||||
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
 | 
			
		||||
// CHECK: linalg.generic {
 | 
			
		||||
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
 | 
			
		||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
 | 
			
		||||
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
 | 
			
		||||
// CHECK: alloca
 | 
			
		||||
// CHECK-NEXT: alloca
 | 
			
		||||
// CHECK-NEXT: alloca
 | 
			
		||||
// CHECK-NEXT: store
 | 
			
		||||
// CHECK-NEXT: store
 | 
			
		||||
// CHECK-NEXT: load
 | 
			
		||||
// CHECK-NEXT: load
 | 
			
		||||
// CHECK: memref.alloca
 | 
			
		||||
// CHECK-NEXT: memref.alloca
 | 
			
		||||
// CHECK-NEXT: memref.alloca
 | 
			
		||||
// CHECK-NEXT: memref.store
 | 
			
		||||
// CHECK-NEXT: memref.store
 | 
			
		||||
// CHECK-NEXT: memref.load
 | 
			
		||||
// CHECK-NEXT: memref.load
 | 
			
		||||
// CHECK: cmpf
 | 
			
		||||
// CHECK: select
 | 
			
		||||
// CHECK: store
 | 
			
		||||
// CHECK-NEXT: load
 | 
			
		||||
// CHECK: memref.store
 | 
			
		||||
// CHECK-NEXT: memref.load
 | 
			
		||||
// CHECK-NEXT: linalg.yield
 | 
			
		||||
// CHECK-NEXT: }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,27 +21,27 @@ func @reduce(%arg: memref<100x10x5xf32>,
 | 
			
		|||
// CHECK-DAG:  [[C5:%.*]] = constant 5 : index
 | 
			
		||||
// CHECK-DAG:  [[C10:%.*]] = constant 10 : index
 | 
			
		||||
// CHECK-DAG:  [[C100:%.*]] = constant 100 : index
 | 
			
		||||
// CHECK:  [[INIT:%.*]] = load [[INIT_BUF]]
 | 
			
		||||
// CHECK:  [[INIT:%.*]] = memref.load [[INIT_BUF]]
 | 
			
		||||
// CHECK:  scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
 | 
			
		||||
// CHECK-SAME:                     to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) {
 | 
			
		||||
// CHECK:    [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
 | 
			
		||||
// CHECK-SAME:      ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 {
 | 
			
		||||
// CHECK:      [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
 | 
			
		||||
// CHECK:      [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
 | 
			
		||||
// CHECK-SAME:                 {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32>
 | 
			
		||||
// CHECK:      scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
 | 
			
		||||
// CHECK:      ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
 | 
			
		||||
// CHECK:        [[ELEM_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:        [[ACC_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:        [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:        store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        store [[ACC]], [[ACC_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:        [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:        [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:        memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
 | 
			
		||||
// CHECK:        [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        scf.reduce.return [[ACC_RESULT]] : f32
 | 
			
		||||
// CHECK:      }
 | 
			
		||||
// CHECK:      scf.yield
 | 
			
		||||
// CHECK:    }
 | 
			
		||||
// CHECK:    store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
 | 
			
		||||
// CHECK:    memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
 | 
			
		||||
// CHECK:    scf.yield
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -65,23 +65,23 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
 | 
			
		|||
// CHECK-DAG:  [[C0:%.*]] = constant 0 : index
 | 
			
		||||
// CHECK-DAG:  [[C1:%.*]] = constant 1 : index
 | 
			
		||||
// CHECK-DAG:  [[C100:%.*]] = constant 100 : index
 | 
			
		||||
// CHECK:      [[INIT:%.*]] = load [[INIT_BUF]]
 | 
			
		||||
// CHECK:      [[INIT:%.*]] = memref.load [[INIT_BUF]]
 | 
			
		||||
// CHECK:      [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]])
 | 
			
		||||
// CHECK-SAME:     to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 {
 | 
			
		||||
// CHECK:        [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
 | 
			
		||||
// CHECK:        [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
 | 
			
		||||
// CHECK:        scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
 | 
			
		||||
// CHECK:        ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
 | 
			
		||||
// CHECK:          [[ELEM_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:          [[ACC_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:          [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:          store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:          store [[ACC]], [[ACC_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:          [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:          [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:          [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:          memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:          memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:          "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
 | 
			
		||||
// CHECK:          [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:          [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:          scf.reduce.return [[ACC_RESULT]]
 | 
			
		||||
// CHECK:        }
 | 
			
		||||
// CHECK:        scf.yield
 | 
			
		||||
// CHECK:      store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]
 | 
			
		||||
// CHECK:      memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -104,30 +104,30 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
 | 
			
		|||
// CHECK-DAG:  [[C0:%.*]] = constant 0 : index
 | 
			
		||||
// CHECK-DAG:  [[C1:%.*]] = constant 1 : index
 | 
			
		||||
// CHECK-DAG:  [[C2:%.*]] = constant 2 : index
 | 
			
		||||
// CHECK:  [[DIM0:%.*]] = dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
 | 
			
		||||
// CHECK:  [[DIM1:%.*]] = dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
 | 
			
		||||
// CHECK:  [[DIM2:%.*]] = dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
 | 
			
		||||
// CHECK:  [[INIT:%.*]] = load [[INIT_BUF]]
 | 
			
		||||
// CHECK:  [[DIM0:%.*]] = memref.dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
 | 
			
		||||
// CHECK:  [[DIM1:%.*]] = memref.dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
 | 
			
		||||
// CHECK:  [[DIM2:%.*]] = memref.dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
 | 
			
		||||
// CHECK:  [[INIT:%.*]] = memref.load [[INIT_BUF]]
 | 
			
		||||
// CHECK:  scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
 | 
			
		||||
// CHECK-SAME:                     to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) {
 | 
			
		||||
// CHECK:    [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
 | 
			
		||||
// CHECK-SAME:     ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 {
 | 
			
		||||
// CHECK:      [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
 | 
			
		||||
// CHECK:      [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
 | 
			
		||||
// CHECK-SAME:                 {{\[}}[[I]], [[J]], [[K]]] : memref<?x?x?xf32>
 | 
			
		||||
// CHECK:      scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
 | 
			
		||||
// CHECK:      ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
 | 
			
		||||
// CHECK:        [[ELEM_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:        [[ACC_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:        [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:        store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        store [[ACC]], [[ACC_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:        [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:        [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:        memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
 | 
			
		||||
// CHECK:        [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:        scf.reduce.return [[ACC_RESULT]] : f32
 | 
			
		||||
// CHECK:      }
 | 
			
		||||
// CHECK:      scf.yield
 | 
			
		||||
// CHECK:    }
 | 
			
		||||
// CHECK:    store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
 | 
			
		||||
// CHECK:    memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
 | 
			
		||||
// CHECK:    scf.yield
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -157,7 +157,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
 | 
			
		|||
// CHECK-DAG:  [[C3:%.*]] = constant 3 : index
 | 
			
		||||
// CHECK-DAG:  [[C56:%.*]] = constant 56 : index
 | 
			
		||||
// CHECK-DAG:  [[C112:%.*]] = constant 112 : index
 | 
			
		||||
// CHECK:      [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:      [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
 | 
			
		||||
// CHECK-SAME:         to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) {
 | 
			
		||||
// CHECK:        [[REDUCTION_RESULT:%.*]] = scf.parallel
 | 
			
		||||
| 
						 | 
				
			
			@ -176,7 +176,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
 | 
			
		|||
 | 
			
		||||
// CHECK:          [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) {
 | 
			
		||||
// CHECK:            [[OPERAND_ELEM:%.*]] =
 | 
			
		||||
// CHECK-SAME:         load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
 | 
			
		||||
// CHECK-SAME:         memref.load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
 | 
			
		||||
// CHECK:              scf.yield [[OPERAND_ELEM]] : f32
 | 
			
		||||
// CHECK:            } else {
 | 
			
		||||
// CHECK:              scf.yield [[INIT]] : f32
 | 
			
		||||
| 
						 | 
				
			
			@ -184,18 +184,18 @@ func @reduce_window(%arg: memref<112x112xf32>,
 | 
			
		|||
 | 
			
		||||
// CHECK:          scf.reduce([[ELEM_TO_REDUCE]])  : f32 {
 | 
			
		||||
// CHECK:          ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
 | 
			
		||||
// CHECK:            [[ELEM_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:            [[ACC_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:            [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
 | 
			
		||||
// CHECK:            store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:            store [[ACC]], [[ACC_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:            [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:            [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:            [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
 | 
			
		||||
// CHECK:            memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:            memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:            "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
 | 
			
		||||
// CHECK:            [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:            [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
 | 
			
		||||
// CHECK:            scf.reduce.return [[ACC_RESULT]] : f32
 | 
			
		||||
// CHECK:          }
 | 
			
		||||
// CHECK:          scf.yield
 | 
			
		||||
// CHECK:        }
 | 
			
		||||
// CHECK:        store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
 | 
			
		||||
// CHECK:        memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
 | 
			
		||||
// CHECK:        scf.yield
 | 
			
		||||
// CHECK:      }
 | 
			
		||||
// CHECK:      return
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,7 +30,7 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
 | 
			
		|||
 | 
			
		||||
// CHECK-LABEL: func @conv_forward
 | 
			
		||||
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
 | 
			
		||||
  %scratch = alloc() : memref<32xi8>
 | 
			
		||||
  %scratch = memref.alloc() : memref<32xi8>
 | 
			
		||||
  // This defined a 2D convolution over a 8x8 single channel input using a 2x2
 | 
			
		||||
  // filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W)
 | 
			
		||||
  "lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch)
 | 
			
		||||
| 
						 | 
				
			
			@ -61,7 +61,7 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
 | 
			
		|||
 | 
			
		||||
// CHECK-LABEL: func @conv_backfilter
 | 
			
		||||
func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
 | 
			
		||||
  %scratch = alloc() : memref<23328xui8>
 | 
			
		||||
  %scratch = memref.alloc() : memref<23328xui8>
 | 
			
		||||
  "lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
 | 
			
		||||
    { backend_config = {algorithm = 1 : i64,
 | 
			
		||||
                        operand_0_layout = [3,2,1,0],
 | 
			
		||||
| 
						 | 
				
			
			@ -91,7 +91,7 @@ func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64x
 | 
			
		|||
 | 
			
		||||
// CHECK-LABEL: func @conv_backinput
 | 
			
		||||
func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
 | 
			
		||||
  %scratch = alloc() : memref<32xui8>
 | 
			
		||||
  %scratch = memref.alloc() : memref<32xui8>
 | 
			
		||||
  "lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
 | 
			
		||||
    { backend_config = {algorithm = 1 : i64,
 | 
			
		||||
                        operand_0_layout = [3,2,1,0],
 | 
			
		||||
| 
						 | 
				
			
			@ -122,7 +122,7 @@ func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf6
 | 
			
		|||
 | 
			
		||||
// CHECK-LABEL: func @conv_fused
 | 
			
		||||
func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
 | 
			
		||||
  %scratch = alloc() : memref<32xui8>
 | 
			
		||||
  %scratch = memref.alloc() : memref<32xui8>
 | 
			
		||||
  "lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
 | 
			
		||||
    {activation_mode = "Relu",
 | 
			
		||||
     backend_config = {algorithm = 1 : i64,
 | 
			
		||||
| 
						 | 
				
			
			@ -153,7 +153,7 @@ func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>,
 | 
			
		|||
 | 
			
		||||
// CHECK-LABEL: func @conv_fused_side_input
 | 
			
		||||
func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input:  memref<32xf16>, %output : memref<1x32x9x9xf16>) {
 | 
			
		||||
  %scratch = alloc() : memref<0xui8>
 | 
			
		||||
  %scratch = memref.alloc() : memref<0xui8>
 | 
			
		||||
  "lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
 | 
			
		||||
    {activation_mode = "Relu",
 | 
			
		||||
     backend_config = {algorithm = 1 : i64,
 | 
			
		||||
| 
						 | 
				
			
			@ -218,8 +218,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
 | 
			
		|||
 | 
			
		||||
// CHECK-LABEL: func @cholesky
 | 
			
		||||
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
 | 
			
		||||
  %scratch = alloc() : memref<32xi8>
 | 
			
		||||
  %info = alloc() : memref<32xi32>
 | 
			
		||||
  %scratch = memref.alloc() : memref<32xi8>
 | 
			
		||||
  %info = memref.alloc() : memref<32xi32>
 | 
			
		||||
  "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
 | 
			
		||||
      : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
 | 
			
		||||
  return
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -457,12 +457,12 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf
 | 
			
		|||
// CHECK-LABEL: func @fusion_memref
 | 
			
		||||
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
 | 
			
		||||
  "lmhlo.fusion"() ( {
 | 
			
		||||
    %0 = tensor_load %input1 : memref<10xf32>
 | 
			
		||||
    %1 = tensor_load %input2 : memref<10xf32>
 | 
			
		||||
    %0 = memref.tensor_load %input1 : memref<10xf32>
 | 
			
		||||
    %1 = memref.tensor_load %input2 : memref<10xf32>
 | 
			
		||||
    %2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
 | 
			
		||||
    %3 = tensor_load %input3 : memref<10xf32>
 | 
			
		||||
    %3 = memref.tensor_load %input3 : memref<10xf32>
 | 
			
		||||
    %4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
 | 
			
		||||
    tensor_store %4, %out : memref<10xf32>
 | 
			
		||||
    memref.tensor_store %4, %out : memref<10xf32>
 | 
			
		||||
    "lmhlo.terminator"() : () -> ()
 | 
			
		||||
  } ) : () -> ()
 | 
			
		||||
  return
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -108,15 +108,15 @@ func @batchNormInference_dynamic_shape(
 | 
			
		|||
  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
 | 
			
		||||
  // CHECK-DAG: %[[C3:.*]] = constant 3 : index
 | 
			
		||||
  // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
 | 
			
		||||
  // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[DIM:.+]] = memref.dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor.from_elements %[[DIM]] : tensor<1xindex>
 | 
			
		||||
  // CHECK-DAG: %[[EPS_BCAST:.+]] =  "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[INPUT_DIM_0:.+]] = memref.dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[INPUT_DIM_1:.+]] = memref.dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[INPUT_DIM_2:.+]] = memref.dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[INPUT_DIM_3:.+]] = memref.dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor.from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex>
 | 
			
		||||
  // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 | 
			
		||||
  // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue