implement shape inference for negate (#108)

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
David Byrd 2020-05-06 17:42:43 -10:00 committed by GitHub
parent 922b6b6c91
commit de9e9edc4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 32 additions and 2 deletions

View File

@ -474,6 +474,25 @@ Value emitScalarOpFor<ONNXAbsOp>(ConversionPatternRewriter &rewriter,
} }
} }
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXNegOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXNegOp>(ConversionPatternRewriter &rewriter,
Location loc, Operation *op, Type elementType,
ArrayRef<Value> scalarOperands) {
Value operand = scalarOperands[0];
if (elementType.isa<FloatType>()) {
return rewriter.create<mlir::NegFOp>(loc, operand);
} else if (elementType.isa<IntegerType>()) {
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
return rewriter.create<mlir::SubIOp>(loc, zero, operand); // 0 - X = -X
} else {
emitError(loc, "unsupported element type");
}
}
// Element-wise unary ops lowering to Krnl dialect. // Element-wise unary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOp> template <typename ElementwiseUnaryOp>
@ -636,6 +655,7 @@ void populateLoweringONNXElementwiseOpPattern(
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXNegOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,

View File

@ -681,6 +681,15 @@ bool ONNXMinOp::inferShapes() {
return true; return true;
} }
//===----------------------------------------------------------------------===//
// Neg
/// Infer the output shape of the ONNXNegOp. This method is required by the
/// shape inference interface.
bool ONNXNegOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Identity // Identity
/// Infer the output shape of the ONNXIdentityOp. This method is required by the /// Infer the output shape of the ONNXIdentityOp. This method is required by the

View File

@ -1798,7 +1798,7 @@ def ONNXMultinomialOp:ONNX_Op<"Multinomial",
} }
def ONNXNegOp:ONNX_Op<"Neg", def ONNXNegOp:ONNX_Op<"Neg",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Neg operation"; let summary = "ONNX Neg operation";
let description = [{ let description = [{
"Neg takes one input data (Tensor<T>) and produces one output data" "Neg takes one input data (Tensor<T>) and produces one output data"

View File

@ -124,6 +124,7 @@ public:
op->getName().getStringRef() != "onnx.Abs" && op->getName().getStringRef() != "onnx.Abs" &&
op->getName().getStringRef() != "onnx.Constant" && op->getName().getStringRef() != "onnx.Constant" &&
op->getName().getStringRef() != "onnx.Concat" && op->getName().getStringRef() != "onnx.Concat" &&
op->getName().getStringRef() != "onnx.Neg" &&
op->getName().getStringRef() != "onnx.Unsqueeze") op->getName().getStringRef() != "onnx.Unsqueeze")
return false; return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) { return llvm::any_of(op->getResultTypes(), [](Type result_type) {

View File

@ -59,7 +59,7 @@ OpsWithShapeInference = [
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat' 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg'
] ]
# Operations supporting canonicalization. # Operations supporting canonicalization.