implement shape inference for negate (#108)
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
922b6b6c91
commit
de9e9edc4d
|
@ -474,6 +474,25 @@ Value emitScalarOpFor<ONNXAbsOp>(ConversionPatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Scalar unary ops for lowering ONNXNegOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
template <>
|
||||||
|
Value emitScalarOpFor<ONNXNegOp>(ConversionPatternRewriter &rewriter,
|
||||||
|
Location loc, Operation *op, Type elementType,
|
||||||
|
ArrayRef<Value> scalarOperands) {
|
||||||
|
Value operand = scalarOperands[0];
|
||||||
|
|
||||||
|
if (elementType.isa<FloatType>()) {
|
||||||
|
return rewriter.create<mlir::NegFOp>(loc, operand);
|
||||||
|
} else if (elementType.isa<IntegerType>()) {
|
||||||
|
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||||
|
return rewriter.create<mlir::SubIOp>(loc, zero, operand); // 0 - X = -X
|
||||||
|
} else {
|
||||||
|
emitError(loc, "unsupported element type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Element-wise unary ops lowering to Krnl dialect.
|
// Element-wise unary ops lowering to Krnl dialect.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <typename ElementwiseUnaryOp>
|
template <typename ElementwiseUnaryOp>
|
||||||
|
@ -636,6 +655,7 @@ void populateLoweringONNXElementwiseOpPattern(
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
|
||||||
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXNegOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
|
||||||
|
|
|
@ -681,6 +681,15 @@ bool ONNXMinOp::inferShapes() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Neg
|
||||||
|
/// Infer the output shape of the ONNXNegOp. This method is required by the
|
||||||
|
/// shape inference interface.
|
||||||
|
bool ONNXNegOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Identity
|
// Identity
|
||||||
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
|
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
|
||||||
|
|
|
@ -1798,7 +1798,7 @@ def ONNXMultinomialOp:ONNX_Op<"Multinomial",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXNegOp:ONNX_Op<"Neg",
|
def ONNXNegOp:ONNX_Op<"Neg",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Neg operation";
|
let summary = "ONNX Neg operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Neg takes one input data (Tensor<T>) and produces one output data"
|
"Neg takes one input data (Tensor<T>) and produces one output data"
|
||||||
|
|
|
@ -124,6 +124,7 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.Abs" &&
|
op->getName().getStringRef() != "onnx.Abs" &&
|
||||||
op->getName().getStringRef() != "onnx.Constant" &&
|
op->getName().getStringRef() != "onnx.Constant" &&
|
||||||
op->getName().getStringRef() != "onnx.Concat" &&
|
op->getName().getStringRef() != "onnx.Concat" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Neg" &&
|
||||||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
op->getName().getStringRef() != "onnx.Unsqueeze")
|
||||||
return false;
|
return false;
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||||
|
|
|
@ -59,7 +59,7 @@ OpsWithShapeInference = [
|
||||||
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
||||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
||||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat'
|
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
|
|
Loading…
Reference in New Issue