Merge pull request #32 from clang-ykt/fix-conv
Fix convolution translation to ONNX dialect
This commit is contained in:
		
						commit
						c2d31c0b78
					
				|  | @ -611,7 +611,7 @@ private: | ||||||
|    * a specialized function is used |    * a specialized function is used | ||||||
|    */ |    */ | ||||||
|   void ImportNodeConv( |   void ImportNodeConv( | ||||||
|       onnx::NodeProto node, int nIn, int nOut, |       onnx::NodeProto node, int nOut, | ||||||
|       std::initializer_list<std::tuple<std::string, std::string, std::string>> |       std::initializer_list<std::tuple<std::string, std::string, std::string>> | ||||||
|           attrs) { |           attrs) { | ||||||
|     // Conv has attribute dilations, kernel_shape, pads, the default value of
 |     // Conv has attribute dilations, kernel_shape, pads, the default value of
 | ||||||
|  | @ -624,12 +624,12 @@ private: | ||||||
|     // similar situation for pads, strides in AveragePool
 |     // similar situation for pads, strides in AveragePool
 | ||||||
|     // axes of ReduceSum,  pads, strides, dilations and kernel_shape of MaxPool
 |     // axes of ReduceSum,  pads, strides, dilations and kernel_shape of MaxPool
 | ||||||
|     // TODO: fix this after type inference
 |     // TODO: fix this after type inference
 | ||||||
|  |     int nOps = node.input().size(); | ||||||
| 
 | 
 | ||||||
|     if (node.input().size() == 1) { |     if (nOps == 2) | ||||||
|       ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs); |       ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs); | ||||||
|     } else { |     else | ||||||
|       ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs); |       ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs); | ||||||
|     } |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /*!
 |   /*!
 | ||||||
|  |  | ||||||
|  | @ -87,7 +87,7 @@ | ||||||
|          {"value","", ""} |          {"value","", ""} | ||||||
|       }); |       }); | ||||||
|     }else if (OpName == "Conv") { |     }else if (OpName == "Conv") { | ||||||
|        ImportNodeConv(node, 3, 1, { |        ImportNodeConv(node, 1, { | ||||||
|          {"auto_pad","str","NOTSET"} |          {"auto_pad","str","NOTSET"} | ||||||
|         ,{"dilations","", ""} |         ,{"dilations","", ""} | ||||||
|         ,{"group","int", "1"} |         ,{"group","int", "1"} | ||||||
|  |  | ||||||
|  | @ -103,26 +103,15 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias", | ||||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); |   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def ONNXConv1Op:ONNX_Op<"Conv1",  | def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias", | ||||||
|     [NoSideEffect]> { |     [NoSideEffect]> { | ||||||
|   let summary = "ONNX Conv operation"; |   let summary = "ONNX Conv operation with no Bias operand."; | ||||||
|   let description = [{ |   let description = [{ | ||||||
|     "The convolution operator consumes an input tensor and a filter, and" |     "The convolution operator consumes an input tensor and a filter, and" | ||||||
|     "computes the output." |     "computes the output." | ||||||
|   }]; |   }]; | ||||||
|   let arguments = (ins AnyTensor:$X); |   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W); | ||||||
|   let results = (outs AnyTensor); |   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); | ||||||
| } |  | ||||||
| 
 |  | ||||||
| def ONNXConv3Op:ONNX_Op<"Conv3",  |  | ||||||
|     [NoSideEffect]> { |  | ||||||
|   let summary = "ONNX Conv operation"; |  | ||||||
|   let description = [{ |  | ||||||
|     "The convolution operator consumes an input tensor and a filter, and" |  | ||||||
|     "computes the output." |  | ||||||
|   }]; |  | ||||||
|   let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B); |  | ||||||
|   let results = (outs AnyTensor); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", | def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue