Introduce krnl.shape operation and its lowering (#267)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Add krnl shape op and lowering pass. * Add lowering function. * Clean-up code. * Remove duplicate entry. * Add test. * Update LowerKrnlShape.cpp * Update KrnlToLLVM.cpp
This commit is contained in:
parent
13e8070708
commit
f278f08120
|
@ -143,6 +143,7 @@ find_mlir_lib(MLIRCopyOpInterface)
|
|||
find_mlir_lib(MLIRDialect)
|
||||
find_mlir_lib(MLIREDSC)
|
||||
find_mlir_lib(MLIRExecutionEngine)
|
||||
find_mlir_lib(MLIRInferTypeOpInterface)
|
||||
find_mlir_lib(MLIRIR)
|
||||
find_mlir_lib(MLIRLLVMIR)
|
||||
find_mlir_lib(MLIRLoopAnalysis)
|
||||
|
@ -167,6 +168,10 @@ find_mlir_lib(MLIRTargetLLVMIR)
|
|||
find_mlir_lib(MLIRTransforms)
|
||||
find_mlir_lib(MLIRTransformUtils)
|
||||
find_mlir_lib(MLIRSupport)
|
||||
find_mlir_lib(MLIRShape)
|
||||
find_mlir_lib(MLIRShapeToStandard)
|
||||
find_mlir_lib(MLIRShapeToSCF)
|
||||
find_mlir_lib(MLIRSideEffectInterfaces)
|
||||
find_mlir_lib(MLIROpenMP)
|
||||
find_mlir_lib(MLIROptLib)
|
||||
find_mlir_lib(MLIRTableGen)
|
||||
|
@ -259,6 +264,10 @@ set(MLIRLibs
|
|||
${MLIRLinalgEDSC}
|
||||
${MLIRViewLikeInterface}
|
||||
${MLIRPresburger}
|
||||
${MLIRShape}
|
||||
${MLIRShapeToStandard}
|
||||
${MLIRShapeToSCF}
|
||||
${MLIRInferTypeOpInterface}
|
||||
# strict order verified
|
||||
${LLVMBitWriter}
|
||||
${LLVMObject}
|
||||
|
|
|
@ -24,7 +24,8 @@ set(OMLibs
|
|||
OMPackKrnlGlobalConstants
|
||||
OMEnableMemoryPool
|
||||
OMBundleMemoryPools
|
||||
OMDisconnectKrnlDimFromAlloc)
|
||||
OMDisconnectKrnlDimFromAlloc
|
||||
OMLowerKrnlShape)
|
||||
set(OMLibs ${OMLibs} PARENT_SCOPE)
|
||||
|
||||
add_subdirectory(Tool)
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
|
@ -857,6 +858,7 @@ void mlir::populateAffineAndKrnlToLLVMConversion(
|
|||
LLVMTypeConverter &typeConverter) {
|
||||
populateAffineToStdConversionPatterns(patterns, ctx);
|
||||
populateLoopToStdConversionPatterns(patterns, ctx);
|
||||
populateShapeToStandardConversionPatterns(patterns, ctx);
|
||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(
|
||||
|
|
|
@ -54,7 +54,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
|
|||
|
||||
// We define the specific operations, or dialects, that are legal targets for
|
||||
// this lowering.
|
||||
target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
|
||||
target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect,
|
||||
shape::ShapeDialect>();
|
||||
|
||||
// TODO: enable this once more ops are supported.
|
||||
// We also define the ONNX dialect as Illegal so that the conversion will fail
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
||||
|
||||
def Krnl_Dialect : Dialect {
|
||||
let name = "krnl";
|
||||
|
@ -355,3 +356,21 @@ def KrnlDimOp : Op<Krnl_Dialect, "dim"> {
|
|||
let parser = ?;
|
||||
let printer = ?;
|
||||
}
|
||||
|
||||
def KrnlShapeOp : Op<Krnl_Dialect, "shape"> {
|
||||
let summary = "Krnl operation to retreieve the shape of a MemRef.";
|
||||
let description = [{
|
||||
Extracts the shape of a MemRef:
|
||||
```
|
||||
"krnl.shape"(%memref)
|
||||
```
|
||||
The return result is of `shape.type`.
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef]>:$alloc);
|
||||
let results = (outs Shape_ShapeType:$shape);
|
||||
|
||||
let parser = ?;
|
||||
let printer = ?;
|
||||
}
|
||||
|
||||
|
|
|
@ -81,5 +81,11 @@ void initOMPasses() {
|
|||
[]() -> std::unique_ptr<mlir::Pass> {
|
||||
return mlir::createDisconnectKrnlDimFromAllocPass();
|
||||
});
|
||||
|
||||
mlir::registerPass("lower-krnl-shape",
|
||||
"Lower krnl.shape operation to use Shape dialect operations.",
|
||||
[]() -> std::unique_ptr<mlir::Pass> {
|
||||
return mlir::createLowerKrnlShapePass();
|
||||
});
|
||||
}
|
||||
} // namespace onnx_mlir
|
|
@ -381,6 +381,7 @@ void registerDialects() {
|
|||
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
|
||||
mlir::registerDialect<mlir::scf::SCFDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::shape::ShapeDialect>();
|
||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||
}
|
||||
|
|
|
@ -26,11 +26,13 @@
|
|||
#include "src/Pass/Passes.hpp"
|
||||
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
|
|
@ -43,6 +43,9 @@ std::unique_ptr<Pass> createConvertKrnlToAffinePass();
|
|||
/// Pass for lowering krnl.dim operations to standard dialect.
|
||||
std::unique_ptr<Pass> createDisconnectKrnlDimFromAllocPass();
|
||||
|
||||
/// Pass for lowering krnl.shape operation.
|
||||
std::unique_ptr<Pass> createLowerKrnlShapePass();
|
||||
|
||||
/// Pass for eliding the values of global Krnl operations.
|
||||
std::unique_ptr<Pass> createElideConstGlobalValuePass();
|
||||
|
||||
|
|
|
@ -66,6 +66,7 @@ int main(int argc, char **argv) {
|
|||
mlir::registerDialect<mlir::scf::SCFDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::vector::VectorDialect>();
|
||||
mlir::registerDialect<mlir::shape::ShapeDialect>();
|
||||
|
||||
registerTransformsPasses();
|
||||
registerAffinePasses();
|
||||
|
|
|
@ -60,4 +60,17 @@ add_dependencies(OMDisconnectKrnlDimFromAlloc
|
|||
OMKrnlOps
|
||||
OMONNXOps)
|
||||
|
||||
add_library(OMLowerKrnlShape
|
||||
LowerKrnlShape.cpp)
|
||||
target_include_directories(OMLowerKrnlShape
|
||||
PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
target_link_libraries(OMLowerKrnlShape
|
||||
onnx)
|
||||
add_dependencies(OMLowerKrnlShape
|
||||
OMKrnlOps
|
||||
OMONNXOps)
|
||||
|
||||
add_subdirectory(ONNX)
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
//===-------- LowerKrnlShape.cpp ------------------------------------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This pass enables the lowering of the krnl.shape operation to use Shape
|
||||
// dialect operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
|
||||
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
||||
#include "src/Pass/Passes.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
/*!
|
||||
* RewritePattern that replaces:
|
||||
* %0 = alloc(%d) : memref<?x10x<type>, #map>
|
||||
* %1 = krnl.shape(%0) : memref<?x10x<type>> -> !shape.shape
|
||||
* with:
|
||||
* %0 = alloc(%d) : memref<?x10x<type>, #map>
|
||||
* %c0 = constant 0 : index
|
||||
* %1 = krnl.dim(%0, %c0) : memref<?x10x<type>, #map>, index
|
||||
* %c1 = constant 1 : index
|
||||
* %2 = krnl.dim(%0, %c1) : memref<?x10x<type>, #map>, index
|
||||
* %shape = shape.from_extents %1, %2
|
||||
*/
|
||||
|
||||
class LowerKrnlShape : public OpRewritePattern<KrnlShapeOp> {
|
||||
public:
|
||||
using OpRewritePattern<KrnlShapeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
KrnlShapeOp krnlShapeOp, PatternRewriter &rewriter) const override {
|
||||
auto loc = krnlShapeOp.getLoc();
|
||||
auto rank =
|
||||
convertToMemRefType(krnlShapeOp.alloc().getType()).getShape().size();
|
||||
|
||||
SmallVector<mlir::Value, 4> fromExtentsOpOperands;
|
||||
for (int idx = 0; idx < rank; idx++) {
|
||||
auto index = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), idx));
|
||||
auto operand = rewriter.create<KrnlDimOp>(
|
||||
loc, rewriter.getIndexType(), krnlShapeOp.alloc(), index);
|
||||
fromExtentsOpOperands.emplace_back(operand);
|
||||
}
|
||||
|
||||
auto fromExtentsOp = rewriter.create<mlir::shape::FromExtentsOp>(
|
||||
loc, rewriter.getType<mlir::shape::ShapeType>(), fromExtentsOpOperands);
|
||||
rewriter.replaceOp(krnlShapeOp, fromExtentsOp.getResult());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* Function pass that emits the shape of a MemRef.
|
||||
*/
|
||||
class LowerKrnlShapePass
|
||||
: public PassWrapper<LowerKrnlShapePass, FunctionPass> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
auto function = getFunction();
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<LowerKrnlShape>(&getContext());
|
||||
|
||||
applyPatternsAndFoldGreedily(function, patterns);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// TODO: integrate with other passes if needed.
|
||||
std::unique_ptr<Pass> mlir::createLowerKrnlShapePass() {
|
||||
return std::make_unique<LowerKrnlShapePass>();
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
// RUN: onnx-mlir-opt --lower-krnl-shape %s -split-input-file | FileCheck %s
|
||||
|
||||
func @test_krnl_shape_lowering(%arg0: memref<?x?xf32>) -> index {
|
||||
%c1 = constant 1 : index
|
||||
%c0 = constant 0 : index
|
||||
%0 = dim %arg0, %c0 : memref<?x?xf32>
|
||||
%1 = alloc(%0) : memref<?x10xf32>
|
||||
%shape = "krnl.shape"(%1) : (memref<?x10xf32>) -> !shape.shape
|
||||
%d1 = "shape.get_extent"(%shape, %c1) : (!shape.shape, index) -> !shape.size
|
||||
%ind = "shape.size_to_index"(%d1) : (!shape.size) -> index
|
||||
%e = addi %ind, %ind : index
|
||||
return %e : index
|
||||
|
||||
// CHECK-LABEL: test_krnl_shape_lowering
|
||||
// CHECK: [[CONST0:%.+]] = constant 0 : index
|
||||
// CHECK: [[CONST1:%.+]] = constant 1 : index
|
||||
// CHECK: [[DIM:%.+]] = dim %arg0, [[CONST0]] : memref<?x?xf32>
|
||||
// CHECK: [[ALLOC:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
|
||||
// CHECK: [[DIM0:%.+]] = "krnl.dim"([[ALLOC]], [[CONST0]]) : (memref<?x10xf32>, index) -> index
|
||||
// CHECK: [[DIM1:%.+]] = "krnl.dim"([[ALLOC]], [[CONST1]]) : (memref<?x10xf32>, index) -> index
|
||||
// CHECK: [[SHAPE:%.+]] = shape.from_extents [[DIM0]], [[DIM1]]
|
||||
// CHECK: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], [[CONST1]] : !shape.shape, index -> !shape.size
|
||||
// CHECK: [[EXTENT_AS_INDEX:%.+]] = shape.size_to_index [[EXTENT]] : !shape.size
|
||||
// CHECK: [[RES:%.+]] = addi [[EXTENT_AS_INDEX]], [[EXTENT_AS_INDEX]] : index
|
||||
// CHECK: return [[RES]] : index
|
||||
}
|
Loading…
Reference in New Issue