diff --git a/BUILD b/BUILD index 77d153f..7d2df24 100644 --- a/BUILD +++ b/BUILD @@ -23,6 +23,7 @@ td_library( ], includes = ["include"], deps = [ + "@llvm-project//mlir:MemRefOpsTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectTdFiles", ], @@ -462,6 +463,7 @@ cc_library( "@llvm-project//mlir:Analysis", "@llvm-project//mlir:CopyOpInterface", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", @@ -615,6 +617,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", @@ -724,6 +727,7 @@ cc_library( "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", @@ -951,6 +955,7 @@ cc_library( ":hlo", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", diff --git a/WORKSPACE b/WORKSPACE index 8a9166e..039e5ca 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -15,9 +15,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "6878be5dc3ec7031d0deec3e321310115bd71103" +LLVM_COMMIT = "678241795c957b18bc473045e48abe3f2a61ff5c" -LLVM_SHA256 = "f55187a3329fd97fd62fd0714783524d50a3be934a35484bd4442195fb25f0e5" +LLVM_SHA256 = "58fd00a9ed7841f36aa7042bb8c98323b030dee98abe36757eea9ddf4fd5ea75" LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT) diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index dd73762..139d9b6 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1,2 +1,2 @@ -6878be5dc3ec7031d0deec3e321310115bd71103 +678241795c957b18bc473045e48abe3f2a61ff5c diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index deee865..7d32cff 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 4a64395..db3aa43 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -33,6 +33,7 @@ limitations under the License. 
 #ifndef LHLO_OPS
 #define LHLO_OPS

+include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/CopyOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -685,7 +686,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
   let extraClassDeclaration = [{
     SmallVector<Value, 4> getInputBuffers() {
       SmallVector<Value, 4> buffers;
-      this->region().walk([&](TensorLoadOp load) {
+      this->region().walk([&](memref::TensorLoadOp load) {
        if (load.memref().getParentRegion()->isProperAncestor(&region()))
          buffers.push_back(load.memref());
      });
@@ -694,7 +695,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]

     SmallVector<Value, 4> getOutputBuffers() {
       SmallVector<Value, 4> buffers;
-      this->region().walk([&](TensorStoreOp store) {
+      this->region().walk([&](memref::TensorStoreOp store) {
        if (store.memref().getParentRegion()->isProperAncestor(&region()))
          buffers.push_back(store.memref());
      });
@@ -703,7 +704,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]

     SmallVector<Value, 4> getFusionParameters() {
       SmallVector<Value, 4> buffers;
-      this->region().walk([&](TensorLoadOp load) {
+      this->region().walk([&](memref::TensorLoadOp load) {
        if (load.memref().getParentRegion()->isProperAncestor(&region()))
          buffers.push_back(load);
      });
@@ -712,7 +713,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]

     SmallVector<Value, 4> getFusionResults() {
       SmallVector<Value, 4> buffers;
-      this->region().walk([&](TensorStoreOp store) {
+      this->region().walk([&](memref::TensorStoreOp store) {
        if (store.memref().getParentRegion()->isProperAncestor(&region()))
          buffers.push_back(store.tensor());
      });
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td
index 2f5b542..ba158d9 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef LHLO_OPS_BASE
 #define LHLO_OPS_BASE

+include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "mlir/IR/OpBase.td"
 include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc
index 960e643..40ebb51 100644
--- a/lib/Dialect/mhlo/IR/lhlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc
@@ -33,6 +33,7 @@ limitations under the License.
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -156,13 +157,13 @@ struct EraseConstOp : public OpRewritePattern<ConstOp> {
   LogicalResult matchAndRewrite(ConstOp op,
                                 PatternRewriter& rewriter) const override {
     Value memref = op.output();
-    if (!memref.getDefiningOp<AllocOp>()) {
+    if (!memref.getDefiningOp<memref::AllocOp>()) {
       return failure();
     }

     // Check that all uses of the memref are either DeallocOps or this op.
     for (Operation* user : memref.getUsers())
-      if (user != op && !isa<DeallocOp>(user)) return failure();
+      if (user != op && !isa<memref::DeallocOp>(user)) return failure();

     rewriter.eraseOp(op);
     return success();
diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index cc9ce58..38f817b 100644
--- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -71,7 +71,7 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
     dynamic_operands.push_back(alloc_operand);
   }

-  return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
+  return rewriter->create<memref::AllocOp>(loc, memref_type, dynamic_operands);
 }

 Value InsertAlloc(Location loc, OpResult result,
@@ -85,7 +85,7 @@ Value InsertAlloc(Location loc, OpResult result,
       MemRefType::get(result_type.getShape(), result_type.getElementType());
   OpBuilder::InsertionGuard guard(*rewriter);
   rewriter->setInsertionPoint(result.getDefiningOp());
-  auto alloc = rewriter->create<AllocOp>(loc, memref_type);
+  auto alloc = rewriter->create<memref::AllocOp>(loc, memref_type);
   return alloc;
 }

@@ -207,7 +207,7 @@ class HloToLhloReshapeUnrankedConverter
     if (unranked_operand_type == nullptr) return failure();

     auto result_type = op.getType().cast<RankedTensorType>();
-    rewriter.replaceOpWithNewOp<MemRefCastOp>(
+    rewriter.replaceOpWithNewOp<memref::CastOp>(
         op, adaptor.operand(),
         MemRefType::get(result_type.getShape(), result_type.getElementType()));
     return success();
@@ -235,7 +235,7 @@ class HloToLhloDynamicReshapeConverter
       return failure();
     }
     mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
-    rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
+    rewriter.replaceOpWithNewOp<memref::ReshapeOp>(
         op, result_type, adaptor.operand(), adaptor.output_shape());
     return success();
   }
@@ -273,7 +273,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
   // Inserts dynamic memref to change the layout of the memref to put 0-stride
   // and size of the target dimension if size-1 dimension expansion is
   // necessary.
-  MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
+  memref::ReinterpretCastOp InsertDynamicMemrefCastOp(
       mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
     auto loc = op.getLoc();
     auto operand_type = operand.getType().cast<MemRefType>();
@@ -295,7 +295,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
     for (int i = operand_rank - 1; i >= 0; --i) {
       Value operand_dim_size =
           ShapedType::isDynamic(operand_shape[i])
-              ? b->create<DimOp>(loc, operand, i).getResult()
+              ? b->create<memref::DimOp>(loc, operand, i).getResult()
              : b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
       operand_sizes[i] = operand_dim_size;

@@ -355,7 +355,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
            makeStridedLinearLayoutMap(dynamic_layout,
                                       /*offset=*/0, b->getContext()));

-    auto transformed_operand = b->create<MemRefReinterpretCastOp>(
+    auto transformed_operand = b->create<memref::ReinterpretCastOp>(
        loc, type_erased_memref_type, operand,
        /*offset=*/b->getI64IntegerAttr(0), sizes, strides);
     return transformed_operand;
@@ -484,12 +484,12 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {

 // TODO(b/175789537) Remove this pattern.
 class HloToLhloTensorStoreOpLegacyConverter
-    : public BaseOpConversion<mlir::TensorStoreOp> {
+    : public BaseOpConversion<mlir::memref::TensorStoreOp> {
  public:
-  using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
+  using BaseOpConversion<mlir::memref::TensorStoreOp>::BaseOpConversion;

   LogicalResult matchAndRewrite(
-      mlir::TensorStoreOp op, ArrayRef<Value> operands,
+      mlir::memref::TensorStoreOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
     rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
                                                operands.back());
@@ -577,14 +577,16 @@ struct HloLegalizeToLhlo
     ConversionTarget target(context);
     target.addLegalDialect<lmhlo::LmhloDialect>();
     target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<memref::MemRefDialect>();
     target.addLegalDialect<shape::ShapeDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
     target.addIllegalDialect<mhlo::MhloDialect>();
     // Declare tensor_load and tensor_store illegal.
-    target.addIllegalOp<mlir::TensorLoadOp, mlir::TensorStoreOp>();
-    // tensor_to_memref is illegal if it has uses.
-    // TODO(b/175670649) Make tensor_to_memref illegal.
-    target.addDynamicallyLegalOp<mlir::TensorToMemrefOp>(
+    target.addIllegalOp<mlir::memref::TensorLoadOp,
+                        mlir::memref::TensorStoreOp>();
+    // buffer_cast is illegal if it has uses.
+    // TODO(b/175670649) Make buffer_cast illegal.
+    target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
         [](auto op) { return op->use_empty(); });

     BufferizeTypeConverter converter;
diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 3d770e7..06651a7 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -108,7 +108,7 @@ SmallVector<Value> ExtractDynamicSizes(OpBuilder& b, Location loc,
       dyn_sizes.push_back(
          b.create<IndexCastOp>(loc, b.getIndexType(), extract));
     } else {
-      dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
+      dyn_sizes.push_back(b.create<memref::DimOp>(loc, tensor, en.index()));
     }
   }
   return dyn_sizes;
@@ -324,13 +324,13 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
     }

     // Create two loads from the input.
-    auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
-    auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
+    auto lhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.lhs());
+    auto rhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.rhs());
     // TODO(ravishankarm) : Move this method out of lmhlo namespace.
     Value op_result = lmhlo::HloOpToStdScalarOp::map(
         lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
        &rewriter);
-    rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
+    rewriter.create<memref::StoreOp>(loc, op_result, lhlo_op.out());
     rewriter.eraseOp(lhlo_op);
     return success();
   }
@@ -590,8 +590,8 @@ class LhloBroadcastInDimConverter
         operand_type.getDimSize(0) <
             result_type.getDimSize(broadcast_dims.front())) {
       Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-      Value val =
-          rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
+      Value val = rewriter.create<memref::LoadOp>(loc, operand,
+                                                  llvm::makeArrayRef({zero}));
       rewriter.create<linalg::GenericOp>(
           loc, /*inputs=*/ValueRange{},
           /*outputBuffers=*/ValueRange{operand_adaptor.output()},
@@ -971,7 +971,8 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
     }

     // First fill the output buffer with the init value.
-    Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
+    Value init_value =
+        rewriter.create<memref::LoadOp>(loc, adaptor.init_values()[0]);
     rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);

     DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
@@ -1011,9 +1012,9 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
     // expects scalar SSA values. Add some allocs around the original op to
     // make it compatible.
     auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
-    Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
-    Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
-    Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);
+    Value alloc_a = rewriter.create<memref::AllocaOp>(loc, arg_type);
+    Value alloc_b = rewriter.create<memref::AllocaOp>(loc, arg_type);
+    Value alloc_res = rewriter.create<memref::AllocaOp>(loc, arg_type);

     // Now turn the existing signature
     //   (memref, memref, memref) -> ()
@@ -1030,13 +1031,15 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {

     // Store the arguments into the newly allocated buffers.
     rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
-    rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
-    rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
+    rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(0),
+                                     alloc_a);
+    rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(1),
+                                     alloc_b);
     rewriter.replaceOp(entry_block->getTerminator(), {});

     // Load & yield the result.
     rewriter.setInsertionPointToEnd(entry_block);
-    auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
+    auto load_res = rewriter.create<memref::LoadOp>(loc, alloc_res);
     rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
   }

@@ -1099,8 +1102,8 @@ class SliceConverter : public OpConversionPattern<OpTy> {
           slice_op.strides().template getValue<int64_t>(i)));
     }
     if (isLHLO) {
-      auto linalg_op =
-          rewriter.create<SubViewOp>(loc, args[0], offsets, sizes, strides);
+      auto linalg_op = rewriter.create<memref::SubViewOp>(loc, args[0], offsets,
+                                                          sizes, strides);
       rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
       rewriter.eraseOp(slice_op);
     } else {
@@ -1149,14 +1152,14 @@ SmallVector<Value> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
   switch (type) {
     case DotOperationType::kMatrixMatrix: {
       if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
-        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
+        dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
       if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
-        dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
+        dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 1));
       break;
     }
     case DotOperationType::kMatrixVector: {
       if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
-        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
+        dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
       break;
     }
     case DotOperationType::kVectorDot:
@@ -1203,11 +1206,11 @@ SmallVector<Value> GetDotGeneralOpInitTensorDynSizes(
     OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
   SmallVector<Value> dyn_shape;
   if (result_type.isDynamicDim(0))
-    dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
+    dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
   if (result_type.isDynamicDim(1))
-    dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1));
+    dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 1));
   if (result_type.isDynamicDim(2))
-    dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2));
+    dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 2));
   return dyn_shape;
 }

@@ -1307,7 +1310,7 @@ SmallVector<Value> GetReduceOpInitTensorDynSizes(
   for (int i = 0, j = 0; i < rank; ++i) {
     if (s.count(i)) continue;
     if (!result_type.isDynamicDim(j++)) continue;
-    dyn_shape.push_back(b.create<DimOp>(loc, arg, i));
+    dyn_shape.push_back(b.create<memref::DimOp>(loc, arg, i));
   }

   return dyn_shape;
@@ -1467,7 +1470,7 @@ struct NormalConvOpOnTensorsConversion
     // The output shape is N spatial_dims F.
     SmallVector<Value> dyn_sizes;
     if (result_type.isDynamicDim(0)) {
-      dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, 0));
+      dyn_sizes.push_back(rewriter.create<memref::DimOp>(loc, input, 0));
     }
     for (int64_t i = 1, e = rank - 1; i < e; ++i) {
       if (result_type.isDynamicDim(i)) {
@@ -1476,7 +1479,8 @@ struct NormalConvOpOnTensorsConversion
       }
     }
     if (result_type.isDynamicDim(rank - 1)) {
-      dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
+      dyn_sizes.push_back(
+          rewriter.create<memref::DimOp>(loc, filter, rank - 1));
     }
     Value init_tensor = rewriter.create<linalg::InitTensorOp>(
         loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
@@ -1856,8 +1860,8 @@ struct LhloLegalizeToLinalgPass
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());
-    target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
-                           math::MathDialect, StandardOpsDialect,
-                           AffineDialect>();
+    target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
+                           math::MathDialect, memref::MemRefDialect,
+                           StandardOpsDialect, AffineDialect>();

     auto func = getFunction();
     populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
@@ -1881,6 +1885,9 @@ struct HloLegalizeToLinalgPass
                            math::MathDialect, StandardOpsDialect,
                            tensor::TensorDialect, scf::SCFDialect>();

+    // TODO: DimOp shouldn't be in MemRefDialect
+    target.addLegalOp<memref::DimOp>();
+
     auto func = getFunction();
     mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
     if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
diff --git a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc
index 7acaf4a..34a201a 100644
--- a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -95,7 +96,7 @@ class LhloFuseLinalgPass
       continue;
     }

-    if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
+    if (auto tensor_load = dyn_cast<memref::TensorLoadOp>(definingOp)) {
       auto alias = tensor_load.memref();
       if (result_buffers.insert(alias).second) {
         worklist.push_back(alias);
@@ -103,7 +104,7 @@ class LhloFuseLinalgPass
       continue;
     }

-    if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
+    if (auto tensor_to_memref = dyn_cast<memref::BufferCastOp>(definingOp)) {
       auto alias = tensor_to_memref.tensor();
       if (result_buffers.insert(alias).second) {
         worklist.push_back(alias);
diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc
index a8710ef..91c50b0 100644
--- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc
+++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc
@@ -96,9 +96,10 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
     // Load the initial value and store it to the output.
     for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
-      auto init_value = rewriter.create<LoadOp>(loc, std::get<0>(pair));
-      rewriter.create<StoreOp>(loc, init_value, std::get<1>(pair),
-                               ArrayRef<Value>{index});
+      auto init_value =
+          rewriter.create<memref::LoadOp>(loc, std::get<0>(pair));
+      rewriter.create<memref::StoreOp>(
+          loc, init_value, std::get<1>(pair), ArrayRef<Value>{index});
     }

     // Insert a loop into the body to compute the reduction. The loop ranges
@@ -128,8 +129,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
       auto oneAttr = rewriter.getI64IntegerAttr(1);
       OpFoldResult size = oneAttr;
       OpFoldResult stride = oneAttr;
-      auto accumulator = rewriter.create<SubViewOp>(loc, resType, output,
-                                                    offset, size, stride);
+      auto accumulator = rewriter.create<memref::SubViewOp>(
+          loc, resType, output, offset, size, stride);
       llvm::SmallVector<Value, 1> indexings;
       auto input_buffer = *reduce_op.operands().begin();
       auto input_type_rank =
@@ -143,8 +144,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
           }));
       SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
       SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
-      auto rhs = rewriter.create<SubViewOp>(loc, accumulator.getType(), input,
-                                            offsets, sizes, strides);
+      auto rhs = rewriter.create<memref::SubViewOp>(
+          loc, accumulator.getType(), input, offsets, sizes, strides);

       // Now copy over the actual body of the reduction, leaving out the
       // terminator.
@@ -179,8 +180,9 @@ struct LhloLegalizeToGpuPass
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());
-    target.addLegalDialect<AffineDialect, gpu::GPUDialect,
-                           linalg::LinalgDialect, scf::SCFDialect,
-                           StandardOpsDialect>();
+    target.addLegalDialect<AffineDialect, gpu::GPUDialect,
+                           linalg::LinalgDialect, memref::MemRefDialect,
+                           scf::SCFDialect, StandardOpsDialect>();
     target.addIllegalOp<ReduceOp>();
     auto func = getFunction();
     patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc
index 4bf73ae..b45440d 100644
--- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc
+++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "llvm/ADT/SmallVector.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -43,10 +44,11 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
                                 Block* lhlo_block, OpBuilder* b) {
   SmallVector<Value, 2> arg_bufs;
   for (auto arg_type : lhlo_block->getArgumentTypes()) {
-    arg_bufs.push_back(b->create<AllocaOp>(loc, arg_type.cast<MemRefType>()));
+    arg_bufs.push_back(
+        b->create<memref::AllocaOp>(loc, arg_type.cast<MemRefType>()));
   }
   for (auto operand : llvm::enumerate(operands)) {
-    b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
+    b->create<memref::StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
   }
   // Clone the ops from `lhlo_block`.
   BlockAndValueMapping mapping;
@@ -55,7 +57,7 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
     auto clone = b->clone(nested, mapping);
     mapping.map(nested.getResults(), clone->getResults());
   }
-  return b->create<LoadOp>(loc, arg_bufs.back());
+  return b->create<memref::LoadOp>(loc, arg_bufs.back());
 }

 // Converts a block with LHLO ops and with signature:
@@ -78,7 +80,8 @@ void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
 Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
                             size_t dim_index, int64_t dim, OpBuilder* b) {
   return dim == ShapedType::kDynamicSize
-             ? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
+             ? b->create<memref::DimOp>(loc, shaped_value, dim_index)
+                   .getResult()
             : b->create<ConstantIndexOp>(loc, dim);
 }

@@ -249,8 +252,8 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
       (is_reducing_dim ? reduce_step : parallel_step).push_back(step);
     }
     // Load initial value from memref.
-    SmallVector<Value, 1> init_value = {
-        rewriter->create<LoadOp>(loc, *reduce_op.init_values().begin())};
+    SmallVector<Value, 1> init_value = {rewriter->create<memref::LoadOp>(
+        loc, *reduce_op.init_values().begin())};
     // Outer ParallelOp is not needed if it is a reduction across all dims.
     scf::ParallelOp outer;
     if (!parallel_lower.empty()) {
@@ -272,7 +275,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
       out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
     }

-    rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);
+    rewriter->create<memref::StoreOp>(loc, reduction_result, out, out_indices);

     // Load the element to reduce.
     SmallVector<Value, 2> indices;
@@ -290,7 +293,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
     }

     rewriter->setInsertionPointToStart(inner.getBody());
-    Value elem = rewriter->create<LoadOp>(
+    Value elem = rewriter->create<memref::LoadOp>(
         loc, *reduce_op.operands().begin(), indices);
     return rewriter->create<scf::ReduceOp>(loc, elem);
   }
@@ -385,7 +388,7 @@ class ReduceWindowOpConverter
       ConversionPatternRewriter* rewriter) const {
     auto loc = reduce_window_op.getLoc();
     Value init_value =
-        rewriter->create<LoadOp>(loc, reduce_window_op.init_value());
+        rewriter->create<memref::LoadOp>(loc, reduce_window_op.init_value());

     Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
     Value one = rewriter->create<ConstantIndexOp>(loc, 1);
@@ -408,7 +411,8 @@ class ReduceWindowOpConverter
     Value reduction_result = *window_loop.getResults().begin();
     auto output_ivs = output_loop.getInductionVars();
-    rewriter->create<StoreOp>(loc, reduction_result, output, output_ivs);
+    rewriter->create<memref::StoreOp>(loc, reduction_result, output,
+                                      output_ivs);
     return std::make_pair(output_loop, window_loop);
   }

@@ -439,7 +443,7 @@ class ReduceWindowOpConverter
     OpBuilder then_builder =
         elem_or_init.getThenBodyBuilder(rewriter->getListener());
-    Value elem = then_builder.create<LoadOp>(
+    Value elem = then_builder.create<memref::LoadOp>(
         loc, reduce_window_op.operand(), mapped_ivs.ivs);
     then_builder.create<scf::YieldOp>(loc, elem);

@@ -497,8 +501,8 @@ class SelectAndScatterOpConverter
     auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);

     // Load `source[selected_ivs]`.
-    auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(),
-                                            loop_over_src.getInductionVars());
+    auto src_elem = rewriter.create<memref::LoadOp>(
+        loc, s_and_s_op.source(), loop_over_src.getInductionVars());

     // Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`.
     auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
@@ -517,14 +521,14 @@ class SelectAndScatterOpConverter
   void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
                         OpBuilder* b) const {
     auto loc = s_and_s_op.getLoc();
-    Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
+    Value init_value = b->create<memref::LoadOp>(loc, s_and_s_op.init_value());

     scf::ParallelOp loop_over_output =
         MakeLoopOverShape(loc, s_and_s_op.out(), b);
     OpBuilder::InsertionGuard guard(*b);
     b->setInsertionPointToStart(loop_over_output.getBody());
-    b->create<StoreOp>(loc, init_value, s_and_s_op.out(),
-                       loop_over_output.getInductionVars());
+    b->create<memref::StoreOp>(loc, init_value, s_and_s_op.out(),
+                               loop_over_output.getInductionVars());
   }

   struct WindowLoops {
@@ -647,7 +651,7 @@ class SelectAndScatterOpConverter
     TypeRange iter_arg_types{ivs_val_flag->to_vector()};
     Value operand_elem =
-        b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
+        b->create<memref::LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
     auto if_init =
         b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
                              /*withElseRegion=*/true);
@@ -712,8 +716,8 @@ struct LhloLegalizeToParallelLoopsPass
     // clang-format on
     ConversionTarget target(getContext());
-    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
-                           scf::SCFDialect>();
+    target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
+                           StandardOpsDialect, scf::SCFDialect>();
     target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
                         lmhlo::SelectAndScatterOp>();
diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc
index 962d4ca..a6a240f 100644
--- a/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc
+++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc
@@ -15,6 +15,7 @@ limitations under the License.

 #include "llvm/ADT/SmallVector.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Attributes.h"
@@ -58,7 +59,8 @@ Value CalculateShapeValue(Location loc, Value operand,
   int64_t rank = result_type.getRank();
   shape_values.reserve(rank);
   for (int64_t i = 0; i < rank; ++i) {
-    shape_values.push_back(rewriter.create<DimOp>(loc, operand, i));
+    shape_values.push_back(
+        rewriter.create<memref::DimOp>(loc, operand, i));
   }
   return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
 }
diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir
index ec35f58..4206ac3 100644
--- a/tests/canonicalize.mlir
+++ b/tests/canonicalize.mlir
@@ -967,10 +967,10 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tupl
 // CHECK-LABEL: func @erase_dead_lhlo_constant
 func @erase_dead_lhlo_constant() {
-  %M = alloc() : memref<256x1024xf32>
+  %M = memref.alloc() : memref<256x1024xf32>
   // CHECK-NEXT: return
   "lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
-  dealloc %M : memref<256x1024xf32>
+  memref.dealloc %M : memref<256x1024xf32>
   return
 }

@@ -979,9 +979,9 @@ func @erase_dead_lhlo_constant()
 // CHECK-LABEL: func @erase_dead_lhlo_constant_negative
 func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
   // CHECK-NEXT: lmhlo.constant
   "lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
-  // CHECK-NEXT: alloc
+  // CHECK-NEXT: memref.alloc
   // CHECK-NEXT: lmhlo.constant
-  %N = alloc() : memref<256x1024xf32>
+  %N = memref.alloc() : memref<256x1024xf32>
   "lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
   return %N : memref<256x1024xf32>
 }
diff --git a/tests/end2end/broadcast.mlir b/tests/end2end/broadcast.mlir
index a405b32..71bd028 100644
--- a/tests/end2end/broadcast.mlir
+++ b/tests/end2end/broadcast.mlir
@@ -27,7 +27,7 @@ func private @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface }
 func private
@print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } func @trivial_broadcast_wrapper() { - %input_buf = alloc() : memref<3xf32> + %input_buf = memref.alloc() : memref<3xf32> %c1f32 = constant 1.0 : f32 %c2f32 = constant 2.0 : f32 @@ -36,19 +36,19 @@ func @trivial_broadcast_wrapper() { %c0 = constant 0 : index %c1 = constant 1 : index %c2 = constant 2 : index - store %c1f32, %input_buf[%c0] : memref<3xf32> - store %c2f32, %input_buf[%c1] : memref<3xf32> - store %c3f32, %input_buf[%c2] : memref<3xf32> - %input = tensor_load %input_buf : memref<3xf32> + memref.store %c1f32, %input_buf[%c0] : memref<3xf32> + memref.store %c2f32, %input_buf[%c1] : memref<3xf32> + memref.store %c3f32, %input_buf[%c2] : memref<3xf32> + %input = memref.tensor_load %input_buf : memref<3xf32> // Test BroadcastInDimOp. %output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<3xf32>) -> tensor<3x4xf32> - %output_buf = tensor_to_memref %output : memref<3x4xf32> + %output_buf = memref.buffer_cast %output : memref<3x4xf32> - %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK-NEXT: [1, 1, 1, 1] @@ -63,9 +63,9 @@ func @trivial_broadcast_wrapper() { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32> - %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> + %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32> - %unranked_dyn_output = memref_cast %dyn_output_buf + %unranked_dyn_output = memref.cast %dyn_output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] @@ -76,29 +76,29 @@ func @trivial_broadcast_wrapper() { } func @broadcast_in_X_dim_wrapper() { - %input_buf = alloc() : memref<1x4xf32> + %input_buf = memref.alloc() : memref<1x4xf32> %c1f32 = constant 1.0 : f32 %c0 = constant 0 : index - store %c1f32, %input_buf[%c0, %c0] : memref<1x4xf32> + memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x4xf32> %c2f32 = constant 2.0 : f32 %c1 = constant 1 : index - store %c2f32, %input_buf[%c0, %c1] : memref<1x4xf32> + memref.store %c2f32, %input_buf[%c0, %c1] : memref<1x4xf32> %c3f32 = constant 3.0 : f32 %c2 = constant 2 : index - store %c3f32, %input_buf[%c0, %c2] : memref<1x4xf32> + memref.store %c3f32, %input_buf[%c0, %c2] : memref<1x4xf32> %c4f32 = constant 4.0 : f32 %c3 = constant 3 : index - store %c4f32, %input_buf[%c0, %c3] : memref<1x4xf32> - %input = tensor_load %input_buf : memref<1x4xf32> + memref.store %c4f32, %input_buf[%c0, %c3] : memref<1x4xf32> + %input = memref.tensor_load %input_buf : memref<1x4xf32> // Test BroadcastInDimOp. 
%output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<1x4xf32>) -> tensor<3x4xf32> - %output_buf = tensor_to_memref %output : memref<3x4xf32> + %output_buf = memref.buffer_cast %output : memref<3x4xf32> - %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK-NEXT: [1, 2, 3, 4] @@ -112,9 +112,9 @@ func @broadcast_in_X_dim_wrapper() { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32> - %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> + %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32> - %unranked_dyn_output = memref_cast %dyn_output_buf + %unranked_dyn_output = memref.cast %dyn_output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] @@ -125,26 +125,26 @@ func @broadcast_in_X_dim_wrapper() { } func @broadcast_in_Y_dim_wrapper() { - %input_buf = alloc() : memref<3x1xf32> + %input_buf = memref.alloc() : memref<3x1xf32> %c1f32 = constant 1.0 : f32 %c0 = constant 0 : index - store %c1f32, %input_buf[%c0, %c0] : memref<3x1xf32> + memref.store %c1f32, %input_buf[%c0, %c0] : memref<3x1xf32> %c2f32 = constant 2.0 : f32 %c1 = constant 1 : index - store %c2f32, %input_buf[%c1, %c0] : memref<3x1xf32> + memref.store %c2f32, %input_buf[%c1, %c0] : memref<3x1xf32> %c3f32 = constant 3.0 : f32 %c2 = constant 2 : index - store %c3f32, %input_buf[%c2, %c0] : memref<3x1xf32> - %input = tensor_load %input_buf : memref<3x1xf32> + memref.store %c3f32, %input_buf[%c2, %c0] : memref<3x1xf32> + %input = memref.tensor_load %input_buf : memref<3x1xf32> // Test BroadcastInDimOp. 
%output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<3x1xf32>) -> tensor<3x4xf32> - %output_buf = tensor_to_memref %output : memref<3x4xf32> + %output_buf = memref.buffer_cast %output : memref<3x4xf32> - %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK-NEXT: [1, 1, 1, 1] @@ -159,9 +159,9 @@ func @broadcast_in_Y_dim_wrapper() { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32> - %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> + %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32> - %unranked_dyn_output = memref_cast %dyn_output_buf + %unranked_dyn_output = memref.cast %dyn_output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] @@ -172,29 +172,29 @@ func @broadcast_in_Y_dim_wrapper() { } func @broadcast_in_X_dim_transpose_wrapper() { - %input_buf = alloc() : memref<4x1xf32> + %input_buf = memref.alloc() : memref<4x1xf32> %c1f32 = constant 1.0 : f32 %c0 = constant 0 : index - store %c1f32, %input_buf[%c0, %c0] : memref<4x1xf32> + memref.store %c1f32, %input_buf[%c0, %c0] : memref<4x1xf32> %c2f32 = constant 2.0 : f32 %c1 = constant 1 : index - store %c2f32, %input_buf[%c1, %c0] : memref<4x1xf32> + memref.store %c2f32, %input_buf[%c1, %c0] : memref<4x1xf32> %c3f32 = constant 3.0 : f32 %c2 = constant 2 : index - store %c3f32, %input_buf[%c2, %c0] : memref<4x1xf32> + memref.store %c3f32, %input_buf[%c2, %c0] : memref<4x1xf32> %c4f32 = constant 4.0 : f32 %c3 = constant 3 : index - store %c4f32, %input_buf[%c3, %c0] : memref<4x1xf32> - %input = tensor_load %input_buf : memref<4x1xf32> + memref.store %c4f32, %input_buf[%c3, %c0] : memref<4x1xf32> + %input = memref.tensor_load %input_buf : memref<4x1xf32> // Test BroadcastInDimOp. 
%output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> } : (tensor<4x1xf32>) -> tensor<3x4xf32> - %output_buf = tensor_to_memref %output : memref<3x4xf32> + %output_buf = memref.buffer_cast %output : memref<3x4xf32> - %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK-NEXT: [1, 2, 3, 4] @@ -208,9 +208,9 @@ func @broadcast_in_X_dim_transpose_wrapper() { broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> } : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32> - %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> + %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32> - %unranked_dyn_output = memref_cast %dyn_output_buf + %unranked_dyn_output = memref.cast %dyn_output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] @@ -221,26 +221,26 @@ func @broadcast_in_X_dim_transpose_wrapper() { } func @broadcast_in_Y_dim_transpose_wrapper() { - %input_buf = alloc() : memref<1x3xf32> + %input_buf = memref.alloc() : memref<1x3xf32> %c1f32 = constant 1.0 : f32 %c0 = constant 0 : index - store %c1f32, %input_buf[%c0, %c0] : memref<1x3xf32> + memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x3xf32> %c2f32 = constant 2.0 : f32 %c1 = constant 1 : index - store %c2f32, %input_buf[%c0, %c1] : memref<1x3xf32> + memref.store %c2f32, %input_buf[%c0, %c1] : memref<1x3xf32> %c3f32 = constant 3.0 : f32 %c2 = constant 2 : index - store %c3f32, %input_buf[%c0, %c2] : memref<1x3xf32> - %input = tensor_load %input_buf : memref<1x3xf32> + memref.store %c3f32, %input_buf[%c0, %c2] : memref<1x3xf32> + %input = memref.tensor_load %input_buf : memref<1x3xf32> // Test BroadcastInDimOp. 
%output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> } : (tensor<1x3xf32>) -> tensor<3x4xf32> - %output_buf = tensor_to_memref %output : memref<3x4xf32> + %output_buf = memref.buffer_cast %output : memref<3x4xf32> - %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK-NEXT-NEXT: [1, 1, 1, 1] @@ -255,9 +255,9 @@ func @broadcast_in_Y_dim_transpose_wrapper() { broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> } : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32> - %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> + %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32> - %unranked_dyn_output = memref_cast %dyn_output_buf + %unranked_dyn_output = memref.cast %dyn_output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] @@ -268,20 +268,20 @@ func @broadcast_in_Y_dim_transpose_wrapper() { } func @broadcast_scalar_1d_wrapper() { - %input_buf = alloc() : memref<1xf32> + %input_buf = memref.alloc() : memref<1xf32> %c1f32 = constant 1.0 : f32 %c0 = constant 0 : index - store %c1f32, %input_buf[%c0] : memref<1xf32> - %input = tensor_load %input_buf : memref<1xf32> + memref.store %c1f32, %input_buf[%c0] : memref<1xf32> + %input = memref.tensor_load %input_buf : memref<1xf32> // Test BroadcastInDimOp. %output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<1xf32>) -> tensor<3x4xf32> - %output_buf = tensor_to_memref %output : memref<3x4xf32> + %output_buf = memref.buffer_cast %output : memref<3x4xf32> - %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK-NEXT: [1, 1, 1, 1] @@ -296,9 +296,9 @@ func @broadcast_scalar_1d_wrapper() { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32> - %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> + %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32> - %unranked_dyn_output = memref_cast %dyn_output_buf + %unranked_dyn_output = memref.cast %dyn_output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] @@ -309,20 +309,20 @@ func @broadcast_scalar_1d_wrapper() { } func @broadcast_scalar_2d_wrapper() { - %input_buf = alloc() : memref<1x1xf32> + %input_buf = memref.alloc() : memref<1x1xf32> %c1f32 = constant 1.0 : f32 %c0 = constant 0 : index - store %c1f32, %input_buf[%c0, %c0] : memref<1x1xf32> - %input = tensor_load %input_buf : memref<1x1xf32> + memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x1xf32> + %input = memref.tensor_load %input_buf : memref<1x1xf32> // Test BroadcastInDimOp. 
%output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<1x1xf32>) -> tensor<3x4xf32> - %output_buf = tensor_to_memref %output : memref<3x4xf32> + %output_buf = memref.buffer_cast %output : memref<3x4xf32> - %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK-NEXT: [1, 1, 1, 1] @@ -337,9 +337,9 @@ func @broadcast_scalar_2d_wrapper() { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32> - %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> + %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32> - %unranked_dyn_output = memref_cast %dyn_output_buf + %unranked_dyn_output = memref.cast %dyn_output_buf : memref<3x4xf32> to memref<*xf32> call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] @@ -350,7 +350,7 @@ func @broadcast_scalar_2d_wrapper() { } func @broadcast_to_the_same_shape() { - %input_buf = alloc() : memref<2x3xf32> + %input_buf = memref.alloc() : memref<2x3xf32> %c1f32 = constant 1.0 : f32 %c2f32 = constant 2.0 : f32 @@ -360,22 +360,22 @@ func @broadcast_to_the_same_shape() { %c1 = constant 1 : index %c2 = constant 2 : index %c3 = constant 3 : index - store %c1f32, %input_buf[%c0, %c0] : memref<2x3xf32> - store %c1f32, %input_buf[%c1, %c0] : memref<2x3xf32> - store %c2f32, %input_buf[%c0, %c1] : memref<2x3xf32> - store %c2f32, %input_buf[%c1, %c1] : memref<2x3xf32> - store %c3f32, %input_buf[%c0, %c2] : memref<2x3xf32> - store %c3f32, %input_buf[%c1, %c2] : memref<2x3xf32> - %input = tensor_load %input_buf : memref<2x3xf32> + memref.store %c1f32, %input_buf[%c0, %c0] : memref<2x3xf32> + memref.store %c1f32, %input_buf[%c1, %c0] : memref<2x3xf32> + memref.store %c2f32, %input_buf[%c0, %c1] : memref<2x3xf32> + memref.store %c2f32, %input_buf[%c1, %c1] : memref<2x3xf32> + memref.store %c3f32, %input_buf[%c0, %c2] : memref<2x3xf32> + memref.store %c3f32, %input_buf[%c1, %c2] : memref<2x3xf32> + %input = memref.tensor_load %input_buf : memref<2x3xf32> // Test BroadcastInDimOp. 
%output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<2x3xf32>) -> tensor<2x3xf32> - %output_buf = tensor_to_memref %output : memref<2x3xf32> + %output_buf = memref.buffer_cast %output : memref<2x3xf32> - %unraked_output = memref_cast %output_buf : memref<2x3xf32> to memref<*xf32> + %unraked_output = memref.cast %output_buf : memref<2x3xf32> to memref<*xf32> call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] // CHECK-NEXT: [1, 2, 3] @@ -387,9 +387,9 @@ func @broadcast_to_the_same_shape() { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32> - %dyn_output_buf = tensor_to_memref %dyn_output : memref<2x3xf32> + %dyn_output_buf = memref.buffer_cast %dyn_output : memref<2x3xf32> - %unranked_dyn_output = memref_cast %dyn_output_buf + %unranked_dyn_output = memref.cast %dyn_output_buf : memref<2x3xf32> to memref<*xf32> call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] @@ -399,7 +399,7 @@ func @broadcast_to_the_same_shape() { } func @broadcast_1d_to_2d() { - %input_buf = alloc() : memref<3xf32> + %input_buf = memref.alloc() : memref<3xf32> %c1f32 = constant 1.0 : f32 %c2f32 = constant 2.0 : f32 @@ -408,19 +408,19 @@ func @broadcast_1d_to_2d() { %c0 = constant 0 : index %c1 = constant 1 : index %c2 = constant 2 : index - store %c1f32, %input_buf[%c0] : memref<3xf32> - store %c2f32, %input_buf[%c1] : memref<3xf32> - store %c3f32, %input_buf[%c2] : memref<3xf32> - %input = tensor_load %input_buf : memref<3xf32> + memref.store %c1f32, %input_buf[%c0] : memref<3xf32> + memref.store %c2f32, %input_buf[%c1] : memref<3xf32> + memref.store %c3f32, %input_buf[%c2] : memref<3xf32> + %input = memref.tensor_load %input_buf : memref<3xf32> // Test BroadcastInDimOp. 
%output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<3xf32>) -> tensor<3x3xf32> - %output_buf = tensor_to_memref %output : memref<3x3xf32> + %output_buf = memref.buffer_cast %output : memref<3x3xf32> - %unraked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32> + %unraked_output = memref.cast %output_buf : memref<3x3xf32> to memref<*xf32> call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] // CHECK-NEXT: [1, 1, 1] @@ -435,9 +435,9 @@ func @broadcast_1d_to_2d() { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32> - %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x3xf32> + %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x3xf32> - %unranked_dyn_output = memref_cast %dyn_output_buf + %unranked_dyn_output = memref.cast %dyn_output_buf : memref<3x3xf32> to memref<*xf32> call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] @@ -448,7 +448,7 @@ func @broadcast_1d_to_2d() { } func @broadcast_1d_to_2d_with_transpose() { - %input_buf = alloc() : memref<3xf32> + %input_buf = memref.alloc() : memref<3xf32> %c1f32 = constant 1.0 : f32 %c2f32 = constant 2.0 : f32 @@ -457,19 +457,19 @@ func @broadcast_1d_to_2d_with_transpose() { %c0 = constant 0 : index %c1 = constant 1 : index %c2 = constant 2 : index - store %c1f32, %input_buf[%c0] : memref<3xf32> - store %c2f32, %input_buf[%c1] : memref<3xf32> - store %c3f32, %input_buf[%c2] : memref<3xf32> - %input = tensor_load %input_buf : memref<3xf32> + memref.store %c1f32, %input_buf[%c0] : memref<3xf32> + memref.store %c2f32, %input_buf[%c1] : memref<3xf32> + memref.store %c3f32, %input_buf[%c2] : memref<3xf32> + %input = memref.tensor_load %input_buf : memref<3xf32> // Test BroadcastInDimOp. %output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<3xf32>) -> tensor<3x3xf32> - %output_buf = tensor_to_memref %output : memref<3x3xf32> + %output_buf = memref.buffer_cast %output : memref<3x3xf32> - %unraked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32> + %unraked_output = memref.cast %output_buf : memref<3x3xf32> to memref<*xf32> call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] // CHECK-NEXT: [1, 2, 3] @@ -483,9 +483,9 @@ func @broadcast_1d_to_2d_with_transpose() { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32> - %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x3xf32> + %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x3xf32> - %unranked_dyn_output = memref_cast %dyn_output_buf + %unranked_dyn_output = memref.cast %dyn_output_buf : memref<3x3xf32> to memref<*xf32> call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] diff --git a/tests/end2end/legalize-trigonometric-to-approximation.mlir b/tests/end2end/legalize-trigonometric-to-approximation.mlir index 601dacb..60e9327 100644 --- a/tests/end2end/legalize-trigonometric-to-approximation.mlir +++ b/tests/end2end/legalize-trigonometric-to-approximation.mlir @@ -4,10 +4,10 @@ func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface // Helper function to print scalar values. 
func @print_f32(%arg : f32) { - %mem = alloca() : memref<1xf32> + %mem = memref.alloca() : memref<1xf32> %c0 = constant 0 : index - store %arg, %mem[%c0] : memref<1xf32> - %mem_unranked = memref_cast %mem : memref<1xf32> to memref<*xf32> + memref.store %arg, %mem[%c0] : memref<1xf32> + %mem_unranked = memref.cast %mem : memref<1xf32> to memref<*xf32> call @print_memref_f32(%mem_unranked) : (memref<*xf32>) -> () return } diff --git a/tests/end2end/reduce.mlir b/tests/end2end/reduce.mlir index 63b9dfc..52c59c4 100644 --- a/tests/end2end/reduce.mlir +++ b/tests/end2end/reduce.mlir @@ -21,21 +21,21 @@ func @reduce_add() { %c1 = constant 1 : index // Initialize input. - %input = alloc() : memref<2x3xf32> - %dim_x = dim %input, %c0 : memref<2x3xf32> - %dim_y = dim %input, %c1 : memref<2x3xf32> + %input = memref.alloc() : memref<2x3xf32> + %dim_x = memref.dim %input, %c0 : memref<2x3xf32> + %dim_y = memref.dim %input, %c1 : memref<2x3xf32> scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) { %i_i64 = index_cast %i : index to i64 %i_f32 = sitofp %i_i64 : i64 to f32 - store %i_f32, %input[%i, %j] : memref<2x3xf32> + memref.store %i_f32, %input[%i, %j] : memref<2x3xf32> } - %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32> + %unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32> call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] // CHECK: [0, 0, 0] // CHECK: [1, 1, 1] - %in = tensor_load %input : memref<2x3xf32> + %in = memref.tensor_load %input : memref<2x3xf32> %init = mhlo.constant dense<0.000000e+00> : tensor %reduce = "mhlo.reduce"(%in, %init) ( { @@ -45,8 +45,8 @@ func @reduce_add() { }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> - %output = tensor_to_memref %reduce : memref<2xf32> - %unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32> + %output = memref.buffer_cast %reduce : memref<2xf32> + %unranked_output = memref.cast %output : memref<2xf32> to memref<*xf32> call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1] // CHECK: [0, 3] @@ -58,21 +58,21 @@ func @reduce_max() { %c1 = constant 1 : index // Initialize input. 
- %input = alloc() : memref<2x3xf32> - %dim_x = dim %input, %c0 : memref<2x3xf32> - %dim_y = dim %input, %c1 : memref<2x3xf32> + %input = memref.alloc() : memref<2x3xf32> + %dim_x = memref.dim %input, %c0 : memref<2x3xf32> + %dim_y = memref.dim %input, %c1 : memref<2x3xf32> scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) { %i_i64 = index_cast %i : index to i64 %i_f32 = sitofp %i_i64 : i64 to f32 - store %i_f32, %input[%i, %j] : memref<2x3xf32> + memref.store %i_f32, %input[%i, %j] : memref<2x3xf32> } - %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32> + %unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32> call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] // CHECK: [0, 0, 0] // CHECK: [1, 1, 1] - %in = tensor_load %input : memref<2x3xf32> + %in = memref.tensor_load %input : memref<2x3xf32> %init = mhlo.constant dense<0xff800000> : tensor %reduce = "mhlo.reduce"(%in, %init) ( { @@ -82,8 +82,8 @@ func @reduce_max() { }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> - %output = tensor_to_memref %reduce : memref<2xf32> - %unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32> + %output = memref.buffer_cast %reduce : memref<2xf32> + %unranked_output = memref.cast %output : memref<2xf32> to memref<*xf32> call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1] // CHECK: [0, 1] diff --git a/tests/hlo-legalize-to-lhlo-unranked.mlir b/tests/hlo-legalize-to-lhlo-unranked.mlir index 79530d0..1c6aeac 100644 --- a/tests/hlo-legalize-to-lhlo-unranked.mlir +++ b/tests/hlo-legalize-to-lhlo-unranked.mlir @@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked( return %reshaped : tensor } // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>) -// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]]) +// CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]]) // CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref // ----- @@ -30,7 +30,7 @@ func @dynamic_reshape_to_unranked( return %reshaped : tensor<*xf32> } // CHECK-SAME: ([[ARG:%.*]]: memref, [[SHAPE:%.*]]: memref) -// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]]) +// CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]]) // CHECK-SAME: : (memref, memref) -> memref<*xf32> // ----- @@ -41,4 +41,4 @@ func @reshape_unranked(%operand: tensor<*xf32>) -> tensor { return %reshaped : tensor } // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -// CHECK-NEXT: memref_cast [[ARG]] : memref<*xf32> to memref +// CHECK-NEXT: memref.cast [[ARG]] : memref<*xf32> to memref diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 7d577c9..554ce1e 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -31,20 +31,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> return %5 : tensor<4xf32> } // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32> -// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: %[[MAX_RESULT:.*]] = memref.alloc() : memref<4xf32> // CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) -// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: %[[ADD_RESULT:.*]] = memref.alloc() : memref<4xf32> // CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) -// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> 
-// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: memref.dealloc %[[MAX_RESULT]] : memref<4xf32> +// CHECK-NEXT: %[[MIN_RESULT:.*]] = memref.alloc() : memref<4xf32> // CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) -// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: %[[SUB_RESULT:.*]] = memref.alloc() : memref<4xf32> //  CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) -// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> -// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: memref.dealloc %[[MIN_RESULT]] : memref<4xf32> +// CHECK-NEXT: %[[MUL_RESULT:.*]] = memref.alloc() : memref<4xf32> // CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) -// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> -// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> +// CHECK-NEXT: memref.dealloc %[[SUB_RESULT]] : memref<4xf32> +// CHECK-NEXT: memref.dealloc %[[ADD_RESULT]] : memref<4xf32> // CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32> // ----- @@ -53,15 +53,15 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>, %summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}) - // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32> + // CHECK-NEXT: %[[ADD_RESULT:.*]] = memref.alloc() : memref<2x2xf32> %sum = "mhlo.add"(%summand_1, %summand_2) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) - // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> + // CHECK-NEXT: %[[MUL_RESULT:.*]] = memref.alloc() : memref<2x2xf32> %result = "mhlo.multiply"(%sum, %multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) - // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> + // CHECK-NEXT: memref.dealloc %[[ADD_RESULT]] : memref<2x2xf32> // CHECK-NEXT: return %[[MUL_RESULT]] : memref<2x2xf32> return %result : tensor<2x2xf32> } @@ -154,9 +154,9 @@ func @dyn_broadcast(%operand: tensor) -> tensor { // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C1:.*]] = constant 1 : index -// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref +// CHECK: %[[OPER_DIM_1:.*]] = memref.dim %[[OPERAND]], %[[C1]] : memref // CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index -// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref +// CHECK: %[[OPER_DIM_0:.*]] = memref.dim %[[OPERAND]], %[[C0]] : memref // CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64> // CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index @@ -172,9 +172,9 @@ func @dyn_broadcast(%operand: tensor) -> tensor { // CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index // CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index -// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref to memref +// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref to memref -// CHECK: %[[RESULT:.*]] = 
@@ -469,7 +469,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
-  // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
+  // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
   // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
   return %result : tensor<?x?xf32>
   // CHECK: return %[[RESULT]]
@@ -485,7 +485,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
-  // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
+  // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
   // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
   return %result : tensor<?x?xf32>
   // CHECK: return %[[RESULT]]
@@ -496,7 +496,7 @@
 // CHECK-LABEL: func @dot
 func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
 // CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
-// CHECK-NEXT: %[[ALLOC:.*]] = alloc
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc
 // CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
 //   dot_dimension_numbers = {
 //     lhs_batching_dimensions = dense<> : tensor<0xi64>,
@@ -517,7 +517,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
 func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
     -> tensor<3x5x5x4xf32> {
   %c0 = constant 0 : index
-  // CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
+  // CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
   // CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
   // CHECK-SAME: padding = dense<[
   // CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
@@ -548,11 +548,11 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
 // CHECK-LABEL: func @reduce
 func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
-  // CHECK: %[[OUT:.*]] = alloc() : memref<1xf32>
+  // CHECK: %[[OUT:.*]] = memref.alloc() : memref<1xf32>
   // CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
   // CHECK: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
   // CHECK-SAME: %[[ARG3:.*]]: memref<f32>):
-  // CHECK: %[[TMP:.*]] = alloc() : memref<f32>
+  // CHECK: %[[TMP:.*]] = memref.alloc() : memref<f32>
   // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
   // CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
   // CHECK: "lmhlo.terminator"() : () -> ()
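One detail worth noting before the linalg tests: dim moved into the memref dialect as memref.dim without losing its tensor overload. At this LLVM revision the op accepts both memref and tensor operands, which is why checks in the tensor-based linalg lowerings below are rewritten too. A small sketch under that assumption (%buf and %val are illustrative names):

  %c0 = constant 0 : index
  %d0 = memref.dim %buf, %c0 : memref<?x?xf32>   // buffer operand
  %d1 = memref.dim %val, %c0 : tensor<?x?xf32>   // tensor operand, same op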
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index 714cd0a..53b4cfe 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -404,7 +404,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
   return %0: tensor<4x2x1x4x?x16xf32>
 }
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[D1:.*]] = dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
+// CHECK: %[[D1:.*]] = memref.dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
 // CHECK: linalg.init_tensor [4, 2, 1, 4, %[[D1]], 16] : tensor<4x2x1x4x?x16xf32>
 // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
 // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
@@ -1024,7 +1024,7 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
 // CHECK-LABEL: func @dot_matmul(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.matmul
@@ -1040,7 +1040,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
 // CHECK-LABEL: func @dot_matmul_i8_i8_i32(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>)
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.matmul
@@ -1058,7 +1058,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
 // CHECK-LABEL: func @dot_matmul_i16_i16_i32(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>)
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.matmul
@@ -1076,7 +1076,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
 // CHECK-LABEL: func @dot_matmul_i32_i32_i32(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>)
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.matmul
@@ -1094,7 +1094,7 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
 // CHECK-LABEL: func @dot_matvec(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
 // CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.matvec
@@ -1134,11 +1134,11 @@ func @dot_general_batch_matmul(%arg0: tensor<?x?x3xf32>,
 // CHECK-LABEL: func @dot_general_batch_matmul(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
 // CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
 // CHECK: %[[C2:.*]] = constant 2 : index
-// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
+// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.batch_matmul
@@ -1163,11 +1163,11 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
 // CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi8>, %[[ARG1:.*]]: tensor<?x3x?xi8>)
 // CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
 // CHECK: %[[C2:.*]] = constant 2 : index
-// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
+// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.batch_matmul
@@ -1192,11 +1192,11 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
 // CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi16>, %[[ARG1:.*]]: tensor<?x3x?xi16>)
 // CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
 // CHECK: %[[C2:.*]] = constant 2 : index
-// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
+// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.batch_matmul
@@ -1420,7 +1420,7 @@ func @reduce_dynamic(%arg0: tensor, %arg1: tensor) -> tensor
 // CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
-// CHECK-DAG: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor
+// CHECK-DAG: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C0]] : tensor
 // CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[DIM1]]]
 // CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 // CHECK: linalg.generic
@@ -1531,9 +1531,9 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x8x?xf32>, %arg1: tenso
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
 // CHECK: %[[C0:.+]] = constant 0 : index
-// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
+// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
 // CHECK: %[[C2:.+]] = constant 2 : index
-// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
+// CHECK: %[[DIM2:.+]] = memref.dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
 // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]]
 // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
 // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
@@ -1571,9 +1571,9 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x4x4x?xf32>, %arg1: tensor<3
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
 // CHECK: %[[C0:.+]] = constant 0 : index
-// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x4x4x?xf32>
+// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x4x4x?xf32>
 // CHECK: %[[C3:.+]] = constant 3 : index
-// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
+// CHECK: %[[DIM3:.+]] = memref.dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
 // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
 // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
 // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
@@ -1611,9 +1611,9 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x8x8x8x?xf32>, %arg1: tens
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
 // CHECK: %[[C0:.+]] = constant 0 : index
-// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
+// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
 // CHECK: %[[C4:.+]] = constant 4 : index
-// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
+// CHECK: %[[DIM4:.+]] = memref.dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
 // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]]
 // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
 // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
diff --git a/tests/lhlo-fuse-linalg.mlir b/tests/lhlo-fuse-linalg.mlir
index a46668c..2e5d494 100644
--- a/tests/lhlo-fuse-linalg.mlir
+++ b/tests/lhlo-fuse-linalg.mlir
@@ -7,7 +7,7 @@ iterator_types = ["parallel", "parallel"]}
 func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
              %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
-  %temp_result = alloc() : memref<6x6xf32>
+  %temp_result = memref.alloc() : memref<6x6xf32>
   linalg.generic #pointwise_2d_trait
     ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
     outs(%temp_result : memref<6x6xf32>) {
@@ -22,7 +22,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
     %out = mulf %temp_result_in, %multiplier_in : f32
     linalg.yield %out : f32
   }
-  dealloc %temp_result : memref<6x6xf32>
+  memref.dealloc %temp_result : memref<6x6xf32>
   return
 }
 // CHECK-LABEL: func @fusion
@@ -62,7 +62,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
 func @fusion_of_three(%arg0: memref<100x10xf32>,
                       %arg1: memref<100xf32>,
                       %arg2: memref<100x10xf32>) {
-  %0 = alloc() : memref<100x10xf32>
+  %0 = memref.alloc() : memref<100x10xf32>
   linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
@@ -72,7 +72,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
   ^bb0(%arg3: f32, %arg4: f32): // no predecessors
     linalg.yield %arg3 : f32
   }
-  %1 = alloc() : memref<100x10xf32>
+  %1 = memref.alloc() : memref<100x10xf32>
   linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>,
@@ -84,7 +84,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
     %2 = subf %arg3, %arg4 : f32
     linalg.yield %2 : f32
   }
-  dealloc %0 : memref<100x10xf32>
+  memref.dealloc %0 : memref<100x10xf32>
   linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
@@ -95,7 +95,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
     %2 = math.exp %arg3 : f32
     linalg.yield %2 : f32
   }
-  dealloc %1 : memref<100x10xf32>
+  memref.dealloc %1 : memref<100x10xf32>
   return
 }
 // CHECK-LABEL: func @fusion
@@ -141,7 +141,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
 "parallel"]}
 func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
                 %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
-  %temp_result = alloc() : memref<6x6x6x6xf32>
+  %temp_result = memref.alloc() : memref<6x6x6x6xf32>
   linalg.generic #pointwise_4d_trait
     ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
     outs(%temp_result : memref<6x6x6x6xf32>) {
@@ -156,7 +156,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
     %out = mulf %temp_result_in, %multiplier_in : f32
     linalg.yield %out : f32
   }
-  dealloc %temp_result : memref<6x6x6x6xf32>
+  memref.dealloc %temp_result : memref<6x6x6x6xf32>
   return
 }
 // CHECK-LABEL: func @fusion_4d
@@ -200,7 +200,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
 iterator_types = ["parallel", "parallel"]}
 func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
              %summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
-  %temp_result = alloc() : memref<6x6xf32>
+  %temp_result = memref.alloc() : memref<6x6xf32>
   linalg.generic #pointwise_2d_trait
     ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
     outs(%temp_result : memref<6x6xf32>) {
@@ -208,7 +208,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
     %out = addf %summand_1_in, %summand_2_in : f32
     linalg.yield %out : f32
   }
-  %result = alloc() : memref<6x6xf32>
+  %result = memref.alloc() : memref<6x6xf32>
   linalg.generic #pointwise_2d_trait
     ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
     outs(%result : memref<6x6xf32>) {
@@ -216,7 +216,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
     %out = mulf %temp_result_in, %multiplier_in : f32
     linalg.yield %out : f32
   }
-  dealloc %temp_result : memref<6x6xf32>
+  memref.dealloc %temp_result : memref<6x6xf32>
   return %result : memref<6x6xf32>
 }
@@ -258,7 +258,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
     -> memref<*xf32> {
   %c1 = constant 1 : index
   %c0 = constant 0 : index
-  %1 = alloc(%arg2) : memref<?xf32>
+  %1 = memref.alloc(%arg2) : memref<?xf32>
   linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                    affine_map<(d0) -> (d0)>],
                   iterator_types = ["parallel"]}
@@ -267,7 +267,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
     %13 = absf %arg3 : f32
     linalg.yield %13 : f32
   }
-  %2 = memref_reshape %1(%arg1)
+  %2 = memref.reshape %1(%arg1)
          : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
   return %2 : memref<*xf32>
 }
@@ -279,7 +279,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
 // CHECK-NOT: scf.for
 // CHECK: linalg.generic
 // CHECK: absf
-// CHECK: memref_reshape
+// CHECK: memref.reshape

 // TILED-LABEL: func @view_result
 // TILED-DAG: %[[C2:.*]] = constant 2
@@ -288,7 +288,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
 // TILED-NOT: scf.for
 // TILED: linalg.generic
 // TILED: absf
-// TILED: memref_reshape
+// TILED: memref.reshape

 // PLOOP-LABEL: func @view_result
@@ -297,20 +297,20 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
 // PLOOP-NOT: scf.parallel
 // PLOOP: linalg.generic
 // PLOOP: absf
-// PLOOP: memref_reshape
+// PLOOP: memref.reshape

 // -----

 // Confirm that tiling information is passed through RegionBranchOpInterfaces.
-// This test also uses memref_reshape, just to have a value to return through
+// This test also uses memref.reshape, just to have a value to return through
 // the if statement.
 func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
     -> memref<*xf32> {
   %c1 = constant 1 : index
   %c0 = constant 0 : index
-  %1 = alloc(%arg2) : memref<?xf32>
+  %1 = memref.alloc(%arg2) : memref<?xf32>
   linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                    affine_map<(d0) -> (d0)>],
                   iterator_types = ["parallel"]}
@@ -321,11 +321,11 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
   }
   %true = constant 1 : i1
   %3 = scf.if %true -> memref<*xf32> {
-    %2 = memref_reshape %1(%arg1)
+    %2 = memref.reshape %1(%arg1)
            : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
     scf.yield %2 : memref<*xf32>
   } else {
-    %2 = memref_reshape %1(%arg1)
+    %2 = memref.reshape %1(%arg1)
            : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
     scf.yield %2 : memref<*xf32>
   }
@@ -340,10 +340,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
 // CHECK: linalg.generic
 // CHECK: absf
 // CHECK: scf.if
-// CHECK: memref_reshape
+// CHECK: memref.reshape
 // CHECK: scf.yield
 // CHECK: else
-// CHECK: memref_reshape
+// CHECK: memref.reshape
 // CHECK: scf.yield

 // TILED-LABEL: func @branching_result
@@ -354,10 +354,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
 // TILED: linalg.generic
 // TILED: absf
 // TILED: scf.if
-// TILED: memref_reshape
+// TILED: memref.reshape
 // TILED: scf.yield
 // TILED: else
-// TILED: memref_reshape
+// TILED: memref.reshape
 // TILED: scf.yield

 // PLOOP-LABEL: func @branching_result
@@ -367,10 +367,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
 // PLOOP: linalg.generic
 // PLOOP: absf
 // PLOOP: scf.if
-// PLOOP: memref_reshape
+// PLOOP: memref.reshape
 // PLOOP: scf.yield
 // PLOOP: else
-// PLOOP: memref_reshape
+// PLOOP: memref.reshape
 // PLOOP: scf.yield

 // -----

@@ -380,7 +380,7 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
 func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
     -> memref<?xf32> {
   %c1 = constant 1 : index
-  %1 = alloc() : memref<32xf32>
+  %1 = memref.alloc() : memref<32xf32>
   linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                    affine_map<(d0) -> (d0)>],
                   iterator_types = ["parallel"]}
@@ -389,9 +389,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
     %13 = absf %arg3 : f32
     linalg.yield %13 : f32
   }
-  %2 = tensor_load %1 : memref<32xf32>
+  %2 = memref.tensor_load %1 : memref<32xf32>
   %3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
-  %4 = tensor_to_memref %3 : memref<?xf32>
+  %4 = memref.buffer_cast %3 : memref<?xf32>
   return %4 : memref<?xf32>
 }
@@ -402,9 +402,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
 // CHECK-NOT: scf.for
 // CHECK: linalg.generic
 // CHECK: absf
-// CHECK: tensor_load
+// CHECK: memref.tensor_load
 // CHECK: tensor.cast
-// CHECK: tensor_to_memref
+// CHECK: memref.buffer_cast

 // TILED-LABEL: func @tensor_ops
 // TILED-DAG: %[[C2:.*]] = constant 2
@@ -413,9 +413,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
 // TILED-NOT: scf.for
 // TILED: linalg.generic
 // TILED: absf
-// TILED: tensor_load
+// TILED: memref.tensor_load
 // TILED: tensor.cast
-// TILED: tensor_to_memref
+// TILED: memref.buffer_cast

 // PLOOP-LABEL: func @tensor_ops
@@ -424,6 +424,6 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
 // PLOOP-NOT: scf.parallel
 // PLOOP: linalg.generic
 // PLOOP: absf
-// PLOOP: tensor_load
+// PLOOP: memref.tensor_load
 // PLOOP: tensor.cast
-// PLOOP: tensor_to_memref
+// PLOOP: memref.buffer_cast
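The select-and-scatter lowering in the next file adapts scalar computations to LHLO's buffer calling convention by round-tripping every scalar through a rank-0 memref: a value is stored into a zero-dimensional buffer (note the empty subscript list), the LHLO op writes into another rank-0 buffer, and the result is loaded back. A minimal sketch of that idiom (buffer and value names invented here):

  %a = memref.alloc() : memref<f32>     // rank-0 buffer holding a single f32
  memref.store %x, %a[] : memref<f32>   // empty index list for 0-d memrefs
  %y = memref.load %a[] : memref<f32>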
diff --git a/tests/lhlo-legalize-select-and-scatter.mlir b/tests/lhlo-legalize-select-and-scatter.mlir
index a022d20..ba79ba4 100644
--- a/tests/lhlo-legalize-select-and-scatter.mlir
+++ b/tests/lhlo-legalize-select-and-scatter.mlir
@@ -49,10 +49,10 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
 // CHECK-DAG: [[CTRUE:%.*]] = constant true

 // Parallel loop to initialize the output buffer.
-// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
+// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
 // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
 // CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) {
-// CHECK: store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
+// CHECK: memref.store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
 // CHECK: scf.yield
 // CHECK: }

@@ -101,7 +101,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
   // INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true

-  // CHECK: [[ARG_ELEM:%.*]] = load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
+  // CHECK: [[ARG_ELEM:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
   // CHECK: [[IF_INIT_RES:%.*]]:4
   // CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) {
@@ -114,16 +114,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,

     // Allocate buffers for ARG element, current selected value to adapt LHLO
     // code.
-    // CHECK: [[ARG_ELEM_BUF:%.*]] = alloc() : memref<f32>
-    // CHECK: [[SEL_VAL_BUF:%.*]] = alloc() : memref<f32>
-    // CHECK: [[PRED_BUF:%.*]] = alloc() : memref<i1>
-    // CHECK: store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
-    // CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
+    // CHECK: [[ARG_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
+    // CHECK: [[SEL_VAL_BUF:%.*]] = memref.alloc() : memref<f32>
+    // CHECK: [[PRED_BUF:%.*]] = memref.alloc() : memref<i1>
+    // CHECK: memref.store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
+    // CHECK: memref.store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>

     // Compute PRED.
     // CHECK: "lmhlo.compare"(
     // CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
-    // CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1>
+    // CHECK: [[PRED:%.*]] = memref.load [[PRED_BUF]][] : memref<i1>

     // Depending on PRED, return ARG ivs & elem or current select ivs and value.
@@ -165,7 +165,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
 // CHECK: }

 // Use selected ivs to load element from the SRC buffer.
-// CHECK: [[SRC_ELEM:%.*]] = load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]
+// CHECK: [[SRC_ELEM:%.*]] = memref.load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]

 // Update of RESULT[SELECTED_I, SELECTED_J] should be done atomically, because
 // it may happen that several other threads select the same IVs if the windows
@@ -175,16 +175,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
 // CHECK: ^bb0([[CUR_RES:%.*]]: f32):

 // Allocate buffers for ARG element, current selected value to adapt LHLO code.
-// CHECK: [[SRC_ELEM_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: [[CUR_RES_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: [[RES_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
-// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
+// CHECK: [[SRC_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: [[CUR_RES_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: [[RES_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: memref.store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
+// CHECK: memref.store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>

 // Compute scatter value.
 // CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
 // CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
-// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>
+// CHECK: [[RES:%.*]] = memref.load [[RES_BUF]][] : memref<f32>

 // Atomic RMW terminator that returns updated value.
 // CHECK: atomic_yield [[RES]] : f32
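The atomic update matched above is, presumably, the standard dialect's generic atomic read-modify-write region, which is still spelled without a dialect prefix at this revision: its body receives the current stored value as a block argument and terminates with atomic_yield. A hedged sketch of the construct these CHECK lines describe (operand names invented):

  %new = generic_atomic_rmw %result[%i, %j] : memref<112x112xf32> {
  ^bb0(%current: f32):
    %sum = addf %current, %src_elem : f32
    atomic_yield %sum : f32
  }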
// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) : // CHECK-SAME: (memref, memref, memref) -> () -// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref +// CHECK: [[RES:%.*]] = memref.load [[RES_BUF]][] : memref // Atomic RMW terminator that returns updated value. // CHECK: atomic_yield [[RES]] : f32 diff --git a/tests/lhlo-legalize-to-gpu.mlir b/tests/lhlo-legalize-to-gpu.mlir index 26b230a..92afbba 100644 --- a/tests/lhlo-legalize-to-gpu.mlir +++ b/tests/lhlo-legalize-to-gpu.mlir @@ -19,14 +19,14 @@ func @reduce(%arg: memref<100x10xf32>, // CHECK-DAG: %[[C100:.*]] = constant 100 : index // CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) { -// CHECK: %[[ACC:.*]] = load %[[ARG1]][] : memref +// CHECK: %[[ACC:.*]] = memref.load %[[ARG1]][] : memref // CHECK: store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32> // CHECK-DAG: %[[LB:.*]] = constant 0 : index // CHECK-DAG: %[[UB:.*]] = constant 10 : index // CHECK-DAG: %[[STEP:.*]] = constant 1 : index // CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { -// CHECK: %[[LHS:.*]] = subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref -// CHECK: %[[RHS:.*]] = subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref +// CHECK: %[[LHS:.*]] = memref.subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref +// CHECK: %[[RHS:.*]] = memref.subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref // CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref, memref, memref) -> () // CHECK: } // CHECK: gpu.terminator diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index f0e0234..249e60c 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -52,10 +52,10 @@ func @element_wise_scalar(%lhs: memref, %rhs: memref, : (memref, memref, memref) -> () return } -// CHECK: %[[LHS:.*]] = load -// CHECK: %[[RHS:.*]] = load +// CHECK: %[[LHS:.*]] = memref.load +// CHECK: %[[RHS:.*]] = memref.load // CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] -// CHECK: store %[[RES]] +// CHECK: memref.store %[[RES]] // CHECK-NEXT: return // ----- @@ -347,7 +347,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>, } // CHECK-NOT: linalg.reshape // CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[VALUE:.*]] = load %{{.*}}[[C0]] +// CHECK: %[[VALUE:.*]] = memref.load %{{.*}}[[C0]] // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]] // CHECK-NEXT: ^bb0(%{{.+}}: f32): // CHECK-NEXT: linalg.yield %[[VALUE]] : f32 @@ -785,7 +785,7 @@ func @slice(%operand: memref, %result: memref) { } : (memref, memref) -> () return } -// CHECK: %[[RESULT:.*]] = subview %[[IN]][0, 1] [2, 2] [1, 1] : memref to memref<2x2xf32, #{{.*}}> +// CHECK: %[[RESULT:.*]] = memref.subview %[[IN]][0, 1] [2, 2] [1, 1] : memref to memref<2x2xf32, #{{.*}}> // CHECK: linalg.copy(%[[RESULT]], %[[OUT]]) // ----- @@ -899,7 +899,7 @@ func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) { func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) { %c0 = constant 0 : index - %0 = alloc() : memref<3x5x5x4xf32> + %0 = memref.alloc() : memref<3x5x5x4xf32> // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}}) // CHECK-SAME: dilations = [1, 2] // CHECK-SAME: padding = 
@@ -948,22 +948,22 @@ func @reduce_add(%arg: memref<100x10xf32>,
       : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
   return
 }
-// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
+// CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
 // CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
 // CHECK: linalg.generic {
 // CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
 // CHECK-SAME: iterator_types = ["parallel", "reduction"]}
 // CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
-// CHECK: alloca
-// CHECK-NEXT: alloca
-// CHECK-NEXT: alloca
-// CHECK-NEXT: store
-// CHECK-NEXT: store
-// CHECK-NEXT: load
-// CHECK-NEXT: load
+// CHECK: memref.alloca
+// CHECK-NEXT: memref.alloca
+// CHECK-NEXT: memref.alloca
+// CHECK-NEXT: memref.store
+// CHECK-NEXT: memref.store
+// CHECK-NEXT: memref.load
+// CHECK-NEXT: memref.load
 // CHECK-NEXT: addf
-// CHECK-NEXT: store
-// CHECK-NEXT: load
+// CHECK-NEXT: memref.store
+// CHECK-NEXT: memref.load
 // CHECK-NEXT: linalg.yield
 // CHECK-NEXT: }

@@ -984,22 +984,22 @@ func @reduce_maximum(%arg: memref<100x10xf32>,
       : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
   return
 }
-// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
+// CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
 // CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
 // CHECK: linalg.generic {
 // CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
 // CHECK-SAME: iterator_types = ["parallel", "reduction"]}
 // CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
-// CHECK: alloca
-// CHECK-NEXT: alloca
-// CHECK-NEXT: alloca
-// CHECK-NEXT: store
-// CHECK-NEXT: store
-// CHECK-NEXT: load
-// CHECK-NEXT: load
+// CHECK: memref.alloca
+// CHECK-NEXT: memref.alloca
+// CHECK-NEXT: memref.alloca
+// CHECK-NEXT: memref.store
+// CHECK-NEXT: memref.store
+// CHECK-NEXT: memref.load
+// CHECK-NEXT: memref.load
 // CHECK: cmpf
 // CHECK: select
-// CHECK: store
-// CHECK-NEXT: load
+// CHECK: memref.store
+// CHECK-NEXT: memref.load
 // CHECK-NEXT: linalg.yield
 // CHECK-NEXT: }
diff --git a/tests/lhlo-legalize-to-parallel-loops.mlir b/tests/lhlo-legalize-to-parallel-loops.mlir
index ba31ade..44c9ce0 100644
--- a/tests/lhlo-legalize-to-parallel-loops.mlir
+++ b/tests/lhlo-legalize-to-parallel-loops.mlir
@@ -21,27 +21,27 @@ func @reduce(%arg: memref<100x10x5xf32>,
 // CHECK-DAG: [[C5:%.*]] = constant 5 : index
 // CHECK-DAG: [[C10:%.*]] = constant 10 : index
 // CHECK-DAG: [[C100:%.*]] = constant 100 : index
-// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
+// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
 // CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
 // CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) {
 // CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
 // CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 {
-// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
+// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
 // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32>
 // CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
 // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
-// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
-// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
+// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
+// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
 // CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
-// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
+// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
 // CHECK: scf.reduce.return [[ACC_RESULT]] : f32
 // CHECK: }
 // CHECK: scf.yield
 // CHECK: }
-// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
+// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
 // CHECK: scf.yield

 // -----

@@ -65,23 +65,23 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
 // CHECK-DAG: [[C0:%.*]] = constant 0 : index
 // CHECK-DAG: [[C1:%.*]] = constant 1 : index
 // CHECK-DAG: [[C100:%.*]] = constant 100 : index
-// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
+// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
 // CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]])
 // CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 {
-// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
+// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
 // CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
 // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
-// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
-// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
+// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
+// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
 // CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
-// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
+// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
 // CHECK: scf.reduce.return [[ACC_RESULT]]
 // CHECK: }
 // CHECK: scf.yield
-// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]
+// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]

 // -----

@@ -104,30 +104,30 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
 // CHECK-DAG: [[C0:%.*]] = constant 0 : index
 // CHECK-DAG: [[C1:%.*]] = constant 1 : index
 // CHECK-DAG: [[C2:%.*]] = constant 2 : index
-// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
-// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
-// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
-// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
+// CHECK: [[DIM0:%.*]] = memref.dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
+// CHECK: [[DIM1:%.*]] = memref.dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
+// CHECK: [[DIM2:%.*]] = memref.dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
+// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
 // CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
 // CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) {
 // CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
 // CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 {
-// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
+// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
 // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<?x?x?xf32>
 // CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
 // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
-// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
-// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
+// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
+// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
 // CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
-// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
+// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
 // CHECK: scf.reduce.return [[ACC_RESULT]] : f32
 // CHECK: }
 // CHECK: scf.yield
 // CHECK: }
-// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
+// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
 // CHECK: scf.yield

 // -----

@@ -157,7 +157,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
 // CHECK-DAG: [[C3:%.*]] = constant 3 : index
 // CHECK-DAG: [[C56:%.*]] = constant 56 : index
 // CHECK-DAG: [[C112:%.*]] = constant 112 : index
-// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
+// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
 // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
 // CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) {
 // CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel
@@ -176,7 +176,7 @@ func @reduce_window(%arg: memref<112x112xf32>,

 // CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) {
 // CHECK: [[OPERAND_ELEM:%.*]] =
-// CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
+// CHECK-SAME: memref.load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
 // CHECK: scf.yield [[OPERAND_ELEM]] : f32
 // CHECK: } else {
 // CHECK: scf.yield [[INIT]] : f32
@@ -184,18 +184,18 @@ func @reduce_window(%arg: memref<112x112xf32>,

 // CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
 // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
-// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
-// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
-// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
+// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
+// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
+// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
 // CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
-// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
+// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
 // CHECK: scf.reduce.return [[ACC_RESULT]] : f32
 // CHECK: }
 // CHECK: scf.yield
 // CHECK: }
-// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
+// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
 // CHECK: scf.yield
 // CHECK: }
 // CHECK: return
diff --git a/tests/lhlo_gpu_ops.mlir b/tests/lhlo_gpu_ops.mlir
index 82c455c..4ffd0f4 100644
--- a/tests/lhlo_gpu_ops.mlir
+++ b/tests/lhlo_gpu_ops.mlir
@@ -30,7 +30,7 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
 // CHECK-LABEL: func @conv_forward
 func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>,
                    %output: memref<1x1x7x7xf16>) {
-  %scratch = alloc() : memref<32xi8>
+  %scratch = memref.alloc() : memref<32xi8>
   // This defines a 2D convolution over an 8x8 single channel input using a 2x2
  // filter, with an output of 7x7xf16. The 1x1x8x8 shape is (N, C, H, W).
   "lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch)
@@ -61,7 +61,7 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>,
 // CHECK-LABEL: func @conv_backfilter
 func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>,
                       %output: memref<54x54x16x64xf64>) {
-  %scratch = alloc() : memref<23328xui8>
+  %scratch = memref.alloc() : memref<23328xui8>
   "lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
     { backend_config = {algorithm = 1 : i64,
                         operand_0_layout = [3,2,1,0],
@@ -91,7 +91,7 @@ func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64x
 // CHECK-LABEL: func @conv_backinput
 func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>,
                      %output : memref<4x3x16x16xf64>) {
-  %scratch = alloc() : memref<32xui8>
+  %scratch = memref.alloc() : memref<32xui8>
   "lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
     { backend_config = {algorithm = 1 : i64,
                         operand_0_layout = [3,2,1,0],
@@ -122,7 +122,7 @@ func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf6
 // CHECK-LABEL: func @conv_fused
 func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>,
                  %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
-  %scratch = alloc() : memref<32xui8>
+  %scratch = memref.alloc() : memref<32xui8>
   "lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
     {activation_mode = "Relu",
      backend_config = {algorithm = 1 : i64,
@@ -153,7 +153,7 @@ func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>,
 // CHECK-LABEL: func @conv_fused_side_input
 func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>,
                             %bias : memref<32xf16>, %side_input: memref<32xf16>,
                             %output : memref<1x32x9x9xf16>) {
-  %scratch = alloc() : memref<0xui8>
+  %scratch = memref.alloc() : memref<0xui8>
   "lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias,
                                                  %side_input, %output, %scratch)
     {activation_mode = "Relu",
     backend_config = {algorithm = 1 : i64,
@@ -218,8 +218,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
 // CHECK-LABEL: func @cholesky
 func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
-  %scratch = alloc() : memref<32xi8>
-  %info = alloc() : memref<32xi32>
+  %scratch = memref.alloc() : memref<32xi8>
+  %info = memref.alloc() : memref<32xi32>
   "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
       : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
   return
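In the lmhlo.fusion test that follows, the fusion region mixes tensor-level mhlo ops with the surrounding buffers: each operand is bridged in with memref.tensor_load and the final value is written back with memref.tensor_store. Reduced to its skeleton, with a single illustrative operand pair (%in and %out are invented names), the round-trip looks like this:

  "lmhlo.fusion"() ( {
    %0 = memref.tensor_load %in : memref<10xf32>
    %1 = "mhlo.add"(%0, %0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
    memref.tensor_store %1, %out : memref<10xf32>
    "lmhlo.terminator"() : () -> ()
  } ) : () -> ()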
"multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> - tensor_store %4, %out : memref<10xf32> + memref.tensor_store %4, %out : memref<10xf32> "lmhlo.terminator"() : () -> () } ) : () -> () return diff --git a/tests/unfuse_batch_norm.mlir b/tests/unfuse_batch_norm.mlir index f8fea55..b094774 100644 --- a/tests/unfuse_batch_norm.mlir +++ b/tests/unfuse_batch_norm.mlir @@ -108,15 +108,15 @@ func @batchNormInference_dynamic_shape( // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C3:.*]] = constant 3 : index // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor + // CHECK-DAG: %[[DIM:.+]] = memref.dim %[[VARIANCE]], %[[C0]] : tensor // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor.from_elements %[[DIM]] : tensor<1xindex> // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor - // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor - // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor - // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor - // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor + // CHECK-DAG: %[[INPUT_DIM_0:.+]] = memref.dim %[[X]], %[[C0]] : tensor + // CHECK-DAG: %[[INPUT_DIM_1:.+]] = memref.dim %[[X]], %[[C1]] : tensor + // CHECK-DAG: %[[INPUT_DIM_2:.+]] = memref.dim %[[X]], %[[C2]] : tensor + // CHECK-DAG: %[[INPUT_DIM_3:.+]] = memref.dim %[[X]], %[[C3]] : tensor // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor.from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex> // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor