diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 32901f4..9762e24 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -314,169 +314,6 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> { ); } -//===----------------------------------------------------------------------===// -// StaticMemRefCastOp -//===----------------------------------------------------------------------===// - -def HLO_StaticMemRefCastOp: Op]> { - let summary = [{ - modifies the offset, sizes and strides of a statically shaped memref - }]; - let description = [{ - Casts the statically shaped memref operand to a memref with optionally - modified offsets, sizes and strides. - - Example: - ```mlir - %buf_transformed = - lmhlo.static_memref_cast %buf - : memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]> - - // The result of the op is a rank-1 memref with `[5]` shape, stride 1 and - // offset 2. - ``` - }]; - - let arguments = (ins Arg:$operand); - let results = (outs Res:$result); - - let builders = [ - OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand), - [{ - $_state.addOperands(operand); - $_state.types.push_back(resultType); - }]>]; - - let extraClassDeclaration = [{ - MemRefType getType() { return getResult().getType().cast(); } - }]; - - let verifier = [{ return Verify(*this); }]; - let assemblyFormat = [{ - $operand attr-dict `:` type($operand) `->` type($result) - }]; -} - -//===----------------------------------------------------------------------===// -// DynamicMemRefCastOp -//===----------------------------------------------------------------------===// - -def HLO_DynamicMemRefCastOp: Op]> { - let summary = "dynamic memref cast operation"; - let description = [{ - Change sizes and strides of a memref using the values computed in runtime. - - Example: - ```mlir - %buf_transformed = - lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y] - : memref -> memref - // The result of the op is a type-erased memref with `[%size_X, %size_Y]` - // shape and `[%step_X, %step_Y]` strides. The offset will be inherited - // from the input. - ``` - }]; - - let arguments = (ins - Arg:$operand, - Variadic:$sizes, - Variadic:$strides - ); - let results = (outs Res:$result); - - let builders = [ - OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand, - "ValueRange":$sizes, "ValueRange":$strides), - [{ - $_state.addOperands(operand); - $_state.addOperands(sizes); - $_state.addOperands(strides); - $_state.types.push_back(resultType); - }]>]; - - let extraClassDeclaration = [{ - MemRefType getType() { return getResult().getType().cast(); } - }]; - - let verifier = [{ return Verify(*this); }]; - let assemblyFormat = [{ - $operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->` - type($result) - }]; -} - -//===----------------------------------------------------------------------===// -// ReshapeMemRefCastOp -//===----------------------------------------------------------------------===// - -def ReshapeMemRefCastOp: Op, - NoSideEffect]> { - let summary = "reshape memref cast operation"; - let description = [{ - The `reshape_memref_cast` operation converts a memref from one type to an - equivalent type with a provided shape. The data is never copied or moved. - The source and destination types are compatible if both have the same - element type, address space and identity layout map. The following - combinations are possible: - - a. Both are ranked memref types. - - ```mlir - // Reshape statically-shaped memref. - %dst = reshape_memref_cast %src(%shape) - : (memref<4x1xf32>, memref<1xi32>) to memref<4xf32> - %dst0 = reshape_memref_cast %src(%shape0) - : (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32> - ``` - - b. Source type is ranked, destination type is unranked. - - ```mlir - // Reshape dynamically-shaped 1D memref. - %dst = reshape_memref_cast %src(%shape) - : (memref, memref) to memref<*xf32> - ``` - - c. Source type is unranked, destination type is ranked. - - ```mlir - // Flatten unranked memref. - %dst = reshape_memref_cast %src(%shape) - : (memref<*xf32>, memref<1xi32>) to memref - ``` - - d. Both are unranked memref types. - - ```mlir - // Reshape unranked memref. - %dst = reshape_memref_cast %src(%shape) - : (memref<*xf32>, memref) to memref<*xf32> - ``` - }]; - - let arguments = (ins - AnyRankedOrUnrankedMemRef:$operand, - LHLO_ExtentBuffer:$shape - ); - let results = (outs AnyRankedOrUnrankedMemRef:$result); - - let extraClassDeclaration = [{ - BaseMemRefType getType() { - return getResult().getType().cast(); } - }]; - - let verifier = [{ return Verify(*this); }]; - let assemblyFormat = [{ - $operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape) - `)` `->` type($result) - }]; -} - - //===----------------------------------------------------------------------===// // LMHLO Other op definitions. //===----------------------------------------------------------------------===// diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td index 39b4ca6..17c0524 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td +++ b/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td @@ -46,12 +46,6 @@ def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> { } -def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> { - let summary = "Legalize from LHLO dialect to LLVM."; - let constructor = "createTestLhloToLLVMPass()"; -} - - def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> { let summary = "Legalize from LHLO dialect to parallel loops."; let constructor = "createLegalizeLhloToParallelLoopsPass()"; diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h index e9418f0..3b6041c 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h @@ -35,8 +35,6 @@ inline void registerAllMhloPasses() { registerMHLOPasses(); } namespace lmhlo { -std::unique_ptr createTestLhloToLLVMPass(); - #define GEN_PASS_REGISTRATION #include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc" diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index c568165..a2066df 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -24,8 +24,6 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" namespace mlir { -class LLVMTypeConverter; -class LowerToLLVMOptions; class OwningRewritePatternList; // Populates a collection of rewrite patterns to realize element-wise operations @@ -94,14 +92,6 @@ void PopulateTrigonometricToApproximationPatterns( } // namespace mhlo -namespace lmhlo { - -/// Collect a set of patterns to convert from the LHLO dialect to LLVM. -void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter, - OwningRewritePatternList *patterns); - -} // namespace lmhlo - namespace chlo { // Populates a collection of conversion patterns for legalizing client-HLO to diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index 4524cf3..126eda0 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -88,76 +88,6 @@ void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results, results.insert(context); } -//===----------------------------------------------------------------------===// -// StaticMemRefCastOp -//===----------------------------------------------------------------------===// - -Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); } - -static LogicalResult Verify(StaticMemRefCastOp op) { - if (!op.operand().getType().cast().hasStaticShape()) - return op.emitOpError("operand must have static shape"); - if (!op.getType().hasStaticShape()) - return op.emitOpError("result must have static shape"); - return success(); -} - -//===----------------------------------------------------------------------===// -// DynamicMemRefCastOp -//===----------------------------------------------------------------------===// - -Value DynamicMemRefCastOp::getViewSource() { - return *getODSOperands(0).begin(); -} - -static LogicalResult Verify(DynamicMemRefCastOp op) { - // Check if `sizes` and `strides` args are compatible with the result type. - if (op.sizes().size() != op.getType().getRank()) - return op.emitOpError( - "`sizes` args count must be equal to the rank of the output memref"); - return success(); -} - -//===----------------------------------------------------------------------===// -// ReshapeMemrefCastOp -//===----------------------------------------------------------------------===// - -Value ReshapeMemRefCastOp::getViewSource() { return operand(); } - -static LogicalResult Verify(ReshapeMemRefCastOp op) { - Type operandType = op.operand().getType(); - Type resultType = op.result().getType(); - - Type operandElementType = operandType.cast().getElementType(); - Type resultElementType = resultType.cast().getElementType(); - if (operandElementType != resultElementType) - return op.emitOpError( - "element types of source and destination memref " - "types should be the same"); - - if (auto operandMemRefType = operandType.dyn_cast()) - if (!operandMemRefType.getAffineMaps().empty()) - return op.emitOpError( - "operand memref type should have identity affine map"); - - int64_t shapeSize = op.shape().getType().cast().getDimSize(0); - auto resultMemRefType = resultType.dyn_cast(); - if (resultMemRefType) { - if (shapeSize == ShapedType::kDynamicSize) - return op.emitOpError( - "cannot use shape operand with dynamic length to " - "cast statically-ranked memref type"); - if (shapeSize != resultMemRefType.getRank()) - return op.emitOpError( - "length of shape operand differs from the result's memref rank"); - - if (!resultMemRefType.getAffineMaps().empty()) - return op.emitOpError( - "result memref type should have identity affine map"); - } - return success(); -} - } // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/CMakeLists.txt b/lib/Dialect/mhlo/transforms/CMakeLists.txt index 2435074..eebdcf4 100644 --- a/lib/Dialect/mhlo/transforms/CMakeLists.txt +++ b/lib/Dialect/mhlo/transforms/CMakeLists.txt @@ -134,8 +134,6 @@ add_mlir_library(LmhloPasses lhlo_fuse_linalg.cc lhlo_legalize_to_affine.cc lhlo_legalize_to_gpu.cc - lhlo_legalize_to_llvm.cc - lhlo_legalize_to_llvm_pass.cc lhlo_legalize_to_parallel_loops.cc DEPENDS diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index ae897a6..aca5977 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -206,7 +206,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter // Inserts dynamic memref to change the layout of the memref to put 0-stride // and size of the target dimension if size-1 dimension expansion is // necessary. - lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp( + MemRefReinterpretCastOp InsertDynamicMemrefCastOp( mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const { auto loc = op.getLoc(); auto operand_type = operand.getType().cast(); @@ -259,8 +259,13 @@ struct HloToLhloDynamicBroadcastInDimOpConverter makeStridedLinearLayoutMap(dynamic_layout, /*offset=*/0, b->getContext())); - auto transformed_operand = b->create( - loc, type_erased_memref_type, operand, sizes, strides); + SmallVector static_sizes(sizes.size(), + ShapedType::kDynamicSize); + SmallVector static_strides(strides.size(), + ShapedType::kDynamicStrideOrOffset); + auto transformed_operand = b->create( + loc, type_erased_memref_type, operand, /*offset=*/0, static_sizes, + static_strides, llvm::None, sizes, strides); return transformed_operand; } }; @@ -284,7 +289,7 @@ struct HloToLhloDynamicReshapeConverter return failure(); } mhlo::DynamicReshapeOp::Adaptor adaptor(operands); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, adaptor.operand(), adaptor.output_shape()); return success(); } diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc deleted file mode 100644 index 57ea947..0000000 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc +++ /dev/null @@ -1,370 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir { -namespace lmhlo { -namespace { - -struct StaticMemRefCastOpConverter - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto cast_op = cast(op); - - StaticMemRefCastOp::Adaptor operands_adaptor(operands); - MemRefDescriptor sourceMemRef(operands_adaptor.operand()); - - MemRefType targetMemRefType = - cast_op.getResult().getType().cast(); - auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) - .dyn_cast_or_null(); - if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) - return failure(); - // Create descriptor. - auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); - Type llvmTargetElementTy = desc.getElementPtrType(); - // Set allocated ptr. - Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); - allocated = - rewriter.create(loc, llvmTargetElementTy, allocated); - desc.setAllocatedPtr(rewriter, loc, allocated); - // Set aligned ptr. - Value ptr = sourceMemRef.alignedPtr(rewriter, loc); - ptr = rewriter.create(loc, llvmTargetElementTy, ptr); - desc.setAlignedPtr(rewriter, loc, ptr); - - // Fill size and stride descriptors in memref. - auto target_sizes = targetMemRefType.getShape(); - int64_t target_offset; - llvm::SmallVector target_strides; - if (failed((getStridesAndOffset(targetMemRefType, target_strides, - target_offset)))) - return failure(); - - // Copy offset of `targetMemRef`. - desc.setConstantOffset(rewriter, loc, target_offset); - for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) { - desc.setConstantSize(rewriter, loc, i, target_sizes[i]); - desc.setConstantStride(rewriter, loc, i, target_strides[i]); - } - rewriter.replaceOp(op, {desc}); - return success(); - } -}; - -struct DynamicMemRefCastOpConverter - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto cast_op = cast(op); - - DynamicMemRefCastOp::Adaptor operands_adaptor(operands); - MemRefDescriptor sourceMemRef(operands_adaptor.operand()); - - MemRefType targetMemRefType = - cast_op.getResult().getType().cast(); - auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) - .dyn_cast_or_null(); - if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) - return failure(); - // Create descriptor. - auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); - Type llvmTargetElementTy = desc.getElementPtrType(); - // Set allocated ptr. - Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); - allocated = - rewriter.create(loc, llvmTargetElementTy, allocated); - desc.setAllocatedPtr(rewriter, loc, allocated); - // Set aligned ptr. - Value ptr = sourceMemRef.alignedPtr(rewriter, loc); - ptr = rewriter.create(loc, llvmTargetElementTy, ptr); - desc.setAlignedPtr(rewriter, loc, ptr); - // Copy offset of `sourceMemRef`. - desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc)); - - // Fill size and stride descriptors in memref. - if (!cast_op.sizes().empty()) { - auto sizes = operands_adaptor.sizes(); - auto strides = operands_adaptor.strides(); - for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) { - desc.setSize(rewriter, loc, i, sizes[i]); - desc.setStride(rewriter, loc, i, strides[i]); - } - } - rewriter.replaceOp(op, {desc}); - return success(); - } -}; - -struct ReshapeMemRefCastOpConverter - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - - auto reshape_op = cast(op); - auto dst_type = reshape_op.getResult().getType().cast(); - auto element_type = dst_type.getElementType(); - - auto shape = reshape_op.shape(); - - ReshapeMemRefCastOp::Adaptor operands_adaptor(operands); - PtrsAndOffset ptrs_n_offset = ExtractMemRefPtrsAndOffset( - loc, reshape_op.operand(), operands_adaptor.operand(), &rewriter); - - MemRefDescriptor shape_desc(operands_adaptor.shape()); - - auto shape_memref_type = shape.getType().cast(); - - if (shape_memref_type.hasStaticShape()) { - auto shape_length = shape_memref_type.getDimSize(0); - - MemRefType targetMemRefType = MemRefType::get( - SmallVector(shape_length, 1), element_type); - auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) - .dyn_cast_or_null(); - if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) - return failure(); - // Create descriptor. - auto desc = - MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); - desc.setAllocatedPtr(rewriter, loc, ptrs_n_offset.allocated_ptr); - desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr); - desc.setOffset(rewriter, loc, ptrs_n_offset.offset); - - auto llvm_index_type = typeConverter.getIndexType(); - auto llvm_index_ptr_type = llvm_index_type.getPointerTo(); - Value stride_carried = rewriter.create( - loc, llvm_index_type, - rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - for (int i = shape_length - 1; i >= 0; --i) { - Value pos = rewriter.create( - loc, llvm_index_type, - rewriter.getIntegerAttr(rewriter.getIndexType(), i)); - Value ptr = rewriter.create( - loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc), - ValueRange{pos}); - Value extracted_size = rewriter.create(loc, ptr); - desc.setSize(rewriter, loc, i, extracted_size); - desc.setStride(rewriter, loc, i, stride_carried); - // Update stride - if (i > 0) { - stride_carried = - rewriter.create(loc, stride_carried, extracted_size); - } - } - if (dst_type.isa()) { - rewriter.replaceOp(op, {desc}); - } else { - Value rank = rewriter.create( - loc, llvm_index_type, - rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length)); - Value alloca = - typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter); - Value void_ptr = - rewriter.create(loc, getVoidPtrType(), alloca); - auto unranked_desc = UnrankedMemRefDescriptor::pack( - rewriter, loc, typeConverter, dst_type.cast(), - {rank, void_ptr}); - rewriter.replaceOp(op, {unranked_desc}); - } - return success(); - } - - // The shape is a rank-1 tensor with unknown length. - Value result_rank = shape_desc.size(rewriter, loc, 0); - // TODO(herhut): Propely handle address spaces. - unsigned address_space = 0; - auto target_type = - typeConverter - .convertType(UnrankedMemRefType::get(element_type, address_space)) - .cast(); - // Create the unranked memref descriptor that holds the ranked one. The - // inner descriptor is allocated on stack. - UnrankedMemRefDescriptor target_desc = - UnrankedMemRefDescriptor::undef(rewriter, loc, target_type); - target_desc.setRank(rewriter, loc, result_rank); - SmallVector sizes; - UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter, - {target_desc}, sizes); - auto void_ptr_type = LLVM::LLVMType::getInt8PtrTy(rewriter.getContext()); - Value ranked_desc_mem = rewriter.create( - loc, void_ptr_type, sizes.front(), llvm::None); - target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem); - - // Fill the fixed parts. For this, we cast to a 0-D memref. - auto zero_d_memref_type = MemRefType::get({}, element_type); - Value as_zero_d = rewriter.create( - loc, - typeConverter.convertType(zero_d_memref_type) - .cast() - .getPointerTo(address_space), - ranked_desc_mem); - // Some common constants. Use 32 bit where required by gep struct indexes. - auto int32_type = typeConverter.convertType(rewriter.getI32Type()); - Value zero_index = rewriter.create( - loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0)); - Value zero = rewriter.create( - loc, int32_type, rewriter.getI32IntegerAttr(0)); - Value one = rewriter.create( - loc, int32_type, rewriter.getI32IntegerAttr(1)); - Value two = rewriter.create( - loc, int32_type, rewriter.getI32IntegerAttr(2)); - // Set base_pointer and aligned pointer. - auto element_ptr_ptr_type = typeConverter.convertType(element_type) - .cast() - .getPointerTo(address_space) - .getPointerTo(address_space); - auto base_gep = rewriter.create( - loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero})); - rewriter.create(loc, ptrs_n_offset.allocated_ptr, base_gep); - auto aligned_gep = rewriter.create( - loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one})); - rewriter.create(loc, ptrs_n_offset.aligned_ptr, aligned_gep); - // Set offset. - auto index_ptr_type = - typeConverter.getIndexType().getPointerTo(address_space); - auto offset_gep = rewriter.create( - loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two})); - rewriter.create(loc, ptrs_n_offset.offset, offset_gep); - - // Use the offset pointer as base for further addressing. Copy over the - // new shape and compute strides. For this, we need to create a loop from - // rank - 1 to 0. - Value one_index = rewriter.create( - loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1)); - auto target_shape_base = rewriter.create( - loc, index_ptr_type, offset_gep, ValueRange({one})); - auto target_strides_base = rewriter.create( - loc, index_ptr_type, target_shape_base, ValueRange({result_rank})); - auto shape_ptr = shape_desc.alignedPtr(rewriter, loc); - auto result_rank_minus_one = - rewriter.create(loc, result_rank, one_index); - - Block *init_block = rewriter.getInsertionBlock(); - Block *cond_block = - rewriter.splitBlock(init_block, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToEnd(init_block); - rewriter.create( - loc, ValueRange({result_rank_minus_one, one_index}), cond_block); - rewriter.setInsertionPointToStart(cond_block); - auto index_arg = cond_block->addArgument(typeConverter.getIndexType()); - auto stride_arg = cond_block->addArgument(typeConverter.getIndexType()); - auto pred = rewriter.create( - loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()), - LLVM::ICmpPredicate::sge, index_arg, zero_index); - - Block *body_block = - rewriter.splitBlock(cond_block, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToStart(body_block); - - // Copy size from shape to descriptor. - auto size_load_gep = rewriter.create( - loc, index_ptr_type, shape_ptr, ValueRange{index_arg}); - auto extracted_size = rewriter.create(loc, size_load_gep); - auto size_store_gep = rewriter.create( - loc, index_ptr_type, target_shape_base, ValueRange({index_arg})); - rewriter.create(loc, extracted_size, size_store_gep); - // Write stride value and compute next one. - auto stride_store_gep = rewriter.create( - loc, index_ptr_type, target_strides_base, ValueRange({index_arg})); - rewriter.create(loc, stride_arg, stride_store_gep); - auto next_stride = - rewriter.create(loc, stride_arg, extracted_size); - - // Decrement loop counter and branch back. - auto decrement = rewriter.create(loc, index_arg, one_index); - rewriter.create(loc, ValueRange({decrement, next_stride}), - cond_block); - - Block *remainder = - rewriter.splitBlock(body_block, rewriter.getInsertionPoint()); - - // Hook up the cond exit to the remainder. - rewriter.setInsertionPointToEnd(cond_block); - rewriter.create(loc, pred, body_block, ValueRange(), - remainder, ValueRange()); - - // Reset position to beginning of new remainder block. - rewriter.setInsertionPointToStart(remainder); - rewriter.replaceOp(op, {target_desc}); - return success(); - } - - private: - struct PtrsAndOffset { - Value allocated_ptr; - Value aligned_ptr; - Value offset; - }; - - PtrsAndOffset ExtractMemRefPtrsAndOffset( - Location loc, Value originalOperand, Value convertedOperand, - ConversionPatternRewriter *rewriter) const { - Type operandType = originalOperand.getType(); - Value descriptor_ptr; - if (operandType.isa()) { - descriptor_ptr = convertedOperand; - } else { - UnrankedMemRefDescriptor unranked_descriptor(convertedOperand); - Value underlying_desc_ptr = - unranked_descriptor.memRefDescPtr(*rewriter, loc); - - Type element_type = - operandType.cast().getElementType(); - LLVM::LLVMType memref_type_0d = - typeConverter.convertType(MemRefType::get(/*shape=*/{}, element_type)) - .cast(); - descriptor_ptr = rewriter->create( - loc, memref_type_0d.getPointerTo(), underlying_desc_ptr); - descriptor_ptr = rewriter->create(loc, descriptor_ptr); - } - MemRefDescriptor descriptor(descriptor_ptr); - PtrsAndOffset result; - result.allocated_ptr = descriptor.allocatedPtr(*rewriter, loc); - result.aligned_ptr = descriptor.alignedPtr(*rewriter, loc); - result.offset = descriptor.offset(*rewriter, loc); - return result; - } -}; - -} // namespace - -void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter, - OwningRewritePatternList *patterns) { - patterns->insert(*converter); -} - -} // namespace lmhlo -} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc deleted file mode 100644 index 6b9286a..0000000 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace lmhlo { -namespace { - -class TestLhloToLLVMPass - : public ::mlir::PassWrapper> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - public: - void runOnOperation() override { - ModuleOp m = getOperation(); - - OwningRewritePatternList patterns; - LLVMTypeConverter converter(&getContext()); - populateStdToLLVMConversionPatterns(converter, patterns); - PopulateLhloToLLVMConversionPatterns(&converter, &patterns); - - ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addLegalOp(); - target.addIllegalDialect(); - - if (failed(applyFullConversion(m, target, std::move(patterns)))) { - signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr createTestLhloToLLVMPass() { - return std::make_unique(); -} - -} // namespace lmhlo -} // namespace mlir diff --git a/tests/end2end/broadcast.mlir b/tests/end2end/broadcast.mlir index dd2c311..f8f6ce3 100644 --- a/tests/end2end/broadcast.mlir +++ b/tests/end2end/broadcast.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s +// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s func @main() -> () { call @trivial_broadcast_wrapper() : () -> () diff --git a/tests/end2end/reduce.mlir b/tests/end2end/reduce.mlir index b018b73..9b8b100 100644 --- a/tests/end2end/reduce.mlir +++ b/tests/end2end/reduce.mlir @@ -3,7 +3,7 @@ // RUN: -buffer-deallocation -copy-removal -canonicalize -cse \ // RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \ // RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \ -// RUN: -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main \ +// RUN: -convert-std-to-llvm | mlir-cpu-runner -e main \ // RUN: -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \ // RUN: FileCheck %s diff --git a/tests/end2end/reshape.mlir b/tests/end2end/reshape.mlir deleted file mode 100644 index 8311546..0000000 --- a/tests/end2end/reshape.mlir +++ /dev/null @@ -1,190 +0,0 @@ -// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s - -func @main() -> () { - call @reshape_with_static_shape_size_matrix_to_1D() : () -> () - call @reshape_with_static_shape_size_matrix_to_3D() : () -> () - call @reshape_with_dynamic_shape_size_matrix_to_1D() : () -> () - call @reshape_with_dynamic_shape_size_matrix_to_3D() : () -> () - return -} - -func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } - -func @reshape_with_static_shape_size_matrix_to_1D() { - %c0 = constant 0 : index - %c1 = constant 1 : index - - // Initialize input. - %input = alloc() : memref<2x3xf32> - %dim_x = dim %input, %c0 : memref<2x3xf32> - %dim_y = dim %input, %c1 : memref<2x3xf32> - scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) { - %i_i64 = index_cast %i : index to i64 - %i_f32 = sitofp %i_i64 : i64 to f32 - store %i_f32, %input[%i, %j] : memref<2x3xf32> - } - %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32> - call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> () - // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] - // CHECK: [0, 0, 0] - // CHECK: [1, 1, 1] - - // Initialize shape. - %shape = alloc() : memref<1xi64> - %num_elements = muli %dim_x, %dim_y : index - %num_elements_i64 = index_cast %num_elements : index to i64 - store %num_elements_i64, %shape[%c0] : memref<1xi64> - - // 1. Ranked input, ranked output. - %output_1 = lmhlo.reshape_memref_cast %input(%shape) - : (memref<2x3xf32>, memref<1xi64>) -> memref<6xf32> - %unranked_output_1 = memref_cast %output_1 : memref<6xf32> to memref<*xf32> - call @print_memref_f32(%unranked_output_1) : (memref<*xf32>) -> () - // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1] - // CHECK: [0, 0, 0, 1, 1, 1] - - // 2. Ranked input, unranked output. - %output_2 = lmhlo.reshape_memref_cast %input(%shape) - : (memref<2x3xf32>, memref<1xi64>) -> memref<*xf32> - call @print_memref_f32(%output_2) : (memref<*xf32>) -> () - // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1] - // CHECK: [0, 0, 0, 1, 1, 1] - - // 3. Unranked input, ranked output. - %output_3 = lmhlo.reshape_memref_cast %unranked_input(%shape) - : (memref<*xf32>, memref<1xi64>) -> memref - %unranked_output_3 = memref_cast %output_3 : memref to memref<*xf32> - call @print_memref_f32(%unranked_output_3) : (memref<*xf32>) -> () - // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1] - // CHECK: [0, 0, 0, 1, 1, 1] - - // 4. Unranked input, unranked output. - %output_4 = lmhlo.reshape_memref_cast %unranked_input(%shape) - : (memref<*xf32>, memref<1xi64>) -> memref<*xf32> - call @print_memref_f32(%output_4) : (memref<*xf32>) -> () - // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1] - // CHECK: [0, 0, 0, 1, 1, 1] - return -} - -func @reshape_with_static_shape_size_matrix_to_3D() { - %c0 = constant 0 : index - %c1 = constant 1 : index - %c2 = constant 2 : index - - // Initialize input. - %input = alloc() : memref<2x3xf32> - %dim_x = dim %input, %c0 : memref<2x3xf32> - %dim_y = dim %input, %c1 : memref<2x3xf32> - scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) { - %i_i64 = index_cast %i : index to i64 - %i_f32 = sitofp %i_i64 : i64 to f32 - store %i_f32, %input[%i, %j] : memref<2x3xf32> - } - %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32> - call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> () - // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] - // CHECK: [0, 0, 0] - // CHECK: [1, 1, 1] - - // Initialize shape. - %shape = alloc() : memref<3xi64> - %c1_i64 = constant 1 : i64 - %c2_i64 = constant 2 : i64 - %c3_i64 = constant 3 : i64 - store %c3_i64, %shape[%c0] : memref<3xi64> - store %c1_i64, %shape[%c1] : memref<3xi64> - store %c2_i64, %shape[%c2] : memref<3xi64> - - // Static shape input and shape, dynamic output. - %unranked_output = lmhlo.reshape_memref_cast %input(%shape) - : (memref<2x3xf32>, memref<3xi64>) -> memref<*xf32> - call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () - // CHECK: rank = 3 offset = 0 sizes = [3, 1, 2] strides = [2, 2, 1] - // CHECK: {{\[}}{{\[}}[0, 0]], - // CHECK: {{\[}}[0, 1]], - // CHECK: {{\[}}[1, 1]]] - return -} - -func @reshape_with_dynamic_shape_size_matrix_to_1D() { - %c0 = constant 0 : index - %c1 = constant 1 : index - - // Initialize input. - %input = alloc() : memref<2x3xf32> - %dim_x = dim %input, %c0 : memref<2x3xf32> - %dim_y = dim %input, %c1 : memref<2x3xf32> - scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) { - %i_i64 = index_cast %i : index to i64 - %i_f32 = sitofp %i_i64 : i64 to f32 - store %i_f32, %input[%i, %j] : memref<2x3xf32> - } - %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32> - call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> () - // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] - // CHECK: [0, 0, 0] - // CHECK: [1, 1, 1] - - // Initialize shape. - %shape = alloc(%c1) : memref - %num_elements = muli %dim_x, %dim_y : index - %num_elements_i64 = index_cast %num_elements : index to i64 - store %num_elements_i64, %shape[%c0] : memref - - // 1. Ranked input, unranked output. - %output_2 = lmhlo.reshape_memref_cast %input(%shape) - : (memref<2x3xf32>, memref) -> memref<*xf32> - call @print_memref_f32(%output_2) : (memref<*xf32>) -> () - // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1] - // CHECK: [0, 0, 0, 1, 1, 1] - - // 2. Unranked input, unranked output. - %output_4 = lmhlo.reshape_memref_cast %unranked_input(%shape) - : (memref<*xf32>, memref) -> memref<*xf32> - call @print_memref_f32(%output_4) : (memref<*xf32>) -> () - // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1] - // CHECK: [0, 0, 0, 1, 1, 1] - return -} - -func @reshape_with_dynamic_shape_size_matrix_to_3D() { - %c0 = constant 0 : index - %c1 = constant 1 : index - %c2 = constant 2 : index - %c3 = constant 3 : index - - // Initialize input. - %input = alloc() : memref<2x3xf32> - %dim_x = dim %input, %c0 : memref<2x3xf32> - %dim_y = dim %input, %c1 : memref<2x3xf32> - scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) { - %i_i64 = index_cast %i : index to i64 - %i_f32 = sitofp %i_i64 : i64 to f32 - store %i_f32, %input[%i, %j] : memref<2x3xf32> - } - %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32> - call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> () - // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] - // CHECK: [0, 0, 0] - // CHECK: [1, 1, 1] - - // Initialize shape. - %shape = alloc(%c3) : memref - %c1_i64 = constant 1 : i64 - %c2_i64 = constant 2 : i64 - %c3_i64 = constant 3 : i64 - store %c3_i64, %shape[%c0] : memref - store %c1_i64, %shape[%c1] : memref - store %c2_i64, %shape[%c2] : memref - - // Static shape input, dynamic output and shape. - %unranked_output = lmhlo.reshape_memref_cast %input(%shape) - : (memref<2x3xf32>, memref) -> memref<*xf32> - call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () - // CHECK: rank = 3 offset = 0 sizes = [3, 1, 2] strides = [2, 2, 1] - // CHECK: {{\[}}{{\[}}[0, 0]], - // CHECK: {{\[}}[0, 1]], - // CHECK: {{\[}}[1, 1]]] - return -} diff --git a/tests/hlo-legalize-to-lhlo-unranked.mlir b/tests/hlo-legalize-to-lhlo-unranked.mlir index 7e3d13e..ee6e7b3 100644 --- a/tests/hlo-legalize-to-lhlo-unranked.mlir +++ b/tests/hlo-legalize-to-lhlo-unranked.mlir @@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked( return %reshaped : tensor } // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>) -// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]]) +// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]]) // CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref // ----- @@ -30,5 +30,5 @@ func @dynamic_reshape_to_unranked( return %reshaped : tensor<*xf32> } // CHECK-SAME: ([[ARG:%.*]]: memref, [[SHAPE:%.*]]: memref) -// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]]) +// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]]) // CHECK-SAME: : (memref, memref) -> memref<*xf32> diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 7f880bd..399ec9e 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -197,10 +197,11 @@ func @dyn_broadcast(%operand: memref) { // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]] // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index - // CHECK: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast - // CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]) - // CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]] - // CHECK-SAME: : memref -> memref + // CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to + // CHECK-SAME: offset: [0], + // CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]] + // CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]] + // CHECK-SAME: : memref to memref // CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { // CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> diff --git a/tests/lhlo-fuse-linalg.mlir b/tests/lhlo-fuse-linalg.mlir index e51bdfe..628b6bd 100644 --- a/tests/lhlo-fuse-linalg.mlir +++ b/tests/lhlo-fuse-linalg.mlir @@ -267,7 +267,7 @@ func @view_result(%arg0: memref, %arg1: memref, %arg2: index) %13 = absf %arg3 : f32 linalg.yield %13 : f32 } - %2 = lmhlo.reshape_memref_cast %1(%arg1) + %2 = memref_reshape %1(%arg1) : (memref, memref) -> memref<*xf32> return %2 : memref<*xf32> } @@ -279,7 +279,7 @@ func @view_result(%arg0: memref, %arg1: memref, %arg2: index) // CHECK-NOT: scf.for // CHECK: linalg.generic // CHECK: absf -// CHECK: reshape_memref_cast +// CHECK: memref_reshape // TILED-LABEL: func @view_result // TILED-DAG: %[[C2:.*]] = constant 2 @@ -288,7 +288,7 @@ func @view_result(%arg0: memref, %arg1: memref, %arg2: index) // TILED-NOT: scf.for // TILED: linalg.generic // TILED: absf -// TILED: reshape_memref_cast +// TILED: memref_reshape // PLOOP-LABEL: func @view_result @@ -297,5 +297,5 @@ func @view_result(%arg0: memref, %arg1: memref, %arg2: index) // PLOOP-NOT: scf.parallel // PLOOP: linalg.generic // PLOOP: absf -// PLOOP: reshape_memref_cast +// PLOOP: memref_reshape diff --git a/tests/lhlo-legalize-to-llvm.mlir b/tests/lhlo-legalize-to-llvm.mlir deleted file mode 100644 index 45c383b..0000000 --- a/tests/lhlo-legalize-to-llvm.mlir +++ /dev/null @@ -1,65 +0,0 @@ -// RUN: mlir-hlo-opt %s -lower-affine -convert-scf-to-std -test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s - -// CHECK-LABEL: func @static_memref_cast -func @static_memref_cast(%buf : memref<10x1x5xf32>) { - %0 = lmhlo.static_memref_cast %buf - : memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]> - return -} -// CHECK: %[[INPUT_MEMREF_BLDR:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_3D:!.*]] -// CHECK: llvm.insertvalue -// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_2D:!.*]] - -// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE_3D]] -// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr to !llvm.ptr -// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE_2D]] - -// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE_3D]] -// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr to !llvm.ptr -// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE_2D]] - -// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[C2]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE_2D]] - -// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64 -// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE_2D]] -// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64 -// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C5]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE_2D]] -// CHECK: %[[C5_:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64 -// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C5_]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE_2D]] -// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE_2D]] - -// ----- - -// CHECK-LABEL: func @dynamic_memref_cast -func @dynamic_memref_cast(%buf : memref) { - %size_X = constant 10 : index - %size_Y = constant 50 : index - %stride_X = constant 1 : index - %stride_Y = constant 0 : index - %0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y] - : memref -> memref - return -} -// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64 -// CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : index) : !llvm.i64 -// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 - -// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]] - -// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]] -// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr to !llvm.ptr -// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]] - -// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]] -// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr to !llvm.ptr -// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]] - -// CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]] -// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[SRC_OFFSET]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE]] -// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE]] -// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE]] -// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C50]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE]] -// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C0]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE]] diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir index 30ff965..2167cad 100644 --- a/tests/lhlo_ops.mlir +++ b/tests/lhlo_ops.mlir @@ -429,120 +429,6 @@ func @case_memref(%index: memref, %operand_1: memref, %operand_2: memr // ----- -func @static_memref_cast(%in: memref<10x1xf32>) { - %out = lmhlo.static_memref_cast %in - : memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]> - return -} -// CHECK-LABEL: func @static_memref_cast - -// ----- - -func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) { - // expected-error @+1 {{operand must have static shape}} - %out = lmhlo.static_memref_cast %in - : memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]> - return -} - -// ----- - -func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) { - // expected-error @+1 {{result must have static shape}} - %out = lmhlo.static_memref_cast %in - : memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]> - return -} - -// ----- - -func @dynamic_memref_cast(%in: memref) { - %size = constant 10 : index - %step = constant 1 : index - %out = lmhlo.dynamic_memref_cast %in(%size)[%step] - : memref -> memref - return -} -// CHECK-LABEL: func @dynamic_memref_cast - -// ----- - -func @dynamic_memref_cast_incompatible_result_type(%in: memref) { - // expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}} - %size = constant 10 : index - %step = constant 1 : index - %out = lmhlo.dynamic_memref_cast %in(%size)[%step] - : memref -> memref - return -} -// ----- - -// CHECK-LABEL: func @reshape_memref_cast( -func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>, - %shape2: memref<2xi32>, %shape3: memref) { - // CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>, - // CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref - - // CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]] - // CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref - %dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1) - : (memref<*xf32>, memref<1xi32>) -> memref - - // CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]] - // CHECK-SAME: : (memref, memref<2xi32>) -> memref - %dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2) - : (memref, memref<2xi32>) -> memref - - // CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]] - // CHECK-SAME: : (memref, memref) -> memref<*xf32> - %new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3) - : (memref, memref) -> memref<*xf32> - return -} - -// ----- - -func @reshape_memref_cast_element_type_mismatch( - %buf: memref<*xf32>, %shape: memref<1xi32>) { - // expected-error @+1 {{element types of source and destination memref types should be the same}} - lmhlo.reshape_memref_cast %buf(%shape) - : (memref<*xf32>, memref<1xi32>) -> memref -} - -// ----- - -func @reshape_memref_cast_dst_ranked_shape_unranked( - %buf: memref<*xf32>, %shape: memref) { - // expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}} - lmhlo.reshape_memref_cast %buf(%shape) - : (memref<*xf32>, memref) -> memref - return -} - -// ----- - -func @reshape_memref_cast_dst_shape_rank_mismatch( - %buf: memref<*xf32>, %shape: memref<1xi32>) { - // expected-error @+1 {{length of shape operand differs from the result's memref rank}} - lmhlo.reshape_memref_cast %buf(%shape) - : (memref<*xf32>, memref<1xi32>) -> memref - return -} - -// ----- - -func @reshape_memref_cast_affine_map_is_not_identity( - %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>, - %shape: memref<1xi32>) { - // expected-error @+1 {{operand memref type should have identity affine map}} - lmhlo.reshape_memref_cast %buf(%shape) - : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>) - -> memref<8xf32> - return -} - -// ----- - // CHECK-LABEL: func @atan2_memrefs func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()