Updates LLVM usage to match
[678241795c95](https://github.com/llvm/llvm-project/commit/678241795c95)

PiperOrigin-RevId: 363257913
This commit is contained in:
A. Unique TensorFlower 2021-03-16 13:31:59 -07:00 committed by TensorFlow MLIR Team
parent 2be112a603
commit c54527fe88
28 changed files with 403 additions and 376 deletions

5
BUILD
View File

@ -23,6 +23,7 @@ td_library(
], ],
includes = ["include"], includes = ["include"],
deps = [ deps = [
"@llvm-project//mlir:MemRefOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles", "@llvm-project//mlir:SideEffectTdFiles",
], ],
@ -462,6 +463,7 @@ cc_library(
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CopyOpInterface", "@llvm-project//mlir:CopyOpInterface",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SideEffects", "@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -615,6 +617,7 @@ cc_library(
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -724,6 +727,7 @@ cc_library(
"@llvm-project//mlir:Affine", "@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -951,6 +955,7 @@ cc_library(
":hlo", ":hlo",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms", "@llvm-project//mlir:Transforms",

View File

@ -15,9 +15,9 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
LLVM_COMMIT = "6878be5dc3ec7031d0deec3e321310115bd71103" LLVM_COMMIT = "678241795c957b18bc473045e48abe3f2a61ff5c"
LLVM_SHA256 = "f55187a3329fd97fd62fd0714783524d50a3be934a35484bd4442195fb25f0e5" LLVM_SHA256 = "58fd00a9ed7841f36aa7042bb8c98323b030dee98abe36757eea9ddf4fd5ea75"
LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT) LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)

View File

@ -1,2 +1,2 @@
6878be5dc3ec7031d0deec3e321310115bd71103 678241795c957b18bc473045e48abe3f2a61ff5c

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"

View File

@ -33,6 +33,7 @@ limitations under the License.
#ifndef LHLO_OPS #ifndef LHLO_OPS
#define LHLO_OPS #define LHLO_OPS
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
@ -685,7 +686,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
let extraClassDeclaration = [{ let extraClassDeclaration = [{
SmallVector<Value, 4> getInputBuffers() { SmallVector<Value, 4> getInputBuffers() {
SmallVector<Value, 4> buffers; SmallVector<Value, 4> buffers;
this->region().walk([&](TensorLoadOp load) { this->region().walk([&](memref::TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region())) if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load.memref()); buffers.push_back(load.memref());
}); });
@ -694,7 +695,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
SmallVector<Value, 4> getOutputBuffers() { SmallVector<Value, 4> getOutputBuffers() {
SmallVector<Value, 4> buffers; SmallVector<Value, 4> buffers;
this->region().walk([&](TensorStoreOp store) { this->region().walk([&](memref::TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region())) if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.memref()); buffers.push_back(store.memref());
}); });
@ -703,7 +704,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
SmallVector<Value, 4> getFusionParameters() { SmallVector<Value, 4> getFusionParameters() {
SmallVector<Value, 4> buffers; SmallVector<Value, 4> buffers;
this->region().walk([&](TensorLoadOp load) { this->region().walk([&](memref::TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region())) if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load); buffers.push_back(load);
}); });
@ -712,7 +713,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
SmallVector<Value, 4> getFusionResults() { SmallVector<Value, 4> getFusionResults() {
SmallVector<Value, 4> buffers; SmallVector<Value, 4> buffers;
this->region().walk([&](TensorStoreOp store) { this->region().walk([&](memref::TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region())) if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.tensor()); buffers.push_back(store.tensor());
}); });

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef LHLO_OPS_BASE #ifndef LHLO_OPS_BASE
#define LHLO_OPS_BASE #define LHLO_OPS_BASE
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
@ -156,13 +157,13 @@ struct EraseConstOp : public OpRewritePattern<ConstOp> {
LogicalResult matchAndRewrite(ConstOp op, LogicalResult matchAndRewrite(ConstOp op,
PatternRewriter& rewriter) const override { PatternRewriter& rewriter) const override {
Value memref = op.output(); Value memref = op.output();
if (!memref.getDefiningOp<AllocOp>()) { if (!memref.getDefiningOp<memref::AllocOp>()) {
return failure(); return failure();
} }
// Check that all uses of the memref are either DeallocOps or this op. // Check that all uses of the memref are either DeallocOps or this op.
for (Operation* user : memref.getUsers()) for (Operation* user : memref.getUsers())
if (user != op && !isa<DeallocOp>(user)) return failure(); if (user != op && !isa<memref::DeallocOp>(user)) return failure();
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();

View File

@ -71,7 +71,7 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
dynamic_operands.push_back(alloc_operand); dynamic_operands.push_back(alloc_operand);
} }
return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands); return rewriter->create<memref::AllocOp>(loc, memref_type, dynamic_operands);
} }
Value InsertAlloc(Location loc, OpResult result, Value InsertAlloc(Location loc, OpResult result,
@ -85,7 +85,7 @@ Value InsertAlloc(Location loc, OpResult result,
MemRefType::get(result_type.getShape(), result_type.getElementType()); MemRefType::get(result_type.getShape(), result_type.getElementType());
OpBuilder::InsertionGuard guard(*rewriter); OpBuilder::InsertionGuard guard(*rewriter);
rewriter->setInsertionPoint(result.getDefiningOp()); rewriter->setInsertionPoint(result.getDefiningOp());
auto alloc = rewriter->create<AllocOp>(loc, memref_type); auto alloc = rewriter->create<memref::AllocOp>(loc, memref_type);
return alloc; return alloc;
} }
@ -207,7 +207,7 @@ class HloToLhloReshapeUnrankedConverter
if (unranked_operand_type == nullptr) return failure(); if (unranked_operand_type == nullptr) return failure();
auto result_type = op.getType().cast<RankedTensorType>(); auto result_type = op.getType().cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<MemRefCastOp>( rewriter.replaceOpWithNewOp<memref::CastOp>(
op, adaptor.operand(), op, adaptor.operand(),
MemRefType::get(result_type.getShape(), result_type.getElementType())); MemRefType::get(result_type.getShape(), result_type.getElementType()));
return success(); return success();
@ -235,7 +235,7 @@ class HloToLhloDynamicReshapeConverter
return failure(); return failure();
} }
mhlo::DynamicReshapeOp::Adaptor adaptor(operands); mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<MemRefReshapeOp>( rewriter.replaceOpWithNewOp<memref::ReshapeOp>(
op, result_type, adaptor.operand(), adaptor.output_shape()); op, result_type, adaptor.operand(), adaptor.output_shape());
return success(); return success();
} }
@ -273,7 +273,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
// Inserts dynamic memref to change the layout of the memref to put 0-stride // Inserts dynamic memref to change the layout of the memref to put 0-stride
// and size of the target dimension if size-1 dimension expansion is // and size of the target dimension if size-1 dimension expansion is
// necessary. // necessary.
MemRefReinterpretCastOp InsertDynamicMemrefCastOp( memref::ReinterpretCastOp InsertDynamicMemrefCastOp(
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const { mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc(); auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>(); auto operand_type = operand.getType().cast<MemRefType>();
@ -295,7 +295,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
for (int i = operand_rank - 1; i >= 0; --i) { for (int i = operand_rank - 1; i >= 0; --i) {
Value operand_dim_size = Value operand_dim_size =
ShapedType::isDynamic(operand_shape[i]) ShapedType::isDynamic(operand_shape[i])
? b->create<DimOp>(loc, operand, i).getResult() ? b->create<memref::DimOp>(loc, operand, i).getResult()
: b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult(); : b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
operand_sizes[i] = operand_dim_size; operand_sizes[i] = operand_dim_size;
@ -355,7 +355,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
makeStridedLinearLayoutMap(dynamic_layout, makeStridedLinearLayoutMap(dynamic_layout,
/*offset=*/0, b->getContext())); /*offset=*/0, b->getContext()));
auto transformed_operand = b->create<MemRefReinterpretCastOp>( auto transformed_operand = b->create<memref::ReinterpretCastOp>(
loc, type_erased_memref_type, operand, loc, type_erased_memref_type, operand,
/*offset=*/b->getI64IntegerAttr(0), sizes, strides); /*offset=*/b->getI64IntegerAttr(0), sizes, strides);
return transformed_operand; return transformed_operand;
@ -484,12 +484,12 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
// TODO(b/175789537) Remove this pattern. // TODO(b/175789537) Remove this pattern.
class HloToLhloTensorStoreOpLegacyConverter class HloToLhloTensorStoreOpLegacyConverter
: public BaseOpConversion<mlir::TensorStoreOp> { : public BaseOpConversion<mlir::memref::TensorStoreOp> {
public: public:
using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion; using BaseOpConversion<mlir::memref::TensorStoreOp>::BaseOpConversion;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
mlir::TensorStoreOp op, ArrayRef<Value> operands, mlir::memref::TensorStoreOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(), rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
operands.back()); operands.back());
@ -577,14 +577,16 @@ struct HloLegalizeToLhlo
ConversionTarget target(context); ConversionTarget target(context);
target.addLegalDialect<lmhlo::LmhloDialect>(); target.addLegalDialect<lmhlo::LmhloDialect>();
target.addLegalDialect<StandardOpsDialect>(); target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<shape::ShapeDialect>(); target.addLegalDialect<shape::ShapeDialect>();
target.addLegalDialect<tensor::TensorDialect>(); target.addLegalDialect<tensor::TensorDialect>();
target.addIllegalDialect<mhlo::MhloDialect>(); target.addIllegalDialect<mhlo::MhloDialect>();
// Declare tensor_load and tensor_store illegal. // Declare tensor_load and tensor_store illegal.
target.addIllegalOp<mlir::TensorLoadOp, mlir::TensorStoreOp>(); target.addIllegalOp<mlir::memref::TensorLoadOp,
// tensor_to_memref is illegal if it has uses. mlir::memref::TensorStoreOp>();
// TODO(b/175670649) Make tensor_to_memref illegal. // buffer_cast is illegal if it has uses.
target.addDynamicallyLegalOp<mlir::TensorToMemrefOp>( // TODO(b/175670649) Make buffer_cast illegal.
target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
[](auto op) { return op->use_empty(); }); [](auto op) { return op->use_empty(); });
BufferizeTypeConverter converter; BufferizeTypeConverter converter;

View File

@ -108,7 +108,7 @@ SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
dyn_sizes.push_back( dyn_sizes.push_back(
b.create<IndexCastOp>(loc, b.getIndexType(), extract)); b.create<IndexCastOp>(loc, b.getIndexType(), extract));
} else { } else {
dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index())); dyn_sizes.push_back(b.create<memref::DimOp>(loc, tensor, en.index()));
} }
} }
return dyn_sizes; return dyn_sizes;
@ -324,13 +324,13 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
} }
// Create two loads from the input. // Create two loads from the input.
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs()); auto lhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs()); auto rhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of lmhlo namespace. // TODO(ravishankarm) : Move this method out of lmhlo namespace.
Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>( Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs}, lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter); &rewriter);
rewriter.create<StoreOp>(loc, op_result, lhlo_op.out()); rewriter.create<memref::StoreOp>(loc, op_result, lhlo_op.out());
rewriter.eraseOp(lhlo_op); rewriter.eraseOp(lhlo_op);
return success(); return success();
} }
@ -590,8 +590,8 @@ class LhloBroadcastInDimConverter
operand_type.getDimSize(0) < operand_type.getDimSize(0) <
result_type.getDimSize(broadcast_dims.front())) { result_type.getDimSize(broadcast_dims.front())) {
Value zero = rewriter.create<ConstantIndexOp>(loc, 0); Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value val = Value val = rewriter.create<memref::LoadOp>(loc, operand,
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero})); llvm::makeArrayRef({zero}));
rewriter.create<linalg::GenericOp>( rewriter.create<linalg::GenericOp>(
loc, /*inputs=*/ValueRange{}, loc, /*inputs=*/ValueRange{},
/*outputBuffers=*/ValueRange{operand_adaptor.output()}, /*outputBuffers=*/ValueRange{operand_adaptor.output()},
@ -971,7 +971,8 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
} }
// First fill the output buffer with the init value. // First fill the output buffer with the init value.
Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]); Value init_value =
rewriter.create<memref::LoadOp>(loc, adaptor.init_values()[0]);
rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value); rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);
DenseIntElementsAttr dimensions_attr = reduce_op.dimensions(); DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
@ -1011,9 +1012,9 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
// expects scalar SSA values. Add some allocs around the original op to // expects scalar SSA values. Add some allocs around the original op to
// make it compatible. // make it compatible.
auto arg_type = block->getArgument(0).getType().cast<MemRefType>(); auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type); Value alloc_a = rewriter.create<memref::AllocaOp>(loc, arg_type);
Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type); Value alloc_b = rewriter.create<memref::AllocaOp>(loc, arg_type);
Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type); Value alloc_res = rewriter.create<memref::AllocaOp>(loc, arg_type);
// Now turn the existing signature // Now turn the existing signature
// (memref<X>, memref<X>, memref<X>) -> () // (memref<X>, memref<X>, memref<X>) -> ()
@ -1030,13 +1031,15 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
// Store the arguments into the newly allocated buffers. // Store the arguments into the newly allocated buffers.
rewriter.setInsertionPointAfter(alloc_res.getDefiningOp()); rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a); rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(0),
rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b); alloc_a);
rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(1),
alloc_b);
rewriter.replaceOp(entry_block->getTerminator(), {}); rewriter.replaceOp(entry_block->getTerminator(), {});
// Load & yield the result. // Load & yield the result.
rewriter.setInsertionPointToEnd(entry_block); rewriter.setInsertionPointToEnd(entry_block);
auto load_res = rewriter.create<LoadOp>(loc, alloc_res); auto load_res = rewriter.create<memref::LoadOp>(loc, alloc_res);
rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res}); rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
} }
@ -1099,8 +1102,8 @@ class SliceConverter : public OpConversionPattern<OpTy> {
slice_op.strides().template getValue<int64_t>(i))); slice_op.strides().template getValue<int64_t>(i)));
} }
if (isLHLO) { if (isLHLO) {
auto linalg_op = auto linalg_op = rewriter.create<memref::SubViewOp>(loc, args[0], offsets,
rewriter.create<SubViewOp>(loc, args[0], offsets, sizes, strides); sizes, strides);
rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]); rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
rewriter.eraseOp(slice_op); rewriter.eraseOp(slice_op);
} else { } else {
@ -1149,14 +1152,14 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
switch (type) { switch (type) {
case DotOperationType::kMatrixMatrix: { case DotOperationType::kMatrixMatrix: {
if (lhs.getType().cast<ShapedType>().isDynamicDim(0)) if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0)); dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
if (rhs.getType().cast<ShapedType>().isDynamicDim(1)) if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1)); dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 1));
break; break;
} }
case DotOperationType::kMatrixVector: { case DotOperationType::kMatrixVector: {
if (lhs.getType().cast<ShapedType>().isDynamicDim(0)) if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0)); dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
break; break;
} }
case DotOperationType::kVectorDot: case DotOperationType::kVectorDot:
@ -1203,11 +1206,11 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) { OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
SmallVector<Value, 8> dyn_shape; SmallVector<Value, 8> dyn_shape;
if (result_type.isDynamicDim(0)) if (result_type.isDynamicDim(0))
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0)); dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
if (result_type.isDynamicDim(1)) if (result_type.isDynamicDim(1))
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1)); dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 1));
if (result_type.isDynamicDim(2)) if (result_type.isDynamicDim(2))
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2)); dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 2));
return dyn_shape; return dyn_shape;
} }
@ -1307,7 +1310,7 @@ SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
for (int i = 0, j = 0; i < rank; ++i) { for (int i = 0, j = 0; i < rank; ++i) {
if (s.count(i)) continue; if (s.count(i)) continue;
if (!result_type.isDynamicDim(j++)) continue; if (!result_type.isDynamicDim(j++)) continue;
dyn_shape.push_back(b.create<DimOp>(loc, arg, i)); dyn_shape.push_back(b.create<memref::DimOp>(loc, arg, i));
} }
return dyn_shape; return dyn_shape;
@ -1467,7 +1470,7 @@ struct NormalConvOpOnTensorsConversion
// The output shape is N spatial_dims F. // The output shape is N spatial_dims F.
SmallVector<Value, 8> dyn_sizes; SmallVector<Value, 8> dyn_sizes;
if (result_type.isDynamicDim(0)) { if (result_type.isDynamicDim(0)) {
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, 0)); dyn_sizes.push_back(rewriter.create<memref::DimOp>(loc, input, 0));
} }
for (int64_t i = 1, e = rank - 1; i < e; ++i) { for (int64_t i = 1, e = rank - 1; i < e; ++i) {
if (result_type.isDynamicDim(i)) { if (result_type.isDynamicDim(i)) {
@ -1476,7 +1479,8 @@ struct NormalConvOpOnTensorsConversion
} }
} }
if (result_type.isDynamicDim(rank - 1)) { if (result_type.isDynamicDim(rank - 1)) {
dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1)); dyn_sizes.push_back(
rewriter.create<memref::DimOp>(loc, filter, rank - 1));
} }
Value init_tensor = rewriter.create<linalg::InitTensorOp>( Value init_tensor = rewriter.create<linalg::InitTensorOp>(
loc, dyn_sizes, result_type.getShape(), result_type.getElementType()); loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
@ -1856,8 +1860,8 @@ struct LhloLegalizeToLinalgPass
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect, target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
math::MathDialect, StandardOpsDialect, math::MathDialect, memref::MemRefDialect,
AffineDialect>(); StandardOpsDialect, AffineDialect>();
auto func = getFunction(); auto func = getFunction();
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns); populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
@ -1881,6 +1885,9 @@ struct HloLegalizeToLinalgPass
math::MathDialect, StandardOpsDialect, math::MathDialect, StandardOpsDialect,
tensor::TensorDialect, scf::SCFDialect>(); tensor::TensorDialect, scf::SCFDialect>();
// TODO: DimOp shouldn't be in MemRefDialect
target.addLegalOp<memref::DimOp>();
auto func = getFunction(); auto func = getFunction();
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, std::move(patterns)))) { if (failed(applyPartialConversion(func, target, std::move(patterns)))) {

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -95,7 +96,7 @@ class LhloFuseLinalgPass
continue; continue;
} }
if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) { if (auto tensor_load = dyn_cast<memref::TensorLoadOp>(definingOp)) {
auto alias = tensor_load.memref(); auto alias = tensor_load.memref();
if (result_buffers.insert(alias).second) { if (result_buffers.insert(alias).second) {
worklist.push_back(alias); worklist.push_back(alias);
@ -103,7 +104,7 @@ class LhloFuseLinalgPass
continue; continue;
} }
if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) { if (auto tensor_to_memref = dyn_cast<memref::BufferCastOp>(definingOp)) {
auto alias = tensor_to_memref.tensor(); auto alias = tensor_to_memref.tensor();
if (result_buffers.insert(alias).second) { if (result_buffers.insert(alias).second) {
worklist.push_back(alias); worklist.push_back(alias);

View File

@ -96,9 +96,10 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
// Load the initial value and store it to the output. // Load the initial value and store it to the output.
for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) { for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
auto init_value = rewriter.create<mlir::LoadOp>(loc, std::get<0>(pair)); auto init_value =
rewriter.create<mlir::StoreOp>(loc, init_value, std::get<1>(pair), rewriter.create<mlir::memref::LoadOp>(loc, std::get<0>(pair));
ArrayRef<Value>{index}); rewriter.create<mlir::memref::StoreOp>(
loc, init_value, std::get<1>(pair), ArrayRef<Value>{index});
} }
// Insert a loop into the body to compute the reduction. The loop ranges // Insert a loop into the body to compute the reduction. The loop ranges
@ -128,8 +129,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
auto oneAttr = rewriter.getI64IntegerAttr(1); auto oneAttr = rewriter.getI64IntegerAttr(1);
OpFoldResult size = oneAttr; OpFoldResult size = oneAttr;
OpFoldResult stride = oneAttr; OpFoldResult stride = oneAttr;
auto accumulator = rewriter.create<SubViewOp>(loc, resType, output, auto accumulator = rewriter.create<memref::SubViewOp>(
offset, size, stride); loc, resType, output, offset, size, stride);
llvm::SmallVector<Value, 4> indexings; llvm::SmallVector<Value, 4> indexings;
auto input_buffer = *reduce_op.operands().begin(); auto input_buffer = *reduce_op.operands().begin();
auto input_type_rank = auto input_type_rank =
@ -143,8 +144,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
})); }));
SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr); SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
SmallVector<OpFoldResult> strides(input_type_rank, oneAttr); SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
auto rhs = rewriter.create<SubViewOp>(loc, accumulator.getType(), input, auto rhs = rewriter.create<memref::SubViewOp>(
offsets, sizes, strides); loc, accumulator.getType(), input, offsets, sizes, strides);
// Now copy over the actual body of the reduction, leaving out the // Now copy over the actual body of the reduction, leaving out the
// terminator. // terminator.
@ -179,8 +180,9 @@ struct LhloLegalizeToGpuPass
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect, target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
gpu::GPUDialect, scf::SCFDialect, LmhloDialect>(); StandardOpsDialect, gpu::GPUDialect, scf::SCFDialect,
LmhloDialect>();
target.addIllegalOp<ReduceOp>(); target.addIllegalOp<ReduceOp>();
auto func = getFunction(); auto func = getFunction();
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext()); patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
@ -43,10 +44,11 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
Block* lhlo_block, OpBuilder* b) { Block* lhlo_block, OpBuilder* b) {
SmallVector<Value, 2> arg_bufs; SmallVector<Value, 2> arg_bufs;
for (auto arg_type : lhlo_block->getArgumentTypes()) { for (auto arg_type : lhlo_block->getArgumentTypes()) {
arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>())); arg_bufs.push_back(
b->create<memref::AllocOp>(loc, arg_type.cast<MemRefType>()));
} }
for (auto operand : llvm::enumerate(operands)) { for (auto operand : llvm::enumerate(operands)) {
b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]); b->create<memref::StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
} }
// Clone the ops from `lhlo_block`. // Clone the ops from `lhlo_block`.
BlockAndValueMapping mapping; BlockAndValueMapping mapping;
@ -55,7 +57,7 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
auto clone = b->clone(nested, mapping); auto clone = b->clone(nested, mapping);
mapping.map(nested.getResults(), clone->getResults()); mapping.map(nested.getResults(), clone->getResults());
} }
return b->create<LoadOp>(loc, arg_bufs.back()); return b->create<memref::LoadOp>(loc, arg_bufs.back());
} }
// Converts a block with LHLO ops and with signature: // Converts a block with LHLO ops and with signature:
@ -78,7 +80,8 @@ void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value, Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
size_t dim_index, int64_t dim, OpBuilder* b) { size_t dim_index, int64_t dim, OpBuilder* b) {
return dim == ShapedType::kDynamicSize return dim == ShapedType::kDynamicSize
? b->create<DimOp>(loc, shaped_value, dim_index).getResult() ? b->create<memref::DimOp>(loc, shaped_value, dim_index)
.getResult()
: b->create<ConstantIndexOp>(loc, dim); : b->create<ConstantIndexOp>(loc, dim);
} }
@ -249,8 +252,8 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
(is_reducing_dim ? reduce_step : parallel_step).push_back(step); (is_reducing_dim ? reduce_step : parallel_step).push_back(step);
} }
// Load initial value from memref<element_type>. // Load initial value from memref<element_type>.
SmallVector<Value, 1> init_value = { SmallVector<Value, 1> init_value = {rewriter->create<memref::LoadOp>(
rewriter->create<LoadOp>(loc, *reduce_op.init_values().begin())}; loc, *reduce_op.init_values().begin())};
// Outer ParallelOp is not needed if it is a reduction across all dims. // Outer ParallelOp is not needed if it is a reduction across all dims.
scf::ParallelOp outer; scf::ParallelOp outer;
if (!parallel_lower.empty()) { if (!parallel_lower.empty()) {
@ -272,7 +275,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0)); out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
} }
rewriter->create<StoreOp>(loc, reduction_result, out, out_indices); rewriter->create<memref::StoreOp>(loc, reduction_result, out, out_indices);
// Load the element to reduce. // Load the element to reduce.
SmallVector<Value, 2> indices; SmallVector<Value, 2> indices;
@ -290,7 +293,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
} }
rewriter->setInsertionPointToStart(inner.getBody()); rewriter->setInsertionPointToStart(inner.getBody());
Value elem = rewriter->create<mlir::LoadOp>( Value elem = rewriter->create<mlir::memref::LoadOp>(
loc, *reduce_op.operands().begin(), indices); loc, *reduce_op.operands().begin(), indices);
return rewriter->create<scf::ReduceOp>(loc, elem); return rewriter->create<scf::ReduceOp>(loc, elem);
} }
@ -385,7 +388,7 @@ class ReduceWindowOpConverter
ConversionPatternRewriter* rewriter) const { ConversionPatternRewriter* rewriter) const {
auto loc = reduce_window_op.getLoc(); auto loc = reduce_window_op.getLoc();
Value init_value = Value init_value =
rewriter->create<LoadOp>(loc, reduce_window_op.init_value()); rewriter->create<memref::LoadOp>(loc, reduce_window_op.init_value());
Value zero = rewriter->create<ConstantIndexOp>(loc, 0); Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
Value one = rewriter->create<ConstantIndexOp>(loc, 1); Value one = rewriter->create<ConstantIndexOp>(loc, 1);
@ -408,7 +411,8 @@ class ReduceWindowOpConverter
Value reduction_result = *window_loop.getResults().begin(); Value reduction_result = *window_loop.getResults().begin();
auto output_ivs = output_loop.getInductionVars(); auto output_ivs = output_loop.getInductionVars();
rewriter->create<StoreOp>(loc, reduction_result, output, output_ivs); rewriter->create<memref::StoreOp>(loc, reduction_result, output,
output_ivs);
return std::make_pair(output_loop, window_loop); return std::make_pair(output_loop, window_loop);
} }
@ -439,7 +443,7 @@ class ReduceWindowOpConverter
OpBuilder then_builder = OpBuilder then_builder =
elem_or_init.getThenBodyBuilder(rewriter->getListener()); elem_or_init.getThenBodyBuilder(rewriter->getListener());
Value elem = then_builder.create<mlir::LoadOp>( Value elem = then_builder.create<mlir::memref::LoadOp>(
loc, reduce_window_op.operand(), mapped_ivs.ivs); loc, reduce_window_op.operand(), mapped_ivs.ivs);
then_builder.create<scf::YieldOp>(loc, elem); then_builder.create<scf::YieldOp>(loc, elem);
@ -497,8 +501,8 @@ class SelectAndScatterOpConverter
auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter); auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);
// Load `source[selected_ivs]`. // Load `source[selected_ivs]`.
auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(), auto src_elem = rewriter.create<memref::LoadOp>(
loop_over_src.getInductionVars()); loc, s_and_s_op.source(), loop_over_src.getInductionVars());
// Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`. // Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`.
auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(), auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
@ -517,13 +521,13 @@ class SelectAndScatterOpConverter
void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op, void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
OpBuilder* b) const { OpBuilder* b) const {
auto loc = s_and_s_op.getLoc(); auto loc = s_and_s_op.getLoc();
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value()); Value init_value = b->create<memref::LoadOp>(loc, s_and_s_op.init_value());
scf::ParallelOp loop_over_output = scf::ParallelOp loop_over_output =
MakeLoopOverShape(loc, s_and_s_op.out(), b); MakeLoopOverShape(loc, s_and_s_op.out(), b);
OpBuilder::InsertionGuard guard(*b); OpBuilder::InsertionGuard guard(*b);
b->setInsertionPointToStart(loop_over_output.getBody()); b->setInsertionPointToStart(loop_over_output.getBody());
b->create<StoreOp>(loc, init_value, s_and_s_op.out(), b->create<memref::StoreOp>(loc, init_value, s_and_s_op.out(),
loop_over_output.getInductionVars()); loop_over_output.getInductionVars());
} }
@ -647,7 +651,7 @@ class SelectAndScatterOpConverter
TypeRange iter_arg_types{ivs_val_flag->to_vector()}; TypeRange iter_arg_types{ivs_val_flag->to_vector()};
Value operand_elem = Value operand_elem =
b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs); b->create<memref::LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
auto if_init = auto if_init =
b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(), b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
/*withElseRegion=*/true); /*withElseRegion=*/true);
@ -712,8 +716,8 @@ struct LhloLegalizeToParallelLoopsPass
// clang-format on // clang-format on
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect, target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
scf::SCFDialect, LmhloDialect>(); StandardOpsDialect, scf::SCFDialect, LmhloDialect>();
target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp, target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
lmhlo::SelectAndScatterOp>(); lmhlo::SelectAndScatterOp>();

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
@ -58,7 +59,8 @@ Value CalculateShapeValue(Location loc, Value operand,
int64_t rank = result_type.getRank(); int64_t rank = result_type.getRank();
shape_values.reserve(rank); shape_values.reserve(rank);
for (int64_t i = 0; i < rank; ++i) { for (int64_t i = 0; i < rank; ++i) {
shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i)); shape_values.push_back(
rewriter.create<mlir::memref::DimOp>(loc, operand, i));
} }
return rewriter.create<tensor::FromElementsOp>(loc, shape_values); return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
} }

View File

@ -967,10 +967,10 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tupl
// CHECK-LABEL: func @erase_dead_lhlo_constant // CHECK-LABEL: func @erase_dead_lhlo_constant
func @erase_dead_lhlo_constant() { func @erase_dead_lhlo_constant() {
%M = alloc() : memref<256x1024xf32> %M = memref.alloc() : memref<256x1024xf32>
// CHECK-NEXT: return // CHECK-NEXT: return
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> () "lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
dealloc %M : memref<256x1024xf32> memref.dealloc %M : memref<256x1024xf32>
return return
} }
@ -979,9 +979,9 @@ func @erase_dead_lhlo_constant() {
func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> { func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
// CHECK-NEXT: lmhlo.constant // CHECK-NEXT: lmhlo.constant
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> () "lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
// CHECK-NEXT: alloc // CHECK-NEXT: memref.alloc
// CHECK-NEXT: lmhlo.constant // CHECK-NEXT: lmhlo.constant
%N = alloc() : memref<256x1024xf32> %N = memref.alloc() : memref<256x1024xf32>
"lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> () "lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
return %N : memref<256x1024xf32> return %N : memref<256x1024xf32>
} }

View File

@ -27,7 +27,7 @@ func private @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface }
func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
func @trivial_broadcast_wrapper() { func @trivial_broadcast_wrapper() {
%input_buf = alloc() : memref<3xf32> %input_buf = memref.alloc() : memref<3xf32>
%c1f32 = constant 1.0 : f32 %c1f32 = constant 1.0 : f32
%c2f32 = constant 2.0 : f32 %c2f32 = constant 2.0 : f32
@ -36,19 +36,19 @@ func @trivial_broadcast_wrapper() {
%c0 = constant 0 : index %c0 = constant 0 : index
%c1 = constant 1 : index %c1 = constant 1 : index
%c2 = constant 2 : index %c2 = constant 2 : index
store %c1f32, %input_buf[%c0] : memref<3xf32> memref.store %c1f32, %input_buf[%c0] : memref<3xf32>
store %c2f32, %input_buf[%c1] : memref<3xf32> memref.store %c2f32, %input_buf[%c1] : memref<3xf32>
store %c3f32, %input_buf[%c2] : memref<3xf32> memref.store %c3f32, %input_buf[%c2] : memref<3xf32>
%input = tensor_load %input_buf : memref<3xf32> %input = memref.tensor_load %input_buf : memref<3xf32>
// Test BroadcastInDimOp. // Test BroadcastInDimOp.
%output = "mhlo.broadcast_in_dim"(%input) { %output = "mhlo.broadcast_in_dim"(%input) {
broadcast_dimensions = dense<0> : tensor<1xi64> broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<3xf32>) -> tensor<3x4xf32> } : (tensor<3xf32>) -> tensor<3x4xf32>
%output_buf = tensor_to_memref %output : memref<3x4xf32> %output_buf = memref.buffer_cast %output : memref<3x4xf32>
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
// CHECK-NEXT: [1, 1, 1, 1] // CHECK-NEXT: [1, 1, 1, 1]
@ -63,9 +63,9 @@ func @trivial_broadcast_wrapper() {
broadcast_dimensions = dense<0> : tensor<1xi64> broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32> } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
%unranked_dyn_output = memref_cast %dyn_output_buf %unranked_dyn_output = memref.cast %dyn_output_buf
: memref<3x4xf32> to memref<*xf32> : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
@ -76,29 +76,29 @@ func @trivial_broadcast_wrapper() {
} }
func @broadcast_in_X_dim_wrapper() { func @broadcast_in_X_dim_wrapper() {
%input_buf = alloc() : memref<1x4xf32> %input_buf = memref.alloc() : memref<1x4xf32>
%c1f32 = constant 1.0 : f32 %c1f32 = constant 1.0 : f32
%c0 = constant 0 : index %c0 = constant 0 : index
store %c1f32, %input_buf[%c0, %c0] : memref<1x4xf32> memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x4xf32>
%c2f32 = constant 2.0 : f32 %c2f32 = constant 2.0 : f32
%c1 = constant 1 : index %c1 = constant 1 : index
store %c2f32, %input_buf[%c0, %c1] : memref<1x4xf32> memref.store %c2f32, %input_buf[%c0, %c1] : memref<1x4xf32>
%c3f32 = constant 3.0 : f32 %c3f32 = constant 3.0 : f32
%c2 = constant 2 : index %c2 = constant 2 : index
store %c3f32, %input_buf[%c0, %c2] : memref<1x4xf32> memref.store %c3f32, %input_buf[%c0, %c2] : memref<1x4xf32>
%c4f32 = constant 4.0 : f32 %c4f32 = constant 4.0 : f32
%c3 = constant 3 : index %c3 = constant 3 : index
store %c4f32, %input_buf[%c0, %c3] : memref<1x4xf32> memref.store %c4f32, %input_buf[%c0, %c3] : memref<1x4xf32>
%input = tensor_load %input_buf : memref<1x4xf32> %input = memref.tensor_load %input_buf : memref<1x4xf32>
// Test BroadcastInDimOp. // Test BroadcastInDimOp.
%output = "mhlo.broadcast_in_dim"(%input) { %output = "mhlo.broadcast_in_dim"(%input) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<1x4xf32>) -> tensor<3x4xf32> } : (tensor<1x4xf32>) -> tensor<3x4xf32>
%output_buf = tensor_to_memref %output : memref<3x4xf32> %output_buf = memref.buffer_cast %output : memref<3x4xf32>
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
// CHECK-NEXT: [1, 2, 3, 4] // CHECK-NEXT: [1, 2, 3, 4]
@ -112,9 +112,9 @@ func @broadcast_in_X_dim_wrapper() {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32> } : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32>
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
%unranked_dyn_output = memref_cast %dyn_output_buf %unranked_dyn_output = memref.cast %dyn_output_buf
: memref<3x4xf32> to memref<*xf32> : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
@ -125,26 +125,26 @@ func @broadcast_in_X_dim_wrapper() {
} }
func @broadcast_in_Y_dim_wrapper() { func @broadcast_in_Y_dim_wrapper() {
%input_buf = alloc() : memref<3x1xf32> %input_buf = memref.alloc() : memref<3x1xf32>
%c1f32 = constant 1.0 : f32 %c1f32 = constant 1.0 : f32
%c0 = constant 0 : index %c0 = constant 0 : index
store %c1f32, %input_buf[%c0, %c0] : memref<3x1xf32> memref.store %c1f32, %input_buf[%c0, %c0] : memref<3x1xf32>
%c2f32 = constant 2.0 : f32 %c2f32 = constant 2.0 : f32
%c1 = constant 1 : index %c1 = constant 1 : index
store %c2f32, %input_buf[%c1, %c0] : memref<3x1xf32> memref.store %c2f32, %input_buf[%c1, %c0] : memref<3x1xf32>
%c3f32 = constant 3.0 : f32 %c3f32 = constant 3.0 : f32
%c2 = constant 2 : index %c2 = constant 2 : index
store %c3f32, %input_buf[%c2, %c0] : memref<3x1xf32> memref.store %c3f32, %input_buf[%c2, %c0] : memref<3x1xf32>
%input = tensor_load %input_buf : memref<3x1xf32> %input = memref.tensor_load %input_buf : memref<3x1xf32>
// Test BroadcastInDimOp. // Test BroadcastInDimOp.
%output = "mhlo.broadcast_in_dim"(%input) { %output = "mhlo.broadcast_in_dim"(%input) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<3x1xf32>) -> tensor<3x4xf32> } : (tensor<3x1xf32>) -> tensor<3x4xf32>
%output_buf = tensor_to_memref %output : memref<3x4xf32> %output_buf = memref.buffer_cast %output : memref<3x4xf32>
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
// CHECK-NEXT: [1, 1, 1, 1] // CHECK-NEXT: [1, 1, 1, 1]
@ -159,9 +159,9 @@ func @broadcast_in_Y_dim_wrapper() {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32> } : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
%unranked_dyn_output = memref_cast %dyn_output_buf %unranked_dyn_output = memref.cast %dyn_output_buf
: memref<3x4xf32> to memref<*xf32> : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
@ -172,29 +172,29 @@ func @broadcast_in_Y_dim_wrapper() {
} }
func @broadcast_in_X_dim_transpose_wrapper() { func @broadcast_in_X_dim_transpose_wrapper() {
%input_buf = alloc() : memref<4x1xf32> %input_buf = memref.alloc() : memref<4x1xf32>
%c1f32 = constant 1.0 : f32 %c1f32 = constant 1.0 : f32
%c0 = constant 0 : index %c0 = constant 0 : index
store %c1f32, %input_buf[%c0, %c0] : memref<4x1xf32> memref.store %c1f32, %input_buf[%c0, %c0] : memref<4x1xf32>
%c2f32 = constant 2.0 : f32 %c2f32 = constant 2.0 : f32
%c1 = constant 1 : index %c1 = constant 1 : index
store %c2f32, %input_buf[%c1, %c0] : memref<4x1xf32> memref.store %c2f32, %input_buf[%c1, %c0] : memref<4x1xf32>
%c3f32 = constant 3.0 : f32 %c3f32 = constant 3.0 : f32
%c2 = constant 2 : index %c2 = constant 2 : index
store %c3f32, %input_buf[%c2, %c0] : memref<4x1xf32> memref.store %c3f32, %input_buf[%c2, %c0] : memref<4x1xf32>
%c4f32 = constant 4.0 : f32 %c4f32 = constant 4.0 : f32
%c3 = constant 3 : index %c3 = constant 3 : index
store %c4f32, %input_buf[%c3, %c0] : memref<4x1xf32> memref.store %c4f32, %input_buf[%c3, %c0] : memref<4x1xf32>
%input = tensor_load %input_buf : memref<4x1xf32> %input = memref.tensor_load %input_buf : memref<4x1xf32>
// Test BroadcastInDimOp. // Test BroadcastInDimOp.
%output = "mhlo.broadcast_in_dim"(%input) { %output = "mhlo.broadcast_in_dim"(%input) {
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
} : (tensor<4x1xf32>) -> tensor<3x4xf32> } : (tensor<4x1xf32>) -> tensor<3x4xf32>
%output_buf = tensor_to_memref %output : memref<3x4xf32> %output_buf = memref.buffer_cast %output : memref<3x4xf32>
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
// CHECK-NEXT: [1, 2, 3, 4] // CHECK-NEXT: [1, 2, 3, 4]
@ -208,9 +208,9 @@ func @broadcast_in_X_dim_transpose_wrapper() {
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
} : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32> } : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
%unranked_dyn_output = memref_cast %dyn_output_buf %unranked_dyn_output = memref.cast %dyn_output_buf
: memref<3x4xf32> to memref<*xf32> : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
@ -221,26 +221,26 @@ func @broadcast_in_X_dim_transpose_wrapper() {
} }
func @broadcast_in_Y_dim_transpose_wrapper() { func @broadcast_in_Y_dim_transpose_wrapper() {
%input_buf = alloc() : memref<1x3xf32> %input_buf = memref.alloc() : memref<1x3xf32>
%c1f32 = constant 1.0 : f32 %c1f32 = constant 1.0 : f32
%c0 = constant 0 : index %c0 = constant 0 : index
store %c1f32, %input_buf[%c0, %c0] : memref<1x3xf32> memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x3xf32>
%c2f32 = constant 2.0 : f32 %c2f32 = constant 2.0 : f32
%c1 = constant 1 : index %c1 = constant 1 : index
store %c2f32, %input_buf[%c0, %c1] : memref<1x3xf32> memref.store %c2f32, %input_buf[%c0, %c1] : memref<1x3xf32>
%c3f32 = constant 3.0 : f32 %c3f32 = constant 3.0 : f32
%c2 = constant 2 : index %c2 = constant 2 : index
store %c3f32, %input_buf[%c0, %c2] : memref<1x3xf32> memref.store %c3f32, %input_buf[%c0, %c2] : memref<1x3xf32>
%input = tensor_load %input_buf : memref<1x3xf32> %input = memref.tensor_load %input_buf : memref<1x3xf32>
// Test BroadcastInDimOp. // Test BroadcastInDimOp.
%output = "mhlo.broadcast_in_dim"(%input) { %output = "mhlo.broadcast_in_dim"(%input) {
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
} : (tensor<1x3xf32>) -> tensor<3x4xf32> } : (tensor<1x3xf32>) -> tensor<3x4xf32>
%output_buf = tensor_to_memref %output : memref<3x4xf32> %output_buf = memref.buffer_cast %output : memref<3x4xf32>
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
// CHECK-NEXT-NEXT: [1, 1, 1, 1] // CHECK-NEXT-NEXT: [1, 1, 1, 1]
@ -255,9 +255,9 @@ func @broadcast_in_Y_dim_transpose_wrapper() {
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
} : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32> } : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
%unranked_dyn_output = memref_cast %dyn_output_buf %unranked_dyn_output = memref.cast %dyn_output_buf
: memref<3x4xf32> to memref<*xf32> : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
@ -268,20 +268,20 @@ func @broadcast_in_Y_dim_transpose_wrapper() {
} }
func @broadcast_scalar_1d_wrapper() { func @broadcast_scalar_1d_wrapper() {
%input_buf = alloc() : memref<1xf32> %input_buf = memref.alloc() : memref<1xf32>
%c1f32 = constant 1.0 : f32 %c1f32 = constant 1.0 : f32
%c0 = constant 0 : index %c0 = constant 0 : index
store %c1f32, %input_buf[%c0] : memref<1xf32> memref.store %c1f32, %input_buf[%c0] : memref<1xf32>
%input = tensor_load %input_buf : memref<1xf32> %input = memref.tensor_load %input_buf : memref<1xf32>
// Test BroadcastInDimOp. // Test BroadcastInDimOp.
%output = "mhlo.broadcast_in_dim"(%input) { %output = "mhlo.broadcast_in_dim"(%input) {
broadcast_dimensions = dense<0> : tensor<1xi64> broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<1xf32>) -> tensor<3x4xf32> } : (tensor<1xf32>) -> tensor<3x4xf32>
%output_buf = tensor_to_memref %output : memref<3x4xf32> %output_buf = memref.buffer_cast %output : memref<3x4xf32>
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
// CHECK-NEXT: [1, 1, 1, 1] // CHECK-NEXT: [1, 1, 1, 1]
@ -296,9 +296,9 @@ func @broadcast_scalar_1d_wrapper() {
broadcast_dimensions = dense<0> : tensor<1xi64> broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32> } : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
%unranked_dyn_output = memref_cast %dyn_output_buf %unranked_dyn_output = memref.cast %dyn_output_buf
: memref<3x4xf32> to memref<*xf32> : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
@ -309,20 +309,20 @@ func @broadcast_scalar_1d_wrapper() {
} }
func @broadcast_scalar_2d_wrapper() { func @broadcast_scalar_2d_wrapper() {
%input_buf = alloc() : memref<1x1xf32> %input_buf = memref.alloc() : memref<1x1xf32>
%c1f32 = constant 1.0 : f32 %c1f32 = constant 1.0 : f32
%c0 = constant 0 : index %c0 = constant 0 : index
store %c1f32, %input_buf[%c0, %c0] : memref<1x1xf32> memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x1xf32>
%input = tensor_load %input_buf : memref<1x1xf32> %input = memref.tensor_load %input_buf : memref<1x1xf32>
// Test BroadcastInDimOp. // Test BroadcastInDimOp.
%output = "mhlo.broadcast_in_dim"(%input) { %output = "mhlo.broadcast_in_dim"(%input) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<1x1xf32>) -> tensor<3x4xf32> } : (tensor<1x1xf32>) -> tensor<3x4xf32>
%output_buf = tensor_to_memref %output : memref<3x4xf32> %output_buf = memref.buffer_cast %output : memref<3x4xf32>
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> %unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
// CHECK-NEXT: [1, 1, 1, 1] // CHECK-NEXT: [1, 1, 1, 1]
@ -337,9 +337,9 @@ func @broadcast_scalar_2d_wrapper() {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32> } : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32> %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
%unranked_dyn_output = memref_cast %dyn_output_buf %unranked_dyn_output = memref.cast %dyn_output_buf
: memref<3x4xf32> to memref<*xf32> : memref<3x4xf32> to memref<*xf32>
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
@ -350,7 +350,7 @@ func @broadcast_scalar_2d_wrapper() {
} }
func @broadcast_to_the_same_shape() { func @broadcast_to_the_same_shape() {
%input_buf = alloc() : memref<2x3xf32> %input_buf = memref.alloc() : memref<2x3xf32>
%c1f32 = constant 1.0 : f32 %c1f32 = constant 1.0 : f32
%c2f32 = constant 2.0 : f32 %c2f32 = constant 2.0 : f32
@ -360,22 +360,22 @@ func @broadcast_to_the_same_shape() {
%c1 = constant 1 : index %c1 = constant 1 : index
%c2 = constant 2 : index %c2 = constant 2 : index
%c3 = constant 3 : index %c3 = constant 3 : index
store %c1f32, %input_buf[%c0, %c0] : memref<2x3xf32> memref.store %c1f32, %input_buf[%c0, %c0] : memref<2x3xf32>
store %c1f32, %input_buf[%c1, %c0] : memref<2x3xf32> memref.store %c1f32, %input_buf[%c1, %c0] : memref<2x3xf32>
store %c2f32, %input_buf[%c0, %c1] : memref<2x3xf32> memref.store %c2f32, %input_buf[%c0, %c1] : memref<2x3xf32>
store %c2f32, %input_buf[%c1, %c1] : memref<2x3xf32> memref.store %c2f32, %input_buf[%c1, %c1] : memref<2x3xf32>
store %c3f32, %input_buf[%c0, %c2] : memref<2x3xf32> memref.store %c3f32, %input_buf[%c0, %c2] : memref<2x3xf32>
store %c3f32, %input_buf[%c1, %c2] : memref<2x3xf32> memref.store %c3f32, %input_buf[%c1, %c2] : memref<2x3xf32>
%input = tensor_load %input_buf : memref<2x3xf32> %input = memref.tensor_load %input_buf : memref<2x3xf32>
// Test BroadcastInDimOp. // Test BroadcastInDimOp.
%output = "mhlo.broadcast_in_dim"(%input) { %output = "mhlo.broadcast_in_dim"(%input) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x3xf32>) -> tensor<2x3xf32> } : (tensor<2x3xf32>) -> tensor<2x3xf32>
%output_buf = tensor_to_memref %output : memref<2x3xf32> %output_buf = memref.buffer_cast %output : memref<2x3xf32>
%unraked_output = memref_cast %output_buf : memref<2x3xf32> to memref<*xf32> %unraked_output = memref.cast %output_buf : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
// CHECK-NEXT: [1, 2, 3] // CHECK-NEXT: [1, 2, 3]
@ -387,9 +387,9 @@ func @broadcast_to_the_same_shape() {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32> } : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32>
%dyn_output_buf = tensor_to_memref %dyn_output : memref<2x3xf32> %dyn_output_buf = memref.buffer_cast %dyn_output : memref<2x3xf32>
%unranked_dyn_output = memref_cast %dyn_output_buf %unranked_dyn_output = memref.cast %dyn_output_buf
: memref<2x3xf32> to memref<*xf32> : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
@ -399,7 +399,7 @@ func @broadcast_to_the_same_shape() {
} }
func @broadcast_1d_to_2d() { func @broadcast_1d_to_2d() {
%input_buf = alloc() : memref<3xf32> %input_buf = memref.alloc() : memref<3xf32>
%c1f32 = constant 1.0 : f32 %c1f32 = constant 1.0 : f32
%c2f32 = constant 2.0 : f32 %c2f32 = constant 2.0 : f32
@ -408,19 +408,19 @@ func @broadcast_1d_to_2d() {
%c0 = constant 0 : index %c0 = constant 0 : index
%c1 = constant 1 : index %c1 = constant 1 : index
%c2 = constant 2 : index %c2 = constant 2 : index
store %c1f32, %input_buf[%c0] : memref<3xf32> memref.store %c1f32, %input_buf[%c0] : memref<3xf32>
store %c2f32, %input_buf[%c1] : memref<3xf32> memref.store %c2f32, %input_buf[%c1] : memref<3xf32>
store %c3f32, %input_buf[%c2] : memref<3xf32> memref.store %c3f32, %input_buf[%c2] : memref<3xf32>
%input = tensor_load %input_buf : memref<3xf32> %input = memref.tensor_load %input_buf : memref<3xf32>
// Test BroadcastInDimOp. // Test BroadcastInDimOp.
%output = "mhlo.broadcast_in_dim"(%input) { %output = "mhlo.broadcast_in_dim"(%input) {
broadcast_dimensions = dense<0> : tensor<1xi64> broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<3xf32>) -> tensor<3x3xf32> } : (tensor<3xf32>) -> tensor<3x3xf32>
%output_buf = tensor_to_memref %output : memref<3x3xf32> %output_buf = memref.buffer_cast %output : memref<3x3xf32>
%unraked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32> %unraked_output = memref.cast %output_buf : memref<3x3xf32> to memref<*xf32>
call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
// CHECK-NEXT: [1, 1, 1] // CHECK-NEXT: [1, 1, 1]
@ -435,9 +435,9 @@ func @broadcast_1d_to_2d() {
broadcast_dimensions = dense<0> : tensor<1xi64> broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32> } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x3xf32> %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x3xf32>
%unranked_dyn_output = memref_cast %dyn_output_buf %unranked_dyn_output = memref.cast %dyn_output_buf
: memref<3x3xf32> to memref<*xf32> : memref<3x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
@ -448,7 +448,7 @@ func @broadcast_1d_to_2d() {
} }
func @broadcast_1d_to_2d_with_transpose() { func @broadcast_1d_to_2d_with_transpose() {
%input_buf = alloc() : memref<3xf32> %input_buf = memref.alloc() : memref<3xf32>
%c1f32 = constant 1.0 : f32 %c1f32 = constant 1.0 : f32
%c2f32 = constant 2.0 : f32 %c2f32 = constant 2.0 : f32
@ -457,19 +457,19 @@ func @broadcast_1d_to_2d_with_transpose() {
%c0 = constant 0 : index %c0 = constant 0 : index
%c1 = constant 1 : index %c1 = constant 1 : index
%c2 = constant 2 : index %c2 = constant 2 : index
store %c1f32, %input_buf[%c0] : memref<3xf32> memref.store %c1f32, %input_buf[%c0] : memref<3xf32>
store %c2f32, %input_buf[%c1] : memref<3xf32> memref.store %c2f32, %input_buf[%c1] : memref<3xf32>
store %c3f32, %input_buf[%c2] : memref<3xf32> memref.store %c3f32, %input_buf[%c2] : memref<3xf32>
%input = tensor_load %input_buf : memref<3xf32> %input = memref.tensor_load %input_buf : memref<3xf32>
// Test BroadcastInDimOp. // Test BroadcastInDimOp.
%output = "mhlo.broadcast_in_dim"(%input) { %output = "mhlo.broadcast_in_dim"(%input) {
broadcast_dimensions = dense<1> : tensor<1xi64> broadcast_dimensions = dense<1> : tensor<1xi64>
} : (tensor<3xf32>) -> tensor<3x3xf32> } : (tensor<3xf32>) -> tensor<3x3xf32>
%output_buf = tensor_to_memref %output : memref<3x3xf32> %output_buf = memref.buffer_cast %output : memref<3x3xf32>
%unraked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32> %unraked_output = memref.cast %output_buf : memref<3x3xf32> to memref<*xf32>
call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
// CHECK-NEXT: [1, 2, 3] // CHECK-NEXT: [1, 2, 3]
@ -483,9 +483,9 @@ func @broadcast_1d_to_2d_with_transpose() {
broadcast_dimensions = dense<1> : tensor<1xi64> broadcast_dimensions = dense<1> : tensor<1xi64>
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32> } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x3xf32> %dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x3xf32>
%unranked_dyn_output = memref_cast %dyn_output_buf %unranked_dyn_output = memref.cast %dyn_output_buf
: memref<3x3xf32> to memref<*xf32> : memref<3x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]

View File

@ -4,10 +4,10 @@ func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface
// Helper function to print scalar values. // Helper function to print scalar values.
func @print_f32(%arg : f32) { func @print_f32(%arg : f32) {
%mem = alloca() : memref<1xf32> %mem = memref.alloca() : memref<1xf32>
%c0 = constant 0 : index %c0 = constant 0 : index
store %arg, %mem[%c0] : memref<1xf32> memref.store %arg, %mem[%c0] : memref<1xf32>
%mem_unranked = memref_cast %mem : memref<1xf32> to memref<*xf32> %mem_unranked = memref.cast %mem : memref<1xf32> to memref<*xf32>
call @print_memref_f32(%mem_unranked) : (memref<*xf32>) -> () call @print_memref_f32(%mem_unranked) : (memref<*xf32>) -> ()
return return
} }

View File

@ -21,21 +21,21 @@ func @reduce_add() {
%c1 = constant 1 : index %c1 = constant 1 : index
// Initialize input. // Initialize input.
%input = alloc() : memref<2x3xf32> %input = memref.alloc() : memref<2x3xf32>
%dim_x = dim %input, %c0 : memref<2x3xf32> %dim_x = memref.dim %input, %c0 : memref<2x3xf32>
%dim_y = dim %input, %c1 : memref<2x3xf32> %dim_y = memref.dim %input, %c1 : memref<2x3xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) { scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
%i_i64 = index_cast %i : index to i64 %i_i64 = index_cast %i : index to i64
%i_f32 = sitofp %i_i64 : i64 to f32 %i_f32 = sitofp %i_i64 : i64 to f32
store %i_f32, %input[%i, %j] : memref<2x3xf32> memref.store %i_f32, %input[%i, %j] : memref<2x3xf32>
} }
%unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32> %unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
// CHECK: [0, 0, 0] // CHECK: [0, 0, 0]
// CHECK: [1, 1, 1] // CHECK: [1, 1, 1]
%in = tensor_load %input : memref<2x3xf32> %in = memref.tensor_load %input : memref<2x3xf32>
%init = mhlo.constant dense<0.000000e+00> : tensor<f32> %init = mhlo.constant dense<0.000000e+00> : tensor<f32>
%reduce = "mhlo.reduce"(%in, %init) ( { %reduce = "mhlo.reduce"(%in, %init) ( {
@ -45,8 +45,8 @@ func @reduce_add() {
}) {dimensions = dense<1> : tensor<1xi64>} }) {dimensions = dense<1> : tensor<1xi64>}
: (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32> : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
%output = tensor_to_memref %reduce : memref<2xf32> %output = memref.buffer_cast %reduce : memref<2xf32>
%unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32> %unranked_output = memref.cast %output : memref<2xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 1 offset = 0 sizes = [2] strides = [1] // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
// CHECK: [0, 3] // CHECK: [0, 3]
@ -58,21 +58,21 @@ func @reduce_max() {
%c1 = constant 1 : index %c1 = constant 1 : index
// Initialize input. // Initialize input.
%input = alloc() : memref<2x3xf32> %input = memref.alloc() : memref<2x3xf32>
%dim_x = dim %input, %c0 : memref<2x3xf32> %dim_x = memref.dim %input, %c0 : memref<2x3xf32>
%dim_y = dim %input, %c1 : memref<2x3xf32> %dim_y = memref.dim %input, %c1 : memref<2x3xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) { scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
%i_i64 = index_cast %i : index to i64 %i_i64 = index_cast %i : index to i64
%i_f32 = sitofp %i_i64 : i64 to f32 %i_f32 = sitofp %i_i64 : i64 to f32
store %i_f32, %input[%i, %j] : memref<2x3xf32> memref.store %i_f32, %input[%i, %j] : memref<2x3xf32>
} }
%unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32> %unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
// CHECK: [0, 0, 0] // CHECK: [0, 0, 0]
// CHECK: [1, 1, 1] // CHECK: [1, 1, 1]
%in = tensor_load %input : memref<2x3xf32> %in = memref.tensor_load %input : memref<2x3xf32>
%init = mhlo.constant dense<0xff800000> : tensor<f32> %init = mhlo.constant dense<0xff800000> : tensor<f32>
%reduce = "mhlo.reduce"(%in, %init) ( { %reduce = "mhlo.reduce"(%in, %init) ( {
@ -82,8 +82,8 @@ func @reduce_max() {
}) {dimensions = dense<1> : tensor<1xi64>} }) {dimensions = dense<1> : tensor<1xi64>}
: (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32> : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
%output = tensor_to_memref %reduce : memref<2xf32> %output = memref.buffer_cast %reduce : memref<2xf32>
%unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32> %unranked_output = memref.cast %output : memref<2xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 1 offset = 0 sizes = [2] strides = [1] // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
// CHECK: [0, 1] // CHECK: [0, 1]

View File

@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked(
return %reshaped : tensor<?xf32> return %reshaped : tensor<?xf32>
} }
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>) // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]]) // CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32> // CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// ----- // -----
@ -30,7 +30,7 @@ func @dynamic_reshape_to_unranked(
return %reshaped : tensor<*xf32> return %reshaped : tensor<*xf32>
} }
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>) // CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]]) // CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32> // CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
// ----- // -----
@ -41,4 +41,4 @@ func @reshape_unranked(%operand: tensor<*xf32>) -> tensor<f32> {
return %reshaped : tensor<f32> return %reshaped : tensor<f32>
} }
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
// CHECK-NEXT: memref_cast [[ARG]] : memref<*xf32> to memref<f32> // CHECK-NEXT: memref.cast [[ARG]] : memref<*xf32> to memref<f32>

View File

@ -31,20 +31,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
return %5 : tensor<4xf32> return %5 : tensor<4xf32>
} }
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32> // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32> // CHECK-NEXT: %[[MAX_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) // CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> // CHECK-NEXT: %[[ADD_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) // CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> // CHECK-NEXT: memref.dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> // CHECK-NEXT: %[[MIN_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) // CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> // CHECK-NEXT: %[[SUB_RESULT:.*]] = memref.alloc() : memref<4xf32>
//  CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) //  CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> // CHECK-NEXT: memref.dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> // CHECK-NEXT: %[[MUL_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) // CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> // CHECK-NEXT: memref.dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> // CHECK-NEXT: memref.dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32> // CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
// ----- // -----
@ -53,15 +53,15 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>, func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>,
%summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> { %summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}) // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}})
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32> // CHECK-NEXT: %[[ADD_RESULT:.*]] = memref.alloc() : memref<2x2xf32>
%sum = "mhlo.add"(%summand_1, %summand_2) %sum = "mhlo.add"(%summand_1, %summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) // CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> // CHECK-NEXT: %[[MUL_RESULT:.*]] = memref.alloc() : memref<2x2xf32>
%result = "mhlo.multiply"(%sum, %multiplier) %result = "mhlo.multiply"(%sum, %multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) // CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> // CHECK-NEXT: memref.dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<2x2xf32> // CHECK-NEXT: return %[[MUL_RESULT]] : memref<2x2xf32>
return %result : tensor<2x2xf32> return %result : tensor<2x2xf32>
} }
@ -154,9 +154,9 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32> // CHECK: %[[OPER_DIM_1:.*]] = memref.dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index // CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32> // CHECK: %[[OPER_DIM_0:.*]] = memref.dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64> // CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index // CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
@ -172,9 +172,9 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index // CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index // CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map> // CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32> // CHECK: %[[RESULT:.*]] = memref.alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> () // CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
// CHECK: return %[[RESULT]] : memref<?x?x?xf32> // CHECK: return %[[RESULT]] : memref<?x?x?xf32>
@ -469,7 +469,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex> // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex> // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex> // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]]) // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return %result : tensor<?x?xf32> return %result : tensor<?x?xf32>
// CHECK: return %[[RESULT]] // CHECK: return %[[RESULT]]
@ -485,7 +485,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex> // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex> // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex> // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]]) // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return %result : tensor<?x?xf32> return %result : tensor<?x?xf32>
// CHECK: return %[[RESULT]] // CHECK: return %[[RESULT]]
@ -496,7 +496,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-LABEL: func @dot // CHECK-LABEL: func @dot
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// CHECK-NEXT: %[[ALLOC:.*]] = alloc // CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc
// CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) { // CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
// dot_dimension_numbers = { // dot_dimension_numbers = {
// lhs_batching_dimensions = dense<> : tensor<0xi64>, // lhs_batching_dimensions = dense<> : tensor<0xi64>,
@ -517,7 +517,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
-> tensor<3x5x5x4xf32> { -> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index %c0 = constant 0 : index
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32> // CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
// CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]]) // CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// CHECK-SAME: padding = dense<[ // CHECK-SAME: padding = dense<[
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64> // CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
@ -548,11 +548,11 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
// CHECK-LABEL: func @reduce // CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> { func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
// CHECK: %[[OUT:.*]] = alloc() : memref<1xf32> // CHECK: %[[OUT:.*]] = memref.alloc() : memref<1xf32>
// CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( { // CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>, // CHECK: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
// CHECK-SAME: %[[ARG3:.*]]: memref<f32>): // CHECK-SAME: %[[ARG3:.*]]: memref<f32>):
// CHECK: %[[TMP:.*]] = alloc() : memref<f32> // CHECK: %[[TMP:.*]] = memref.alloc() : memref<f32>
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]]) // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
// CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]]) // CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
// CHECK: "lmhlo.terminator"() : () -> () // CHECK: "lmhlo.terminator"() : () -> ()

View File

@ -404,7 +404,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
return %0: tensor<4x2x1x4x?x16xf32> return %0: tensor<4x2x1x4x?x16xf32>
} }
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32> // CHECK: %[[D1:.*]] = memref.dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
// CHECK: linalg.init_tensor [4, 2, 1, 4, %[[D1]], 16] : tensor<4x2x1x4x?x16xf32> // CHECK: linalg.init_tensor [4, 2, 1, 4, %[[D1]], 16] : tensor<4x2x1x4x?x16xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
@ -1024,7 +1024,7 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
// CHECK-LABEL: func @dot_matmul( // CHECK-LABEL: func @dot_matmul(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>) // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul // CHECK: linalg.matmul
@ -1040,7 +1040,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
// CHECK-LABEL: func @dot_matmul_i8_i8_i32( // CHECK-LABEL: func @dot_matmul_i8_i8_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>) // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>)
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul // CHECK: linalg.matmul
@ -1058,7 +1058,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
// CHECK-LABEL: func @dot_matmul_i16_i16_i32( // CHECK-LABEL: func @dot_matmul_i16_i16_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>) // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>)
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul // CHECK: linalg.matmul
@ -1076,7 +1076,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
// CHECK-LABEL: func @dot_matmul_i32_i32_i32( // CHECK-LABEL: func @dot_matmul_i32_i32_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>) // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>)
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul // CHECK: linalg.matmul
@ -1094,7 +1094,7 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
// CHECK-LABEL: func @dot_matvec( // CHECK-LABEL: func @dot_matvec(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>) // CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] // CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matvec // CHECK: linalg.matvec
@ -1134,11 +1134,11 @@ func @dot_general_batch_matmul(%arg0: tensor<?x?x3xf32>,
// CHECK-LABEL: func @dot_general_batch_matmul( // CHECK-LABEL: func @dot_general_batch_matmul(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>) // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] // CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]] // CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]] // CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul // CHECK: linalg.batch_matmul
@ -1163,11 +1163,11 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
// CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32( // CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi8>, %[[ARG1:.*]]: tensor<?x3x?xi8>) // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi8>, %[[ARG1:.*]]: tensor<?x3x?xi8>)
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] // CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]] // CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]] // CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul // CHECK: linalg.batch_matmul
@ -1192,11 +1192,11 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
// CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32( // CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi16>, %[[ARG1:.*]]: tensor<?x3x?xi16>) // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi16>, %[[ARG1:.*]]: tensor<?x3x?xi16>)
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] // CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]] // CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]] // CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul // CHECK: linalg.batch_matmul
@ -1420,7 +1420,7 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32> // CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32> // CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32> // CHECK-DAG: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[DIM1]]] // CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[DIM1]]]
// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]]) // CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
// CHECK: linalg.generic // CHECK: linalg.generic
@ -1531,9 +1531,9 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x8x?xf32>, %arg1: tenso
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index // CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32> // CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index // CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32> // CHECK: %[[DIM2:.+]] = memref.dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
@ -1571,9 +1571,9 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index // CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32> // CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
// CHECK: %[[C3:.+]] = constant 3 : index // CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32> // CHECK: %[[DIM3:.+]] = memref.dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
@ -1611,9 +1611,9 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x8x8x8x?xf32>, %arg1: tens
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index // CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32> // CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
// CHECK: %[[C4:.+]] = constant 4 : index // CHECK: %[[C4:.+]] = constant 4 : index
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32> // CHECK: %[[DIM4:.+]] = memref.dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])

View File

@ -7,7 +7,7 @@
iterator_types = ["parallel", "parallel"]} iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) { %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
%temp_result = alloc() : memref<6x6xf32> %temp_result = memref.alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait linalg.generic #pointwise_2d_trait
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>) ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
outs(%temp_result : memref<6x6xf32>) { outs(%temp_result : memref<6x6xf32>) {
@ -22,7 +22,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%out = mulf %temp_result_in, %multiplier_in : f32 %out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32 linalg.yield %out : f32
} }
dealloc %temp_result : memref<6x6xf32> memref.dealloc %temp_result : memref<6x6xf32>
return return
} }
// CHECK-LABEL: func @fusion // CHECK-LABEL: func @fusion
@ -62,7 +62,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
func @fusion_of_three(%arg0: memref<100x10xf32>, func @fusion_of_three(%arg0: memref<100x10xf32>,
%arg1: memref<100xf32>, %arg1: memref<100xf32>,
%arg2: memref<100x10xf32>) { %arg2: memref<100x10xf32>) {
%0 = alloc() : memref<100x10xf32> %0 = memref.alloc() : memref<100x10xf32>
linalg.generic { linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0)>, indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>], affine_map<(d0, d1) -> (d0, d1)>],
@ -72,7 +72,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
^bb0(%arg3: f32, %arg4: f32): // no predecessors ^bb0(%arg3: f32, %arg4: f32): // no predecessors
linalg.yield %arg3 : f32 linalg.yield %arg3 : f32
} }
%1 = alloc() : memref<100x10xf32> %1 = memref.alloc() : memref<100x10xf32>
linalg.generic { linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>,
@ -84,7 +84,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
%2 = subf %arg3, %arg4 : f32 %2 = subf %arg3, %arg4 : f32
linalg.yield %2 : f32 linalg.yield %2 : f32
} }
dealloc %0 : memref<100x10xf32> memref.dealloc %0 : memref<100x10xf32>
linalg.generic { linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>], affine_map<(d0, d1) -> (d0, d1)>],
@ -95,7 +95,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
%2 = math.exp %arg3 : f32 %2 = math.exp %arg3 : f32
linalg.yield %2 : f32 linalg.yield %2 : f32
} }
dealloc %1 : memref<100x10xf32> memref.dealloc %1 : memref<100x10xf32>
return return
} }
// CHECK-LABEL: func @fusion // CHECK-LABEL: func @fusion
@ -141,7 +141,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
"parallel"]} "parallel"]}
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>, func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) { %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
%temp_result = alloc() : memref<6x6x6x6xf32> %temp_result = memref.alloc() : memref<6x6x6x6xf32>
linalg.generic #pointwise_4d_trait linalg.generic #pointwise_4d_trait
ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>) ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
outs(%temp_result : memref<6x6x6x6xf32>) { outs(%temp_result : memref<6x6x6x6xf32>) {
@ -156,7 +156,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
%out = mulf %temp_result_in, %multiplier_in : f32 %out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32 linalg.yield %out : f32
} }
dealloc %temp_result : memref<6x6x6x6xf32> memref.dealloc %temp_result : memref<6x6x6x6xf32>
return return
} }
// CHECK-LABEL: func @fusion_4d // CHECK-LABEL: func @fusion_4d
@ -200,7 +200,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
iterator_types = ["parallel", "parallel"]} iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> { %summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
%temp_result = alloc() : memref<6x6xf32> %temp_result = memref.alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait linalg.generic #pointwise_2d_trait
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>) ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
outs(%temp_result : memref<6x6xf32>) { outs(%temp_result : memref<6x6xf32>) {
@ -208,7 +208,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%out = addf %summand_1_in, %summand_2_in : f32 %out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32 linalg.yield %out : f32
} }
%result = alloc() : memref<6x6xf32> %result = memref.alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait linalg.generic #pointwise_2d_trait
ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>) ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
outs(%result : memref<6x6xf32>) { outs(%result : memref<6x6xf32>) {
@ -216,7 +216,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%out = mulf %temp_result_in, %multiplier_in : f32 %out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32 linalg.yield %out : f32
} }
dealloc %temp_result : memref<6x6xf32> memref.dealloc %temp_result : memref<6x6xf32>
return %result : memref<6x6xf32> return %result : memref<6x6xf32>
} }
@ -258,7 +258,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
-> memref<*xf32> { -> memref<*xf32> {
%c1 = constant 1 : index %c1 = constant 1 : index
%c0 = constant 0 : index %c0 = constant 0 : index
%1 = alloc(%arg2) : memref<?xf32> %1 = memref.alloc(%arg2) : memref<?xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>], affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]} iterator_types = ["parallel"]}
@ -267,7 +267,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
%13 = absf %arg3 : f32 %13 = absf %arg3 : f32
linalg.yield %13 : f32 linalg.yield %13 : f32
} }
%2 = memref_reshape %1(%arg1) %2 = memref.reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32> : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
return %2 : memref<*xf32> return %2 : memref<*xf32>
} }
@ -279,7 +279,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// CHECK-NOT: scf.for // CHECK-NOT: scf.for
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: absf // CHECK: absf
// CHECK: memref_reshape // CHECK: memref.reshape
// TILED-LABEL: func @view_result // TILED-LABEL: func @view_result
// TILED-DAG: %[[C2:.*]] = constant 2 // TILED-DAG: %[[C2:.*]] = constant 2
@ -288,7 +288,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// TILED-NOT: scf.for // TILED-NOT: scf.for
// TILED: linalg.generic // TILED: linalg.generic
// TILED: absf // TILED: absf
// TILED: memref_reshape // TILED: memref.reshape
// PLOOP-LABEL: func @view_result // PLOOP-LABEL: func @view_result
@ -297,20 +297,20 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// PLOOP-NOT: scf.parallel // PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic // PLOOP: linalg.generic
// PLOOP: absf // PLOOP: absf
// PLOOP: memref_reshape // PLOOP: memref.reshape
// ----- // -----
// Confirm that tiling information is passed through RegionBranchOpInterfaces. // Confirm that tiling information is passed through RegionBranchOpInterfaces.
// This test also uses memref_reshape, just to have a value to return through // This test also uses memref.reshape, just to have a value to return through
// the if statement. // the if statement.
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index) func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
-> memref<*xf32> { -> memref<*xf32> {
%c1 = constant 1 : index %c1 = constant 1 : index
%c0 = constant 0 : index %c0 = constant 0 : index
%1 = alloc(%arg2) : memref<?xf32> %1 = memref.alloc(%arg2) : memref<?xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>], affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]} iterator_types = ["parallel"]}
@ -321,11 +321,11 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
} }
%true = constant 1 : i1 %true = constant 1 : i1
%3 = scf.if %true -> memref<*xf32> { %3 = scf.if %true -> memref<*xf32> {
%2 = memref_reshape %1(%arg1) %2 = memref.reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32> : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
scf.yield %2 : memref<*xf32> scf.yield %2 : memref<*xf32>
} else { } else {
%2 = memref_reshape %1(%arg1) %2 = memref.reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32> : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
scf.yield %2 : memref<*xf32> scf.yield %2 : memref<*xf32>
} }
@ -340,10 +340,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: absf // CHECK: absf
// CHECK: scf.if // CHECK: scf.if
// CHECK: memref_reshape // CHECK: memref.reshape
// CHECK: scf.yield // CHECK: scf.yield
// CHECK: else // CHECK: else
// CHECK: memref_reshape // CHECK: memref.reshape
// CHECK: scf.yield // CHECK: scf.yield
// TILED-LABEL: func @branching_result // TILED-LABEL: func @branching_result
@ -354,10 +354,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
// TILED: linalg.generic // TILED: linalg.generic
// TILED: absf // TILED: absf
// TILED: scf.if // TILED: scf.if
// TILED: memref_reshape // TILED: memref.reshape
// TILED: scf.yield // TILED: scf.yield
// TILED: else // TILED: else
// TILED: memref_reshape // TILED: memref.reshape
// TILED: scf.yield // TILED: scf.yield
// PLOOP-LABEL: func @branching_result // PLOOP-LABEL: func @branching_result
@ -367,10 +367,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
// PLOOP: linalg.generic // PLOOP: linalg.generic
// PLOOP: absf // PLOOP: absf
// PLOOP: scf.if // PLOOP: scf.if
// PLOOP: memref_reshape // PLOOP: memref.reshape
// PLOOP: scf.yield // PLOOP: scf.yield
// PLOOP: else // PLOOP: else
// PLOOP: memref_reshape // PLOOP: memref.reshape
// PLOOP: scf.yield // PLOOP: scf.yield
// ----- // -----
@ -380,7 +380,7 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>) func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
-> memref<?xf32> { -> memref<?xf32> {
%c1 = constant 1 : index %c1 = constant 1 : index
%1 = alloc() : memref<32xf32> %1 = memref.alloc() : memref<32xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>], affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]} iterator_types = ["parallel"]}
@ -389,9 +389,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
%13 = absf %arg3 : f32 %13 = absf %arg3 : f32
linalg.yield %13 : f32 linalg.yield %13 : f32
} }
%2 = tensor_load %1 : memref<32xf32> %2 = memref.tensor_load %1 : memref<32xf32>
%3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32> %3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
%4 = tensor_to_memref %3 : memref<?xf32> %4 = memref.buffer_cast %3 : memref<?xf32>
return %4 : memref<?xf32> return %4 : memref<?xf32>
} }
@ -402,9 +402,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
// CHECK-NOT: scf.for // CHECK-NOT: scf.for
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: absf // CHECK: absf
// CHECK: tensor_load // CHECK: memref.tensor_load
// CHECK: tensor.cast // CHECK: tensor.cast
// CHECK: tensor_to_memref // CHECK: memref.buffer_cast
// TILED-LABEL: func @tensor_ops // TILED-LABEL: func @tensor_ops
// TILED-DAG: %[[C2:.*]] = constant 2 // TILED-DAG: %[[C2:.*]] = constant 2
@ -413,9 +413,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
// TILED-NOT: scf.for // TILED-NOT: scf.for
// TILED: linalg.generic // TILED: linalg.generic
// TILED: absf // TILED: absf
// TILED: tensor_load // TILED: memref.tensor_load
// TILED: tensor.cast // TILED: tensor.cast
// TILED: tensor_to_memref // TILED: memref.buffer_cast
// PLOOP-LABEL: func @tensor_ops // PLOOP-LABEL: func @tensor_ops
@ -424,6 +424,6 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
// PLOOP-NOT: scf.parallel // PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic // PLOOP: linalg.generic
// PLOOP: absf // PLOOP: absf
// PLOOP: tensor_load // PLOOP: memref.tensor_load
// PLOOP: tensor.cast // PLOOP: tensor.cast
// PLOOP: tensor_to_memref // PLOOP: memref.buffer_cast

View File

@ -49,10 +49,10 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK-DAG: [[CTRUE:%.*]] = constant true // CHECK-DAG: [[CTRUE:%.*]] = constant true
// Parallel loop to initialize the output buffer. // Parallel loop to initialize the output buffer.
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32> // CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) { // CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) {
// CHECK: store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] // CHECK: memref.store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
// CHECK: scf.yield // CHECK: scf.yield
// CHECK: } // CHECK: }
@ -101,7 +101,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true // INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true
// CHECK: [[ARG_ELEM:%.*]] = load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]] // CHECK: [[ARG_ELEM:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
// CHECK: [[IF_INIT_RES:%.*]]:4 // CHECK: [[IF_INIT_RES:%.*]]:4
// CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) { // CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) {
@ -114,16 +114,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// Allocate buffers for ARG element, current selected value to adapt LHLO // Allocate buffers for ARG element, current selected value to adapt LHLO
// code. // code.
// CHECK: [[ARG_ELEM_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ARG_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[SEL_VAL_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[SEL_VAL_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[PRED_BUF:%.*]] = alloc() : memref<i1> // CHECK: [[PRED_BUF:%.*]] = memref.alloc() : memref<i1>
// CHECK: store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32> // CHECK: memref.store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
// CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32> // CHECK: memref.store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
// Compute PRED. // Compute PRED.
// CHECK: "lmhlo.compare"( // CHECK: "lmhlo.compare"(
// CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]]) // CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
// CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1> // CHECK: [[PRED:%.*]] = memref.load [[PRED_BUF]][] : memref<i1>
// Depending on PRED, return ARG ivs & elem or current select ivs and value. // Depending on PRED, return ARG ivs & elem or current select ivs and value.
@ -165,7 +165,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: } // CHECK: }
// Use selected ivs to load element from the SRC buffer. // Use selected ivs to load element from the SRC buffer.
// CHECK: [[SRC_ELEM:%.*]] = load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]] // CHECK: [[SRC_ELEM:%.*]] = memref.load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]
// Update of RESULT[SELECTED_I, SELECTED_J] should be done atomically, because // Update of RESULT[SELECTED_I, SELECTED_J] should be done atomically, because
// it may happen that several other threads select the same IVs if the windows // it may happen that several other threads select the same IVs if the windows
@ -175,16 +175,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: ^bb0([[CUR_RES:%.*]]: f32): // CHECK: ^bb0([[CUR_RES:%.*]]: f32):
// Allocate buffers for ARG element, current selected value to adapt LHLO code. // Allocate buffers for ARG element, current selected value to adapt LHLO code.
// CHECK: [[SRC_ELEM_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[SRC_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[CUR_RES_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[CUR_RES_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[RES_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[RES_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32> // CHECK: memref.store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32> // CHECK: memref.store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
// Compute scatter value. // Compute scatter value.
// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) : // CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> () // CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32> // CHECK: [[RES:%.*]] = memref.load [[RES_BUF]][] : memref<f32>
// Atomic RMW terminator that returns updated value. // Atomic RMW terminator that returns updated value.
// CHECK: atomic_yield [[RES]] : f32 // CHECK: atomic_yield [[RES]] : f32

View File

@ -19,14 +19,14 @@ func @reduce(%arg: memref<100x10xf32>,
// CHECK-DAG: %[[C100:.*]] = constant 100 : index // CHECK-DAG: %[[C100:.*]] = constant 100 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) { // CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) {
// CHECK: %[[ACC:.*]] = load %[[ARG1]][] : memref<f32> // CHECK: %[[ACC:.*]] = memref.load %[[ARG1]][] : memref<f32>
// CHECK: store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32> // CHECK: store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32>
// CHECK-DAG: %[[LB:.*]] = constant 0 : index // CHECK-DAG: %[[LB:.*]] = constant 0 : index
// CHECK-DAG: %[[UB:.*]] = constant 10 : index // CHECK-DAG: %[[UB:.*]] = constant 10 : index
// CHECK-DAG: %[[STEP:.*]] = constant 1 : index // CHECK-DAG: %[[STEP:.*]] = constant 1 : index
// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { // CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[LHS:.*]] = subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref<f32, #[[$MAP]]> // CHECK: %[[LHS:.*]] = memref.subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref<f32, #[[$MAP]]>
// CHECK: %[[RHS:.*]] = subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref<f32, #[[$MAP]]> // CHECK: %[[RHS:.*]] = memref.subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref<f32, #[[$MAP]]>
// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> () // CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
// CHECK: } // CHECK: }
// CHECK: gpu.terminator // CHECK: gpu.terminator

View File

@ -52,10 +52,10 @@ func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
: (memref<f32>, memref<f32>, memref<f32>) -> () : (memref<f32>, memref<f32>, memref<f32>) -> ()
return return
} }
// CHECK: %[[LHS:.*]] = load // CHECK: %[[LHS:.*]] = memref.load
// CHECK: %[[RHS:.*]] = load // CHECK: %[[RHS:.*]] = memref.load
// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] // CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]]
// CHECK: store %[[RES]] // CHECK: memref.store %[[RES]]
// CHECK-NEXT: return // CHECK-NEXT: return
// ----- // -----
@ -347,7 +347,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
} }
// CHECK-NOT: linalg.reshape // CHECK-NOT: linalg.reshape
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[VALUE:.*]] = load %{{.*}}[[C0]] // CHECK: %[[VALUE:.*]] = memref.load %{{.*}}[[C0]]
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]] // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%{{.+}}: f32): // CHECK-NEXT: ^bb0(%{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[VALUE]] : f32 // CHECK-NEXT: linalg.yield %[[VALUE]] : f32
@ -785,7 +785,7 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
} : (memref<?x?xf32>, memref<?x?xf32>) -> () } : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return return
} }
// CHECK: %[[RESULT:.*]] = subview %[[IN]][0, 1] [2, 2] [1, 1] : memref<?x?xf32> to memref<2x2xf32, #{{.*}}> // CHECK: %[[RESULT:.*]] = memref.subview %[[IN]][0, 1] [2, 2] [1, 1] : memref<?x?xf32> to memref<2x2xf32, #{{.*}}>
// CHECK: linalg.copy(%[[RESULT]], %[[OUT]]) // CHECK: linalg.copy(%[[RESULT]], %[[OUT]])
// ----- // -----
@ -899,7 +899,7 @@ func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) { func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
%c0 = constant 0 : index %c0 = constant 0 : index
%0 = alloc() : memref<3x5x5x4xf32> %0 = memref.alloc() : memref<3x5x5x4xf32>
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}}) // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 2] // CHECK-SAME: dilations = [1, 2]
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64> // CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
@ -948,22 +948,22 @@ func @reduce_add(%arg: memref<100x10xf32>,
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> () : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
return return
} }
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32> // CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]]) // CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
// CHECK: linalg.generic { // CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]], // CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction"]} // CHECK-SAME: iterator_types = ["parallel", "reduction"]}
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) { // CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
// CHECK: alloca // CHECK: memref.alloca
// CHECK-NEXT: alloca // CHECK-NEXT: memref.alloca
// CHECK-NEXT: alloca // CHECK-NEXT: memref.alloca
// CHECK-NEXT: store // CHECK-NEXT: memref.store
// CHECK-NEXT: store // CHECK-NEXT: memref.store
// CHECK-NEXT: load // CHECK-NEXT: memref.load
// CHECK-NEXT: load // CHECK-NEXT: memref.load
// CHECK-NEXT: addf // CHECK-NEXT: addf
// CHECK-NEXT: store // CHECK-NEXT: memref.store
// CHECK-NEXT: load // CHECK-NEXT: memref.load
// CHECK-NEXT: linalg.yield // CHECK-NEXT: linalg.yield
// CHECK-NEXT: } // CHECK-NEXT: }
@ -984,22 +984,22 @@ func @reduce_maximum(%arg: memref<100x10xf32>,
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> () : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
return return
} }
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32> // CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]]) // CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
// CHECK: linalg.generic { // CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]], // CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction"]} // CHECK-SAME: iterator_types = ["parallel", "reduction"]}
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) { // CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
// CHECK: alloca // CHECK: memref.alloca
// CHECK-NEXT: alloca // CHECK-NEXT: memref.alloca
// CHECK-NEXT: alloca // CHECK-NEXT: memref.alloca
// CHECK-NEXT: store // CHECK-NEXT: memref.store
// CHECK-NEXT: store // CHECK-NEXT: memref.store
// CHECK-NEXT: load // CHECK-NEXT: memref.load
// CHECK-NEXT: load // CHECK-NEXT: memref.load
// CHECK: cmpf // CHECK: cmpf
// CHECK: select // CHECK: select
// CHECK: store // CHECK: memref.store
// CHECK-NEXT: load // CHECK-NEXT: memref.load
// CHECK-NEXT: linalg.yield // CHECK-NEXT: linalg.yield
// CHECK-NEXT: } // CHECK-NEXT: }

View File

@ -21,27 +21,27 @@ func @reduce(%arg: memref<100x10x5xf32>,
// CHECK-DAG: [[C5:%.*]] = constant 5 : index // CHECK-DAG: [[C5:%.*]] = constant 5 : index
// CHECK-DAG: [[C10:%.*]] = constant 10 : index // CHECK-DAG: [[C10:%.*]] = constant 10 : index
// CHECK-DAG: [[C100:%.*]] = constant 100 : index // CHECK-DAG: [[C100:%.*]] = constant 100 : index
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]] // CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) // CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) { // CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) {
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) = // CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
// CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 {
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] // CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32> // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32>
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32> // CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: } // CHECK: }
// CHECK: scf.yield // CHECK: scf.yield
// CHECK: } // CHECK: }
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] // CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
// CHECK: scf.yield // CHECK: scf.yield
// ----- // -----
@ -65,23 +65,23 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
// CHECK-DAG: [[C0:%.*]] = constant 0 : index // CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index // CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C100:%.*]] = constant 100 : index // CHECK-DAG: [[C100:%.*]] = constant 100 : index
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]] // CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]]) // CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]])
// CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 {
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}} // CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32> // CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] // CHECK: scf.reduce.return [[ACC_RESULT]]
// CHECK: } // CHECK: }
// CHECK: scf.yield // CHECK: scf.yield
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]] // CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]
// ----- // -----
@ -104,30 +104,30 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK-DAG: [[C0:%.*]] = constant 0 : index // CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index // CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C2:%.*]] = constant 2 : index // CHECK-DAG: [[C2:%.*]] = constant 2 : index
// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32> // CHECK: [[DIM0:%.*]] = memref.dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32> // CHECK: [[DIM1:%.*]] = memref.dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32> // CHECK: [[DIM2:%.*]] = memref.dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]] // CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) // CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) { // CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) {
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) = // CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
// CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 {
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] // CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<?x?x?xf32> // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<?x?x?xf32>
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32> // CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: } // CHECK: }
// CHECK: scf.yield // CHECK: scf.yield
// CHECK: } // CHECK: }
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] // CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
// CHECK: scf.yield // CHECK: scf.yield
// ----- // -----
@ -157,7 +157,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK-DAG: [[C3:%.*]] = constant 3 : index // CHECK-DAG: [[C3:%.*]] = constant 3 : index
// CHECK-DAG: [[C56:%.*]] = constant 56 : index // CHECK-DAG: [[C56:%.*]] = constant 56 : index
// CHECK-DAG: [[C112:%.*]] = constant 112 : index // CHECK-DAG: [[C112:%.*]] = constant 112 : index
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32> // CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) { // CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) {
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel // CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel
@ -176,7 +176,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) { // CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) {
// CHECK: [[OPERAND_ELEM:%.*]] = // CHECK: [[OPERAND_ELEM:%.*]] =
// CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]] // CHECK-SAME: memref.load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
// CHECK: scf.yield [[OPERAND_ELEM]] : f32 // CHECK: scf.yield [[OPERAND_ELEM]] : f32
// CHECK: } else { // CHECK: } else {
// CHECK: scf.yield [[INIT]] : f32 // CHECK: scf.yield [[INIT]] : f32
@ -184,18 +184,18 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32> // CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: } // CHECK: }
// CHECK: scf.yield // CHECK: scf.yield
// CHECK: } // CHECK: }
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] // CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
// CHECK: scf.yield // CHECK: scf.yield
// CHECK: } // CHECK: }
// CHECK: return // CHECK: return

View File

@ -30,7 +30,7 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
// CHECK-LABEL: func @conv_forward // CHECK-LABEL: func @conv_forward
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) { func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
%scratch = alloc() : memref<32xi8> %scratch = memref.alloc() : memref<32xi8>
// This defined a 2D convolution over a 8x8 single channel input using a 2x2 // This defined a 2D convolution over a 8x8 single channel input using a 2x2
// filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W) // filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W)
"lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch) "lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch)
@ -61,7 +61,7 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
// CHECK-LABEL: func @conv_backfilter // CHECK-LABEL: func @conv_backfilter
func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) { func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
%scratch = alloc() : memref<23328xui8> %scratch = memref.alloc() : memref<23328xui8>
"lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch) "lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
{ backend_config = {algorithm = 1 : i64, { backend_config = {algorithm = 1 : i64,
operand_0_layout = [3,2,1,0], operand_0_layout = [3,2,1,0],
@ -91,7 +91,7 @@ func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64x
// CHECK-LABEL: func @conv_backinput // CHECK-LABEL: func @conv_backinput
func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) { func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
%scratch = alloc() : memref<32xui8> %scratch = memref.alloc() : memref<32xui8>
"lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch) "lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
{ backend_config = {algorithm = 1 : i64, { backend_config = {algorithm = 1 : i64,
operand_0_layout = [3,2,1,0], operand_0_layout = [3,2,1,0],
@ -122,7 +122,7 @@ func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf6
// CHECK-LABEL: func @conv_fused // CHECK-LABEL: func @conv_fused
func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) { func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
%scratch = alloc() : memref<32xui8> %scratch = memref.alloc() : memref<32xui8>
"lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch) "lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
{activation_mode = "Relu", {activation_mode = "Relu",
backend_config = {algorithm = 1 : i64, backend_config = {algorithm = 1 : i64,
@ -153,7 +153,7 @@ func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>,
// CHECK-LABEL: func @conv_fused_side_input // CHECK-LABEL: func @conv_fused_side_input
func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) { func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) {
%scratch = alloc() : memref<0xui8> %scratch = memref.alloc() : memref<0xui8>
"lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch) "lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
{activation_mode = "Relu", {activation_mode = "Relu",
backend_config = {algorithm = 1 : i64, backend_config = {algorithm = 1 : i64,
@ -218,8 +218,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
// CHECK-LABEL: func @cholesky // CHECK-LABEL: func @cholesky
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) { func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
%scratch = alloc() : memref<32xi8> %scratch = memref.alloc() : memref<32xi8>
%info = alloc() : memref<32xi32> %info = memref.alloc() : memref<32xi32>
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true } "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> () : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
return return

View File

@ -457,12 +457,12 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf
// CHECK-LABEL: func @fusion_memref // CHECK-LABEL: func @fusion_memref
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () { func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
"lmhlo.fusion"() ( { "lmhlo.fusion"() ( {
%0 = tensor_load %input1 : memref<10xf32> %0 = memref.tensor_load %input1 : memref<10xf32>
%1 = tensor_load %input2 : memref<10xf32> %1 = memref.tensor_load %input2 : memref<10xf32>
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> %2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%3 = tensor_load %input3 : memref<10xf32> %3 = memref.tensor_load %input3 : memref<10xf32>
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> %4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
tensor_store %4, %out : memref<10xf32> memref.tensor_store %4, %out : memref<10xf32>
"lmhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} ) : () -> () } ) : () -> ()
return return

View File

@ -108,15 +108,15 @@ func @batchNormInference_dynamic_shape(
// CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index // CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32> // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32> // CHECK-DAG: %[[DIM:.+]] = memref.dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor.from_elements %[[DIM]] : tensor<1xindex> // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor.from_elements %[[DIM]] : tensor<1xindex>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32> // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32> // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_0:.+]] = memref.dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_1:.+]] = memref.dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_2:.+]] = memref.dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_3:.+]] = memref.dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor.from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex> // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor.from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>