Merge pull request #34 from clang-ykt/fix-maxpool
Fix MaxPool translation to ONNX dialect.
This commit is contained in:
		
						commit
						b50fc1fdeb
					
				| 
						 | 
				
			
			@ -632,6 +632,21 @@ private:
 | 
			
		|||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /*!
 | 
			
		||||
   * Special handle for MaxPool operations.
 | 
			
		||||
   */
 | 
			
		||||
  void ImportNodeMaxPool(
 | 
			
		||||
      onnx::NodeProto node, int nIn,
 | 
			
		||||
      std::initializer_list<std::tuple<std::string, std::string, std::string>>
 | 
			
		||||
          attrs) {
 | 
			
		||||
    int nOuts = node.output().size();
 | 
			
		||||
    if (nOuts == 1) {
 | 
			
		||||
      ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts, attrs);
 | 
			
		||||
    } else {
 | 
			
		||||
      ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, nIn, nOuts, attrs);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void ImportNode(onnx::NodeProto node) {
 | 
			
		||||
    std::vector<mlir::Value> inputs;
 | 
			
		||||
    for (auto item : node.input()) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -303,7 +303,7 @@
 | 
			
		|||
       ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
 | 
			
		||||
      });
 | 
			
		||||
    }else if (OpName == "MaxPool") {
 | 
			
		||||
       ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, 1, 2, {
 | 
			
		||||
       ImportNodeMaxPool(node, 1, {
 | 
			
		||||
         {"auto_pad","str","NOTSET"}
 | 
			
		||||
        ,{"ceil_mode","int","0"}
 | 
			
		||||
        ,{"dilations","", ""}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -365,6 +365,7 @@ special cases:
 | 
			
		|||
def gen_code(schema,fefile) :
 | 
			
		||||
    special_handler = dict([
 | 
			
		||||
        ("Conv", "ImportNodeConv"),
 | 
			
		||||
        ("MaxPool", "ImportNodeMaxPool"),
 | 
			
		||||
        #("Transpose", "ImportNodeTranspose")
 | 
			
		||||
        ])
 | 
			
		||||
    special_type = dict([
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -78,6 +78,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
 | 
			
		|||
  }];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ONNX Operations for handling optional arguments
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// To allow pattern matching on operations with optional arguments/outputs we
 | 
			
		||||
// implement variants of the original ONNX dialect operations. The ONNX
 | 
			
		||||
// operations automatically generated by the `gen_doc.py` script and included
 | 
			
		||||
// in the `onnxop.inc` file have all optional arguments and outputs present.
 | 
			
		||||
// In the operations below we include the variants with missing operands
 | 
			
		||||
// or outputs. This decision affects only ONNX operations with optional
 | 
			
		||||
// arguments not ONNX operations with variadic operands.
 | 
			
		||||
 | 
			
		||||
def ONNXFullGemmOp: ONNX_Op<"FullGemm",
 | 
			
		||||
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
			
		||||
| 
						 | 
				
			
			@ -114,4 +125,15 @@ def ONNXConv3Op:ONNX_Op<"Conv3",
 | 
			
		|||
  let results = (outs AnyTensor);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
 | 
			
		||||
    [NoSideEffect]> {
 | 
			
		||||
  let summary = "ONNX MaxPool operation with a single output.";
 | 
			
		||||
  let description = [{
 | 
			
		||||
    "ONNX MaxPool operation with a single output."
 | 
			
		||||
    "See ONNXMaxPoolOp for a full description of the MaxPool semantics."
 | 
			
		||||
  }];
 | 
			
		||||
  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
 | 
			
		||||
  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // ONNX_OPS
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue