[HLO] Delete LHLO memref cast ops and migrate to STD ones.

PiperOrigin-RevId: 340663578

parent 82031b356c
commit 3d930d08c2
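The hunks below delete the LHLO-specific cast ops (`lmhlo.static_memref_cast`, `lmhlo.dynamic_memref_cast`, `lmhlo.reshape_memref_cast`) together with their verifiers, LLVM lowerings, and tests, and rewire the remaining in-tree users to the standard-dialect ops: `lmhlo.dynamic_memref_cast` uses become `memref_reinterpret_cast`, and `lmhlo.reshape_memref_cast` uses become `memref_reshape`. A minimal sketch of the mapping follows; the SSA names are illustrative and the syntax is inferred from the deleted op documentation and the updated FileCheck patterns, not copied verbatim from any one hunk:

```mlir
// Before this change: LHLO-specific cast ops (deleted below).
%t = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
       : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
%r = lmhlo.reshape_memref_cast %src(%shape)
       : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>

// After this change: standard-dialect equivalents used by the migrated
// lowerings and tests.
%t2 = memref_reinterpret_cast %buf to
        offset: [0], sizes: [%size_X, %size_Y], strides: [%stride_X, %stride_Y]
        : memref<?x?xf32> to memref<?x?xf32, offset: 0, strides: [?, ?]>
%r2 = memref_reshape %src(%shape)
        : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
```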
@@ -314,169 +314,6 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
  );
}

//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//

def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
    [NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
  let summary = [{
    modifies the offset, sizes and strides of a statically shaped memref
  }];
  let description = [{
    Casts the statically shaped memref operand to a memref with optionally
    modified offsets, sizes and strides.

    Example:
    ```mlir
    %buf_transformed =
        lmhlo.static_memref_cast %buf
        : memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>

    // The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
    // offset 2.
    ```
  }];

  let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
  let results = (outs Res<LHLO_Buffer, "", []>:$result);

  let builders = [
    OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand),
    [{
      $_state.addOperands(operand);
      $_state.types.push_back(resultType);
    }]>];

  let extraClassDeclaration = [{
    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
  }];

  let verifier = [{ return Verify(*this); }];
  let assemblyFormat = [{
    $operand attr-dict `:` type($operand) `->` type($result)
  }];
}

//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//

def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
    [SameVariadicOperandSize, NoSideEffect,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
  let summary = "dynamic memref cast operation";
  let description = [{
    Change sizes and strides of a memref using the values computed in runtime.

    Example:
    ```mlir
    %buf_transformed =
        lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
        : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
    // The result of the op is a type-erased memref with `[%size_X, %size_Y]`
    // shape and `[%step_X, %step_Y]` strides. The offset will be inherited
    // from the input.
    ```
  }];

  let arguments = (ins
    Arg<LHLO_Buffer, "", []>:$operand,
    Variadic<Index>:$sizes,
    Variadic<Index>:$strides
  );
  let results = (outs Res<LHLO_Buffer, "", []>:$result);

  let builders = [
    OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand,
      "ValueRange":$sizes, "ValueRange":$strides),
    [{
      $_state.addOperands(operand);
      $_state.addOperands(sizes);
      $_state.addOperands(strides);
      $_state.types.push_back(resultType);
    }]>];

  let extraClassDeclaration = [{
    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
  }];

  let verifier = [{ return Verify(*this); }];
  let assemblyFormat = [{
    $operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->`
    type($result)
  }];
}

//===----------------------------------------------------------------------===//
// ReshapeMemRefCastOp
//===----------------------------------------------------------------------===//

def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
    DeclareOpInterfaceMethods<ViewLikeOpInterface>,
    NoSideEffect]> {
  let summary = "reshape memref cast operation";
  let description = [{
    The `reshape_memref_cast` operation converts a memref from one type to an
    equivalent type with a provided shape. The data is never copied or moved.
    The source and destination types are compatible if both have the same
    element type, address space and identity layout map. The following
    combinations are possible:

    a. Both are ranked memref types.

    ```mlir
    // Reshape statically-shaped memref.
    %dst = reshape_memref_cast %src(%shape)
             : (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
    %dst0 = reshape_memref_cast %src(%shape0)
             : (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
    ```

    b. Source type is ranked, destination type is unranked.

    ```mlir
    // Reshape dynamically-shaped 1D memref.
    %dst = reshape_memref_cast %src(%shape)
             : (memref<?xf32>, memref<?xi32>) to memref<*xf32>
    ```

    c. Source type is unranked, destination type is ranked.

    ```mlir
    // Flatten unranked memref.
    %dst = reshape_memref_cast %src(%shape)
             : (memref<*xf32>, memref<1xi32>) to memref<?xf32>
    ```

    d. Both are unranked memref types.

    ```mlir
    // Reshape unranked memref.
    %dst = reshape_memref_cast %src(%shape)
             : (memref<*xf32>, memref<?xi32>) to memref<*xf32>
    ```
  }];

  let arguments = (ins
    AnyRankedOrUnrankedMemRef:$operand,
    LHLO_ExtentBuffer:$shape
  );
  let results = (outs AnyRankedOrUnrankedMemRef:$result);

  let extraClassDeclaration = [{
    BaseMemRefType getType() {
      return getResult().getType().cast<BaseMemRefType>(); }
  }];

  let verifier = [{ return Verify(*this); }];
  let assemblyFormat = [{
    $operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape)
    `)` `->` type($result)
  }];
}


//===----------------------------------------------------------------------===//
// LMHLO Other op definitions.
//===----------------------------------------------------------------------===//
@@ -46,12 +46,6 @@ def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> {
}

def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> {
  let summary = "Legalize from LHLO dialect to LLVM.";
  let constructor = "createTestLhloToLLVMPass()";
}

def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> {
  let summary = "Legalize from LHLO dialect to parallel loops.";
  let constructor = "createLegalizeLhloToParallelLoopsPass()";
@@ -35,8 +35,6 @@ inline void registerAllMhloPasses() { registerMHLOPasses(); }

namespace lmhlo {

std::unique_ptr<Pass> createTestLhloToLLVMPass();

#define GEN_PASS_REGISTRATION
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"
@@ -24,8 +24,6 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
class LLVMTypeConverter;
class LowerToLLVMOptions;
class OwningRewritePatternList;

// Populates a collection of rewrite patterns to realize element-wise operations

@@ -94,14 +92,6 @@ void PopulateTrigonometricToApproximationPatterns(

}  // namespace mhlo

namespace lmhlo {

/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
                                          OwningRewritePatternList *patterns);

}  // namespace lmhlo

namespace chlo {

// Populates a collection of conversion patterns for legalizing client-HLO to
@@ -88,76 +88,6 @@ void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
  results.insert<EraseConstOp>(context);
}

//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//

Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); }

static LogicalResult Verify(StaticMemRefCastOp op) {
  if (!op.operand().getType().cast<ShapedType>().hasStaticShape())
    return op.emitOpError("operand must have static shape");
  if (!op.getType().hasStaticShape())
    return op.emitOpError("result must have static shape");
  return success();
}

//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//

Value DynamicMemRefCastOp::getViewSource() {
  return *getODSOperands(0).begin();
}

static LogicalResult Verify(DynamicMemRefCastOp op) {
  // Check if `sizes` and `strides` args are compatible with the result type.
  if (op.sizes().size() != op.getType().getRank())
    return op.emitOpError(
        "`sizes` args count must be equal to the rank of the output memref");
  return success();
}

//===----------------------------------------------------------------------===//
// ReshapeMemrefCastOp
//===----------------------------------------------------------------------===//

Value ReshapeMemRefCastOp::getViewSource() { return operand(); }

static LogicalResult Verify(ReshapeMemRefCastOp op) {
  Type operandType = op.operand().getType();
  Type resultType = op.result().getType();

  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return op.emitOpError(
        "element types of source and destination memref "
        "types should be the same");

  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "operand memref type should have identity affine map");

  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (shapeSize == ShapedType::kDynamicSize)
      return op.emitOpError(
          "cannot use shape operand with dynamic length to "
          "cast statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");

    if (!resultMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "result memref type should have identity affine map");
  }
  return success();
}

}  // namespace lmhlo
}  // namespace mlir
@@ -134,8 +134,6 @@ add_mlir_library(LmhloPasses
  lhlo_fuse_linalg.cc
  lhlo_legalize_to_affine.cc
  lhlo_legalize_to_gpu.cc
  lhlo_legalize_to_llvm.cc
  lhlo_legalize_to_llvm_pass.cc
  lhlo_legalize_to_parallel_loops.cc

  DEPENDS
@@ -206,7 +206,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
  // Inserts dynamic memref to change the layout of the memref to put 0-stride
  // and size of the target dimension if size-1 dimension expansion is
  // necessary.
  lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
  MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
      mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
    auto loc = op.getLoc();
    auto operand_type = operand.getType().cast<MemRefType>();
@@ -259,8 +259,13 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
        makeStridedLinearLayoutMap(dynamic_layout,
                                   /*offset=*/0, b->getContext()));

    auto transformed_operand = b->create<lmhlo::DynamicMemRefCastOp>(
        loc, type_erased_memref_type, operand, sizes, strides);
    SmallVector<int64_t, 2> static_sizes(sizes.size(),
                                         ShapedType::kDynamicSize);
    SmallVector<int64_t, 2> static_strides(strides.size(),
                                           ShapedType::kDynamicStrideOrOffset);
    auto transformed_operand = b->create<MemRefReinterpretCastOp>(
        loc, type_erased_memref_type, operand, /*offset=*/0, static_sizes,
        static_strides, llvm::None, sizes, strides);
    return transformed_operand;
  }
};
@@ -284,7 +289,7 @@ struct HloToLhloDynamicReshapeConverter
      return failure();
    }
    mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
    rewriter.replaceOpWithNewOp<lmhlo::ReshapeMemRefCastOp>(
    rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
        op, result_type, adaptor.operand(), adaptor.output_shape());
    return success();
  }
@@ -1,370 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace lmhlo {
namespace {

struct StaticMemRefCastOpConverter
    : public ConvertOpToLLVMPattern<StaticMemRefCastOp> {
  using ConvertOpToLLVMPattern<StaticMemRefCastOp>::ConvertOpToLLVMPattern;

  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto cast_op = cast<StaticMemRefCastOp>(op);

    StaticMemRefCastOp::Adaptor operands_adaptor(operands);
    MemRefDescriptor sourceMemRef(operands_adaptor.operand());

    MemRefType targetMemRefType =
        cast_op.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();
    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);

    // Fill size and stride descriptors in memref.
    auto target_sizes = targetMemRefType.getShape();
    int64_t target_offset;
    llvm::SmallVector<int64_t, 4> target_strides;
    if (failed((getStridesAndOffset(targetMemRefType, target_strides,
                                    target_offset))))
      return failure();

    // Copy offset of `targetMemRef`.
    desc.setConstantOffset(rewriter, loc, target_offset);
    for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      desc.setConstantSize(rewriter, loc, i, target_sizes[i]);
      desc.setConstantStride(rewriter, loc, i, target_strides[i]);
    }
    rewriter.replaceOp(op, {desc});
    return success();
  }
};

struct DynamicMemRefCastOpConverter
    : public ConvertOpToLLVMPattern<DynamicMemRefCastOp> {
  using ConvertOpToLLVMPattern<DynamicMemRefCastOp>::ConvertOpToLLVMPattern;

  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto cast_op = cast<DynamicMemRefCastOp>(op);

    DynamicMemRefCastOp::Adaptor operands_adaptor(operands);
    MemRefDescriptor sourceMemRef(operands_adaptor.operand());

    MemRefType targetMemRefType =
        cast_op.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();
    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Copy offset of `sourceMemRef`.
    desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc));

    // Fill size and stride descriptors in memref.
    if (!cast_op.sizes().empty()) {
      auto sizes = operands_adaptor.sizes();
      auto strides = operands_adaptor.strides();
      for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
        desc.setSize(rewriter, loc, i, sizes[i]);
        desc.setStride(rewriter, loc, i, strides[i]);
      }
    }
    rewriter.replaceOp(op, {desc});
    return success();
  }
};

struct ReshapeMemRefCastOpConverter
    : public ConvertOpToLLVMPattern<ReshapeMemRefCastOp> {
  using ConvertOpToLLVMPattern<ReshapeMemRefCastOp>::ConvertOpToLLVMPattern;

  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto reshape_op = cast<ReshapeMemRefCastOp>(op);
    auto dst_type = reshape_op.getResult().getType().cast<BaseMemRefType>();
    auto element_type = dst_type.getElementType();

    auto shape = reshape_op.shape();

    ReshapeMemRefCastOp::Adaptor operands_adaptor(operands);
    PtrsAndOffset ptrs_n_offset = ExtractMemRefPtrsAndOffset(
        loc, reshape_op.operand(), operands_adaptor.operand(), &rewriter);

    MemRefDescriptor shape_desc(operands_adaptor.shape());

    auto shape_memref_type = shape.getType().cast<MemRefType>();

    if (shape_memref_type.hasStaticShape()) {
      auto shape_length = shape_memref_type.getDimSize(0);

      MemRefType targetMemRefType = MemRefType::get(
          SmallVector<int64_t, 1>(shape_length, 1), element_type);
      auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                        .dyn_cast_or_null<LLVM::LLVMType>();
      if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
        return failure();
      // Create descriptor.
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
      desc.setAllocatedPtr(rewriter, loc, ptrs_n_offset.allocated_ptr);
      desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr);
      desc.setOffset(rewriter, loc, ptrs_n_offset.offset);

      auto llvm_index_type = typeConverter.getIndexType();
      auto llvm_index_ptr_type = llvm_index_type.getPointerTo();
      Value stride_carried = rewriter.create<LLVM::ConstantOp>(
          loc, llvm_index_type,
          rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
      for (int i = shape_length - 1; i >= 0; --i) {
        Value pos = rewriter.create<LLVM::ConstantOp>(
            loc, llvm_index_type,
            rewriter.getIntegerAttr(rewriter.getIndexType(), i));
        Value ptr = rewriter.create<LLVM::GEPOp>(
            loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc),
            ValueRange{pos});
        Value extracted_size = rewriter.create<LLVM::LoadOp>(loc, ptr);
        desc.setSize(rewriter, loc, i, extracted_size);
        desc.setStride(rewriter, loc, i, stride_carried);
        // Update stride
        if (i > 0) {
          stride_carried =
              rewriter.create<LLVM::MulOp>(loc, stride_carried, extracted_size);
        }
      }
      if (dst_type.isa<MemRefType>()) {
        rewriter.replaceOp(op, {desc});
      } else {
        Value rank = rewriter.create<LLVM::ConstantOp>(
            loc, llvm_index_type,
            rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length));
        Value alloca =
            typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter);
        Value void_ptr =
            rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), alloca);
        auto unranked_desc = UnrankedMemRefDescriptor::pack(
            rewriter, loc, typeConverter, dst_type.cast<UnrankedMemRefType>(),
            {rank, void_ptr});
        rewriter.replaceOp(op, {unranked_desc});
      }
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Value result_rank = shape_desc.size(rewriter, loc, 0);
    // TODO(herhut): Propely handle address spaces.
    unsigned address_space = 0;
    auto target_type =
        typeConverter
            .convertType(UnrankedMemRefType::get(element_type, address_space))
            .cast<LLVM::LLVMType>();
    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    UnrankedMemRefDescriptor target_desc =
        UnrankedMemRefDescriptor::undef(rewriter, loc, target_type);
    target_desc.setRank(rewriter, loc, result_rank);
    SmallVector<Value, 1> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
                                           {target_desc}, sizes);
    auto void_ptr_type = LLVM::LLVMType::getInt8PtrTy(rewriter.getContext());
    Value ranked_desc_mem = rewriter.create<LLVM::AllocaOp>(
        loc, void_ptr_type, sizes.front(), llvm::None);
    target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem);

    // Fill the fixed parts. For this, we cast to a 0-D memref.
    auto zero_d_memref_type = MemRefType::get({}, element_type);
    Value as_zero_d = rewriter.create<LLVM::BitcastOp>(
        loc,
        typeConverter.convertType(zero_d_memref_type)
            .cast<LLVM::LLVMType>()
            .getPointerTo(address_space),
        ranked_desc_mem);
    // Some common constants. Use 32 bit where required by gep struct indexes.
    auto int32_type = typeConverter.convertType(rewriter.getI32Type());
    Value zero_index = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0));
    Value zero = rewriter.create<LLVM::ConstantOp>(
        loc, int32_type, rewriter.getI32IntegerAttr(0));
    Value one = rewriter.create<LLVM::ConstantOp>(
        loc, int32_type, rewriter.getI32IntegerAttr(1));
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, int32_type, rewriter.getI32IntegerAttr(2));
    // Set base_pointer and aligned pointer.
    auto element_ptr_ptr_type = typeConverter.convertType(element_type)
                                    .cast<LLVM::LLVMType>()
                                    .getPointerTo(address_space)
                                    .getPointerTo(address_space);
    auto base_gep = rewriter.create<LLVM::GEPOp>(
        loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero}));
    rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.allocated_ptr, base_gep);
    auto aligned_gep = rewriter.create<LLVM::GEPOp>(
        loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one}));
    rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.aligned_ptr, aligned_gep);
    // Set offset.
    auto index_ptr_type =
        typeConverter.getIndexType().getPointerTo(address_space);
    auto offset_gep = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two}));
    rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.offset, offset_gep);

    // Use the offset pointer as base for further addressing. Copy over the
    // new shape and compute strides. For this, we need to create a loop from
    // rank - 1 to 0.
    Value one_index = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1));
    auto target_shape_base = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, offset_gep, ValueRange({one}));
    auto target_strides_base = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, target_shape_base, ValueRange({result_rank}));
    auto shape_ptr = shape_desc.alignedPtr(rewriter, loc);
    auto result_rank_minus_one =
        rewriter.create<LLVM::SubOp>(loc, result_rank, one_index);

    Block *init_block = rewriter.getInsertionBlock();
    Block *cond_block =
        rewriter.splitBlock(init_block, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToEnd(init_block);
    rewriter.create<LLVM::BrOp>(
        loc, ValueRange({result_rank_minus_one, one_index}), cond_block);
    rewriter.setInsertionPointToStart(cond_block);
    auto index_arg = cond_block->addArgument(typeConverter.getIndexType());
    auto stride_arg = cond_block->addArgument(typeConverter.getIndexType());
    auto pred = rewriter.create<LLVM::ICmpOp>(
        loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()),
        LLVM::ICmpPredicate::sge, index_arg, zero_index);

    Block *body_block =
        rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(body_block);

    // Copy size from shape to descriptor.
    auto size_load_gep = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, shape_ptr, ValueRange{index_arg});
    auto extracted_size = rewriter.create<LLVM::LoadOp>(loc, size_load_gep);
    auto size_store_gep = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, target_shape_base, ValueRange({index_arg}));
    rewriter.create<LLVM::StoreOp>(loc, extracted_size, size_store_gep);
    // Write stride value and compute next one.
    auto stride_store_gep = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, target_strides_base, ValueRange({index_arg}));
    rewriter.create<LLVM::StoreOp>(loc, stride_arg, stride_store_gep);
    auto next_stride =
        rewriter.create<LLVM::MulOp>(loc, stride_arg, extracted_size);

    // Decrement loop counter and branch back.
    auto decrement = rewriter.create<LLVM::SubOp>(loc, index_arg, one_index);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, next_stride}),
                                cond_block);

    Block *remainder =
        rewriter.splitBlock(body_block, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(cond_block);
    rewriter.create<LLVM::CondBrOp>(loc, pred, body_block, ValueRange(),
                                    remainder, ValueRange());

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);
    rewriter.replaceOp(op, {target_desc});
    return success();
  }

 private:
  struct PtrsAndOffset {
    Value allocated_ptr;
    Value aligned_ptr;
    Value offset;
  };

  PtrsAndOffset ExtractMemRefPtrsAndOffset(
      Location loc, Value originalOperand, Value convertedOperand,
      ConversionPatternRewriter *rewriter) const {
    Type operandType = originalOperand.getType();
    Value descriptor_ptr;
    if (operandType.isa<MemRefType>()) {
      descriptor_ptr = convertedOperand;
    } else {
      UnrankedMemRefDescriptor unranked_descriptor(convertedOperand);
      Value underlying_desc_ptr =
          unranked_descriptor.memRefDescPtr(*rewriter, loc);

      Type element_type =
          operandType.cast<UnrankedMemRefType>().getElementType();
      LLVM::LLVMType memref_type_0d =
          typeConverter.convertType(MemRefType::get(/*shape=*/{}, element_type))
              .cast<LLVM::LLVMType>();
      descriptor_ptr = rewriter->create<LLVM::BitcastOp>(
          loc, memref_type_0d.getPointerTo(), underlying_desc_ptr);
      descriptor_ptr = rewriter->create<LLVM::LoadOp>(loc, descriptor_ptr);
    }
    MemRefDescriptor descriptor(descriptor_ptr);
    PtrsAndOffset result;
    result.allocated_ptr = descriptor.allocatedPtr(*rewriter, loc);
    result.aligned_ptr = descriptor.alignedPtr(*rewriter, loc);
    result.offset = descriptor.offset(*rewriter, loc);
    return result;
  }
};

}  // namespace

void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
                                          OwningRewritePatternList *patterns) {
  patterns->insert<DynamicMemRefCastOpConverter, ReshapeMemRefCastOpConverter,
                   StaticMemRefCastOpConverter>(*converter);
}

}  // namespace lmhlo
}  // namespace mlir
@@ -1,63 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace lmhlo {
namespace {

class TestLhloToLLVMPass
    : public ::mlir::PassWrapper<TestLhloToLLVMPass,
                                 ::mlir::OperationPass<::mlir::ModuleOp>> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

 public:
  void runOnOperation() override {
    ModuleOp m = getOperation();

    OwningRewritePatternList patterns;
    LLVMTypeConverter converter(&getContext());
    populateStdToLLVMConversionPatterns(converter, patterns);
    PopulateLhloToLLVMConversionPatterns(&converter, &patterns);

    ConversionTarget target(getContext());
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
    target.addIllegalDialect<LmhloDialect>();

    if (failed(applyFullConversion(m, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

std::unique_ptr<Pass> createTestLhloToLLVMPass() {
  return std::make_unique<TestLhloToLLVMPass>();
}

}  // namespace lmhlo
}  // namespace mlir
@@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s

func @main() -> () {
  call @trivial_broadcast_wrapper() : () -> ()
@@ -3,7 +3,7 @@
// RUN: -buffer-deallocation -copy-removal -canonicalize -cse \
// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
// RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \
// RUN: -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main \
// RUN: -convert-std-to-llvm | mlir-cpu-runner -e main \
// RUN: -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
// RUN: FileCheck %s
@@ -1,190 +0,0 @@
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s

func @main() -> () {
  call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()
  call @reshape_with_static_shape_size_matrix_to_3D() : () -> ()
  call @reshape_with_dynamic_shape_size_matrix_to_1D() : () -> ()
  call @reshape_with_dynamic_shape_size_matrix_to_3D() : () -> ()
  return
}

func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }

func @reshape_with_static_shape_size_matrix_to_1D() {
  %c0 = constant 0 : index
  %c1 = constant 1 : index

  // Initialize input.
  %input = alloc() : memref<2x3xf32>
  %dim_x = dim %input, %c0 : memref<2x3xf32>
  %dim_y = dim %input, %c1 : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
    %i_i64 = index_cast %i : index to i64
    %i_f32 = sitofp %i_i64 : i64 to f32
    store %i_f32, %input[%i, %j] : memref<2x3xf32>
  }
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
  // CHECK: [0, 0, 0]
  // CHECK: [1, 1, 1]

  // Initialize shape.
  %shape = alloc() : memref<1xi64>
  %num_elements = muli %dim_x, %dim_y : index
  %num_elements_i64 = index_cast %num_elements : index to i64
  store %num_elements_i64, %shape[%c0] : memref<1xi64>

  // 1. Ranked input, ranked output.
  %output_1 = lmhlo.reshape_memref_cast %input(%shape)
    : (memref<2x3xf32>, memref<1xi64>) -> memref<6xf32>
  %unranked_output_1 = memref_cast %output_1 : memref<6xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_output_1) : (memref<*xf32>) -> ()
  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
  // CHECK: [0, 0, 0, 1, 1, 1]

  // 2. Ranked input, unranked output.
  %output_2 = lmhlo.reshape_memref_cast %input(%shape)
    : (memref<2x3xf32>, memref<1xi64>) -> memref<*xf32>
  call @print_memref_f32(%output_2) : (memref<*xf32>) -> ()
  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
  // CHECK: [0, 0, 0, 1, 1, 1]

  // 3. Unranked input, ranked output.
  %output_3 = lmhlo.reshape_memref_cast %unranked_input(%shape)
    : (memref<*xf32>, memref<1xi64>) -> memref<?xf32>
  %unranked_output_3 = memref_cast %output_3 : memref<?xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_output_3) : (memref<*xf32>) -> ()
  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
  // CHECK: [0, 0, 0, 1, 1, 1]

  // 4. Unranked input, unranked output.
  %output_4 = lmhlo.reshape_memref_cast %unranked_input(%shape)
    : (memref<*xf32>, memref<1xi64>) -> memref<*xf32>
  call @print_memref_f32(%output_4) : (memref<*xf32>) -> ()
  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
  // CHECK: [0, 0, 0, 1, 1, 1]
  return
}

func @reshape_with_static_shape_size_matrix_to_3D() {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index

  // Initialize input.
  %input = alloc() : memref<2x3xf32>
  %dim_x = dim %input, %c0 : memref<2x3xf32>
  %dim_y = dim %input, %c1 : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
    %i_i64 = index_cast %i : index to i64
    %i_f32 = sitofp %i_i64 : i64 to f32
    store %i_f32, %input[%i, %j] : memref<2x3xf32>
  }
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
  // CHECK: [0, 0, 0]
  // CHECK: [1, 1, 1]

  // Initialize shape.
  %shape = alloc() : memref<3xi64>
  %c1_i64 = constant 1 : i64
  %c2_i64 = constant 2 : i64
  %c3_i64 = constant 3 : i64
  store %c3_i64, %shape[%c0] : memref<3xi64>
  store %c1_i64, %shape[%c1] : memref<3xi64>
  store %c2_i64, %shape[%c2] : memref<3xi64>

  // Static shape input and shape, dynamic output.
  %unranked_output = lmhlo.reshape_memref_cast %input(%shape)
    : (memref<2x3xf32>, memref<3xi64>) -> memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 3 offset = 0 sizes = [3, 1, 2] strides = [2, 2, 1]
  // CHECK: {{\[}}{{\[}}[0, 0]],
  // CHECK: {{\[}}[0, 1]],
  // CHECK: {{\[}}[1, 1]]]
  return
}

func @reshape_with_dynamic_shape_size_matrix_to_1D() {
  %c0 = constant 0 : index
  %c1 = constant 1 : index

  // Initialize input.
  %input = alloc() : memref<2x3xf32>
  %dim_x = dim %input, %c0 : memref<2x3xf32>
  %dim_y = dim %input, %c1 : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
    %i_i64 = index_cast %i : index to i64
    %i_f32 = sitofp %i_i64 : i64 to f32
    store %i_f32, %input[%i, %j] : memref<2x3xf32>
  }
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
  // CHECK: [0, 0, 0]
  // CHECK: [1, 1, 1]

  // Initialize shape.
  %shape = alloc(%c1) : memref<?xi64>
  %num_elements = muli %dim_x, %dim_y : index
  %num_elements_i64 = index_cast %num_elements : index to i64
  store %num_elements_i64, %shape[%c0] : memref<?xi64>

  // 1. Ranked input, unranked output.
  %output_2 = lmhlo.reshape_memref_cast %input(%shape)
    : (memref<2x3xf32>, memref<?xi64>) -> memref<*xf32>
  call @print_memref_f32(%output_2) : (memref<*xf32>) -> ()
  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
  // CHECK: [0, 0, 0, 1, 1, 1]

  // 2. Unranked input, unranked output.
  %output_4 = lmhlo.reshape_memref_cast %unranked_input(%shape)
    : (memref<*xf32>, memref<?xi64>) -> memref<*xf32>
  call @print_memref_f32(%output_4) : (memref<*xf32>) -> ()
  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
  // CHECK: [0, 0, 0, 1, 1, 1]
  return
}

func @reshape_with_dynamic_shape_size_matrix_to_3D() {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %c3 = constant 3 : index

  // Initialize input.
  %input = alloc() : memref<2x3xf32>
  %dim_x = dim %input, %c0 : memref<2x3xf32>
  %dim_y = dim %input, %c1 : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
    %i_i64 = index_cast %i : index to i64
    %i_f32 = sitofp %i_i64 : i64 to f32
    store %i_f32, %input[%i, %j] : memref<2x3xf32>
  }
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
  // CHECK: [0, 0, 0]
  // CHECK: [1, 1, 1]

  // Initialize shape.
  %shape = alloc(%c3) : memref<?xi64>
  %c1_i64 = constant 1 : i64
  %c2_i64 = constant 2 : i64
  %c3_i64 = constant 3 : i64
  store %c3_i64, %shape[%c0] : memref<?xi64>
  store %c1_i64, %shape[%c1] : memref<?xi64>
  store %c2_i64, %shape[%c2] : memref<?xi64>

  // Static shape input, dynamic output and shape.
  %unranked_output = lmhlo.reshape_memref_cast %input(%shape)
    : (memref<2x3xf32>, memref<?xi64>) -> memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 3 offset = 0 sizes = [3, 1, 2] strides = [2, 2, 1]
  // CHECK: {{\[}}{{\[}}[0, 0]],
  // CHECK: {{\[}}[0, 1]],
  // CHECK: {{\[}}[1, 1]]]
  return
}
@@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked(
  return %reshaped : tensor<?xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>

// -----
@@ -30,5 +30,5 @@ func @dynamic_reshape_to_unranked(
  return %reshaped : tensor<*xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
@@ -197,10 +197,11 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index

// CHECK: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast
// CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map>
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to
// CHECK-SAME: offset: [0],
// CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]]
// CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #map>

// CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
@@ -267,7 +267,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
    %13 = absf %arg3 : f32
    linalg.yield %13 : f32
  }
  %2 = lmhlo.reshape_memref_cast %1(%arg1)
  %2 = memref_reshape %1(%arg1)
         : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
  return %2 : memref<*xf32>
}
@@ -279,7 +279,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: reshape_memref_cast
// CHECK: memref_reshape

// TILED-LABEL: func @view_result
// TILED-DAG: %[[C2:.*]] = constant 2
@@ -288,7 +288,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: reshape_memref_cast
// TILED: memref_reshape


// PLOOP-LABEL: func @view_result
@@ -297,5 +297,5 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: reshape_memref_cast
// PLOOP: memref_reshape
@@ -1,65 +0,0 @@
// RUN: mlir-hlo-opt %s -lower-affine -convert-scf-to-std -test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s

// CHECK-LABEL: func @static_memref_cast
func @static_memref_cast(%buf : memref<10x1x5xf32>) {
  %0 = lmhlo.static_memref_cast %buf
    : memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]>
  return
}
// CHECK: %[[INPUT_MEMREF_BLDR:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_3D:!.*]]
// CHECK: llvm.insertvalue
// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_2D:!.*]]

// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE_3D]]
// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE_2D]]

// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE_3D]]
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE_2D]]

// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[C2]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE_2D]]

// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C5]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C5_:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C5_]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE_2D]]

// -----

// CHECK-LABEL: func @dynamic_memref_cast
func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
  %size_X = constant 10 : index
  %size_Y = constant 50 : index
  %stride_X = constant 1 : index
  %stride_Y = constant 0 : index
  %0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
    : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
  return
}
// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : index) : !llvm.i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64

// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]]

// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]]

// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]]

// CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[SRC_OFFSET]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C50]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C0]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE]]
@@ -429,120 +429,6 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr

// -----

func @static_memref_cast(%in: memref<10x1xf32>) {
  %out = lmhlo.static_memref_cast %in
    : memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]>
  return
}
// CHECK-LABEL: func @static_memref_cast

// -----

func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
  // expected-error @+1 {{operand must have static shape}}
  %out = lmhlo.static_memref_cast %in
    : memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]>
  return
}

// -----

func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
  // expected-error @+1 {{result must have static shape}}
  %out = lmhlo.static_memref_cast %in
    : memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]>
  return
}

// -----

func @dynamic_memref_cast(%in: memref<?xf32>) {
  %size = constant 10 : index
  %step = constant 1 : index
  %out = lmhlo.dynamic_memref_cast %in(%size)[%step]
    : memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>
  return
}
// CHECK-LABEL: func @dynamic_memref_cast

// -----

func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
  // expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}}
  %size = constant 10 : index
  %step = constant 1 : index
  %out = lmhlo.dynamic_memref_cast %in(%size)[%step]
    : memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
  return
}
// -----

// CHECK-LABEL: func @reshape_memref_cast(
func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
         %shape2: memref<2xi32>, %shape3: memref<?xi32>) {
  // CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>,
  // CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32>

  // CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]]
  // CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
  %dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1)
               : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>

  // CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]]
  // CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
  %dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2)
               : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>

  // CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]]
  // CHECK-SAME: : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
  %new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3)
               : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
  return
}

// -----

func @reshape_memref_cast_element_type_mismatch(
       %buf: memref<*xf32>, %shape: memref<1xi32>) {
  // expected-error @+1 {{element types of source and destination memref types should be the same}}
  lmhlo.reshape_memref_cast %buf(%shape)
    : (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
}

// -----

func @reshape_memref_cast_dst_ranked_shape_unranked(
       %buf: memref<*xf32>, %shape: memref<?xi32>) {
  // expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}}
  lmhlo.reshape_memref_cast %buf(%shape)
    : (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
  return
}

// -----

func @reshape_memref_cast_dst_shape_rank_mismatch(
       %buf: memref<*xf32>, %shape: memref<1xi32>) {
  // expected-error @+1 {{length of shape operand differs from the result's memref rank}}
  lmhlo.reshape_memref_cast %buf(%shape)
    : (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
  return
}

// -----

func @reshape_memref_cast_affine_map_is_not_identity(
        %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
        %shape: memref<1xi32>) {
  // expected-error @+1 {{operand memref type should have identity affine map}}
  lmhlo.reshape_memref_cast %buf(%shape)
    : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
   -> memref<8xf32>
  return
}

// -----

// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()