implement shape inference for negate (#108)
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
parent 922b6b6c91, commit de9e9edc4d
@@ -474,6 +474,25 @@ Value emitScalarOpFor<ONNXAbsOp>(ConversionPatternRewriter &rewriter,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXNegOp
+//===----------------------------------------------------------------------===//
+template <>
+Value emitScalarOpFor<ONNXNegOp>(ConversionPatternRewriter &rewriter,
+    Location loc, Operation *op, Type elementType,
+    ArrayRef<Value> scalarOperands) {
+  Value operand = scalarOperands[0];
+
+  if (elementType.isa<FloatType>()) {
+    return rewriter.create<mlir::NegFOp>(loc, operand);
+  } else if (elementType.isa<IntegerType>()) {
+    auto zero = emitConstantOp(rewriter, loc, elementType, 0);
+    return rewriter.create<mlir::SubIOp>(loc, zero, operand); // 0 - X = -X
+  } else {
+    emitError(loc, "unsupported element type");
+  }
+}
+
 // Element-wise unary ops lowering to Krnl dialect.
 //===----------------------------------------------------------------------===//
 template <typename ElementwiseUnaryOp>
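Aside (not part of the commit): the hunk above lowers the Neg scalar in two ways, a dedicated floating-point negate for float element types and 0 - X via SubIOp for integer element types, presumably because the standard dialect offers no integer negate op. A minimal standalone C++ sketch of the same arithmetic, with hypothetical helper names negFloat and negInt:

// Standalone sketch (plain C++, not MLIR); negFloat and negInt are hypothetical
// helpers that mirror what the NegFOp and SubIOp(zero, operand) lowerings compute.
#include <cstdint>
#include <iostream>

float negFloat(float x) { return -x; }       // NegFOp analogue
int64_t negInt(int64_t x) { return 0 - x; }  // SubIOp(zero, x) analogue: 0 - X = -X

int main() {
  std::cout << negFloat(3.5f) << " " << negInt(42) << "\n";  // prints: -3.5 -42
  return 0;
}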
@@ -636,6 +655,7 @@ void populateLoweringONNXElementwiseOpPattern(
       ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXNegOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
@@ -681,6 +681,15 @@ bool ONNXMinOp::inferShapes() {
   return true;
 }
 
+//===----------------------------------------------------------------------===//
+// Neg
+/// Infer the output shape of the ONNXNegOp. This method is required by the
+/// shape inference interface.
+bool ONNXNegOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // Identity
 /// Infer the output shape of the ONNXIdentityOp. This method is required by the
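Aside (not part of the commit): the inferShapes() added here simply copies the operand's type onto the result, which is the whole rule for an element-wise unary op: same shape, same element type. A small illustrative C++ sketch of that rule, using a made-up TensorType struct rather than the real MLIR types:

// Illustration only: TensorType is a made-up stand-in for an MLIR tensor type
// (shape plus element kind), not an actual MLIR class.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

struct TensorType {
  std::vector<int64_t> shape;  // e.g. {2, 3}
  std::string elementType;     // e.g. "f32"
};

// For element-wise unary ops such as Neg, the result type equals the operand type.
TensorType inferUnaryElementwiseType(const TensorType &operand) { return operand; }

int main() {
  TensorType in{{2, 3}, "f32"};
  TensorType out = inferUnaryElementwiseType(in);
  assert(out.shape == in.shape && out.elementType == in.elementType);
  return 0;
}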
@@ -1798,7 +1798,7 @@ def ONNXMultinomialOp:ONNX_Op<"Multinomial",
   }
 
 def ONNXNegOp:ONNX_Op<"Neg",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Neg operation";
   let description = [{
   "Neg takes one input data (Tensor<T>) and produces one output data"
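Aside (not part of the commit): listing DeclareOpInterfaceMethods<ShapeInferenceOpInterface> here is what makes the generated ONNXNegOp class declare inferShapes(), so the hand-written ONNXNegOp::inferShapes() definition in the earlier hunk has something to attach to. A rough hand-written sketch of that declare/define split (the real class is generated by TableGen and looks different):

// Hand-written sketch of the declare/define split; the real ONNXNegOp class is
// generated from the .td entry above and contains far more than this.
class ONNXNegOpSketch {
public:
  // Declared because the op lists DeclareOpInterfaceMethods<ShapeInferenceOpInterface>;
  // the body is written by hand, as in the ONNXNegOp::inferShapes() hunk earlier.
  bool inferShapes();
};

bool ONNXNegOpSketch::inferShapes() { return true; }  // placeholder body

int main() { return ONNXNegOpSketch().inferShapes() ? 0 : 1; }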
@@ -124,6 +124,7 @@ public:
         op->getName().getStringRef() != "onnx.Abs" &&
         op->getName().getStringRef() != "onnx.Constant" &&
         op->getName().getStringRef() != "onnx.Concat" &&
+        op->getName().getStringRef() != "onnx.Neg" &&
         op->getName().getStringRef() != "onnx.Unsqueeze")
       return false;
     return llvm::any_of(op->getResultTypes(), [](Type result_type) {
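Aside (not part of the commit): this check is a temporary whitelist in the shape inference pass; an op is only considered if its name appears in the list (the hunk shows just a few of the entries) and, judging from the truncated llvm::any_of over the result types, presumably only if some result still lacks a ranked shape. A simplified C++ sketch of that decision, with the unranked-result test treated as an assumption:

// Simplified sketch of the candidate test; opName and resultIsRanked are stand-ins,
// and the "some result not yet ranked" condition is an assumption based on the
// llvm::any_of call that the hunk truncates.
#include <algorithm>
#include <string>
#include <vector>

bool needsShapeInference(const std::string &opName,
                         const std::vector<bool> &resultIsRanked) {
  // Only part of the real whitelist is visible in the hunk above.
  static const std::vector<std::string> supported = {
      "onnx.Abs", "onnx.Constant", "onnx.Concat", "onnx.Neg", "onnx.Unsqueeze"};
  if (std::find(supported.begin(), supported.end(), opName) == supported.end())
    return false;  // op not whitelisted for shape inference yet
  // Candidate only if at least one result still lacks a ranked shape.
  return std::any_of(resultIsRanked.begin(), resultIsRanked.end(),
                     [](bool ranked) { return !ranked; });
}

int main() {
  // onnx.Neg with one still-unranked result is now a candidate.
  return needsShapeInference("onnx.Neg", {false}) ? 0 : 1;
}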
@@ -59,7 +59,7 @@ OpsWithShapeInference = [
     'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
-    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat'
+    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg'
 ]
 
 # Operations supporting canonicalization.