[MLIR] Add SizeOp conversion from ONNX dialect to Krnl dialect (#295)
* [MLIR] Add SizeOp conversion from ONNX dialect to Krnl dialect. Added ONNXSizeOp conversion from ONNX dialect to Krnl dialect. This op is added as part of the --convert-onnx-to-krnl pass. Signed-off-by: Prashant Kumar <pk5561@gmail.com> * Add unit tests for Size op. * Remove unit tests. Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
3520dbd6e1
commit
4cc16aceb7
|
@ -24,6 +24,7 @@ add_library(OMONNXToKrnl
|
|||
Tensor/Concat.cpp
|
||||
Tensor/Split.cpp
|
||||
Tensor/Gather.cpp
|
||||
Tensor/Size.cpp
|
||||
ConvertONNXToKrnl.cpp)
|
||||
target_link_libraries(OMONNXToKrnl
|
||||
onnx)
|
||||
|
|
|
@ -103,6 +103,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
|
|||
populateLoweringONNXConcatOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXSqueezeOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXSplitOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXSizeOpPattern(patterns, &getContext());
|
||||
// Neural network
|
||||
populateLoweringONNXConvOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
|
||||
|
|
|
@ -253,6 +253,9 @@ void populateLoweringONNXSqueezeOpPattern(
|
|||
void populateLoweringONNXSplitOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
||||
void populateLoweringONNXSizeOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
||||
bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);
|
||||
|
||||
int64_t getMemRefSizeInBytes(Value val);
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
//===---------------- Size.cpp - Lowering Size Op -------------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file lowers the ONNX Size Operator to Krnl dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
struct ONNXSizeOpLowering : public ConversionPattern {
|
||||
ONNXSizeOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXSizeOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// Gather info.
|
||||
Location loc = op->getLoc();
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
ONNXSizeOp sizeOp = llvm::dyn_cast<ONNXSizeOp>(op);
|
||||
|
||||
ONNXSizeOpAdaptor operandAdaptor(operands);
|
||||
Value data = operandAdaptor.data();
|
||||
ArrayRef<int64_t> dataShape = data.getType().cast<MemRefType>().getShape();
|
||||
Value resultOperand = sizeOp.size();
|
||||
ValueRange indices;
|
||||
MemRefType memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
else
|
||||
alloc = insertAllocAndDealloc(
|
||||
memRefType, loc, rewriter, insertDealloc, {resultOperand});
|
||||
|
||||
// Accumulate static dimensions first.
|
||||
int64_t staticNumElement = 1;
|
||||
bool allStaticDimensions = true;
|
||||
for (unsigned i = 0; i < dataShape.size(); i++) {
|
||||
if (dataShape[i] != -1)
|
||||
staticNumElement *= dataShape[i];
|
||||
else
|
||||
allStaticDimensions = false;
|
||||
}
|
||||
// Accumulate the remaining dimensions that are unknown.
|
||||
Value noElements = emitConstantOp(
|
||||
rewriter, loc, memRefType.getElementType(), staticNumElement);
|
||||
if (!allStaticDimensions) {
|
||||
for (unsigned i = 0; i < dataShape.size(); i++) {
|
||||
if (dataShape[i] == -1) {
|
||||
Value index = rewriter.create<DimOp>(loc, data, i);
|
||||
Value dim = rewriter.create<IndexCastOp>(
|
||||
loc, index, memRefType.getElementType());
|
||||
noElements = rewriter.create<MulIOp>(loc, noElements, dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.create<AffineStoreOp>(loc, noElements, alloc, llvm::None);
|
||||
rewriter.replaceOp(op, alloc);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Registers the onnx.Size lowering pattern with the given pattern list so
/// that it participates in the --convert-onnx-to-krnl pass.
void populateLoweringONNXSizeOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXSizeOpLowering>(ctx);
}
|
|
@ -422,6 +422,12 @@ test_to_enable = [
|
|||
|
||||
# ConstantOfShape
|
||||
"test_constantofshape_float_ones_cpu",
|
||||
|
||||
# Size
|
||||
# TODO(tjingrant): fix unit test for size ops.
|
||||
# "test_size_cpu",
|
||||
# "test_size_example_cpu",
|
||||
|
||||
# Error:
|
||||
# Items are not equal:
|
||||
# ACTUAL: dtype('int32')
|
||||
|
|
|
@ -2135,6 +2135,42 @@ func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// Size of a fully static 2x2 tensor folds to the constant 4.
func @test_size_known(%arg0: tensor<2x2xf32>) -> tensor<i64> {
  %1 = "onnx.Size"(%arg0) : (tensor<2x2xf32>) -> tensor<i64>
  "std.return"(%1) : (tensor<i64>) -> ()

  // CHECK-LABEL: test_size_known
  // CHECK: [[RES:%.+]] = alloc() : memref<i64>
  // Note: a bare "CHECK-NEXT" (without the trailing colon) is silently
  // ignored by FileCheck; the colon is required for the directive to run.
  // CHECK-NEXT: [[SIZE:%.+]] = constant 4 : i64
  // CHECK-NEXT: affine.store [[SIZE]], [[RES]][] : memref<i64>
  // CHECK-NEXT: return [[RES]] : memref<i64>

}
|
||||
|
||||
// -----
|
||||
|
||||
// Size of a tensor with two dynamic dimensions: the static product (2) is a
// compile-time constant; each dynamic extent is read with dim, cast to i64,
// and multiplied into the running count before the final affine.store.
func @test_size_unknown(%arg0 : tensor<?x2x?xf32>) -> tensor<i64> {

  // CHECK-LABEL: test_size_unknown
  // CHECK: [[RES:%.+]] = alloc() : memref<i64>
  // CHECK-NEXT: [[INIT:%.+]] = constant 2 : i64
  // CHECK-NEXT: [[IND1:%.+]] = constant 0 : index
  // CHECK-NEXT: [[DIM1:%.+]] = dim %arg0, [[IND1]] : memref<?x2x?xf32>
  // CHECK-NEXT: [[CAST1:%.+]] = index_cast [[DIM1]] : index to i64
  // CHECK-NEXT: [[TMP1:%.+]] = muli [[INIT]], [[CAST1]] : i64
  // CHECK-NEXT: [[IND2:%.+]] = constant 2 : index
  // CHECK-NEXT: [[DIM2:%.+]] = dim %arg0, [[IND2]] : memref<?x2x?xf32>
  // CHECK-NEXT: [[IND3:%.+]] = index_cast [[DIM2]] : index to i64
  // CHECK-NEXT: [[SIZE:%.+]] = muli [[TMP1]], [[IND3]] : i64
  // CHECK-NEXT: affine.store [[SIZE]], [[RES]][] : memref<i64>
  // CHECK-NEXT: return [[RES]] : memref<i64>

  %1 = "onnx.Size"(%arg0) : (tensor<?x2x?xf32>) -> tensor<i64>
  "std.return"(%1) : (tensor<i64>) -> ()
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test gather along axis 0, first example in ONNX for Gather.
|
||||
func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
|
||||
%indices = "onnx.Constant"() {value = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64>
|
||||
|
|
Loading…
Reference in New Issue