Merge pull request #33 from clang-ykt/fix-gemm
Fix Gemm translation to ONNX dialect.
This commit is contained in:
		
						commit
						2ea0724e4d
					
				|  | @ -90,17 +90,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> { | |||
| // or outputs. This decision affects only ONNX operations with optional | ||||
| // arguments not ONNX operations with variadic operands. | ||||
| 
 | ||||
| def ONNXFullGemmOp: ONNX_Op<"FullGemm", | ||||
| def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias", | ||||
|     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||
|   let summary = "ONNX general matrix multiply operation"; | ||||
|   let summary = "ONNX general matrix multiply operation without bias."; | ||||
|   let description = [{ | ||||
| 
 | ||||
|     The "onnx.gemm" generic matrix multiplication with bias. | ||||
|     The "onnx.Gemm" generic matrix multiplication without bias. | ||||
| 
 | ||||
|   }]; | ||||
| 
 | ||||
|   let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in); | ||||
|   let results = (outs AnyTensor); | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$lhs_in, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rhs_in); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); | ||||
| } | ||||
| 
 | ||||
| def ONNXConv1Op:ONNX_Op<"Conv1",  | ||||
|  |  | |||
|  | @ -347,9 +347,9 @@ void ONNXGemmOp::inferShapes() { | |||
|   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); | ||||
| } | ||||
| 
 | ||||
| // FullGemm
 | ||||
| // GemmNoBias
 | ||||
| 
 | ||||
| void ONNXFullGemmOp::inferShapes() { | ||||
| void ONNXGemmNoBiasOp::inferShapes() { | ||||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!getOperand(0).getType().isa<RankedTensorType>() || | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) | ||||
|  |  | |||
|  | @ -30,9 +30,9 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>; | |||
| // Pattern-Match and Rewrite | ||||
| //===----------------------------------------------------------------------===// | ||||
| 
 | ||||
| // onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z) | ||||
| // onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z) | ||||
| def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), | ||||
|                                  (ONNXFullGemmOp $m1, $m2, $m3), | ||||
|                                  (ONNXGemmOp $m1, $m2, $m3), | ||||
| 				 [(HasOneUse $res)]>; | ||||
| 
 | ||||
| // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) | ||||
|  |  | |||
|  | @ -114,7 +114,7 @@ public: | |||
|         op->getName().getStringRef() != "onnx.Identity" && | ||||
|         op->getName().getStringRef() != "onnx.MatMul" && | ||||
|         op->getName().getStringRef() != "onnx.Gemm" && | ||||
|         op->getName().getStringRef() != "onnx.FullGemm" && | ||||
|         op->getName().getStringRef() != "onnx.GemmNoBias" && | ||||
|         op->getName().getStringRef() != "onnx.Reshape" && | ||||
|         op->getName().getStringRef() != "onnx.Transpose") | ||||
|       return false; | ||||
|  |  | |||
|  | @ -2,7 +2,7 @@ | |||
| 
 | ||||
| func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||||
|   // CHECK-LABEL: test_matmul_add_simplification | ||||
|   // CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> | ||||
|   // CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> | ||||
|   %0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> | ||||
|   %1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> | ||||
|   "std.return"(%1) : (tensor<10x10xf32>) -> () | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue