[HLO] Delete LHLO memref cast ops and migrate to STD ones.

PiperOrigin-RevId: 340663578
Alexander Belyaev 2020-11-04 09:25:57 -08:00 committed by TensorFlow MLIR Team
parent 82031b356c
commit 3d930d08c2
17 changed files with 22 additions and 1071 deletions

View File

@@ -314,169 +314,6 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
);
}
//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = [{
modifies the offset, sizes and strides of a statically shaped memref
}];
let description = [{
Casts the statically shaped memref operand to a memref with optionally
modified offsets, sizes and strides.
Example:
```mlir
%buf_transformed =
lmhlo.static_memref_cast %buf
: memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
// The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
// offset 2.
```
}];
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand),
[{
$_state.addOperands(operand);
$_state.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
}];
let verifier = [{ return Verify(*this); }];
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
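Since `lmhlo.static_memref_cast` is deleted by this change, its role is taken over by the standard dialect. A minimal sketch of the example above rewritten with std `memref_reinterpret_cast` (assuming the assembly format used in the updated tests further down; the result name is illustrative):
```mlir
// Hypothetical std-dialect equivalent of the lmhlo.static_memref_cast example:
// offset, sizes and strides are all static attributes here.
%buf_transformed = memref_reinterpret_cast %buf to
    offset: [2], sizes: [5], strides: [1]
    : memref<1x5xf32> to memref<5xf32, offset: 2, strides: [1]>
```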
//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//
def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
[SameVariadicOperandSize, NoSideEffect,
DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = "dynamic memref cast operation";
let description = [{
Changes sizes and strides of a memref using values computed at runtime.
Example:
```mlir
%buf_transformed =
lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
// The result of the op is a type-erased memref with `[%size_X, %size_Y]`
// shape and `[%step_X, %step_Y]` strides. The offset will be inherited
// from the input.
```
}];
let arguments = (ins
Arg<LHLO_Buffer, "", []>:$operand,
Variadic<Index>:$sizes,
Variadic<Index>:$strides
);
let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand,
"ValueRange":$sizes, "ValueRange":$strides),
[{
$_state.addOperands(operand);
$_state.addOperands(sizes);
$_state.addOperands(strides);
$_state.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
}];
let verifier = [{ return Verify(*this); }];
let assemblyFormat = [{
$operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->`
type($result)
}];
}
//===----------------------------------------------------------------------===//
// ReshapeMemRefCastOp
//===----------------------------------------------------------------------===//
def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
NoSideEffect]> {
let summary = "reshape memref cast operation";
let description = [{
The `reshape_memref_cast` operation converts a memref from one type to an
equivalent type with a provided shape. The data is never copied or moved.
The source and destination types are compatible if both have the same
element type, address space and identity layout map. The following
combinations are possible:
a. Both are ranked memref types.
```mlir
// Reshape statically-shaped memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
%dst0 = reshape_memref_cast %src(%shape0)
: (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
```
b. Source type is ranked, destination type is unranked.
```mlir
// Reshape dynamically-shaped 1D memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<?xf32>, memref<?xi32>) to memref<*xf32>
```
c. Source type is unranked, destination type is ranked.
```mlir
// Flatten unranked memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<*xf32>, memref<1xi32>) to memref<?xf32>
```
d. Both are unranked memref types.
```mlir
// Reshape unranked memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<*xf32>, memref<?xi32>) to memref<*xf32>
```
}];
let arguments = (ins
AnyRankedOrUnrankedMemRef:$operand,
LHLO_ExtentBuffer:$shape
);
let results = (outs AnyRankedOrUnrankedMemRef:$result);
let extraClassDeclaration = [{
BaseMemRefType getType() {
return getResult().getType().cast<BaseMemRefType>(); }
}];
let verifier = [{ return Verify(*this); }];
let assemblyFormat = [{
$operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape)
`)` `->` type($result)
}];
}
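`lmhlo.reshape_memref_cast` is likewise superseded by the std `memref_reshape` op, which the converter and test changes below switch to. A sketch of combination (a) above in the std form (assuming the assembly format shown in the updated CHECK lines below):
```mlir
// Hypothetical std-dialect counterpart of combination (a): reshape a
// statically shaped memref through a shape buffer; data is never copied.
%dst = memref_reshape %src(%shape)
    : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
```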
//===----------------------------------------------------------------------===//
// LMHLO Other op definitions.
//===----------------------------------------------------------------------===//

View File

@@ -46,12 +46,6 @@ def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> {
}
def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> {
let summary = "Legalize from LHLO dialect to LLVM.";
let constructor = "createTestLhloToLLVMPass()";
}
def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> {
let summary = "Legalize from LHLO dialect to parallel loops.";
let constructor = "createLegalizeLhloToParallelLoopsPass()";

View File

@@ -35,8 +35,6 @@ inline void registerAllMhloPasses() { registerMHLOPasses(); }
namespace lmhlo {
std::unique_ptr<Pass> createTestLhloToLLVMPass();
#define GEN_PASS_REGISTRATION
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"

View File

@@ -24,8 +24,6 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class LLVMTypeConverter;
class LowerToLLVMOptions;
class OwningRewritePatternList;
// Populates a collection of rewrite patterns to realize element-wise operations
@@ -94,14 +92,6 @@ void PopulateTrigonometricToApproximationPatterns(
} // namespace mhlo
namespace lmhlo {
/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
OwningRewritePatternList *patterns);
} // namespace lmhlo
namespace chlo {
// Populates a collection of conversion patterns for legalizing client-HLO to

View File

@@ -88,76 +88,6 @@ void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
results.insert<EraseConstOp>(context);
}
//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//
Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); }
static LogicalResult Verify(StaticMemRefCastOp op) {
if (!op.operand().getType().cast<ShapedType>().hasStaticShape())
return op.emitOpError("operand must have static shape");
if (!op.getType().hasStaticShape())
return op.emitOpError("result must have static shape");
return success();
}
//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//
Value DynamicMemRefCastOp::getViewSource() {
return *getODSOperands(0).begin();
}
static LogicalResult Verify(DynamicMemRefCastOp op) {
// Check if `sizes` and `strides` args are compatible with the result type.
if (op.sizes().size() != op.getType().getRank())
return op.emitOpError(
"`sizes` args count must be equal to the rank of the output memref");
return success();
}
//===----------------------------------------------------------------------===//
// ReshapeMemRefCastOp
//===----------------------------------------------------------------------===//
Value ReshapeMemRefCastOp::getViewSource() { return operand(); }
static LogicalResult Verify(ReshapeMemRefCastOp op) {
Type operandType = op.operand().getType();
Type resultType = op.result().getType();
Type operandElementType = operandType.cast<ShapedType>().getElementType();
Type resultElementType = resultType.cast<ShapedType>().getElementType();
if (operandElementType != resultElementType)
return op.emitOpError(
"element types of source and destination memref "
"types should be the same");
if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
if (!operandMemRefType.getAffineMaps().empty())
return op.emitOpError(
"operand memref type should have identity affine map");
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
auto resultMemRefType = resultType.dyn_cast<MemRefType>();
if (resultMemRefType) {
if (shapeSize == ShapedType::kDynamicSize)
return op.emitOpError(
"cannot use shape operand with dynamic length to "
"cast statically-ranked memref type");
if (shapeSize != resultMemRefType.getRank())
return op.emitOpError(
"length of shape operand differs from the result's memref rank");
if (!resultMemRefType.getAffineMaps().empty())
return op.emitOpError(
"result memref type should have identity affine map");
}
return success();
}
} // namespace lmhlo
} // namespace mlir

View File

@@ -134,8 +134,6 @@ add_mlir_library(LmhloPasses
lhlo_fuse_linalg.cc
lhlo_legalize_to_affine.cc
lhlo_legalize_to_gpu.cc
lhlo_legalize_to_llvm.cc
lhlo_legalize_to_llvm_pass.cc
lhlo_legalize_to_parallel_loops.cc
DEPENDS

View File

@@ -206,7 +206,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
// Inserts a memref cast to change the layout of the memref: where a size-1
// dimension must be expanded, its stride is set to 0 and its size to that of
// the target dimension.
lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
@@ -259,8 +259,13 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
makeStridedLinearLayoutMap(dynamic_layout,
/*offset=*/0, b->getContext()));
auto transformed_operand = b->create<lmhlo::DynamicMemRefCastOp>(
loc, type_erased_memref_type, operand, sizes, strides);
SmallVector<int64_t, 2> static_sizes(sizes.size(),
ShapedType::kDynamicSize);
SmallVector<int64_t, 2> static_strides(strides.size(),
ShapedType::kDynamicStrideOrOffset);
auto transformed_operand = b->create<MemRefReinterpretCastOp>(
loc, type_erased_memref_type, operand, /*offset=*/0, static_sizes,
static_strides, llvm::None, sizes, strides);
return transformed_operand;
}
};
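The new builder call above expresses the old `lmhlo.dynamic_memref_cast` as a `MemRefReinterpretCastOp` with a static offset of 0 and all-dynamic static size/stride placeholders, so the runtime `sizes` and `strides` values are passed as SSA operands. Roughly, the emitted IR looks like the following sketch (SSA names are illustrative; `#map` stands for the strided layout computed above):
```mlir
// Sketch of the memref_reinterpret_cast produced for the dynamic broadcast:
// the offset is the constant 0, sizes and strides remain dynamic SSA values.
%transformed = memref_reinterpret_cast %operand to
    offset: [0], sizes: [%size_0, %size_1], strides: [%stride_0, %stride_1]
    : memref<?x?xf32> to memref<?x?xf32, #map>
```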
@@ -284,7 +289,7 @@ struct HloToLhloDynamicReshapeConverter
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<lmhlo::ReshapeMemRefCastOp>(
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}

View File

@@ -1,370 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace lmhlo {
namespace {
struct StaticMemRefCastOpConverter
: public ConvertOpToLLVMPattern<StaticMemRefCastOp> {
using ConvertOpToLLVMPattern<StaticMemRefCastOp>::ConvertOpToLLVMPattern;
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto cast_op = cast<StaticMemRefCastOp>(op);
StaticMemRefCastOp::Adaptor operands_adaptor(operands);
MemRefDescriptor sourceMemRef(operands_adaptor.operand());
MemRefType targetMemRefType =
cast_op.getResult().getType().cast<MemRefType>();
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
Type llvmTargetElementTy = desc.getElementPtrType();
// Set allocated ptr.
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
desc.setAllocatedPtr(rewriter, loc, allocated);
// Set aligned ptr.
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
desc.setAlignedPtr(rewriter, loc, ptr);
// Fill size and stride descriptors in memref.
auto target_sizes = targetMemRefType.getShape();
int64_t target_offset;
llvm::SmallVector<int64_t, 4> target_strides;
if (failed((getStridesAndOffset(targetMemRefType, target_strides,
target_offset))))
return failure();
// Copy offset of `targetMemRef`.
desc.setConstantOffset(rewriter, loc, target_offset);
for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
desc.setConstantSize(rewriter, loc, i, target_sizes[i]);
desc.setConstantStride(rewriter, loc, i, target_strides[i]);
}
rewriter.replaceOp(op, {desc});
return success();
}
};
struct DynamicMemRefCastOpConverter
: public ConvertOpToLLVMPattern<DynamicMemRefCastOp> {
using ConvertOpToLLVMPattern<DynamicMemRefCastOp>::ConvertOpToLLVMPattern;
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto cast_op = cast<DynamicMemRefCastOp>(op);
DynamicMemRefCastOp::Adaptor operands_adaptor(operands);
MemRefDescriptor sourceMemRef(operands_adaptor.operand());
MemRefType targetMemRefType =
cast_op.getResult().getType().cast<MemRefType>();
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
Type llvmTargetElementTy = desc.getElementPtrType();
// Set allocated ptr.
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
desc.setAllocatedPtr(rewriter, loc, allocated);
// Set aligned ptr.
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
desc.setAlignedPtr(rewriter, loc, ptr);
// Copy offset of `sourceMemRef`.
desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc));
// Fill size and stride descriptors in memref.
if (!cast_op.sizes().empty()) {
auto sizes = operands_adaptor.sizes();
auto strides = operands_adaptor.strides();
for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
desc.setSize(rewriter, loc, i, sizes[i]);
desc.setStride(rewriter, loc, i, strides[i]);
}
}
rewriter.replaceOp(op, {desc});
return success();
}
};
struct ReshapeMemRefCastOpConverter
: public ConvertOpToLLVMPattern<ReshapeMemRefCastOp> {
using ConvertOpToLLVMPattern<ReshapeMemRefCastOp>::ConvertOpToLLVMPattern;
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto reshape_op = cast<ReshapeMemRefCastOp>(op);
auto dst_type = reshape_op.getResult().getType().cast<BaseMemRefType>();
auto element_type = dst_type.getElementType();
auto shape = reshape_op.shape();
ReshapeMemRefCastOp::Adaptor operands_adaptor(operands);
PtrsAndOffset ptrs_n_offset = ExtractMemRefPtrsAndOffset(
loc, reshape_op.operand(), operands_adaptor.operand(), &rewriter);
MemRefDescriptor shape_desc(operands_adaptor.shape());
auto shape_memref_type = shape.getType().cast<MemRefType>();
if (shape_memref_type.hasStaticShape()) {
auto shape_length = shape_memref_type.getDimSize(0);
MemRefType targetMemRefType = MemRefType::get(
SmallVector<int64_t, 1>(shape_length, 1), element_type);
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
// Create descriptor.
auto desc =
MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
desc.setAllocatedPtr(rewriter, loc, ptrs_n_offset.allocated_ptr);
desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr);
desc.setOffset(rewriter, loc, ptrs_n_offset.offset);
auto llvm_index_type = typeConverter.getIndexType();
auto llvm_index_ptr_type = llvm_index_type.getPointerTo();
Value stride_carried = rewriter.create<LLVM::ConstantOp>(
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
for (int i = shape_length - 1; i >= 0; --i) {
Value pos = rewriter.create<LLVM::ConstantOp>(
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), i));
Value ptr = rewriter.create<LLVM::GEPOp>(
loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc),
ValueRange{pos});
Value extracted_size = rewriter.create<LLVM::LoadOp>(loc, ptr);
desc.setSize(rewriter, loc, i, extracted_size);
desc.setStride(rewriter, loc, i, stride_carried);
// Update stride
if (i > 0) {
stride_carried =
rewriter.create<LLVM::MulOp>(loc, stride_carried, extracted_size);
}
}
if (dst_type.isa<MemRefType>()) {
rewriter.replaceOp(op, {desc});
} else {
Value rank = rewriter.create<LLVM::ConstantOp>(
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length));
Value alloca =
typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter);
Value void_ptr =
rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), alloca);
auto unranked_desc = UnrankedMemRefDescriptor::pack(
rewriter, loc, typeConverter, dst_type.cast<UnrankedMemRefType>(),
{rank, void_ptr});
rewriter.replaceOp(op, {unranked_desc});
}
return success();
}
// The shape is a rank-1 tensor with unknown length.
Value result_rank = shape_desc.size(rewriter, loc, 0);
// TODO(herhut): Properly handle address spaces.
unsigned address_space = 0;
auto target_type =
typeConverter
.convertType(UnrankedMemRefType::get(element_type, address_space))
.cast<LLVM::LLVMType>();
// Create the unranked memref descriptor that holds the ranked one. The
// inner descriptor is allocated on stack.
UnrankedMemRefDescriptor target_desc =
UnrankedMemRefDescriptor::undef(rewriter, loc, target_type);
target_desc.setRank(rewriter, loc, result_rank);
SmallVector<Value, 1> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
{target_desc}, sizes);
auto void_ptr_type = LLVM::LLVMType::getInt8PtrTy(rewriter.getContext());
Value ranked_desc_mem = rewriter.create<LLVM::AllocaOp>(
loc, void_ptr_type, sizes.front(), llvm::None);
target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem);
// Fill the fixed parts. For this, we cast to a 0-D memref.
auto zero_d_memref_type = MemRefType::get({}, element_type);
Value as_zero_d = rewriter.create<LLVM::BitcastOp>(
loc,
typeConverter.convertType(zero_d_memref_type)
.cast<LLVM::LLVMType>()
.getPointerTo(address_space),
ranked_desc_mem);
// Some common constants. Use 32 bit where required by gep struct indexes.
auto int32_type = typeConverter.convertType(rewriter.getI32Type());
Value zero_index = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0));
Value zero = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(0));
Value one = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(1));
Value two = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(2));
// Set base_pointer and aligned pointer.
auto element_ptr_ptr_type = typeConverter.convertType(element_type)
.cast<LLVM::LLVMType>()
.getPointerTo(address_space)
.getPointerTo(address_space);
auto base_gep = rewriter.create<LLVM::GEPOp>(
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.allocated_ptr, base_gep);
auto aligned_gep = rewriter.create<LLVM::GEPOp>(
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.aligned_ptr, aligned_gep);
// Set offset.
auto index_ptr_type =
typeConverter.getIndexType().getPointerTo(address_space);
auto offset_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.offset, offset_gep);
// Use the offset pointer as base for further addressing. Copy over the
// new shape and compute strides. For this, we need to create a loop from
// rank - 1 to 0.
Value one_index = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1));
auto target_shape_base = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, offset_gep, ValueRange({one}));
auto target_strides_base = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_shape_base, ValueRange({result_rank}));
auto shape_ptr = shape_desc.alignedPtr(rewriter, loc);
auto result_rank_minus_one =
rewriter.create<LLVM::SubOp>(loc, result_rank, one_index);
Block *init_block = rewriter.getInsertionBlock();
Block *cond_block =
rewriter.splitBlock(init_block, rewriter.getInsertionPoint());
rewriter.setInsertionPointToEnd(init_block);
rewriter.create<LLVM::BrOp>(
loc, ValueRange({result_rank_minus_one, one_index}), cond_block);
rewriter.setInsertionPointToStart(cond_block);
auto index_arg = cond_block->addArgument(typeConverter.getIndexType());
auto stride_arg = cond_block->addArgument(typeConverter.getIndexType());
auto pred = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()),
LLVM::ICmpPredicate::sge, index_arg, zero_index);
Block *body_block =
rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
rewriter.setInsertionPointToStart(body_block);
// Copy size from shape to descriptor.
auto size_load_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, shape_ptr, ValueRange{index_arg});
auto extracted_size = rewriter.create<LLVM::LoadOp>(loc, size_load_gep);
auto size_store_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_shape_base, ValueRange({index_arg}));
rewriter.create<LLVM::StoreOp>(loc, extracted_size, size_store_gep);
// Write stride value and compute next one.
auto stride_store_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_strides_base, ValueRange({index_arg}));
rewriter.create<LLVM::StoreOp>(loc, stride_arg, stride_store_gep);
auto next_stride =
rewriter.create<LLVM::MulOp>(loc, stride_arg, extracted_size);
// Decrement loop counter and branch back.
auto decrement = rewriter.create<LLVM::SubOp>(loc, index_arg, one_index);
rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, next_stride}),
cond_block);
Block *remainder =
rewriter.splitBlock(body_block, rewriter.getInsertionPoint());
// Hook up the cond exit to the remainder.
rewriter.setInsertionPointToEnd(cond_block);
rewriter.create<LLVM::CondBrOp>(loc, pred, body_block, ValueRange(),
remainder, ValueRange());
// Reset position to beginning of new remainder block.
rewriter.setInsertionPointToStart(remainder);
rewriter.replaceOp(op, {target_desc});
return success();
}
private:
struct PtrsAndOffset {
Value allocated_ptr;
Value aligned_ptr;
Value offset;
};
PtrsAndOffset ExtractMemRefPtrsAndOffset(
Location loc, Value originalOperand, Value convertedOperand,
ConversionPatternRewriter *rewriter) const {
Type operandType = originalOperand.getType();
Value descriptor_ptr;
if (operandType.isa<MemRefType>()) {
descriptor_ptr = convertedOperand;
} else {
UnrankedMemRefDescriptor unranked_descriptor(convertedOperand);
Value underlying_desc_ptr =
unranked_descriptor.memRefDescPtr(*rewriter, loc);
Type element_type =
operandType.cast<UnrankedMemRefType>().getElementType();
LLVM::LLVMType memref_type_0d =
typeConverter.convertType(MemRefType::get(/*shape=*/{}, element_type))
.cast<LLVM::LLVMType>();
descriptor_ptr = rewriter->create<LLVM::BitcastOp>(
loc, memref_type_0d.getPointerTo(), underlying_desc_ptr);
descriptor_ptr = rewriter->create<LLVM::LoadOp>(loc, descriptor_ptr);
}
MemRefDescriptor descriptor(descriptor_ptr);
PtrsAndOffset result;
result.allocated_ptr = descriptor.allocatedPtr(*rewriter, loc);
result.aligned_ptr = descriptor.alignedPtr(*rewriter, loc);
result.offset = descriptor.offset(*rewriter, loc);
return result;
}
};
} // namespace
void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
OwningRewritePatternList *patterns) {
patterns->insert<DynamicMemRefCastOpConverter, ReshapeMemRefCastOpConverter,
StaticMemRefCastOpConverter>(*converter);
}
} // namespace lmhlo
} // namespace mlir

View File

@@ -1,63 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace lmhlo {
namespace {
class TestLhloToLLVMPass
: public ::mlir::PassWrapper<TestLhloToLLVMPass,
::mlir::OperationPass<::mlir::ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect>();
}
public:
void runOnOperation() override {
ModuleOp m = getOperation();
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
PopulateLhloToLLVMConversionPatterns(&converter, &patterns);
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalDialect<LmhloDialect>();
if (failed(applyFullConversion(m, target, std::move(patterns)))) {
signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<Pass> createTestLhloToLLVMPass() {
return std::make_unique<TestLhloToLLVMPass>();
}
} // namespace lmhlo
} // namespace mlir

View File

@@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
func @main() -> () {
call @trivial_broadcast_wrapper() : () -> ()

View File

@@ -3,7 +3,7 @@
// RUN: -buffer-deallocation -copy-removal -canonicalize -cse \
// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
// RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \
// RUN: -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main \
// RUN: -convert-std-to-llvm | mlir-cpu-runner -e main \
// RUN: -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
// RUN: FileCheck %s

View File

@@ -1,190 +0,0 @@
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
func @main() -> () {
call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()
call @reshape_with_static_shape_size_matrix_to_3D() : () -> ()
call @reshape_with_dynamic_shape_size_matrix_to_1D() : () -> ()
call @reshape_with_dynamic_shape_size_matrix_to_3D() : () -> ()
return
}
func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
func @reshape_with_static_shape_size_matrix_to_1D() {
%c0 = constant 0 : index
%c1 = constant 1 : index
// Initialize input.
%input = alloc() : memref<2x3xf32>
%dim_x = dim %input, %c0 : memref<2x3xf32>
%dim_y = dim %input, %c1 : memref<2x3xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
%i_i64 = index_cast %i : index to i64
%i_f32 = sitofp %i_i64 : i64 to f32
store %i_f32, %input[%i, %j] : memref<2x3xf32>
}
%unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
// CHECK: [0, 0, 0]
// CHECK: [1, 1, 1]
// Initialize shape.
%shape = alloc() : memref<1xi64>
%num_elements = muli %dim_x, %dim_y : index
%num_elements_i64 = index_cast %num_elements : index to i64
store %num_elements_i64, %shape[%c0] : memref<1xi64>
// 1. Ranked input, ranked output.
%output_1 = lmhlo.reshape_memref_cast %input(%shape)
: (memref<2x3xf32>, memref<1xi64>) -> memref<6xf32>
%unranked_output_1 = memref_cast %output_1 : memref<6xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output_1) : (memref<*xf32>) -> ()
// CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
// CHECK: [0, 0, 0, 1, 1, 1]
// 2. Ranked input, unranked output.
%output_2 = lmhlo.reshape_memref_cast %input(%shape)
: (memref<2x3xf32>, memref<1xi64>) -> memref<*xf32>
call @print_memref_f32(%output_2) : (memref<*xf32>) -> ()
// CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
// CHECK: [0, 0, 0, 1, 1, 1]
// 3. Unranked input, ranked output.
%output_3 = lmhlo.reshape_memref_cast %unranked_input(%shape)
: (memref<*xf32>, memref<1xi64>) -> memref<?xf32>
%unranked_output_3 = memref_cast %output_3 : memref<?xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output_3) : (memref<*xf32>) -> ()
// CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
// CHECK: [0, 0, 0, 1, 1, 1]
// 4. Unranked input, unranked output.
%output_4 = lmhlo.reshape_memref_cast %unranked_input(%shape)
: (memref<*xf32>, memref<1xi64>) -> memref<*xf32>
call @print_memref_f32(%output_4) : (memref<*xf32>) -> ()
// CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
// CHECK: [0, 0, 0, 1, 1, 1]
return
}
func @reshape_with_static_shape_size_matrix_to_3D() {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
// Initialize input.
%input = alloc() : memref<2x3xf32>
%dim_x = dim %input, %c0 : memref<2x3xf32>
%dim_y = dim %input, %c1 : memref<2x3xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
%i_i64 = index_cast %i : index to i64
%i_f32 = sitofp %i_i64 : i64 to f32
store %i_f32, %input[%i, %j] : memref<2x3xf32>
}
%unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
// CHECK: [0, 0, 0]
// CHECK: [1, 1, 1]
// Initialize shape.
%shape = alloc() : memref<3xi64>
%c1_i64 = constant 1 : i64
%c2_i64 = constant 2 : i64
%c3_i64 = constant 3 : i64
store %c3_i64, %shape[%c0] : memref<3xi64>
store %c1_i64, %shape[%c1] : memref<3xi64>
store %c2_i64, %shape[%c2] : memref<3xi64>
// Static shape input and shape, dynamic output.
%unranked_output = lmhlo.reshape_memref_cast %input(%shape)
: (memref<2x3xf32>, memref<3xi64>) -> memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 3 offset = 0 sizes = [3, 1, 2] strides = [2, 2, 1]
// CHECK: {{\[}}{{\[}}[0, 0]],
// CHECK: {{\[}}[0, 1]],
// CHECK: {{\[}}[1, 1]]]
return
}
func @reshape_with_dynamic_shape_size_matrix_to_1D() {
%c0 = constant 0 : index
%c1 = constant 1 : index
// Initialize input.
%input = alloc() : memref<2x3xf32>
%dim_x = dim %input, %c0 : memref<2x3xf32>
%dim_y = dim %input, %c1 : memref<2x3xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
%i_i64 = index_cast %i : index to i64
%i_f32 = sitofp %i_i64 : i64 to f32
store %i_f32, %input[%i, %j] : memref<2x3xf32>
}
%unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
// CHECK: [0, 0, 0]
// CHECK: [1, 1, 1]
// Initialize shape.
%shape = alloc(%c1) : memref<?xi64>
%num_elements = muli %dim_x, %dim_y : index
%num_elements_i64 = index_cast %num_elements : index to i64
store %num_elements_i64, %shape[%c0] : memref<?xi64>
// 1. Ranked input, unranked output.
%output_2 = lmhlo.reshape_memref_cast %input(%shape)
: (memref<2x3xf32>, memref<?xi64>) -> memref<*xf32>
call @print_memref_f32(%output_2) : (memref<*xf32>) -> ()
// CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
// CHECK: [0, 0, 0, 1, 1, 1]
// 2. Unranked input, unranked output.
%output_4 = lmhlo.reshape_memref_cast %unranked_input(%shape)
: (memref<*xf32>, memref<?xi64>) -> memref<*xf32>
call @print_memref_f32(%output_4) : (memref<*xf32>) -> ()
// CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
// CHECK: [0, 0, 0, 1, 1, 1]
return
}
func @reshape_with_dynamic_shape_size_matrix_to_3D() {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = constant 3 : index
// Initialize input.
%input = alloc() : memref<2x3xf32>
%dim_x = dim %input, %c0 : memref<2x3xf32>
%dim_y = dim %input, %c1 : memref<2x3xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
%i_i64 = index_cast %i : index to i64
%i_f32 = sitofp %i_i64 : i64 to f32
store %i_f32, %input[%i, %j] : memref<2x3xf32>
}
%unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
// CHECK: [0, 0, 0]
// CHECK: [1, 1, 1]
// Initialize shape.
%shape = alloc(%c3) : memref<?xi64>
%c1_i64 = constant 1 : i64
%c2_i64 = constant 2 : i64
%c3_i64 = constant 3 : i64
store %c3_i64, %shape[%c0] : memref<?xi64>
store %c1_i64, %shape[%c1] : memref<?xi64>
store %c2_i64, %shape[%c2] : memref<?xi64>
// Static shape input, dynamic output and shape.
%unranked_output = lmhlo.reshape_memref_cast %input(%shape)
: (memref<2x3xf32>, memref<?xi64>) -> memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 3 offset = 0 sizes = [3, 1, 2] strides = [2, 2, 1]
// CHECK: {{\[}}{{\[}}[0, 0]],
// CHECK: {{\[}}[0, 1]],
// CHECK: {{\[}}[1, 1]]]
return
}

View File

@@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked(
return %reshaped : tensor<?xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// -----
@@ -30,5 +30,5 @@ func @dynamic_reshape_to_unranked(
return %reshaped : tensor<*xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>

View File

@@ -197,10 +197,11 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast
// CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map>
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to
// CHECK-SAME: offset: [0],
// CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]]
// CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #map>
// CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>

View File

@@ -267,7 +267,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%2 = lmhlo.reshape_memref_cast %1(%arg1)
%2 = memref_reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
return %2 : memref<*xf32>
}
@@ -279,7 +279,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: reshape_memref_cast
// CHECK: memref_reshape
// TILED-LABEL: func @view_result
// TILED-DAG: %[[C2:.*]] = constant 2
@@ -288,7 +288,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: reshape_memref_cast
// TILED: memref_reshape
// PLOOP-LABEL: func @view_result
@@ -297,5 +297,5 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: reshape_memref_cast
// PLOOP: memref_reshape

View File

@@ -1,65 +0,0 @@
// RUN: mlir-hlo-opt %s -lower-affine -convert-scf-to-std -test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s
// CHECK-LABEL: func @static_memref_cast
func @static_memref_cast(%buf : memref<10x1x5xf32>) {
%0 = lmhlo.static_memref_cast %buf
: memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]>
return
}
// CHECK: %[[INPUT_MEMREF_BLDR:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_3D:!.*]]
// CHECK: llvm.insertvalue
// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_2D:!.*]]
// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE_3D]]
// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE_3D]]
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[C2]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C5]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C5_:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C5_]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE_2D]]
// -----
// CHECK-LABEL: func @dynamic_memref_cast
func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
%size_X = constant 10 : index
%size_Y = constant 50 : index
%stride_X = constant 1 : index
%stride_Y = constant 0 : index
%0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}
// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : index) : !llvm.i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]]
// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[SRC_OFFSET]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C50]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C0]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE]]

View File

@@ -429,120 +429,6 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr
// -----
func @static_memref_cast(%in: memref<10x1xf32>) {
%out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]>
return
}
// CHECK-LABEL: func @static_memref_cast
// -----
func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
// expected-error @+1 {{operand must have static shape}}
%out = lmhlo.static_memref_cast %in
: memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]>
return
}
// -----
func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
// expected-error @+1 {{result must have static shape}}
%out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]>
return
}
// -----
func @dynamic_memref_cast(%in: memref<?xf32>) {
%size = constant 10 : index
%step = constant 1 : index
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>
return
}
// CHECK-LABEL: func @dynamic_memref_cast
// -----
func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
// expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}}
%size = constant 10 : index
%step = constant 1 : index
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}
// -----
// CHECK-LABEL: func @reshape_memref_cast(
func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
%shape2: memref<2xi32>, %shape3: memref<?xi32>) {
// CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>,
// CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32>
// CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]]
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
%dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1)
: (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]]
// CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
%dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2)
: (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
// CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]]
// CHECK-SAME: : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
%new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3)
: (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
return
}
// -----
func @reshape_memref_cast_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}}
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
}
// -----
func @reshape_memref_cast_dst_ranked_shape_unranked(
%buf: memref<*xf32>, %shape: memref<?xi32>) {
// expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}}
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
return
}
// -----
func @reshape_memref_cast_dst_shape_rank_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{length of shape operand differs from the result's memref rank}}
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
return
}
// -----
func @reshape_memref_cast_affine_map_is_not_identity(
%buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
%shape: memref<1xi32>) {
// expected-error @+1 {{operand memref type should have identity affine map}}
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
-> memref<8xf32>
return
}
// -----
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()