diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index f8adb20..29460ca 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -474,6 +474,25 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, } } +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXNegOp +//===----------------------------------------------------------------------===// +template <> +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { + Value operand = scalarOperands[0]; + + if (elementType.isa()) { + return rewriter.create(loc, operand); + } else if (elementType.isa()) { + auto zero = emitConstantOp(rewriter, loc, elementType, 0); + return rewriter.create(loc, zero, operand); // 0 - X = -X + } else { + emitError(loc, "unsupported element type"); + } +} + // Element-wise unary ops lowering to Krnl dialect. //===----------------------------------------------------------------------===// template @@ -636,6 +655,7 @@ void populateLoweringONNXElementwiseOpPattern( ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, + ONNXElementwiseUnaryOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index b407c85..fe06204 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -681,6 +681,15 @@ bool ONNXMinOp::inferShapes() { return true; } +//===----------------------------------------------------------------------===// +// Neg +/// Infer the output shape of the ONNXNegOp. This method is required by the +/// shape inference interface. +bool ONNXNegOp::inferShapes() { + getResult().setType(getOperand().getType()); + return true; +} + //===----------------------------------------------------------------------===// // Identity /// Infer the output shape of the ONNXIdentityOp. This method is required by the diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 2a7c34a..8533863 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -1798,7 +1798,7 @@ def ONNXMultinomialOp:ONNX_Op<"Multinomial", } def ONNXNegOp:ONNX_Op<"Neg", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Neg operation"; let description = [{ "Neg takes one input data (Tensor) and produces one output data" diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index ec8112b..36655cd 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -124,6 +124,7 @@ public: op->getName().getStringRef() != "onnx.Abs" && op->getName().getStringRef() != "onnx.Constant" && op->getName().getStringRef() != "onnx.Concat" && + op->getName().getStringRef() != "onnx.Neg" && op->getName().getStringRef() != "onnx.Unsqueeze") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { diff --git a/utils/gen_doc.py b/utils/gen_doc.py index 075dd36..f7f44ff 100644 --- a/utils/gen_doc.py +++ b/utils/gen_doc.py @@ -59,7 +59,7 @@ OpsWithShapeInference = [ 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', 'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', - 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat' + 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg' ] # Operations supporting canonicalization.