Integrate LLVM at llvm/llvm-project@678241795c
Updates LLVM usage to match [678241795c95](https://github.com/llvm/llvm-project/commit/678241795c95) PiperOrigin-RevId: 363257913
This commit is contained in:
parent
2be112a603
commit
c54527fe88
5
BUILD
5
BUILD
|
@ -23,6 +23,7 @@ td_library(
|
||||||
],
|
],
|
||||||
includes = ["include"],
|
includes = ["include"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"@llvm-project//mlir:MemRefOpsTdFiles",
|
||||||
"@llvm-project//mlir:OpBaseTdFiles",
|
"@llvm-project//mlir:OpBaseTdFiles",
|
||||||
"@llvm-project//mlir:SideEffectTdFiles",
|
"@llvm-project//mlir:SideEffectTdFiles",
|
||||||
],
|
],
|
||||||
|
@ -462,6 +463,7 @@ cc_library(
|
||||||
"@llvm-project//mlir:Analysis",
|
"@llvm-project//mlir:Analysis",
|
||||||
"@llvm-project//mlir:CopyOpInterface",
|
"@llvm-project//mlir:CopyOpInterface",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:SideEffects",
|
"@llvm-project//mlir:SideEffects",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
@ -615,6 +617,7 @@ cc_library(
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:LinalgOps",
|
"@llvm-project//mlir:LinalgOps",
|
||||||
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:SCFDialect",
|
"@llvm-project//mlir:SCFDialect",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
@ -724,6 +727,7 @@ cc_library(
|
||||||
"@llvm-project//mlir:Affine",
|
"@llvm-project//mlir:Affine",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:LinalgTransforms",
|
"@llvm-project//mlir:LinalgTransforms",
|
||||||
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:SCFDialect",
|
"@llvm-project//mlir:SCFDialect",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
@ -951,6 +955,7 @@ cc_library(
|
||||||
":hlo",
|
":hlo",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
"@llvm-project//mlir:TensorDialect",
|
"@llvm-project//mlir:TensorDialect",
|
||||||
"@llvm-project//mlir:Transforms",
|
"@llvm-project//mlir:Transforms",
|
||||||
|
|
|
@ -15,9 +15,9 @@
|
||||||
|
|
||||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||||
|
|
||||||
LLVM_COMMIT = "6878be5dc3ec7031d0deec3e321310115bd71103"
|
LLVM_COMMIT = "678241795c957b18bc473045e48abe3f2a61ff5c"
|
||||||
|
|
||||||
LLVM_SHA256 = "f55187a3329fd97fd62fd0714783524d50a3be934a35484bd4442195fb25f0e5"
|
LLVM_SHA256 = "58fd00a9ed7841f36aa7042bb8c98323b030dee98abe36757eea9ddf4fd5ea75"
|
||||||
|
|
||||||
LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)
|
LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)
|
||||||
|
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
6878be5dc3ec7031d0deec3e321310115bd71103
|
678241795c957b18bc473045e48abe3f2a61ff5c
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||||
#ifndef LHLO_OPS
|
#ifndef LHLO_OPS
|
||||||
#define LHLO_OPS
|
#define LHLO_OPS
|
||||||
|
|
||||||
|
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Interfaces/CopyOpInterface.td"
|
include "mlir/Interfaces/CopyOpInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
@ -685,7 +686,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
SmallVector<Value, 4> getInputBuffers() {
|
SmallVector<Value, 4> getInputBuffers() {
|
||||||
SmallVector<Value, 4> buffers;
|
SmallVector<Value, 4> buffers;
|
||||||
this->region().walk([&](TensorLoadOp load) {
|
this->region().walk([&](memref::TensorLoadOp load) {
|
||||||
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||||
buffers.push_back(load.memref());
|
buffers.push_back(load.memref());
|
||||||
});
|
});
|
||||||
|
@ -694,7 +695,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
||||||
|
|
||||||
SmallVector<Value, 4> getOutputBuffers() {
|
SmallVector<Value, 4> getOutputBuffers() {
|
||||||
SmallVector<Value, 4> buffers;
|
SmallVector<Value, 4> buffers;
|
||||||
this->region().walk([&](TensorStoreOp store) {
|
this->region().walk([&](memref::TensorStoreOp store) {
|
||||||
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||||
buffers.push_back(store.memref());
|
buffers.push_back(store.memref());
|
||||||
});
|
});
|
||||||
|
@ -703,7 +704,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
||||||
|
|
||||||
SmallVector<Value, 4> getFusionParameters() {
|
SmallVector<Value, 4> getFusionParameters() {
|
||||||
SmallVector<Value, 4> buffers;
|
SmallVector<Value, 4> buffers;
|
||||||
this->region().walk([&](TensorLoadOp load) {
|
this->region().walk([&](memref::TensorLoadOp load) {
|
||||||
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||||
buffers.push_back(load);
|
buffers.push_back(load);
|
||||||
});
|
});
|
||||||
|
@ -712,7 +713,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
||||||
|
|
||||||
SmallVector<Value, 4> getFusionResults() {
|
SmallVector<Value, 4> getFusionResults() {
|
||||||
SmallVector<Value, 4> buffers;
|
SmallVector<Value, 4> buffers;
|
||||||
this->region().walk([&](TensorStoreOp store) {
|
this->region().walk([&](memref::TensorStoreOp store) {
|
||||||
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||||
buffers.push_back(store.tensor());
|
buffers.push_back(store.tensor());
|
||||||
});
|
});
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
#ifndef LHLO_OPS_BASE
|
#ifndef LHLO_OPS_BASE
|
||||||
#define LHLO_OPS_BASE
|
#define LHLO_OPS_BASE
|
||||||
|
|
||||||
|
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
@ -156,13 +157,13 @@ struct EraseConstOp : public OpRewritePattern<ConstOp> {
|
||||||
LogicalResult matchAndRewrite(ConstOp op,
|
LogicalResult matchAndRewrite(ConstOp op,
|
||||||
PatternRewriter& rewriter) const override {
|
PatternRewriter& rewriter) const override {
|
||||||
Value memref = op.output();
|
Value memref = op.output();
|
||||||
if (!memref.getDefiningOp<AllocOp>()) {
|
if (!memref.getDefiningOp<memref::AllocOp>()) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that all uses of the memref are either DeallocOps or this op.
|
// Check that all uses of the memref are either DeallocOps or this op.
|
||||||
for (Operation* user : memref.getUsers())
|
for (Operation* user : memref.getUsers())
|
||||||
if (user != op && !isa<DeallocOp>(user)) return failure();
|
if (user != op && !isa<memref::DeallocOp>(user)) return failure();
|
||||||
|
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -71,7 +71,7 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
||||||
dynamic_operands.push_back(alloc_operand);
|
dynamic_operands.push_back(alloc_operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
|
return rewriter->create<memref::AllocOp>(loc, memref_type, dynamic_operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value InsertAlloc(Location loc, OpResult result,
|
Value InsertAlloc(Location loc, OpResult result,
|
||||||
|
@ -85,7 +85,7 @@ Value InsertAlloc(Location loc, OpResult result,
|
||||||
MemRefType::get(result_type.getShape(), result_type.getElementType());
|
MemRefType::get(result_type.getShape(), result_type.getElementType());
|
||||||
OpBuilder::InsertionGuard guard(*rewriter);
|
OpBuilder::InsertionGuard guard(*rewriter);
|
||||||
rewriter->setInsertionPoint(result.getDefiningOp());
|
rewriter->setInsertionPoint(result.getDefiningOp());
|
||||||
auto alloc = rewriter->create<AllocOp>(loc, memref_type);
|
auto alloc = rewriter->create<memref::AllocOp>(loc, memref_type);
|
||||||
return alloc;
|
return alloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,7 +207,7 @@ class HloToLhloReshapeUnrankedConverter
|
||||||
if (unranked_operand_type == nullptr) return failure();
|
if (unranked_operand_type == nullptr) return failure();
|
||||||
|
|
||||||
auto result_type = op.getType().cast<RankedTensorType>();
|
auto result_type = op.getType().cast<RankedTensorType>();
|
||||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(
|
rewriter.replaceOpWithNewOp<memref::CastOp>(
|
||||||
op, adaptor.operand(),
|
op, adaptor.operand(),
|
||||||
MemRefType::get(result_type.getShape(), result_type.getElementType()));
|
MemRefType::get(result_type.getShape(), result_type.getElementType()));
|
||||||
return success();
|
return success();
|
||||||
|
@ -235,7 +235,7 @@ class HloToLhloDynamicReshapeConverter
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
|
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
|
||||||
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
|
rewriter.replaceOpWithNewOp<memref::ReshapeOp>(
|
||||||
op, result_type, adaptor.operand(), adaptor.output_shape());
|
op, result_type, adaptor.operand(), adaptor.output_shape());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -273,7 +273,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
|
||||||
// Inserts dynamic memref to change the layout of the memref to put 0-stride
|
// Inserts dynamic memref to change the layout of the memref to put 0-stride
|
||||||
// and size of the target dimension if size-1 dimension expansion is
|
// and size of the target dimension if size-1 dimension expansion is
|
||||||
// necessary.
|
// necessary.
|
||||||
MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
|
memref::ReinterpretCastOp InsertDynamicMemrefCastOp(
|
||||||
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
|
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
auto operand_type = operand.getType().cast<MemRefType>();
|
auto operand_type = operand.getType().cast<MemRefType>();
|
||||||
|
@ -295,7 +295,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
|
||||||
for (int i = operand_rank - 1; i >= 0; --i) {
|
for (int i = operand_rank - 1; i >= 0; --i) {
|
||||||
Value operand_dim_size =
|
Value operand_dim_size =
|
||||||
ShapedType::isDynamic(operand_shape[i])
|
ShapedType::isDynamic(operand_shape[i])
|
||||||
? b->create<DimOp>(loc, operand, i).getResult()
|
? b->create<memref::DimOp>(loc, operand, i).getResult()
|
||||||
: b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
|
: b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
|
||||||
operand_sizes[i] = operand_dim_size;
|
operand_sizes[i] = operand_dim_size;
|
||||||
|
|
||||||
|
@ -355,7 +355,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
|
||||||
makeStridedLinearLayoutMap(dynamic_layout,
|
makeStridedLinearLayoutMap(dynamic_layout,
|
||||||
/*offset=*/0, b->getContext()));
|
/*offset=*/0, b->getContext()));
|
||||||
|
|
||||||
auto transformed_operand = b->create<MemRefReinterpretCastOp>(
|
auto transformed_operand = b->create<memref::ReinterpretCastOp>(
|
||||||
loc, type_erased_memref_type, operand,
|
loc, type_erased_memref_type, operand,
|
||||||
/*offset=*/b->getI64IntegerAttr(0), sizes, strides);
|
/*offset=*/b->getI64IntegerAttr(0), sizes, strides);
|
||||||
return transformed_operand;
|
return transformed_operand;
|
||||||
|
@ -484,12 +484,12 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
|
||||||
|
|
||||||
// TODO(b/175789537) Remove this pattern.
|
// TODO(b/175789537) Remove this pattern.
|
||||||
class HloToLhloTensorStoreOpLegacyConverter
|
class HloToLhloTensorStoreOpLegacyConverter
|
||||||
: public BaseOpConversion<mlir::TensorStoreOp> {
|
: public BaseOpConversion<mlir::memref::TensorStoreOp> {
|
||||||
public:
|
public:
|
||||||
using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
|
using BaseOpConversion<mlir::memref::TensorStoreOp>::BaseOpConversion;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
mlir::TensorStoreOp op, ArrayRef<Value> operands,
|
mlir::memref::TensorStoreOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
|
rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
|
||||||
operands.back());
|
operands.back());
|
||||||
|
@ -577,14 +577,16 @@ struct HloLegalizeToLhlo
|
||||||
ConversionTarget target(context);
|
ConversionTarget target(context);
|
||||||
target.addLegalDialect<lmhlo::LmhloDialect>();
|
target.addLegalDialect<lmhlo::LmhloDialect>();
|
||||||
target.addLegalDialect<StandardOpsDialect>();
|
target.addLegalDialect<StandardOpsDialect>();
|
||||||
|
target.addLegalDialect<memref::MemRefDialect>();
|
||||||
target.addLegalDialect<shape::ShapeDialect>();
|
target.addLegalDialect<shape::ShapeDialect>();
|
||||||
target.addLegalDialect<tensor::TensorDialect>();
|
target.addLegalDialect<tensor::TensorDialect>();
|
||||||
target.addIllegalDialect<mhlo::MhloDialect>();
|
target.addIllegalDialect<mhlo::MhloDialect>();
|
||||||
// Declare tensor_load and tensor_store illegal.
|
// Declare tensor_load and tensor_store illegal.
|
||||||
target.addIllegalOp<mlir::TensorLoadOp, mlir::TensorStoreOp>();
|
target.addIllegalOp<mlir::memref::TensorLoadOp,
|
||||||
// tensor_to_memref is illegal if it has uses.
|
mlir::memref::TensorStoreOp>();
|
||||||
// TODO(b/175670649) Make tensor_to_memref illegal.
|
// buffer_cast is illegal if it has uses.
|
||||||
target.addDynamicallyLegalOp<mlir::TensorToMemrefOp>(
|
// TODO(b/175670649) Make buffer_cast illegal.
|
||||||
|
target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
|
||||||
[](auto op) { return op->use_empty(); });
|
[](auto op) { return op->use_empty(); });
|
||||||
|
|
||||||
BufferizeTypeConverter converter;
|
BufferizeTypeConverter converter;
|
||||||
|
|
|
@ -108,7 +108,7 @@ SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
|
||||||
dyn_sizes.push_back(
|
dyn_sizes.push_back(
|
||||||
b.create<IndexCastOp>(loc, b.getIndexType(), extract));
|
b.create<IndexCastOp>(loc, b.getIndexType(), extract));
|
||||||
} else {
|
} else {
|
||||||
dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
|
dyn_sizes.push_back(b.create<memref::DimOp>(loc, tensor, en.index()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return dyn_sizes;
|
return dyn_sizes;
|
||||||
|
@ -324,13 +324,13 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create two loads from the input.
|
// Create two loads from the input.
|
||||||
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
auto lhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.lhs());
|
||||||
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
auto rhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.rhs());
|
||||||
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
|
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
|
||||||
Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
|
Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
|
||||||
lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
||||||
&rewriter);
|
&rewriter);
|
||||||
rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
|
rewriter.create<memref::StoreOp>(loc, op_result, lhlo_op.out());
|
||||||
rewriter.eraseOp(lhlo_op);
|
rewriter.eraseOp(lhlo_op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -590,8 +590,8 @@ class LhloBroadcastInDimConverter
|
||||||
operand_type.getDimSize(0) <
|
operand_type.getDimSize(0) <
|
||||||
result_type.getDimSize(broadcast_dims.front())) {
|
result_type.getDimSize(broadcast_dims.front())) {
|
||||||
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||||
Value val =
|
Value val = rewriter.create<memref::LoadOp>(loc, operand,
|
||||||
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
|
llvm::makeArrayRef({zero}));
|
||||||
rewriter.create<linalg::GenericOp>(
|
rewriter.create<linalg::GenericOp>(
|
||||||
loc, /*inputs=*/ValueRange{},
|
loc, /*inputs=*/ValueRange{},
|
||||||
/*outputBuffers=*/ValueRange{operand_adaptor.output()},
|
/*outputBuffers=*/ValueRange{operand_adaptor.output()},
|
||||||
|
@ -971,7 +971,8 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// First fill the output buffer with the init value.
|
// First fill the output buffer with the init value.
|
||||||
Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
|
Value init_value =
|
||||||
|
rewriter.create<memref::LoadOp>(loc, adaptor.init_values()[0]);
|
||||||
rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);
|
rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);
|
||||||
|
|
||||||
DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
|
DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
|
||||||
|
@ -1011,9 +1012,9 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
||||||
// expects scalar SSA values. Add some allocs around the original op to
|
// expects scalar SSA values. Add some allocs around the original op to
|
||||||
// make it compatible.
|
// make it compatible.
|
||||||
auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
|
auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
|
||||||
Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
|
Value alloc_a = rewriter.create<memref::AllocaOp>(loc, arg_type);
|
||||||
Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
|
Value alloc_b = rewriter.create<memref::AllocaOp>(loc, arg_type);
|
||||||
Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);
|
Value alloc_res = rewriter.create<memref::AllocaOp>(loc, arg_type);
|
||||||
|
|
||||||
// Now turn the existing signature
|
// Now turn the existing signature
|
||||||
// (memref<X>, memref<X>, memref<X>) -> ()
|
// (memref<X>, memref<X>, memref<X>) -> ()
|
||||||
|
@ -1030,13 +1031,15 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
||||||
|
|
||||||
// Store the arguments into the newly allocated buffers.
|
// Store the arguments into the newly allocated buffers.
|
||||||
rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
|
rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
|
||||||
rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
|
rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(0),
|
||||||
rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
|
alloc_a);
|
||||||
|
rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(1),
|
||||||
|
alloc_b);
|
||||||
rewriter.replaceOp(entry_block->getTerminator(), {});
|
rewriter.replaceOp(entry_block->getTerminator(), {});
|
||||||
|
|
||||||
// Load & yield the result.
|
// Load & yield the result.
|
||||||
rewriter.setInsertionPointToEnd(entry_block);
|
rewriter.setInsertionPointToEnd(entry_block);
|
||||||
auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
|
auto load_res = rewriter.create<memref::LoadOp>(loc, alloc_res);
|
||||||
rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
|
rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1099,8 +1102,8 @@ class SliceConverter : public OpConversionPattern<OpTy> {
|
||||||
slice_op.strides().template getValue<int64_t>(i)));
|
slice_op.strides().template getValue<int64_t>(i)));
|
||||||
}
|
}
|
||||||
if (isLHLO) {
|
if (isLHLO) {
|
||||||
auto linalg_op =
|
auto linalg_op = rewriter.create<memref::SubViewOp>(loc, args[0], offsets,
|
||||||
rewriter.create<SubViewOp>(loc, args[0], offsets, sizes, strides);
|
sizes, strides);
|
||||||
rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
|
rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
|
||||||
rewriter.eraseOp(slice_op);
|
rewriter.eraseOp(slice_op);
|
||||||
} else {
|
} else {
|
||||||
|
@ -1149,14 +1152,14 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case DotOperationType::kMatrixMatrix: {
|
case DotOperationType::kMatrixMatrix: {
|
||||||
if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
|
if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
|
||||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
|
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
|
||||||
if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
|
if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
|
||||||
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
|
dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 1));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case DotOperationType::kMatrixVector: {
|
case DotOperationType::kMatrixVector: {
|
||||||
if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
|
if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
|
||||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
|
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case DotOperationType::kVectorDot:
|
case DotOperationType::kVectorDot:
|
||||||
|
@ -1203,11 +1206,11 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
|
||||||
OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
|
OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
|
||||||
SmallVector<Value, 8> dyn_shape;
|
SmallVector<Value, 8> dyn_shape;
|
||||||
if (result_type.isDynamicDim(0))
|
if (result_type.isDynamicDim(0))
|
||||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
|
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
|
||||||
if (result_type.isDynamicDim(1))
|
if (result_type.isDynamicDim(1))
|
||||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1));
|
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 1));
|
||||||
if (result_type.isDynamicDim(2))
|
if (result_type.isDynamicDim(2))
|
||||||
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2));
|
dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 2));
|
||||||
return dyn_shape;
|
return dyn_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1307,7 +1310,7 @@ SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
|
||||||
for (int i = 0, j = 0; i < rank; ++i) {
|
for (int i = 0, j = 0; i < rank; ++i) {
|
||||||
if (s.count(i)) continue;
|
if (s.count(i)) continue;
|
||||||
if (!result_type.isDynamicDim(j++)) continue;
|
if (!result_type.isDynamicDim(j++)) continue;
|
||||||
dyn_shape.push_back(b.create<DimOp>(loc, arg, i));
|
dyn_shape.push_back(b.create<memref::DimOp>(loc, arg, i));
|
||||||
}
|
}
|
||||||
|
|
||||||
return dyn_shape;
|
return dyn_shape;
|
||||||
|
@ -1467,7 +1470,7 @@ struct NormalConvOpOnTensorsConversion
|
||||||
// The output shape is N spatial_dims F.
|
// The output shape is N spatial_dims F.
|
||||||
SmallVector<Value, 8> dyn_sizes;
|
SmallVector<Value, 8> dyn_sizes;
|
||||||
if (result_type.isDynamicDim(0)) {
|
if (result_type.isDynamicDim(0)) {
|
||||||
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, 0));
|
dyn_sizes.push_back(rewriter.create<memref::DimOp>(loc, input, 0));
|
||||||
}
|
}
|
||||||
for (int64_t i = 1, e = rank - 1; i < e; ++i) {
|
for (int64_t i = 1, e = rank - 1; i < e; ++i) {
|
||||||
if (result_type.isDynamicDim(i)) {
|
if (result_type.isDynamicDim(i)) {
|
||||||
|
@ -1476,7 +1479,8 @@ struct NormalConvOpOnTensorsConversion
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (result_type.isDynamicDim(rank - 1)) {
|
if (result_type.isDynamicDim(rank - 1)) {
|
||||||
dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
|
dyn_sizes.push_back(
|
||||||
|
rewriter.create<memref::DimOp>(loc, filter, rank - 1));
|
||||||
}
|
}
|
||||||
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
|
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
|
||||||
loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
|
loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
|
||||||
|
@ -1856,8 +1860,8 @@ struct LhloLegalizeToLinalgPass
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
||||||
math::MathDialect, StandardOpsDialect,
|
math::MathDialect, memref::MemRefDialect,
|
||||||
AffineDialect>();
|
StandardOpsDialect, AffineDialect>();
|
||||||
|
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||||
|
@ -1881,6 +1885,9 @@ struct HloLegalizeToLinalgPass
|
||||||
math::MathDialect, StandardOpsDialect,
|
math::MathDialect, StandardOpsDialect,
|
||||||
tensor::TensorDialect, scf::SCFDialect>();
|
tensor::TensorDialect, scf::SCFDialect>();
|
||||||
|
|
||||||
|
// TODO: DimOp shouldn't be in MemRefDialect
|
||||||
|
target.addLegalOp<memref::DimOp>();
|
||||||
|
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||||
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
||||||
|
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
@ -95,7 +96,7 @@ class LhloFuseLinalgPass
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
|
if (auto tensor_load = dyn_cast<memref::TensorLoadOp>(definingOp)) {
|
||||||
auto alias = tensor_load.memref();
|
auto alias = tensor_load.memref();
|
||||||
if (result_buffers.insert(alias).second) {
|
if (result_buffers.insert(alias).second) {
|
||||||
worklist.push_back(alias);
|
worklist.push_back(alias);
|
||||||
|
@ -103,7 +104,7 @@ class LhloFuseLinalgPass
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
|
if (auto tensor_to_memref = dyn_cast<memref::BufferCastOp>(definingOp)) {
|
||||||
auto alias = tensor_to_memref.tensor();
|
auto alias = tensor_to_memref.tensor();
|
||||||
if (result_buffers.insert(alias).second) {
|
if (result_buffers.insert(alias).second) {
|
||||||
worklist.push_back(alias);
|
worklist.push_back(alias);
|
||||||
|
|
|
@ -96,9 +96,10 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
||||||
|
|
||||||
// Load the initial value and store it to the output.
|
// Load the initial value and store it to the output.
|
||||||
for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
|
for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
|
||||||
auto init_value = rewriter.create<mlir::LoadOp>(loc, std::get<0>(pair));
|
auto init_value =
|
||||||
rewriter.create<mlir::StoreOp>(loc, init_value, std::get<1>(pair),
|
rewriter.create<mlir::memref::LoadOp>(loc, std::get<0>(pair));
|
||||||
ArrayRef<Value>{index});
|
rewriter.create<mlir::memref::StoreOp>(
|
||||||
|
loc, init_value, std::get<1>(pair), ArrayRef<Value>{index});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert a loop into the body to compute the reduction. The loop ranges
|
// Insert a loop into the body to compute the reduction. The loop ranges
|
||||||
|
@ -128,8 +129,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
||||||
auto oneAttr = rewriter.getI64IntegerAttr(1);
|
auto oneAttr = rewriter.getI64IntegerAttr(1);
|
||||||
OpFoldResult size = oneAttr;
|
OpFoldResult size = oneAttr;
|
||||||
OpFoldResult stride = oneAttr;
|
OpFoldResult stride = oneAttr;
|
||||||
auto accumulator = rewriter.create<SubViewOp>(loc, resType, output,
|
auto accumulator = rewriter.create<memref::SubViewOp>(
|
||||||
offset, size, stride);
|
loc, resType, output, offset, size, stride);
|
||||||
llvm::SmallVector<Value, 4> indexings;
|
llvm::SmallVector<Value, 4> indexings;
|
||||||
auto input_buffer = *reduce_op.operands().begin();
|
auto input_buffer = *reduce_op.operands().begin();
|
||||||
auto input_type_rank =
|
auto input_type_rank =
|
||||||
|
@ -143,8 +144,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
||||||
}));
|
}));
|
||||||
SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
|
SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
|
||||||
SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
|
SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
|
||||||
auto rhs = rewriter.create<SubViewOp>(loc, accumulator.getType(), input,
|
auto rhs = rewriter.create<memref::SubViewOp>(
|
||||||
offsets, sizes, strides);
|
loc, accumulator.getType(), input, offsets, sizes, strides);
|
||||||
|
|
||||||
// Now copy over the actual body of the reduction, leaving out the
|
// Now copy over the actual body of the reduction, leaving out the
|
||||||
// terminator.
|
// terminator.
|
||||||
|
@ -179,8 +180,9 @@ struct LhloLegalizeToGpuPass
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
|
||||||
gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
|
StandardOpsDialect, gpu::GPUDialect, scf::SCFDialect,
|
||||||
|
LmhloDialect>();
|
||||||
target.addIllegalOp<ReduceOp>();
|
target.addIllegalOp<ReduceOp>();
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
|
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
|
||||||
|
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
@ -43,10 +44,11 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
|
||||||
Block* lhlo_block, OpBuilder* b) {
|
Block* lhlo_block, OpBuilder* b) {
|
||||||
SmallVector<Value, 2> arg_bufs;
|
SmallVector<Value, 2> arg_bufs;
|
||||||
for (auto arg_type : lhlo_block->getArgumentTypes()) {
|
for (auto arg_type : lhlo_block->getArgumentTypes()) {
|
||||||
arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
|
arg_bufs.push_back(
|
||||||
|
b->create<memref::AllocOp>(loc, arg_type.cast<MemRefType>()));
|
||||||
}
|
}
|
||||||
for (auto operand : llvm::enumerate(operands)) {
|
for (auto operand : llvm::enumerate(operands)) {
|
||||||
b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
|
b->create<memref::StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
|
||||||
}
|
}
|
||||||
// Clone the ops from `lhlo_block`.
|
// Clone the ops from `lhlo_block`.
|
||||||
BlockAndValueMapping mapping;
|
BlockAndValueMapping mapping;
|
||||||
|
@ -55,7 +57,7 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
|
||||||
auto clone = b->clone(nested, mapping);
|
auto clone = b->clone(nested, mapping);
|
||||||
mapping.map(nested.getResults(), clone->getResults());
|
mapping.map(nested.getResults(), clone->getResults());
|
||||||
}
|
}
|
||||||
return b->create<LoadOp>(loc, arg_bufs.back());
|
return b->create<memref::LoadOp>(loc, arg_bufs.back());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts a block with LHLO ops and with signature:
|
// Converts a block with LHLO ops and with signature:
|
||||||
|
@ -78,7 +80,8 @@ void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
|
||||||
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
|
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
|
||||||
size_t dim_index, int64_t dim, OpBuilder* b) {
|
size_t dim_index, int64_t dim, OpBuilder* b) {
|
||||||
return dim == ShapedType::kDynamicSize
|
return dim == ShapedType::kDynamicSize
|
||||||
? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
|
? b->create<memref::DimOp>(loc, shaped_value, dim_index)
|
||||||
|
.getResult()
|
||||||
: b->create<ConstantIndexOp>(loc, dim);
|
: b->create<ConstantIndexOp>(loc, dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -249,8 +252,8 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
||||||
(is_reducing_dim ? reduce_step : parallel_step).push_back(step);
|
(is_reducing_dim ? reduce_step : parallel_step).push_back(step);
|
||||||
}
|
}
|
||||||
// Load initial value from memref<element_type>.
|
// Load initial value from memref<element_type>.
|
||||||
SmallVector<Value, 1> init_value = {
|
SmallVector<Value, 1> init_value = {rewriter->create<memref::LoadOp>(
|
||||||
rewriter->create<LoadOp>(loc, *reduce_op.init_values().begin())};
|
loc, *reduce_op.init_values().begin())};
|
||||||
// Outer ParallelOp is not needed if it is a reduction across all dims.
|
// Outer ParallelOp is not needed if it is a reduction across all dims.
|
||||||
scf::ParallelOp outer;
|
scf::ParallelOp outer;
|
||||||
if (!parallel_lower.empty()) {
|
if (!parallel_lower.empty()) {
|
||||||
|
@ -272,7 +275,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
||||||
out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
|
out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);
|
rewriter->create<memref::StoreOp>(loc, reduction_result, out, out_indices);
|
||||||
|
|
||||||
// Load the element to reduce.
|
// Load the element to reduce.
|
||||||
SmallVector<Value, 2> indices;
|
SmallVector<Value, 2> indices;
|
||||||
|
@ -290,7 +293,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter->setInsertionPointToStart(inner.getBody());
|
rewriter->setInsertionPointToStart(inner.getBody());
|
||||||
Value elem = rewriter->create<mlir::LoadOp>(
|
Value elem = rewriter->create<mlir::memref::LoadOp>(
|
||||||
loc, *reduce_op.operands().begin(), indices);
|
loc, *reduce_op.operands().begin(), indices);
|
||||||
return rewriter->create<scf::ReduceOp>(loc, elem);
|
return rewriter->create<scf::ReduceOp>(loc, elem);
|
||||||
}
|
}
|
||||||
|
@ -385,7 +388,7 @@ class ReduceWindowOpConverter
|
||||||
ConversionPatternRewriter* rewriter) const {
|
ConversionPatternRewriter* rewriter) const {
|
||||||
auto loc = reduce_window_op.getLoc();
|
auto loc = reduce_window_op.getLoc();
|
||||||
Value init_value =
|
Value init_value =
|
||||||
rewriter->create<LoadOp>(loc, reduce_window_op.init_value());
|
rewriter->create<memref::LoadOp>(loc, reduce_window_op.init_value());
|
||||||
|
|
||||||
Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
|
Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
|
||||||
Value one = rewriter->create<ConstantIndexOp>(loc, 1);
|
Value one = rewriter->create<ConstantIndexOp>(loc, 1);
|
||||||
|
@ -408,7 +411,8 @@ class ReduceWindowOpConverter
|
||||||
|
|
||||||
Value reduction_result = *window_loop.getResults().begin();
|
Value reduction_result = *window_loop.getResults().begin();
|
||||||
auto output_ivs = output_loop.getInductionVars();
|
auto output_ivs = output_loop.getInductionVars();
|
||||||
rewriter->create<StoreOp>(loc, reduction_result, output, output_ivs);
|
rewriter->create<memref::StoreOp>(loc, reduction_result, output,
|
||||||
|
output_ivs);
|
||||||
return std::make_pair(output_loop, window_loop);
|
return std::make_pair(output_loop, window_loop);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -439,7 +443,7 @@ class ReduceWindowOpConverter
|
||||||
|
|
||||||
OpBuilder then_builder =
|
OpBuilder then_builder =
|
||||||
elem_or_init.getThenBodyBuilder(rewriter->getListener());
|
elem_or_init.getThenBodyBuilder(rewriter->getListener());
|
||||||
Value elem = then_builder.create<mlir::LoadOp>(
|
Value elem = then_builder.create<mlir::memref::LoadOp>(
|
||||||
loc, reduce_window_op.operand(), mapped_ivs.ivs);
|
loc, reduce_window_op.operand(), mapped_ivs.ivs);
|
||||||
then_builder.create<scf::YieldOp>(loc, elem);
|
then_builder.create<scf::YieldOp>(loc, elem);
|
||||||
|
|
||||||
|
@ -497,8 +501,8 @@ class SelectAndScatterOpConverter
|
||||||
auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);
|
auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);
|
||||||
|
|
||||||
// Load `source[selected_ivs]`.
|
// Load `source[selected_ivs]`.
|
||||||
auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(),
|
auto src_elem = rewriter.create<memref::LoadOp>(
|
||||||
loop_over_src.getInductionVars());
|
loc, s_and_s_op.source(), loop_over_src.getInductionVars());
|
||||||
|
|
||||||
// Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`.
|
// Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`.
|
||||||
auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
|
auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
|
||||||
|
@ -517,13 +521,13 @@ class SelectAndScatterOpConverter
|
||||||
void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
|
void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
|
||||||
OpBuilder* b) const {
|
OpBuilder* b) const {
|
||||||
auto loc = s_and_s_op.getLoc();
|
auto loc = s_and_s_op.getLoc();
|
||||||
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
|
Value init_value = b->create<memref::LoadOp>(loc, s_and_s_op.init_value());
|
||||||
|
|
||||||
scf::ParallelOp loop_over_output =
|
scf::ParallelOp loop_over_output =
|
||||||
MakeLoopOverShape(loc, s_and_s_op.out(), b);
|
MakeLoopOverShape(loc, s_and_s_op.out(), b);
|
||||||
OpBuilder::InsertionGuard guard(*b);
|
OpBuilder::InsertionGuard guard(*b);
|
||||||
b->setInsertionPointToStart(loop_over_output.getBody());
|
b->setInsertionPointToStart(loop_over_output.getBody());
|
||||||
b->create<StoreOp>(loc, init_value, s_and_s_op.out(),
|
b->create<memref::StoreOp>(loc, init_value, s_and_s_op.out(),
|
||||||
loop_over_output.getInductionVars());
|
loop_over_output.getInductionVars());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -647,7 +651,7 @@ class SelectAndScatterOpConverter
|
||||||
|
|
||||||
TypeRange iter_arg_types{ivs_val_flag->to_vector()};
|
TypeRange iter_arg_types{ivs_val_flag->to_vector()};
|
||||||
Value operand_elem =
|
Value operand_elem =
|
||||||
b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
|
b->create<memref::LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
|
||||||
auto if_init =
|
auto if_init =
|
||||||
b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
|
b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
|
||||||
/*withElseRegion=*/true);
|
/*withElseRegion=*/true);
|
||||||
|
@ -712,8 +716,8 @@ struct LhloLegalizeToParallelLoopsPass
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
|
||||||
scf::SCFDialect, LmhloDialect>();
|
StandardOpsDialect, scf::SCFDialect, LmhloDialect>();
|
||||||
target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
|
target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
|
||||||
lmhlo::SelectAndScatterOp>();
|
lmhlo::SelectAndScatterOp>();
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
@ -58,7 +59,8 @@ Value CalculateShapeValue(Location loc, Value operand,
|
||||||
int64_t rank = result_type.getRank();
|
int64_t rank = result_type.getRank();
|
||||||
shape_values.reserve(rank);
|
shape_values.reserve(rank);
|
||||||
for (int64_t i = 0; i < rank; ++i) {
|
for (int64_t i = 0; i < rank; ++i) {
|
||||||
shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i));
|
shape_values.push_back(
|
||||||
|
rewriter.create<mlir::memref::DimOp>(loc, operand, i));
|
||||||
}
|
}
|
||||||
return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
|
return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
|
||||||
}
|
}
|
||||||
|
|
|
@ -967,10 +967,10 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tupl
|
||||||
|
|
||||||
// CHECK-LABEL: func @erase_dead_lhlo_constant
|
// CHECK-LABEL: func @erase_dead_lhlo_constant
|
||||||
func @erase_dead_lhlo_constant() {
|
func @erase_dead_lhlo_constant() {
|
||||||
%M = alloc() : memref<256x1024xf32>
|
%M = memref.alloc() : memref<256x1024xf32>
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
|
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
|
||||||
dealloc %M : memref<256x1024xf32>
|
memref.dealloc %M : memref<256x1024xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -979,9 +979,9 @@ func @erase_dead_lhlo_constant() {
|
||||||
func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
|
func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
|
||||||
// CHECK-NEXT: lmhlo.constant
|
// CHECK-NEXT: lmhlo.constant
|
||||||
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
|
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
|
||||||
// CHECK-NEXT: alloc
|
// CHECK-NEXT: memref.alloc
|
||||||
// CHECK-NEXT: lmhlo.constant
|
// CHECK-NEXT: lmhlo.constant
|
||||||
%N = alloc() : memref<256x1024xf32>
|
%N = memref.alloc() : memref<256x1024xf32>
|
||||||
"lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
|
"lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
|
||||||
return %N : memref<256x1024xf32>
|
return %N : memref<256x1024xf32>
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ func private @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface }
|
||||||
func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
|
func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
|
||||||
|
|
||||||
func @trivial_broadcast_wrapper() {
|
func @trivial_broadcast_wrapper() {
|
||||||
%input_buf = alloc() : memref<3xf32>
|
%input_buf = memref.alloc() : memref<3xf32>
|
||||||
|
|
||||||
%c1f32 = constant 1.0 : f32
|
%c1f32 = constant 1.0 : f32
|
||||||
%c2f32 = constant 2.0 : f32
|
%c2f32 = constant 2.0 : f32
|
||||||
|
@ -36,19 +36,19 @@ func @trivial_broadcast_wrapper() {
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
%c2 = constant 2 : index
|
%c2 = constant 2 : index
|
||||||
store %c1f32, %input_buf[%c0] : memref<3xf32>
|
memref.store %c1f32, %input_buf[%c0] : memref<3xf32>
|
||||||
store %c2f32, %input_buf[%c1] : memref<3xf32>
|
memref.store %c2f32, %input_buf[%c1] : memref<3xf32>
|
||||||
store %c3f32, %input_buf[%c2] : memref<3xf32>
|
memref.store %c3f32, %input_buf[%c2] : memref<3xf32>
|
||||||
%input = tensor_load %input_buf : memref<3xf32>
|
%input = memref.tensor_load %input_buf : memref<3xf32>
|
||||||
|
|
||||||
// Test BroadcastInDimOp.
|
// Test BroadcastInDimOp.
|
||||||
%output = "mhlo.broadcast_in_dim"(%input) {
|
%output = "mhlo.broadcast_in_dim"(%input) {
|
||||||
broadcast_dimensions = dense<0> : tensor<1xi64>
|
broadcast_dimensions = dense<0> : tensor<1xi64>
|
||||||
} : (tensor<3xf32>) -> tensor<3x4xf32>
|
} : (tensor<3xf32>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%output_buf = tensor_to_memref %output : memref<3x4xf32>
|
%output_buf = memref.buffer_cast %output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
%unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
// CHECK-NEXT: [1, 1, 1, 1]
|
// CHECK-NEXT: [1, 1, 1, 1]
|
||||||
|
@ -63,9 +63,9 @@ func @trivial_broadcast_wrapper() {
|
||||||
broadcast_dimensions = dense<0> : tensor<1xi64>
|
broadcast_dimensions = dense<0> : tensor<1xi64>
|
||||||
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
|
%dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_dyn_output = memref_cast %dyn_output_buf
|
%unranked_dyn_output = memref.cast %dyn_output_buf
|
||||||
: memref<3x4xf32> to memref<*xf32>
|
: memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
|
@ -76,29 +76,29 @@ func @trivial_broadcast_wrapper() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @broadcast_in_X_dim_wrapper() {
|
func @broadcast_in_X_dim_wrapper() {
|
||||||
%input_buf = alloc() : memref<1x4xf32>
|
%input_buf = memref.alloc() : memref<1x4xf32>
|
||||||
%c1f32 = constant 1.0 : f32
|
%c1f32 = constant 1.0 : f32
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
store %c1f32, %input_buf[%c0, %c0] : memref<1x4xf32>
|
memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x4xf32>
|
||||||
%c2f32 = constant 2.0 : f32
|
%c2f32 = constant 2.0 : f32
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
store %c2f32, %input_buf[%c0, %c1] : memref<1x4xf32>
|
memref.store %c2f32, %input_buf[%c0, %c1] : memref<1x4xf32>
|
||||||
%c3f32 = constant 3.0 : f32
|
%c3f32 = constant 3.0 : f32
|
||||||
%c2 = constant 2 : index
|
%c2 = constant 2 : index
|
||||||
store %c3f32, %input_buf[%c0, %c2] : memref<1x4xf32>
|
memref.store %c3f32, %input_buf[%c0, %c2] : memref<1x4xf32>
|
||||||
%c4f32 = constant 4.0 : f32
|
%c4f32 = constant 4.0 : f32
|
||||||
%c3 = constant 3 : index
|
%c3 = constant 3 : index
|
||||||
store %c4f32, %input_buf[%c0, %c3] : memref<1x4xf32>
|
memref.store %c4f32, %input_buf[%c0, %c3] : memref<1x4xf32>
|
||||||
%input = tensor_load %input_buf : memref<1x4xf32>
|
%input = memref.tensor_load %input_buf : memref<1x4xf32>
|
||||||
|
|
||||||
// Test BroadcastInDimOp.
|
// Test BroadcastInDimOp.
|
||||||
%output = "mhlo.broadcast_in_dim"(%input) {
|
%output = "mhlo.broadcast_in_dim"(%input) {
|
||||||
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||||
} : (tensor<1x4xf32>) -> tensor<3x4xf32>
|
} : (tensor<1x4xf32>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%output_buf = tensor_to_memref %output : memref<3x4xf32>
|
%output_buf = memref.buffer_cast %output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
%unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
// CHECK-NEXT: [1, 2, 3, 4]
|
// CHECK-NEXT: [1, 2, 3, 4]
|
||||||
|
@ -112,9 +112,9 @@ func @broadcast_in_X_dim_wrapper() {
|
||||||
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||||
} : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
} : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
|
%dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_dyn_output = memref_cast %dyn_output_buf
|
%unranked_dyn_output = memref.cast %dyn_output_buf
|
||||||
: memref<3x4xf32> to memref<*xf32>
|
: memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
|
@ -125,26 +125,26 @@ func @broadcast_in_X_dim_wrapper() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @broadcast_in_Y_dim_wrapper() {
|
func @broadcast_in_Y_dim_wrapper() {
|
||||||
%input_buf = alloc() : memref<3x1xf32>
|
%input_buf = memref.alloc() : memref<3x1xf32>
|
||||||
%c1f32 = constant 1.0 : f32
|
%c1f32 = constant 1.0 : f32
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
store %c1f32, %input_buf[%c0, %c0] : memref<3x1xf32>
|
memref.store %c1f32, %input_buf[%c0, %c0] : memref<3x1xf32>
|
||||||
%c2f32 = constant 2.0 : f32
|
%c2f32 = constant 2.0 : f32
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
store %c2f32, %input_buf[%c1, %c0] : memref<3x1xf32>
|
memref.store %c2f32, %input_buf[%c1, %c0] : memref<3x1xf32>
|
||||||
%c3f32 = constant 3.0 : f32
|
%c3f32 = constant 3.0 : f32
|
||||||
%c2 = constant 2 : index
|
%c2 = constant 2 : index
|
||||||
store %c3f32, %input_buf[%c2, %c0] : memref<3x1xf32>
|
memref.store %c3f32, %input_buf[%c2, %c0] : memref<3x1xf32>
|
||||||
%input = tensor_load %input_buf : memref<3x1xf32>
|
%input = memref.tensor_load %input_buf : memref<3x1xf32>
|
||||||
|
|
||||||
// Test BroadcastInDimOp.
|
// Test BroadcastInDimOp.
|
||||||
%output = "mhlo.broadcast_in_dim"(%input) {
|
%output = "mhlo.broadcast_in_dim"(%input) {
|
||||||
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||||
} : (tensor<3x1xf32>) -> tensor<3x4xf32>
|
} : (tensor<3x1xf32>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%output_buf = tensor_to_memref %output : memref<3x4xf32>
|
%output_buf = memref.buffer_cast %output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
%unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
// CHECK-NEXT: [1, 1, 1, 1]
|
// CHECK-NEXT: [1, 1, 1, 1]
|
||||||
|
@ -159,9 +159,9 @@ func @broadcast_in_Y_dim_wrapper() {
|
||||||
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||||
} : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
} : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
|
%dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_dyn_output = memref_cast %dyn_output_buf
|
%unranked_dyn_output = memref.cast %dyn_output_buf
|
||||||
: memref<3x4xf32> to memref<*xf32>
|
: memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
|
@ -172,29 +172,29 @@ func @broadcast_in_Y_dim_wrapper() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @broadcast_in_X_dim_transpose_wrapper() {
|
func @broadcast_in_X_dim_transpose_wrapper() {
|
||||||
%input_buf = alloc() : memref<4x1xf32>
|
%input_buf = memref.alloc() : memref<4x1xf32>
|
||||||
%c1f32 = constant 1.0 : f32
|
%c1f32 = constant 1.0 : f32
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
store %c1f32, %input_buf[%c0, %c0] : memref<4x1xf32>
|
memref.store %c1f32, %input_buf[%c0, %c0] : memref<4x1xf32>
|
||||||
%c2f32 = constant 2.0 : f32
|
%c2f32 = constant 2.0 : f32
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
store %c2f32, %input_buf[%c1, %c0] : memref<4x1xf32>
|
memref.store %c2f32, %input_buf[%c1, %c0] : memref<4x1xf32>
|
||||||
%c3f32 = constant 3.0 : f32
|
%c3f32 = constant 3.0 : f32
|
||||||
%c2 = constant 2 : index
|
%c2 = constant 2 : index
|
||||||
store %c3f32, %input_buf[%c2, %c0] : memref<4x1xf32>
|
memref.store %c3f32, %input_buf[%c2, %c0] : memref<4x1xf32>
|
||||||
%c4f32 = constant 4.0 : f32
|
%c4f32 = constant 4.0 : f32
|
||||||
%c3 = constant 3 : index
|
%c3 = constant 3 : index
|
||||||
store %c4f32, %input_buf[%c3, %c0] : memref<4x1xf32>
|
memref.store %c4f32, %input_buf[%c3, %c0] : memref<4x1xf32>
|
||||||
%input = tensor_load %input_buf : memref<4x1xf32>
|
%input = memref.tensor_load %input_buf : memref<4x1xf32>
|
||||||
|
|
||||||
// Test BroadcastInDimOp.
|
// Test BroadcastInDimOp.
|
||||||
%output = "mhlo.broadcast_in_dim"(%input) {
|
%output = "mhlo.broadcast_in_dim"(%input) {
|
||||||
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
|
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
|
||||||
} : (tensor<4x1xf32>) -> tensor<3x4xf32>
|
} : (tensor<4x1xf32>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%output_buf = tensor_to_memref %output : memref<3x4xf32>
|
%output_buf = memref.buffer_cast %output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
%unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
// CHECK-NEXT: [1, 2, 3, 4]
|
// CHECK-NEXT: [1, 2, 3, 4]
|
||||||
|
@ -208,9 +208,9 @@ func @broadcast_in_X_dim_transpose_wrapper() {
|
||||||
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
|
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
|
||||||
} : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
} : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
|
%dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_dyn_output = memref_cast %dyn_output_buf
|
%unranked_dyn_output = memref.cast %dyn_output_buf
|
||||||
: memref<3x4xf32> to memref<*xf32>
|
: memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
|
@ -221,26 +221,26 @@ func @broadcast_in_X_dim_transpose_wrapper() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @broadcast_in_Y_dim_transpose_wrapper() {
|
func @broadcast_in_Y_dim_transpose_wrapper() {
|
||||||
%input_buf = alloc() : memref<1x3xf32>
|
%input_buf = memref.alloc() : memref<1x3xf32>
|
||||||
%c1f32 = constant 1.0 : f32
|
%c1f32 = constant 1.0 : f32
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
store %c1f32, %input_buf[%c0, %c0] : memref<1x3xf32>
|
memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x3xf32>
|
||||||
%c2f32 = constant 2.0 : f32
|
%c2f32 = constant 2.0 : f32
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
store %c2f32, %input_buf[%c0, %c1] : memref<1x3xf32>
|
memref.store %c2f32, %input_buf[%c0, %c1] : memref<1x3xf32>
|
||||||
%c3f32 = constant 3.0 : f32
|
%c3f32 = constant 3.0 : f32
|
||||||
%c2 = constant 2 : index
|
%c2 = constant 2 : index
|
||||||
store %c3f32, %input_buf[%c0, %c2] : memref<1x3xf32>
|
memref.store %c3f32, %input_buf[%c0, %c2] : memref<1x3xf32>
|
||||||
%input = tensor_load %input_buf : memref<1x3xf32>
|
%input = memref.tensor_load %input_buf : memref<1x3xf32>
|
||||||
|
|
||||||
// Test BroadcastInDimOp.
|
// Test BroadcastInDimOp.
|
||||||
%output = "mhlo.broadcast_in_dim"(%input) {
|
%output = "mhlo.broadcast_in_dim"(%input) {
|
||||||
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
|
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
|
||||||
} : (tensor<1x3xf32>) -> tensor<3x4xf32>
|
} : (tensor<1x3xf32>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%output_buf = tensor_to_memref %output : memref<3x4xf32>
|
%output_buf = memref.buffer_cast %output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
%unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
// CHECK-NEXT-NEXT: [1, 1, 1, 1]
|
// CHECK-NEXT-NEXT: [1, 1, 1, 1]
|
||||||
|
@ -255,9 +255,9 @@ func @broadcast_in_Y_dim_transpose_wrapper() {
|
||||||
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
|
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
|
||||||
} : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
} : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
|
%dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_dyn_output = memref_cast %dyn_output_buf
|
%unranked_dyn_output = memref.cast %dyn_output_buf
|
||||||
: memref<3x4xf32> to memref<*xf32>
|
: memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
|
@ -268,20 +268,20 @@ func @broadcast_in_Y_dim_transpose_wrapper() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @broadcast_scalar_1d_wrapper() {
|
func @broadcast_scalar_1d_wrapper() {
|
||||||
%input_buf = alloc() : memref<1xf32>
|
%input_buf = memref.alloc() : memref<1xf32>
|
||||||
%c1f32 = constant 1.0 : f32
|
%c1f32 = constant 1.0 : f32
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
store %c1f32, %input_buf[%c0] : memref<1xf32>
|
memref.store %c1f32, %input_buf[%c0] : memref<1xf32>
|
||||||
%input = tensor_load %input_buf : memref<1xf32>
|
%input = memref.tensor_load %input_buf : memref<1xf32>
|
||||||
|
|
||||||
// Test BroadcastInDimOp.
|
// Test BroadcastInDimOp.
|
||||||
%output = "mhlo.broadcast_in_dim"(%input) {
|
%output = "mhlo.broadcast_in_dim"(%input) {
|
||||||
broadcast_dimensions = dense<0> : tensor<1xi64>
|
broadcast_dimensions = dense<0> : tensor<1xi64>
|
||||||
} : (tensor<1xf32>) -> tensor<3x4xf32>
|
} : (tensor<1xf32>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%output_buf = tensor_to_memref %output : memref<3x4xf32>
|
%output_buf = memref.buffer_cast %output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
%unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
// CHECK-NEXT: [1, 1, 1, 1]
|
// CHECK-NEXT: [1, 1, 1, 1]
|
||||||
|
@ -296,9 +296,9 @@ func @broadcast_scalar_1d_wrapper() {
|
||||||
broadcast_dimensions = dense<0> : tensor<1xi64>
|
broadcast_dimensions = dense<0> : tensor<1xi64>
|
||||||
} : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
} : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
|
%dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_dyn_output = memref_cast %dyn_output_buf
|
%unranked_dyn_output = memref.cast %dyn_output_buf
|
||||||
: memref<3x4xf32> to memref<*xf32>
|
: memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
|
@ -309,20 +309,20 @@ func @broadcast_scalar_1d_wrapper() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @broadcast_scalar_2d_wrapper() {
|
func @broadcast_scalar_2d_wrapper() {
|
||||||
%input_buf = alloc() : memref<1x1xf32>
|
%input_buf = memref.alloc() : memref<1x1xf32>
|
||||||
%c1f32 = constant 1.0 : f32
|
%c1f32 = constant 1.0 : f32
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
store %c1f32, %input_buf[%c0, %c0] : memref<1x1xf32>
|
memref.store %c1f32, %input_buf[%c0, %c0] : memref<1x1xf32>
|
||||||
%input = tensor_load %input_buf : memref<1x1xf32>
|
%input = memref.tensor_load %input_buf : memref<1x1xf32>
|
||||||
|
|
||||||
// Test BroadcastInDimOp.
|
// Test BroadcastInDimOp.
|
||||||
%output = "mhlo.broadcast_in_dim"(%input) {
|
%output = "mhlo.broadcast_in_dim"(%input) {
|
||||||
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||||
} : (tensor<1x1xf32>) -> tensor<3x4xf32>
|
} : (tensor<1x1xf32>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%output_buf = tensor_to_memref %output : memref<3x4xf32>
|
%output_buf = memref.buffer_cast %output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
%unranked_output = memref.cast %output_buf : memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
// CHECK-NEXT: [1, 1, 1, 1]
|
// CHECK-NEXT: [1, 1, 1, 1]
|
||||||
|
@ -337,9 +337,9 @@ func @broadcast_scalar_2d_wrapper() {
|
||||||
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||||
} : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
} : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
|
||||||
|
|
||||||
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
|
%dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x4xf32>
|
||||||
|
|
||||||
%unranked_dyn_output = memref_cast %dyn_output_buf
|
%unranked_dyn_output = memref.cast %dyn_output_buf
|
||||||
: memref<3x4xf32> to memref<*xf32>
|
: memref<3x4xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
|
||||||
|
@ -350,7 +350,7 @@ func @broadcast_scalar_2d_wrapper() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @broadcast_to_the_same_shape() {
|
func @broadcast_to_the_same_shape() {
|
||||||
%input_buf = alloc() : memref<2x3xf32>
|
%input_buf = memref.alloc() : memref<2x3xf32>
|
||||||
|
|
||||||
%c1f32 = constant 1.0 : f32
|
%c1f32 = constant 1.0 : f32
|
||||||
%c2f32 = constant 2.0 : f32
|
%c2f32 = constant 2.0 : f32
|
||||||
|
@ -360,22 +360,22 @@ func @broadcast_to_the_same_shape() {
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
%c2 = constant 2 : index
|
%c2 = constant 2 : index
|
||||||
%c3 = constant 3 : index
|
%c3 = constant 3 : index
|
||||||
store %c1f32, %input_buf[%c0, %c0] : memref<2x3xf32>
|
memref.store %c1f32, %input_buf[%c0, %c0] : memref<2x3xf32>
|
||||||
store %c1f32, %input_buf[%c1, %c0] : memref<2x3xf32>
|
memref.store %c1f32, %input_buf[%c1, %c0] : memref<2x3xf32>
|
||||||
store %c2f32, %input_buf[%c0, %c1] : memref<2x3xf32>
|
memref.store %c2f32, %input_buf[%c0, %c1] : memref<2x3xf32>
|
||||||
store %c2f32, %input_buf[%c1, %c1] : memref<2x3xf32>
|
memref.store %c2f32, %input_buf[%c1, %c1] : memref<2x3xf32>
|
||||||
store %c3f32, %input_buf[%c0, %c2] : memref<2x3xf32>
|
memref.store %c3f32, %input_buf[%c0, %c2] : memref<2x3xf32>
|
||||||
store %c3f32, %input_buf[%c1, %c2] : memref<2x3xf32>
|
memref.store %c3f32, %input_buf[%c1, %c2] : memref<2x3xf32>
|
||||||
%input = tensor_load %input_buf : memref<2x3xf32>
|
%input = memref.tensor_load %input_buf : memref<2x3xf32>
|
||||||
|
|
||||||
// Test BroadcastInDimOp.
|
// Test BroadcastInDimOp.
|
||||||
%output = "mhlo.broadcast_in_dim"(%input) {
|
%output = "mhlo.broadcast_in_dim"(%input) {
|
||||||
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||||
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||||
|
|
||||||
%output_buf = tensor_to_memref %output : memref<2x3xf32>
|
%output_buf = memref.buffer_cast %output : memref<2x3xf32>
|
||||||
|
|
||||||
%unraked_output = memref_cast %output_buf : memref<2x3xf32> to memref<*xf32>
|
%unraked_output = memref.cast %output_buf : memref<2x3xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
|
||||||
// CHECK-NEXT: [1, 2, 3]
|
// CHECK-NEXT: [1, 2, 3]
|
||||||
|
@ -387,9 +387,9 @@ func @broadcast_to_the_same_shape() {
|
||||||
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||||
} : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32>
|
} : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32>
|
||||||
|
|
||||||
%dyn_output_buf = tensor_to_memref %dyn_output : memref<2x3xf32>
|
%dyn_output_buf = memref.buffer_cast %dyn_output : memref<2x3xf32>
|
||||||
|
|
||||||
%unranked_dyn_output = memref_cast %dyn_output_buf
|
%unranked_dyn_output = memref.cast %dyn_output_buf
|
||||||
: memref<2x3xf32> to memref<*xf32>
|
: memref<2x3xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
|
||||||
|
@ -399,7 +399,7 @@ func @broadcast_to_the_same_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @broadcast_1d_to_2d() {
|
func @broadcast_1d_to_2d() {
|
||||||
%input_buf = alloc() : memref<3xf32>
|
%input_buf = memref.alloc() : memref<3xf32>
|
||||||
|
|
||||||
%c1f32 = constant 1.0 : f32
|
%c1f32 = constant 1.0 : f32
|
||||||
%c2f32 = constant 2.0 : f32
|
%c2f32 = constant 2.0 : f32
|
||||||
|
@ -408,19 +408,19 @@ func @broadcast_1d_to_2d() {
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
%c2 = constant 2 : index
|
%c2 = constant 2 : index
|
||||||
store %c1f32, %input_buf[%c0] : memref<3xf32>
|
memref.store %c1f32, %input_buf[%c0] : memref<3xf32>
|
||||||
store %c2f32, %input_buf[%c1] : memref<3xf32>
|
memref.store %c2f32, %input_buf[%c1] : memref<3xf32>
|
||||||
store %c3f32, %input_buf[%c2] : memref<3xf32>
|
memref.store %c3f32, %input_buf[%c2] : memref<3xf32>
|
||||||
%input = tensor_load %input_buf : memref<3xf32>
|
%input = memref.tensor_load %input_buf : memref<3xf32>
|
||||||
|
|
||||||
// Test BroadcastInDimOp.
|
// Test BroadcastInDimOp.
|
||||||
%output = "mhlo.broadcast_in_dim"(%input) {
|
%output = "mhlo.broadcast_in_dim"(%input) {
|
||||||
broadcast_dimensions = dense<0> : tensor<1xi64>
|
broadcast_dimensions = dense<0> : tensor<1xi64>
|
||||||
} : (tensor<3xf32>) -> tensor<3x3xf32>
|
} : (tensor<3xf32>) -> tensor<3x3xf32>
|
||||||
|
|
||||||
%output_buf = tensor_to_memref %output : memref<3x3xf32>
|
%output_buf = memref.buffer_cast %output : memref<3x3xf32>
|
||||||
|
|
||||||
%unraked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
|
%unraked_output = memref.cast %output_buf : memref<3x3xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
|
||||||
// CHECK-NEXT: [1, 1, 1]
|
// CHECK-NEXT: [1, 1, 1]
|
||||||
|
@ -435,9 +435,9 @@ func @broadcast_1d_to_2d() {
|
||||||
broadcast_dimensions = dense<0> : tensor<1xi64>
|
broadcast_dimensions = dense<0> : tensor<1xi64>
|
||||||
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
|
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
|
||||||
|
|
||||||
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x3xf32>
|
%dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x3xf32>
|
||||||
|
|
||||||
%unranked_dyn_output = memref_cast %dyn_output_buf
|
%unranked_dyn_output = memref.cast %dyn_output_buf
|
||||||
: memref<3x3xf32> to memref<*xf32>
|
: memref<3x3xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
|
||||||
|
@ -448,7 +448,7 @@ func @broadcast_1d_to_2d() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @broadcast_1d_to_2d_with_transpose() {
|
func @broadcast_1d_to_2d_with_transpose() {
|
||||||
%input_buf = alloc() : memref<3xf32>
|
%input_buf = memref.alloc() : memref<3xf32>
|
||||||
|
|
||||||
%c1f32 = constant 1.0 : f32
|
%c1f32 = constant 1.0 : f32
|
||||||
%c2f32 = constant 2.0 : f32
|
%c2f32 = constant 2.0 : f32
|
||||||
|
@ -457,19 +457,19 @@ func @broadcast_1d_to_2d_with_transpose() {
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
%c2 = constant 2 : index
|
%c2 = constant 2 : index
|
||||||
store %c1f32, %input_buf[%c0] : memref<3xf32>
|
memref.store %c1f32, %input_buf[%c0] : memref<3xf32>
|
||||||
store %c2f32, %input_buf[%c1] : memref<3xf32>
|
memref.store %c2f32, %input_buf[%c1] : memref<3xf32>
|
||||||
store %c3f32, %input_buf[%c2] : memref<3xf32>
|
memref.store %c3f32, %input_buf[%c2] : memref<3xf32>
|
||||||
%input = tensor_load %input_buf : memref<3xf32>
|
%input = memref.tensor_load %input_buf : memref<3xf32>
|
||||||
|
|
||||||
// Test BroadcastInDimOp.
|
// Test BroadcastInDimOp.
|
||||||
%output = "mhlo.broadcast_in_dim"(%input) {
|
%output = "mhlo.broadcast_in_dim"(%input) {
|
||||||
broadcast_dimensions = dense<1> : tensor<1xi64>
|
broadcast_dimensions = dense<1> : tensor<1xi64>
|
||||||
} : (tensor<3xf32>) -> tensor<3x3xf32>
|
} : (tensor<3xf32>) -> tensor<3x3xf32>
|
||||||
|
|
||||||
%output_buf = tensor_to_memref %output : memref<3x3xf32>
|
%output_buf = memref.buffer_cast %output : memref<3x3xf32>
|
||||||
|
|
||||||
%unraked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
|
%unraked_output = memref.cast %output_buf : memref<3x3xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
|
||||||
// CHECK-NEXT: [1, 2, 3]
|
// CHECK-NEXT: [1, 2, 3]
|
||||||
|
@ -483,9 +483,9 @@ func @broadcast_1d_to_2d_with_transpose() {
|
||||||
broadcast_dimensions = dense<1> : tensor<1xi64>
|
broadcast_dimensions = dense<1> : tensor<1xi64>
|
||||||
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
|
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
|
||||||
|
|
||||||
%dyn_output_buf = tensor_to_memref %dyn_output : memref<3x3xf32>
|
%dyn_output_buf = memref.buffer_cast %dyn_output : memref<3x3xf32>
|
||||||
|
|
||||||
%unranked_dyn_output = memref_cast %dyn_output_buf
|
%unranked_dyn_output = memref.cast %dyn_output_buf
|
||||||
: memref<3x3xf32> to memref<*xf32>
|
: memref<3x3xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
|
||||||
|
|
|
@ -4,10 +4,10 @@ func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface
|
||||||
|
|
||||||
// Helper function to print scalar values.
|
// Helper function to print scalar values.
|
||||||
func @print_f32(%arg : f32) {
|
func @print_f32(%arg : f32) {
|
||||||
%mem = alloca() : memref<1xf32>
|
%mem = memref.alloca() : memref<1xf32>
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
store %arg, %mem[%c0] : memref<1xf32>
|
memref.store %arg, %mem[%c0] : memref<1xf32>
|
||||||
%mem_unranked = memref_cast %mem : memref<1xf32> to memref<*xf32>
|
%mem_unranked = memref.cast %mem : memref<1xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%mem_unranked) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%mem_unranked) : (memref<*xf32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,21 +21,21 @@ func @reduce_add() {
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
|
|
||||||
// Initialize input.
|
// Initialize input.
|
||||||
%input = alloc() : memref<2x3xf32>
|
%input = memref.alloc() : memref<2x3xf32>
|
||||||
%dim_x = dim %input, %c0 : memref<2x3xf32>
|
%dim_x = memref.dim %input, %c0 : memref<2x3xf32>
|
||||||
%dim_y = dim %input, %c1 : memref<2x3xf32>
|
%dim_y = memref.dim %input, %c1 : memref<2x3xf32>
|
||||||
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
|
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
|
||||||
%i_i64 = index_cast %i : index to i64
|
%i_i64 = index_cast %i : index to i64
|
||||||
%i_f32 = sitofp %i_i64 : i64 to f32
|
%i_f32 = sitofp %i_i64 : i64 to f32
|
||||||
store %i_f32, %input[%i, %j] : memref<2x3xf32>
|
memref.store %i_f32, %input[%i, %j] : memref<2x3xf32>
|
||||||
}
|
}
|
||||||
%unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
|
%unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
|
||||||
// CHECK: [0, 0, 0]
|
// CHECK: [0, 0, 0]
|
||||||
// CHECK: [1, 1, 1]
|
// CHECK: [1, 1, 1]
|
||||||
|
|
||||||
%in = tensor_load %input : memref<2x3xf32>
|
%in = memref.tensor_load %input : memref<2x3xf32>
|
||||||
%init = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
%init = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
|
||||||
%reduce = "mhlo.reduce"(%in, %init) ( {
|
%reduce = "mhlo.reduce"(%in, %init) ( {
|
||||||
|
@ -45,8 +45,8 @@ func @reduce_add() {
|
||||||
}) {dimensions = dense<1> : tensor<1xi64>}
|
}) {dimensions = dense<1> : tensor<1xi64>}
|
||||||
: (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
|
: (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
|
||||||
|
|
||||||
%output = tensor_to_memref %reduce : memref<2xf32>
|
%output = memref.buffer_cast %reduce : memref<2xf32>
|
||||||
%unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
|
%unranked_output = memref.cast %output : memref<2xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
|
// CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
|
||||||
// CHECK: [0, 3]
|
// CHECK: [0, 3]
|
||||||
|
@ -58,21 +58,21 @@ func @reduce_max() {
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
|
|
||||||
// Initialize input.
|
// Initialize input.
|
||||||
%input = alloc() : memref<2x3xf32>
|
%input = memref.alloc() : memref<2x3xf32>
|
||||||
%dim_x = dim %input, %c0 : memref<2x3xf32>
|
%dim_x = memref.dim %input, %c0 : memref<2x3xf32>
|
||||||
%dim_y = dim %input, %c1 : memref<2x3xf32>
|
%dim_y = memref.dim %input, %c1 : memref<2x3xf32>
|
||||||
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
|
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
|
||||||
%i_i64 = index_cast %i : index to i64
|
%i_i64 = index_cast %i : index to i64
|
||||||
%i_f32 = sitofp %i_i64 : i64 to f32
|
%i_f32 = sitofp %i_i64 : i64 to f32
|
||||||
store %i_f32, %input[%i, %j] : memref<2x3xf32>
|
memref.store %i_f32, %input[%i, %j] : memref<2x3xf32>
|
||||||
}
|
}
|
||||||
%unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
|
%unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
|
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
|
||||||
// CHECK: [0, 0, 0]
|
// CHECK: [0, 0, 0]
|
||||||
// CHECK: [1, 1, 1]
|
// CHECK: [1, 1, 1]
|
||||||
|
|
||||||
%in = tensor_load %input : memref<2x3xf32>
|
%in = memref.tensor_load %input : memref<2x3xf32>
|
||||||
%init = mhlo.constant dense<0xff800000> : tensor<f32>
|
%init = mhlo.constant dense<0xff800000> : tensor<f32>
|
||||||
|
|
||||||
%reduce = "mhlo.reduce"(%in, %init) ( {
|
%reduce = "mhlo.reduce"(%in, %init) ( {
|
||||||
|
@ -82,8 +82,8 @@ func @reduce_max() {
|
||||||
}) {dimensions = dense<1> : tensor<1xi64>}
|
}) {dimensions = dense<1> : tensor<1xi64>}
|
||||||
: (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
|
: (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
|
||||||
|
|
||||||
%output = tensor_to_memref %reduce : memref<2xf32>
|
%output = memref.buffer_cast %reduce : memref<2xf32>
|
||||||
%unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
|
%unranked_output = memref.cast %output : memref<2xf32> to memref<*xf32>
|
||||||
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
||||||
// CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
|
// CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
|
||||||
// CHECK: [0, 1]
|
// CHECK: [0, 1]
|
||||||
|
|
|
@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked(
|
||||||
return %reshaped : tensor<?xf32>
|
return %reshaped : tensor<?xf32>
|
||||||
}
|
}
|
||||||
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
|
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
|
||||||
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
|
// CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]])
|
||||||
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
|
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -30,7 +30,7 @@ func @dynamic_reshape_to_unranked(
|
||||||
return %reshaped : tensor<*xf32>
|
return %reshaped : tensor<*xf32>
|
||||||
}
|
}
|
||||||
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
|
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
|
||||||
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
|
// CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]])
|
||||||
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
|
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -41,4 +41,4 @@ func @reshape_unranked(%operand: tensor<*xf32>) -> tensor<f32> {
|
||||||
return %reshaped : tensor<f32>
|
return %reshaped : tensor<f32>
|
||||||
}
|
}
|
||||||
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
|
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
|
||||||
// CHECK-NEXT: memref_cast [[ARG]] : memref<*xf32> to memref<f32>
|
// CHECK-NEXT: memref.cast [[ARG]] : memref<*xf32> to memref<f32>
|
||||||
|
|
|
@ -31,20 +31,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %5 : tensor<4xf32>
|
return %5 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
|
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
|
||||||
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
|
// CHECK-NEXT: %[[MAX_RESULT:.*]] = memref.alloc() : memref<4xf32>
|
||||||
// CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
|
// CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
|
||||||
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
|
// CHECK-NEXT: %[[ADD_RESULT:.*]] = memref.alloc() : memref<4xf32>
|
||||||
// CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
|
// CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
|
||||||
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
|
// CHECK-NEXT: memref.dealloc %[[MAX_RESULT]] : memref<4xf32>
|
||||||
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
|
// CHECK-NEXT: %[[MIN_RESULT:.*]] = memref.alloc() : memref<4xf32>
|
||||||
// CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
|
// CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
|
||||||
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
|
// CHECK-NEXT: %[[SUB_RESULT:.*]] = memref.alloc() : memref<4xf32>
|
||||||
// CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
|
// CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
|
||||||
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
|
// CHECK-NEXT: memref.dealloc %[[MIN_RESULT]] : memref<4xf32>
|
||||||
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
|
// CHECK-NEXT: %[[MUL_RESULT:.*]] = memref.alloc() : memref<4xf32>
|
||||||
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
|
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
|
||||||
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
|
// CHECK-NEXT: memref.dealloc %[[SUB_RESULT]] : memref<4xf32>
|
||||||
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
|
// CHECK-NEXT: memref.dealloc %[[ADD_RESULT]] : memref<4xf32>
|
||||||
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
|
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -53,15 +53,15 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
|
||||||
func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>,
|
func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>,
|
||||||
%summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
%summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}})
|
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}})
|
||||||
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
|
// CHECK-NEXT: %[[ADD_RESULT:.*]] = memref.alloc() : memref<2x2xf32>
|
||||||
%sum = "mhlo.add"(%summand_1, %summand_2)
|
%sum = "mhlo.add"(%summand_1, %summand_2)
|
||||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||||
// CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
|
// CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
|
||||||
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
|
// CHECK-NEXT: %[[MUL_RESULT:.*]] = memref.alloc() : memref<2x2xf32>
|
||||||
%result = "mhlo.multiply"(%sum, %multiplier)
|
%result = "mhlo.multiply"(%sum, %multiplier)
|
||||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||||
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
||||||
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
|
// CHECK-NEXT: memref.dealloc %[[ADD_RESULT]] : memref<2x2xf32>
|
||||||
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<2x2xf32>
|
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<2x2xf32>
|
||||||
return %result : tensor<2x2xf32>
|
return %result : tensor<2x2xf32>
|
||||||
}
|
}
|
||||||
|
@ -154,9 +154,9 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
|
||||||
|
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
|
// CHECK: %[[OPER_DIM_1:.*]] = memref.dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
|
||||||
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
|
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
|
||||||
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
|
// CHECK: %[[OPER_DIM_0:.*]] = memref.dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
|
||||||
|
|
||||||
// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
|
// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
|
||||||
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
|
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||||
|
@ -172,9 +172,9 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
|
||||||
// CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index
|
// CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index
|
||||||
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
|
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
|
||||||
|
|
||||||
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
|
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
|
||||||
|
|
||||||
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
|
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
|
||||||
|
|
||||||
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
|
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
|
||||||
// CHECK: return %[[RESULT]] : memref<?x?x?xf32>
|
// CHECK: return %[[RESULT]] : memref<?x?x?xf32>
|
||||||
|
@ -469,7 +469,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
|
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
|
||||||
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
||||||
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
||||||
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
|
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
|
||||||
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
return %result : tensor<?x?xf32>
|
return %result : tensor<?x?xf32>
|
||||||
// CHECK: return %[[RESULT]]
|
// CHECK: return %[[RESULT]]
|
||||||
|
@ -485,7 +485,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
|
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
|
||||||
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
||||||
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
||||||
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
|
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
|
||||||
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
return %result : tensor<?x?xf32>
|
return %result : tensor<?x?xf32>
|
||||||
// CHECK: return %[[RESULT]]
|
// CHECK: return %[[RESULT]]
|
||||||
|
@ -496,7 +496,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
// CHECK-LABEL: func @dot
|
// CHECK-LABEL: func @dot
|
||||||
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||||
// CHECK-NEXT: %[[ALLOC:.*]] = alloc
|
// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc
|
||||||
// CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
|
// CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
|
||||||
// dot_dimension_numbers = {
|
// dot_dimension_numbers = {
|
||||||
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
|
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||||
|
@ -517,7 +517,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
||||||
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
|
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
|
||||||
-> tensor<3x5x5x4xf32> {
|
-> tensor<3x5x5x4xf32> {
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
|
// CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
|
||||||
// CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
// CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||||
// CHECK-SAME: padding = dense<[
|
// CHECK-SAME: padding = dense<[
|
||||||
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||||
|
@ -548,11 +548,11 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
|
||||||
|
|
||||||
// CHECK-LABEL: func @reduce
|
// CHECK-LABEL: func @reduce
|
||||||
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
|
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
|
||||||
// CHECK: %[[OUT:.*]] = alloc() : memref<1xf32>
|
// CHECK: %[[OUT:.*]] = memref.alloc() : memref<1xf32>
|
||||||
// CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
|
// CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
|
||||||
// CHECK: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
|
// CHECK: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
|
||||||
// CHECK-SAME: %[[ARG3:.*]]: memref<f32>):
|
// CHECK-SAME: %[[ARG3:.*]]: memref<f32>):
|
||||||
// CHECK: %[[TMP:.*]] = alloc() : memref<f32>
|
// CHECK: %[[TMP:.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
|
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
|
||||||
// CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
|
// CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
|
||||||
// CHECK: "lmhlo.terminator"() : () -> ()
|
// CHECK: "lmhlo.terminator"() : () -> ()
|
||||||
|
|
|
@ -404,7 +404,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
|
||||||
return %0: tensor<4x2x1x4x?x16xf32>
|
return %0: tensor<4x2x1x4x?x16xf32>
|
||||||
}
|
}
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[D1:.*]] = dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
|
// CHECK: %[[D1:.*]] = memref.dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
|
||||||
// CHECK: linalg.init_tensor [4, 2, 1, 4, %[[D1]], 16] : tensor<4x2x1x4x?x16xf32>
|
// CHECK: linalg.init_tensor [4, 2, 1, 4, %[[D1]], 16] : tensor<4x2x1x4x?x16xf32>
|
||||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||||
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
|
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
|
||||||
|
@ -1024,7 +1024,7 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
|
||||||
// CHECK-LABEL: func @dot_matmul(
|
// CHECK-LABEL: func @dot_matmul(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.matmul
|
// CHECK: linalg.matmul
|
||||||
|
@ -1040,7 +1040,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
|
||||||
// CHECK-LABEL: func @dot_matmul_i8_i8_i32(
|
// CHECK-LABEL: func @dot_matmul_i8_i8_i32(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>)
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>)
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.matmul
|
// CHECK: linalg.matmul
|
||||||
|
@ -1058,7 +1058,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
|
||||||
// CHECK-LABEL: func @dot_matmul_i16_i16_i32(
|
// CHECK-LABEL: func @dot_matmul_i16_i16_i32(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>)
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>)
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.matmul
|
// CHECK: linalg.matmul
|
||||||
|
@ -1076,7 +1076,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
|
||||||
// CHECK-LABEL: func @dot_matmul_i32_i32_i32(
|
// CHECK-LABEL: func @dot_matmul_i32_i32_i32(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>)
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>)
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.matmul
|
// CHECK: linalg.matmul
|
||||||
|
@ -1094,7 +1094,7 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
|
||||||
// CHECK-LABEL: func @dot_matvec(
|
// CHECK-LABEL: func @dot_matvec(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
|
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.matvec
|
// CHECK: linalg.matvec
|
||||||
|
@ -1134,11 +1134,11 @@ func @dot_general_batch_matmul(%arg0: tensor<?x?x3xf32>,
|
||||||
// CHECK-LABEL: func @dot_general_batch_matmul(
|
// CHECK-LABEL: func @dot_general_batch_matmul(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
|
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
|
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
|
||||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
|
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.batch_matmul
|
// CHECK: linalg.batch_matmul
|
||||||
|
@ -1163,11 +1163,11 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
|
||||||
// CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32(
|
// CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi8>, %[[ARG1:.*]]: tensor<?x3x?xi8>)
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi8>, %[[ARG1:.*]]: tensor<?x3x?xi8>)
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
|
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
|
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
|
||||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
|
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.batch_matmul
|
// CHECK: linalg.batch_matmul
|
||||||
|
@ -1192,11 +1192,11 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
|
||||||
// CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32(
|
// CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi16>, %[[ARG1:.*]]: tensor<?x3x?xi16>)
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi16>, %[[ARG1:.*]]: tensor<?x3x?xi16>)
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
|
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
|
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
|
||||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
|
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.batch_matmul
|
// CHECK: linalg.batch_matmul
|
||||||
|
@ -1420,7 +1420,7 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
|
||||||
// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
|
// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
|
||||||
// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
|
// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
|
||||||
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
|
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK-DAG: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
|
// CHECK-DAG: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
|
||||||
// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[DIM1]]]
|
// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[DIM1]]]
|
||||||
// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
|
// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
|
@ -1531,9 +1531,9 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x8x?xf32>, %arg1: tenso
|
||||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
|
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
|
||||||
// CHECK: %[[C2:.+]] = constant 2 : index
|
// CHECK: %[[C2:.+]] = constant 2 : index
|
||||||
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
|
// CHECK: %[[DIM2:.+]] = memref.dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
|
||||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]]
|
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]]
|
||||||
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
||||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
||||||
|
@ -1571,9 +1571,9 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3
|
||||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
|
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
|
||||||
// CHECK: %[[C3:.+]] = constant 3 : index
|
// CHECK: %[[C3:.+]] = constant 3 : index
|
||||||
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
|
// CHECK: %[[DIM3:.+]] = memref.dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
|
||||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
|
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
|
||||||
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
||||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
||||||
|
@ -1611,9 +1611,9 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x8x8x8x?xf32>, %arg1: tens
|
||||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
|
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
|
||||||
// CHECK: %[[C4:.+]] = constant 4 : index
|
// CHECK: %[[C4:.+]] = constant 4 : index
|
||||||
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
|
// CHECK: %[[DIM4:.+]] = memref.dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
|
||||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]]
|
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]]
|
||||||
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
||||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
iterator_types = ["parallel", "parallel"]}
|
iterator_types = ["parallel", "parallel"]}
|
||||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
||||||
%temp_result = alloc() : memref<6x6xf32>
|
%temp_result = memref.alloc() : memref<6x6xf32>
|
||||||
linalg.generic #pointwise_2d_trait
|
linalg.generic #pointwise_2d_trait
|
||||||
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
|
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
|
||||||
outs(%temp_result : memref<6x6xf32>) {
|
outs(%temp_result : memref<6x6xf32>) {
|
||||||
|
@ -22,7 +22,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
%out = mulf %temp_result_in, %multiplier_in : f32
|
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||||
linalg.yield %out : f32
|
linalg.yield %out : f32
|
||||||
}
|
}
|
||||||
dealloc %temp_result : memref<6x6xf32>
|
memref.dealloc %temp_result : memref<6x6xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @fusion
|
// CHECK-LABEL: func @fusion
|
||||||
|
@ -62,7 +62,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
func @fusion_of_three(%arg0: memref<100x10xf32>,
|
func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
%arg1: memref<100xf32>,
|
%arg1: memref<100xf32>,
|
||||||
%arg2: memref<100x10xf32>) {
|
%arg2: memref<100x10xf32>) {
|
||||||
%0 = alloc() : memref<100x10xf32>
|
%0 = memref.alloc() : memref<100x10xf32>
|
||||||
linalg.generic {
|
linalg.generic {
|
||||||
indexing_maps = [affine_map<(d0, d1) -> (d0)>,
|
indexing_maps = [affine_map<(d0, d1) -> (d0)>,
|
||||||
affine_map<(d0, d1) -> (d0, d1)>],
|
affine_map<(d0, d1) -> (d0, d1)>],
|
||||||
|
@ -72,7 +72,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||||
linalg.yield %arg3 : f32
|
linalg.yield %arg3 : f32
|
||||||
}
|
}
|
||||||
%1 = alloc() : memref<100x10xf32>
|
%1 = memref.alloc() : memref<100x10xf32>
|
||||||
linalg.generic {
|
linalg.generic {
|
||||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||||
affine_map<(d0, d1) -> (d0, d1)>,
|
affine_map<(d0, d1) -> (d0, d1)>,
|
||||||
|
@ -84,7 +84,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
%2 = subf %arg3, %arg4 : f32
|
%2 = subf %arg3, %arg4 : f32
|
||||||
linalg.yield %2 : f32
|
linalg.yield %2 : f32
|
||||||
}
|
}
|
||||||
dealloc %0 : memref<100x10xf32>
|
memref.dealloc %0 : memref<100x10xf32>
|
||||||
linalg.generic {
|
linalg.generic {
|
||||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||||
affine_map<(d0, d1) -> (d0, d1)>],
|
affine_map<(d0, d1) -> (d0, d1)>],
|
||||||
|
@ -95,7 +95,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
%2 = math.exp %arg3 : f32
|
%2 = math.exp %arg3 : f32
|
||||||
linalg.yield %2 : f32
|
linalg.yield %2 : f32
|
||||||
}
|
}
|
||||||
dealloc %1 : memref<100x10xf32>
|
memref.dealloc %1 : memref<100x10xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @fusion
|
// CHECK-LABEL: func @fusion
|
||||||
|
@ -141,7 +141,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
"parallel"]}
|
"parallel"]}
|
||||||
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
||||||
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
||||||
%temp_result = alloc() : memref<6x6x6x6xf32>
|
%temp_result = memref.alloc() : memref<6x6x6x6xf32>
|
||||||
linalg.generic #pointwise_4d_trait
|
linalg.generic #pointwise_4d_trait
|
||||||
ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
|
ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
|
||||||
outs(%temp_result : memref<6x6x6x6xf32>) {
|
outs(%temp_result : memref<6x6x6x6xf32>) {
|
||||||
|
@ -156,7 +156,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
||||||
%out = mulf %temp_result_in, %multiplier_in : f32
|
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||||
linalg.yield %out : f32
|
linalg.yield %out : f32
|
||||||
}
|
}
|
||||||
dealloc %temp_result : memref<6x6x6x6xf32>
|
memref.dealloc %temp_result : memref<6x6x6x6xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @fusion_4d
|
// CHECK-LABEL: func @fusion_4d
|
||||||
|
@ -200,7 +200,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
||||||
iterator_types = ["parallel", "parallel"]}
|
iterator_types = ["parallel", "parallel"]}
|
||||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
|
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
|
||||||
%temp_result = alloc() : memref<6x6xf32>
|
%temp_result = memref.alloc() : memref<6x6xf32>
|
||||||
linalg.generic #pointwise_2d_trait
|
linalg.generic #pointwise_2d_trait
|
||||||
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
|
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
|
||||||
outs(%temp_result : memref<6x6xf32>) {
|
outs(%temp_result : memref<6x6xf32>) {
|
||||||
|
@ -208,7 +208,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
%out = addf %summand_1_in, %summand_2_in : f32
|
%out = addf %summand_1_in, %summand_2_in : f32
|
||||||
linalg.yield %out : f32
|
linalg.yield %out : f32
|
||||||
}
|
}
|
||||||
%result = alloc() : memref<6x6xf32>
|
%result = memref.alloc() : memref<6x6xf32>
|
||||||
linalg.generic #pointwise_2d_trait
|
linalg.generic #pointwise_2d_trait
|
||||||
ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
|
ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
|
||||||
outs(%result : memref<6x6xf32>) {
|
outs(%result : memref<6x6xf32>) {
|
||||||
|
@ -216,7 +216,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
%out = mulf %temp_result_in, %multiplier_in : f32
|
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||||
linalg.yield %out : f32
|
linalg.yield %out : f32
|
||||||
}
|
}
|
||||||
dealloc %temp_result : memref<6x6xf32>
|
memref.dealloc %temp_result : memref<6x6xf32>
|
||||||
return %result : memref<6x6xf32>
|
return %result : memref<6x6xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -258,7 +258,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||||
-> memref<*xf32> {
|
-> memref<*xf32> {
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
%1 = alloc(%arg2) : memref<?xf32>
|
%1 = memref.alloc(%arg2) : memref<?xf32>
|
||||||
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||||
affine_map<(d0) -> (d0)>],
|
affine_map<(d0) -> (d0)>],
|
||||||
iterator_types = ["parallel"]}
|
iterator_types = ["parallel"]}
|
||||||
|
@ -267,7 +267,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||||
%13 = absf %arg3 : f32
|
%13 = absf %arg3 : f32
|
||||||
linalg.yield %13 : f32
|
linalg.yield %13 : f32
|
||||||
}
|
}
|
||||||
%2 = memref_reshape %1(%arg1)
|
%2 = memref.reshape %1(%arg1)
|
||||||
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||||
return %2 : memref<*xf32>
|
return %2 : memref<*xf32>
|
||||||
}
|
}
|
||||||
|
@ -279,7 +279,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||||
// CHECK-NOT: scf.for
|
// CHECK-NOT: scf.for
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK: absf
|
// CHECK: absf
|
||||||
// CHECK: memref_reshape
|
// CHECK: memref.reshape
|
||||||
|
|
||||||
// TILED-LABEL: func @view_result
|
// TILED-LABEL: func @view_result
|
||||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||||
|
@ -288,7 +288,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||||
// TILED-NOT: scf.for
|
// TILED-NOT: scf.for
|
||||||
// TILED: linalg.generic
|
// TILED: linalg.generic
|
||||||
// TILED: absf
|
// TILED: absf
|
||||||
// TILED: memref_reshape
|
// TILED: memref.reshape
|
||||||
|
|
||||||
|
|
||||||
// PLOOP-LABEL: func @view_result
|
// PLOOP-LABEL: func @view_result
|
||||||
|
@ -297,20 +297,20 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||||
// PLOOP-NOT: scf.parallel
|
// PLOOP-NOT: scf.parallel
|
||||||
// PLOOP: linalg.generic
|
// PLOOP: linalg.generic
|
||||||
// PLOOP: absf
|
// PLOOP: absf
|
||||||
// PLOOP: memref_reshape
|
// PLOOP: memref.reshape
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Confirm that tiling information is passed through RegionBranchOpInterfaces.
|
// Confirm that tiling information is passed through RegionBranchOpInterfaces.
|
||||||
// This test also uses memref_reshape, just to have a value to return through
|
// This test also uses memref.reshape, just to have a value to return through
|
||||||
// the if statement.
|
// the if statement.
|
||||||
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||||
-> memref<*xf32> {
|
-> memref<*xf32> {
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
%1 = alloc(%arg2) : memref<?xf32>
|
%1 = memref.alloc(%arg2) : memref<?xf32>
|
||||||
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||||
affine_map<(d0) -> (d0)>],
|
affine_map<(d0) -> (d0)>],
|
||||||
iterator_types = ["parallel"]}
|
iterator_types = ["parallel"]}
|
||||||
|
@ -321,11 +321,11 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
||||||
}
|
}
|
||||||
%true = constant 1 : i1
|
%true = constant 1 : i1
|
||||||
%3 = scf.if %true -> memref<*xf32> {
|
%3 = scf.if %true -> memref<*xf32> {
|
||||||
%2 = memref_reshape %1(%arg1)
|
%2 = memref.reshape %1(%arg1)
|
||||||
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||||
scf.yield %2 : memref<*xf32>
|
scf.yield %2 : memref<*xf32>
|
||||||
} else {
|
} else {
|
||||||
%2 = memref_reshape %1(%arg1)
|
%2 = memref.reshape %1(%arg1)
|
||||||
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||||
scf.yield %2 : memref<*xf32>
|
scf.yield %2 : memref<*xf32>
|
||||||
}
|
}
|
||||||
|
@ -340,10 +340,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK: absf
|
// CHECK: absf
|
||||||
// CHECK: scf.if
|
// CHECK: scf.if
|
||||||
// CHECK: memref_reshape
|
// CHECK: memref.reshape
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
// CHECK: else
|
// CHECK: else
|
||||||
// CHECK: memref_reshape
|
// CHECK: memref.reshape
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
|
|
||||||
// TILED-LABEL: func @branching_result
|
// TILED-LABEL: func @branching_result
|
||||||
|
@ -354,10 +354,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
||||||
// TILED: linalg.generic
|
// TILED: linalg.generic
|
||||||
// TILED: absf
|
// TILED: absf
|
||||||
// TILED: scf.if
|
// TILED: scf.if
|
||||||
// TILED: memref_reshape
|
// TILED: memref.reshape
|
||||||
// TILED: scf.yield
|
// TILED: scf.yield
|
||||||
// TILED: else
|
// TILED: else
|
||||||
// TILED: memref_reshape
|
// TILED: memref.reshape
|
||||||
// TILED: scf.yield
|
// TILED: scf.yield
|
||||||
|
|
||||||
// PLOOP-LABEL: func @branching_result
|
// PLOOP-LABEL: func @branching_result
|
||||||
|
@ -367,10 +367,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
||||||
// PLOOP: linalg.generic
|
// PLOOP: linalg.generic
|
||||||
// PLOOP: absf
|
// PLOOP: absf
|
||||||
// PLOOP: scf.if
|
// PLOOP: scf.if
|
||||||
// PLOOP: memref_reshape
|
// PLOOP: memref.reshape
|
||||||
// PLOOP: scf.yield
|
// PLOOP: scf.yield
|
||||||
// PLOOP: else
|
// PLOOP: else
|
||||||
// PLOOP: memref_reshape
|
// PLOOP: memref.reshape
|
||||||
// PLOOP: scf.yield
|
// PLOOP: scf.yield
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -380,7 +380,7 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
||||||
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
||||||
-> memref<?xf32> {
|
-> memref<?xf32> {
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
%1 = alloc() : memref<32xf32>
|
%1 = memref.alloc() : memref<32xf32>
|
||||||
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||||
affine_map<(d0) -> (d0)>],
|
affine_map<(d0) -> (d0)>],
|
||||||
iterator_types = ["parallel"]}
|
iterator_types = ["parallel"]}
|
||||||
|
@ -389,9 +389,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
||||||
%13 = absf %arg3 : f32
|
%13 = absf %arg3 : f32
|
||||||
linalg.yield %13 : f32
|
linalg.yield %13 : f32
|
||||||
}
|
}
|
||||||
%2 = tensor_load %1 : memref<32xf32>
|
%2 = memref.tensor_load %1 : memref<32xf32>
|
||||||
%3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
|
%3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
|
||||||
%4 = tensor_to_memref %3 : memref<?xf32>
|
%4 = memref.buffer_cast %3 : memref<?xf32>
|
||||||
return %4 : memref<?xf32>
|
return %4 : memref<?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -402,9 +402,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
||||||
// CHECK-NOT: scf.for
|
// CHECK-NOT: scf.for
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK: absf
|
// CHECK: absf
|
||||||
// CHECK: tensor_load
|
// CHECK: memref.tensor_load
|
||||||
// CHECK: tensor.cast
|
// CHECK: tensor.cast
|
||||||
// CHECK: tensor_to_memref
|
// CHECK: memref.buffer_cast
|
||||||
|
|
||||||
// TILED-LABEL: func @tensor_ops
|
// TILED-LABEL: func @tensor_ops
|
||||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||||
|
@ -413,9 +413,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
||||||
// TILED-NOT: scf.for
|
// TILED-NOT: scf.for
|
||||||
// TILED: linalg.generic
|
// TILED: linalg.generic
|
||||||
// TILED: absf
|
// TILED: absf
|
||||||
// TILED: tensor_load
|
// TILED: memref.tensor_load
|
||||||
// TILED: tensor.cast
|
// TILED: tensor.cast
|
||||||
// TILED: tensor_to_memref
|
// TILED: memref.buffer_cast
|
||||||
|
|
||||||
|
|
||||||
// PLOOP-LABEL: func @tensor_ops
|
// PLOOP-LABEL: func @tensor_ops
|
||||||
|
@ -424,6 +424,6 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
||||||
// PLOOP-NOT: scf.parallel
|
// PLOOP-NOT: scf.parallel
|
||||||
// PLOOP: linalg.generic
|
// PLOOP: linalg.generic
|
||||||
// PLOOP: absf
|
// PLOOP: absf
|
||||||
// PLOOP: tensor_load
|
// PLOOP: memref.tensor_load
|
||||||
// PLOOP: tensor.cast
|
// PLOOP: tensor.cast
|
||||||
// PLOOP: tensor_to_memref
|
// PLOOP: memref.buffer_cast
|
||||||
|
|
|
@ -49,10 +49,10 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
||||||
// CHECK-DAG: [[CTRUE:%.*]] = constant true
|
// CHECK-DAG: [[CTRUE:%.*]] = constant true
|
||||||
|
|
||||||
// Parallel loop to initialize the output buffer.
|
// Parallel loop to initialize the output buffer.
|
||||||
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
|
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
|
||||||
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
|
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
|
||||||
// CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) {
|
// CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) {
|
||||||
// CHECK: store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
|
// CHECK: memref.store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
||||||
|
|
||||||
// INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true
|
// INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true
|
||||||
|
|
||||||
// CHECK: [[ARG_ELEM:%.*]] = load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
|
// CHECK: [[ARG_ELEM:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
|
||||||
// CHECK: [[IF_INIT_RES:%.*]]:4
|
// CHECK: [[IF_INIT_RES:%.*]]:4
|
||||||
// CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) {
|
// CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) {
|
||||||
|
|
||||||
|
@ -114,16 +114,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
||||||
|
|
||||||
// Allocate buffers for ARG element, current selected value to adapt LHLO
|
// Allocate buffers for ARG element, current selected value to adapt LHLO
|
||||||
// code.
|
// code.
|
||||||
// CHECK: [[ARG_ELEM_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ARG_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[SEL_VAL_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[SEL_VAL_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[PRED_BUF:%.*]] = alloc() : memref<i1>
|
// CHECK: [[PRED_BUF:%.*]] = memref.alloc() : memref<i1>
|
||||||
// CHECK: store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
|
// CHECK: memref.store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
|
||||||
// CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
|
// CHECK: memref.store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
|
||||||
|
|
||||||
// Compute PRED.
|
// Compute PRED.
|
||||||
// CHECK: "lmhlo.compare"(
|
// CHECK: "lmhlo.compare"(
|
||||||
// CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
|
// CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
|
||||||
// CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1>
|
// CHECK: [[PRED:%.*]] = memref.load [[PRED_BUF]][] : memref<i1>
|
||||||
|
|
||||||
|
|
||||||
// Depending on PRED, return ARG ivs & elem or current select ivs and value.
|
// Depending on PRED, return ARG ivs & elem or current select ivs and value.
|
||||||
|
@ -165,7 +165,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
// Use selected ivs to load element from the SRC buffer.
|
// Use selected ivs to load element from the SRC buffer.
|
||||||
// CHECK: [[SRC_ELEM:%.*]] = load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]
|
// CHECK: [[SRC_ELEM:%.*]] = memref.load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]
|
||||||
|
|
||||||
// Update of RESULT[SELECTED_I, SELECTED_J] should be done atomically, because
|
// Update of RESULT[SELECTED_I, SELECTED_J] should be done atomically, because
|
||||||
// it may happen that several other threads select the same IVs if the windows
|
// it may happen that several other threads select the same IVs if the windows
|
||||||
|
@ -175,16 +175,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
||||||
// CHECK: ^bb0([[CUR_RES:%.*]]: f32):
|
// CHECK: ^bb0([[CUR_RES:%.*]]: f32):
|
||||||
|
|
||||||
// Allocate buffers for ARG element, current selected value to adapt LHLO code.
|
// Allocate buffers for ARG element, current selected value to adapt LHLO code.
|
||||||
// CHECK: [[SRC_ELEM_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[SRC_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[CUR_RES_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[CUR_RES_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[RES_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[RES_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
|
// CHECK: memref.store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
|
||||||
// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
|
// CHECK: memref.store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
|
||||||
|
|
||||||
// Compute scatter value.
|
// Compute scatter value.
|
||||||
// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
|
// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
|
||||||
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>
|
// CHECK: [[RES:%.*]] = memref.load [[RES_BUF]][] : memref<f32>
|
||||||
|
|
||||||
// Atomic RMW terminator that returns updated value.
|
// Atomic RMW terminator that returns updated value.
|
||||||
// CHECK: atomic_yield [[RES]] : f32
|
// CHECK: atomic_yield [[RES]] : f32
|
||||||
|
|
|
@ -19,14 +19,14 @@ func @reduce(%arg: memref<100x10xf32>,
|
||||||
// CHECK-DAG: %[[C100:.*]] = constant 100 : index
|
// CHECK-DAG: %[[C100:.*]] = constant 100 : index
|
||||||
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
|
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) {
|
// CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) {
|
||||||
// CHECK: %[[ACC:.*]] = load %[[ARG1]][] : memref<f32>
|
// CHECK: %[[ACC:.*]] = memref.load %[[ARG1]][] : memref<f32>
|
||||||
// CHECK: store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32>
|
// CHECK: store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32>
|
||||||
// CHECK-DAG: %[[LB:.*]] = constant 0 : index
|
// CHECK-DAG: %[[LB:.*]] = constant 0 : index
|
||||||
// CHECK-DAG: %[[UB:.*]] = constant 10 : index
|
// CHECK-DAG: %[[UB:.*]] = constant 10 : index
|
||||||
// CHECK-DAG: %[[STEP:.*]] = constant 1 : index
|
// CHECK-DAG: %[[STEP:.*]] = constant 1 : index
|
||||||
// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
|
// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
|
||||||
// CHECK: %[[LHS:.*]] = subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref<f32, #[[$MAP]]>
|
// CHECK: %[[LHS:.*]] = memref.subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref<f32, #[[$MAP]]>
|
||||||
// CHECK: %[[RHS:.*]] = subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref<f32, #[[$MAP]]>
|
// CHECK: %[[RHS:.*]] = memref.subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref<f32, #[[$MAP]]>
|
||||||
// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
|
// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: gpu.terminator
|
// CHECK: gpu.terminator
|
||||||
|
|
|
@ -52,10 +52,10 @@ func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
|
||||||
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK: %[[LHS:.*]] = load
|
// CHECK: %[[LHS:.*]] = memref.load
|
||||||
// CHECK: %[[RHS:.*]] = load
|
// CHECK: %[[RHS:.*]] = memref.load
|
||||||
// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]]
|
// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]]
|
||||||
// CHECK: store %[[RES]]
|
// CHECK: memref.store %[[RES]]
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -347,7 +347,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
|
||||||
}
|
}
|
||||||
// CHECK-NOT: linalg.reshape
|
// CHECK-NOT: linalg.reshape
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[VALUE:.*]] = load %{{.*}}[[C0]]
|
// CHECK: %[[VALUE:.*]] = memref.load %{{.*}}[[C0]]
|
||||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
|
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
|
||||||
// CHECK-NEXT: ^bb0(%{{.+}}: f32):
|
// CHECK-NEXT: ^bb0(%{{.+}}: f32):
|
||||||
// CHECK-NEXT: linalg.yield %[[VALUE]] : f32
|
// CHECK-NEXT: linalg.yield %[[VALUE]] : f32
|
||||||
|
@ -785,7 +785,7 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
|
||||||
} : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
} : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK: %[[RESULT:.*]] = subview %[[IN]][0, 1] [2, 2] [1, 1] : memref<?x?xf32> to memref<2x2xf32, #{{.*}}>
|
// CHECK: %[[RESULT:.*]] = memref.subview %[[IN]][0, 1] [2, 2] [1, 1] : memref<?x?xf32> to memref<2x2xf32, #{{.*}}>
|
||||||
// CHECK: linalg.copy(%[[RESULT]], %[[OUT]])
|
// CHECK: linalg.copy(%[[RESULT]], %[[OUT]])
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -899,7 +899,7 @@ func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
|
||||||
|
|
||||||
func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
|
func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
%0 = alloc() : memref<3x5x5x4xf32>
|
%0 = memref.alloc() : memref<3x5x5x4xf32>
|
||||||
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
|
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
|
||||||
// CHECK-SAME: dilations = [1, 2]
|
// CHECK-SAME: dilations = [1, 2]
|
||||||
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
|
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||||
|
@ -948,22 +948,22 @@ func @reduce_add(%arg: memref<100x10xf32>,
|
||||||
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
|
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
|
// CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
|
||||||
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
|
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
|
||||||
// CHECK: linalg.generic {
|
// CHECK: linalg.generic {
|
||||||
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
|
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
|
||||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
|
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
|
||||||
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
|
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
|
||||||
// CHECK: alloca
|
// CHECK: memref.alloca
|
||||||
// CHECK-NEXT: alloca
|
// CHECK-NEXT: memref.alloca
|
||||||
// CHECK-NEXT: alloca
|
// CHECK-NEXT: memref.alloca
|
||||||
// CHECK-NEXT: store
|
// CHECK-NEXT: memref.store
|
||||||
// CHECK-NEXT: store
|
// CHECK-NEXT: memref.store
|
||||||
// CHECK-NEXT: load
|
// CHECK-NEXT: memref.load
|
||||||
// CHECK-NEXT: load
|
// CHECK-NEXT: memref.load
|
||||||
// CHECK-NEXT: addf
|
// CHECK-NEXT: addf
|
||||||
// CHECK-NEXT: store
|
// CHECK-NEXT: memref.store
|
||||||
// CHECK-NEXT: load
|
// CHECK-NEXT: memref.load
|
||||||
// CHECK-NEXT: linalg.yield
|
// CHECK-NEXT: linalg.yield
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
@ -984,22 +984,22 @@ func @reduce_maximum(%arg: memref<100x10xf32>,
|
||||||
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
|
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
|
// CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
|
||||||
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
|
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
|
||||||
// CHECK: linalg.generic {
|
// CHECK: linalg.generic {
|
||||||
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
|
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
|
||||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
|
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
|
||||||
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
|
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
|
||||||
// CHECK: alloca
|
// CHECK: memref.alloca
|
||||||
// CHECK-NEXT: alloca
|
// CHECK-NEXT: memref.alloca
|
||||||
// CHECK-NEXT: alloca
|
// CHECK-NEXT: memref.alloca
|
||||||
// CHECK-NEXT: store
|
// CHECK-NEXT: memref.store
|
||||||
// CHECK-NEXT: store
|
// CHECK-NEXT: memref.store
|
||||||
// CHECK-NEXT: load
|
// CHECK-NEXT: memref.load
|
||||||
// CHECK-NEXT: load
|
// CHECK-NEXT: memref.load
|
||||||
// CHECK: cmpf
|
// CHECK: cmpf
|
||||||
// CHECK: select
|
// CHECK: select
|
||||||
// CHECK: store
|
// CHECK: memref.store
|
||||||
// CHECK-NEXT: load
|
// CHECK-NEXT: memref.load
|
||||||
// CHECK-NEXT: linalg.yield
|
// CHECK-NEXT: linalg.yield
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
|
|
|
@ -21,27 +21,27 @@ func @reduce(%arg: memref<100x10x5xf32>,
|
||||||
// CHECK-DAG: [[C5:%.*]] = constant 5 : index
|
// CHECK-DAG: [[C5:%.*]] = constant 5 : index
|
||||||
// CHECK-DAG: [[C10:%.*]] = constant 10 : index
|
// CHECK-DAG: [[C10:%.*]] = constant 10 : index
|
||||||
// CHECK-DAG: [[C100:%.*]] = constant 100 : index
|
// CHECK-DAG: [[C100:%.*]] = constant 100 : index
|
||||||
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
|
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
|
||||||
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
|
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
|
||||||
// CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) {
|
// CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) {
|
||||||
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
|
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
|
||||||
// CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 {
|
// CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 {
|
||||||
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
|
// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
|
||||||
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32>
|
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32>
|
||||||
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||||
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
|
||||||
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
|
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -65,23 +65,23 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
|
||||||
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
|
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
|
||||||
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
|
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
|
||||||
// CHECK-DAG: [[C100:%.*]] = constant 100 : index
|
// CHECK-DAG: [[C100:%.*]] = constant 100 : index
|
||||||
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
|
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
|
||||||
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]])
|
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]])
|
||||||
// CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 {
|
// CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 {
|
||||||
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
|
// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
|
||||||
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||||
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
|
||||||
// CHECK: scf.reduce.return [[ACC_RESULT]]
|
// CHECK: scf.reduce.return [[ACC_RESULT]]
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]
|
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
@ -104,30 +104,30 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
|
||||||
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
|
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
|
||||||
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
|
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
|
||||||
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
|
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
|
||||||
// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
|
// CHECK: [[DIM0:%.*]] = memref.dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
|
||||||
// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
|
// CHECK: [[DIM1:%.*]] = memref.dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
|
||||||
// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
|
// CHECK: [[DIM2:%.*]] = memref.dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
|
||||||
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
|
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
|
||||||
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
|
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
|
||||||
// CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) {
|
// CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) {
|
||||||
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
|
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
|
||||||
// CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 {
|
// CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 {
|
||||||
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
|
// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
|
||||||
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<?x?x?xf32>
|
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<?x?x?xf32>
|
||||||
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||||
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
|
||||||
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
|
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -157,7 +157,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
|
||||||
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
|
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
|
||||||
// CHECK-DAG: [[C56:%.*]] = constant 56 : index
|
// CHECK-DAG: [[C56:%.*]] = constant 56 : index
|
||||||
// CHECK-DAG: [[C112:%.*]] = constant 112 : index
|
// CHECK-DAG: [[C112:%.*]] = constant 112 : index
|
||||||
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
|
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
|
||||||
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
|
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
|
||||||
// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) {
|
// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) {
|
||||||
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel
|
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel
|
||||||
|
@ -176,7 +176,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
|
||||||
|
|
||||||
// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) {
|
// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) {
|
||||||
// CHECK: [[OPERAND_ELEM:%.*]] =
|
// CHECK: [[OPERAND_ELEM:%.*]] =
|
||||||
// CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
|
// CHECK-SAME: memref.load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
|
||||||
// CHECK: scf.yield [[OPERAND_ELEM]] : f32
|
// CHECK: scf.yield [[OPERAND_ELEM]] : f32
|
||||||
// CHECK: } else {
|
// CHECK: } else {
|
||||||
// CHECK: scf.yield [[INIT]] : f32
|
// CHECK: scf.yield [[INIT]] : f32
|
||||||
|
@ -184,18 +184,18 @@ func @reduce_window(%arg: memref<112x112xf32>,
|
||||||
|
|
||||||
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||||
// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
|
||||||
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
|
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
|
|
|
@ -30,7 +30,7 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
|
||||||
|
|
||||||
// CHECK-LABEL: func @conv_forward
|
// CHECK-LABEL: func @conv_forward
|
||||||
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
|
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
|
||||||
%scratch = alloc() : memref<32xi8>
|
%scratch = memref.alloc() : memref<32xi8>
|
||||||
// This defined a 2D convolution over a 8x8 single channel input using a 2x2
|
// This defined a 2D convolution over a 8x8 single channel input using a 2x2
|
||||||
// filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W)
|
// filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W)
|
||||||
"lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch)
|
"lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch)
|
||||||
|
@ -61,7 +61,7 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
|
||||||
|
|
||||||
// CHECK-LABEL: func @conv_backfilter
|
// CHECK-LABEL: func @conv_backfilter
|
||||||
func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
|
func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
|
||||||
%scratch = alloc() : memref<23328xui8>
|
%scratch = memref.alloc() : memref<23328xui8>
|
||||||
"lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
|
"lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
|
||||||
{ backend_config = {algorithm = 1 : i64,
|
{ backend_config = {algorithm = 1 : i64,
|
||||||
operand_0_layout = [3,2,1,0],
|
operand_0_layout = [3,2,1,0],
|
||||||
|
@ -91,7 +91,7 @@ func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64x
|
||||||
|
|
||||||
// CHECK-LABEL: func @conv_backinput
|
// CHECK-LABEL: func @conv_backinput
|
||||||
func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
|
func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
|
||||||
%scratch = alloc() : memref<32xui8>
|
%scratch = memref.alloc() : memref<32xui8>
|
||||||
"lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
|
"lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
|
||||||
{ backend_config = {algorithm = 1 : i64,
|
{ backend_config = {algorithm = 1 : i64,
|
||||||
operand_0_layout = [3,2,1,0],
|
operand_0_layout = [3,2,1,0],
|
||||||
|
@ -122,7 +122,7 @@ func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf6
|
||||||
|
|
||||||
// CHECK-LABEL: func @conv_fused
|
// CHECK-LABEL: func @conv_fused
|
||||||
func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
|
func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
|
||||||
%scratch = alloc() : memref<32xui8>
|
%scratch = memref.alloc() : memref<32xui8>
|
||||||
"lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
|
"lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
|
||||||
{activation_mode = "Relu",
|
{activation_mode = "Relu",
|
||||||
backend_config = {algorithm = 1 : i64,
|
backend_config = {algorithm = 1 : i64,
|
||||||
|
@ -153,7 +153,7 @@ func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>,
|
||||||
|
|
||||||
// CHECK-LABEL: func @conv_fused_side_input
|
// CHECK-LABEL: func @conv_fused_side_input
|
||||||
func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) {
|
func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) {
|
||||||
%scratch = alloc() : memref<0xui8>
|
%scratch = memref.alloc() : memref<0xui8>
|
||||||
"lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
|
"lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
|
||||||
{activation_mode = "Relu",
|
{activation_mode = "Relu",
|
||||||
backend_config = {algorithm = 1 : i64,
|
backend_config = {algorithm = 1 : i64,
|
||||||
|
@ -218,8 +218,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
|
||||||
|
|
||||||
// CHECK-LABEL: func @cholesky
|
// CHECK-LABEL: func @cholesky
|
||||||
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
|
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
|
||||||
%scratch = alloc() : memref<32xi8>
|
%scratch = memref.alloc() : memref<32xi8>
|
||||||
%info = alloc() : memref<32xi32>
|
%info = memref.alloc() : memref<32xi32>
|
||||||
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
|
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
|
||||||
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
|
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
|
||||||
return
|
return
|
||||||
|
|
|
@ -457,12 +457,12 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf
|
||||||
// CHECK-LABEL: func @fusion_memref
|
// CHECK-LABEL: func @fusion_memref
|
||||||
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
|
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||||
"lmhlo.fusion"() ( {
|
"lmhlo.fusion"() ( {
|
||||||
%0 = tensor_load %input1 : memref<10xf32>
|
%0 = memref.tensor_load %input1 : memref<10xf32>
|
||||||
%1 = tensor_load %input2 : memref<10xf32>
|
%1 = memref.tensor_load %input2 : memref<10xf32>
|
||||||
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||||
%3 = tensor_load %input3 : memref<10xf32>
|
%3 = memref.tensor_load %input3 : memref<10xf32>
|
||||||
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||||
tensor_store %4, %out : memref<10xf32>
|
memref.tensor_store %4, %out : memref<10xf32>
|
||||||
"lmhlo.terminator"() : () -> ()
|
"lmhlo.terminator"() : () -> ()
|
||||||
} ) : () -> ()
|
} ) : () -> ()
|
||||||
return
|
return
|
||||||
|
|
|
@ -108,15 +108,15 @@ func @batchNormInference_dynamic_shape(
|
||||||
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
|
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
|
||||||
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
|
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
|
||||||
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
|
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
|
||||||
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
|
// CHECK-DAG: %[[DIM:.+]] = memref.dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
|
||||||
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor.from_elements %[[DIM]] : tensor<1xindex>
|
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor.from_elements %[[DIM]] : tensor<1xindex>
|
||||||
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
|
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
|
||||||
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
|
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
|
||||||
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
|
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = memref.dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||||
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
|
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = memref.dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||||
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
|
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = memref.dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||||
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
|
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = memref.dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||||
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor.from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex>
|
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor.from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex>
|
||||||
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||||
|
|
Loading…
Reference in New Issue