Fix convolution translation to MLIR.
This commit is contained in:
		
							parent
							
								
									fc352745e0
								
							
						
					
					
						commit
						67ec9e9009
					
				|  | @ -611,7 +611,7 @@ private: | |||
|    * a specialized function is used | ||||
|    */ | ||||
|   void ImportNodeConv( | ||||
|       onnx::NodeProto node, int nIn, int nOut, | ||||
|       onnx::NodeProto node, int nOut, | ||||
|       std::initializer_list<std::tuple<std::string, std::string, std::string>> | ||||
|           attrs) { | ||||
|     // Conv has attribute dilations, kernel_shape, pads, the default value of
 | ||||
|  | @ -624,12 +624,12 @@ private: | |||
|     // similar situation for pads, strides in AveragePool
 | ||||
|     // axes of ReduceSum,  pads, strides, dilations and kernel_shape of MaxPool
 | ||||
|     // TODO: fix this after type inference
 | ||||
|     int nOps = node.input().size(); | ||||
| 
 | ||||
|     if (node.input().size() == 1) { | ||||
|       ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs); | ||||
|     } else { | ||||
|       ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs); | ||||
|     } | ||||
|     if (nOps == 2) | ||||
|       ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs); | ||||
|     else | ||||
|       ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs); | ||||
|   } | ||||
| 
 | ||||
|   void ImportNode(onnx::NodeProto node) { | ||||
|  |  | |||
|  | @ -87,7 +87,7 @@ | |||
|          {"value","", ""} | ||||
|       }); | ||||
|     }else if (OpName == "Conv") { | ||||
|        ImportNodeConv(node, 3, 1, { | ||||
|        ImportNodeConv(node, 1, { | ||||
|          {"auto_pad","str","NOTSET"} | ||||
|         ,{"dilations","", ""} | ||||
|         ,{"group","int", "1"} | ||||
|  |  | |||
|  | @ -92,26 +92,15 @@ def ONNXFullGemmOp: ONNX_Op<"FullGemm", | |||
|   let results = (outs AnyTensor); | ||||
| } | ||||
| 
 | ||||
| def ONNXConv1Op:ONNX_Op<"Conv1",  | ||||
| def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias", | ||||
|     [NoSideEffect]> { | ||||
|   let summary = "ONNX Conv operation"; | ||||
|   let summary = "ONNX Conv operation with no Bias operand."; | ||||
|   let description = [{ | ||||
|     "The convolution operator consumes an input tensor and a filter, and" | ||||
|     "computes the output." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTensor:$X); | ||||
|   let results = (outs AnyTensor); | ||||
| } | ||||
| 
 | ||||
| def ONNXConv3Op:ONNX_Op<"Conv3",  | ||||
|     [NoSideEffect]> { | ||||
|   let summary = "ONNX Conv operation"; | ||||
|   let description = [{ | ||||
|     "The convolution operator consumes an input tensor and a filter, and" | ||||
|     "computes the output." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B); | ||||
|   let results = (outs AnyTensor); | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); | ||||
| } | ||||
| 
 | ||||
| #endif // ONNX_OPS | ||||
|  |  | |||
|  | @ -322,7 +322,7 @@ def ONNXConvOp:ONNX_Op<"Conv", | |||
|     "The convolution operator consumes an input tensor and a filter, and" | ||||
|     "computes the output." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$B); | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue