diff --git a/src/dialect/onnx/gen_doc.py b/src/dialect/onnx/gen_doc.py index 0f56234..18b8431 100644 --- a/src/dialect/onnx/gen_doc.py +++ b/src/dialect/onnx/gen_doc.py @@ -268,7 +268,7 @@ def gen_schema(schema) : 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', - 'Softplus', 'Softsign', 'Sqrt'] + 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze'] CanonicalList=['Add', 'Identity'] manual_code = dict([ ('DummyExample', ' let extraClassDeclaration = [{ \n'+ diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index cedcea4..bcb9b27 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -724,6 +724,46 @@ void ONNXConvNoBiasOp::inferShapes() { getResult().setType(RankedTensorType::get(dims, dataTy.getElementType())); } +//===----------------------------------------------------------------------===// +// Unsqueeze + +void ONNXUnsqueezeOp::inferShapes() { + if (!getOperand().getType().isa()) + return; + + auto operandTy = getOperand().getType().cast(); + int inRank = operandTy.getRank(); + + ArrayAttr axisAttrs = axesAttr(); + SmallVector axes; + int outRank = 0; + if (axisAttrs) { + outRank = inRank + axisAttrs.getValue().size(); + for (auto axisAttr : axisAttrs.getValue()) { + int axis = axisAttr.cast().getInt(); + axis = axis >= 0 ? axis : (outRank + axis); + // Valid range + assert(axis >= -outRank && axis <= outRank - 1); + if (std::find(axes.begin(), axes.end(), axis) == axes.end()) + axes.emplace_back(axis); + else + emitError("Duplicated axes."); + } + } else { + emitError("Axes attribute is required."); + } + + SmallVector dims; + for (int i = 0, j = 0; i < outRank || j < inRank; ++i) { + if (std::find(axes.begin(), axes.end(), i) != axes.end()) { + dims.emplace_back(1); + } else { + dims.emplace_back(operandTy.getShape()[j++]); + } + } + getResult().setType(RankedTensorType::get(dims, operandTy.getElementType())); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index 2467dc1..8c50330 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -3502,7 +3502,7 @@ def ONNXUniqueOp:ONNX_Op<"Unique", } def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Unsqueeze operation"; let description = [{ "Insert single-dimensional entries to the shape of an input tensor (`data`)." diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp index 01e4fa9..0695c26 100644 --- a/src/pass/lower_frontend_to_krnl.cpp +++ b/src/pass/lower_frontend_to_krnl.cpp @@ -1176,6 +1176,79 @@ struct ONNXReshapeOpLowering : public ConversionPattern { } }; +struct ONNXUnsqueezeOpLowering : public ConversionPattern { + ONNXUnsqueezeOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + auto tensorType = (*op->result_type_begin()).cast(); + int outRank = tensorType.getRank(); + + // Assume that `axes` has been validated by shape inference. + // So, here we just get it. + ArrayAttr axisAttrs = llvm::dyn_cast(op).axesAttr(); + SmallVector axes; + for (auto axisAttr : axisAttrs.getValue()) { + int axis = axisAttr.cast().getInt(); + axis = axis >= 0 ? axis : (outRank + axis); + axes.emplace_back(axis); + } + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + Value alloc; + + // Compute size in bytes. + Value tensorSize = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + getMemRefEltSizeInBytes(memRefType))); + + bool insertDealloc = checkInsertDealloc(op); + auto memRefShape = memRefType.getShape(); + if (hasAllConstantDimensions(memRefType)) { + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + for (int i = 0; i < memRefShape.size(); ++i) { + Value dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + memRefShape[i])); + tensorSize = rewriter.create(loc, tensorSize, dimVal); + } + } else { + // Unknown dimensions are always the operand's dimensions. + SmallVector allocOperands; + for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) { + Value dimVal = nullptr; + if (memRefShape[outIdx] < 0) { + Value index = rewriter.create(loc, operands[0], inIdx); + dimVal = rewriter.create( + loc, index, rewriter.getIntegerType(64)); + allocOperands.emplace_back(index); + } else { + dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + memRefShape[outIdx])); + } + tensorSize = rewriter.create(loc, tensorSize, dimVal); + if (std::find(axes.begin(), axes.end(), outIdx) == axes.end()) + inIdx++; + } + alloc = rewriter.create(loc, memRefType, allocOperands); + auto *parentBlock = alloc.getDefiningOp()->getBlock(); + if (insertDealloc) { + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + } + + rewriter.create(loc, alloc, operands[0], tensorSize); + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + //===----------------------------------------------------------------------===// // EntryPoint Op lowering to Krnl Entry Point. //===----------------------------------------------------------------------===// @@ -1304,7 +1377,7 @@ void FrontendToKrnlLoweringPass::runOnModule() { ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXReshapeOpLowering, ONNXEntryPointLowering, - ONNXSoftmaxOpLowering>(&getContext()); + ONNXSoftmaxOpLowering, ONNXUnsqueezeOpLowering>(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index a861ed7..0b5993b 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -121,7 +121,8 @@ public: op->getName().getStringRef() != "onnx.Transpose" && op->getName().getStringRef() != "onnx.Softmax" && op->getName().getStringRef() != "onnx.Sqrt" && - op->getName().getStringRef() != "onnx.ConvNoBias") + op->getName().getStringRef() != "onnx.ConvNoBias" && + op->getName().getStringRef() != "onnx.Unsqueeze") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { return !result_type.isa(); diff --git a/test/backend/test.py b/test/backend/test.py index ae71d85..6ec754e 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -147,6 +147,16 @@ test_to_enable = [ "test_sum_one_input_cpu", "test_sum_two_inputs_cpu", + # Unsqueeze Op: + "test_unsqueeze_axis_0_cpu", + "test_unsqueeze_axis_1_cpu", + "test_unsqueeze_axis_2_cpu", + "test_unsqueeze_axis_3_cpu", + "test_unsqueeze_negative_axes_cpu", + "test_unsqueeze_three_axes_cpu", + "test_unsqueeze_two_axes_cpu", + "test_unsqueeze_unsorted_axes_cpu", + # Reciprocal Op: "test_reciprocal_cpu", "test_reciprocal_example_cpu", diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index b74b65f..2217403 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -652,3 +652,22 @@ func @test_sqrt(%arg0 : tensor) -> tensor<*xf32> { // CHECK: return [[RES]] : memref } +func @test_unsqueeze(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Unsqueeze"(%arg0) {axes=[0,3]} : (tensor<10x10xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_unsqueeze + // CHECK: [[RES:%.+]] = alloc() : memref<1x10x10x1xf32> + // CHECK: [[INBYTES:%.+]] = constant 4 : i64 + // CHECK: [[DIM1:%.+]] = constant 1 : i64 + // CHECK: [[SIZE1:%.+]] = muli [[INBYTES]], [[DIM1]] : i64 + // CHECK: [[DIM2:%.+]] = constant 10 : i64 + // CHECK: [[SIZE2:%.+]] = muli [[SIZE1]], [[DIM2]] : i64 + // CHECK: [[DIM3:%.+]] = constant 10 : i64 + // CHECK: [[SIZE3:%.+]] = muli [[SIZE2]], [[DIM3]] : i64 + // CHECK: [[DIM4:%.+]] = constant 1 : i64 + // CHECK: [[SIZE4:%.+]] = muli [[SIZE3]], [[DIM4]] : i64 + // CHECK: "krnl.memcpy"([[RES]], %arg0, [[SIZE4]]) : (memref<1x10x10x1xf32>, memref<10x10xf32>, i64) -> () + // CHECK: return [[RES]] : memref<1x10x10x1xf32> +} +