Remove special GemmNoBias since we can handle it using NoneType bias (#100)
* Remove special GemmNoBias since we can handle it using NoneType bias
* Remove GemmNoBias from onnx.md

Co-authored-by: Tian Jin <tjingrant@gmail.com>
parent 732317cd5a
commit a720f9a7b2
@@ -1558,33 +1558,6 @@ ONNX Gather operation
 
 1. `output`: memref of any type values or tensor of any type values
 
-### onnx.GemmNoBias (ONNXGemmNoBiasOp)
-ONNX general matrix multiply operation without bias.
-
-#### Description:
-
-
-The "onnx.Gemm" generic matrix multiplication without bias.
-
-
-#### Operands:
-
-1. `A`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-
-#### Attributes:
-
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `alpha` | `FloatAttr` | 32-bit float attribute attribute |
-| `beta` | `FloatAttr` | 32-bit float attribute attribute |
-| `transA` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `transB` | `IntegerAttr` | 64-bit integer attribute attribute |
-
-#### Results:
-
-1. `o_Y`: memref of any type values or tensor of any type values
-
 ### onnx.Gemm (ONNXGemmOp)
 ONNX Gemm operation
 
@@ -17,9 +17,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
-    // The first predicate is unnecessary when we remove ONXGemmNoBiasOp.
-    bool hasBias = (operands.size() == 3) &&
-                   (!op->getOperand(2).getType().isa<NoneType>());
+    bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
 
     Value A, B, C;
     A = operands[0];
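
The simplified predicate relies on the convention referenced in the commit message: an omitted optional ONNX input is still materialized as an operand, just with NoneType as its type, so the operands.size() check becomes redundant. A minimal C++ sketch of that check as a reusable helper (hypothetical, not part of this commit; it assumes the MLIR declarations already included by the lowering file):

// Hypothetical helper: an optional input that was omitted in the model shows
// up as a value of mlir::NoneType, so inspecting the operand's type is enough
// to tell whether it was supplied.
static bool hasOptionalOperand(mlir::Operation *op, unsigned index) {
  return index < op->getNumOperands() &&
         !op->getOperand(index).getType().isa<mlir::NoneType>();
}

// Equivalent to the new predicate above:
//   bool hasBias = hasOptionalOperand(op, 2);
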
@@ -215,5 +213,4 @@ struct ONNXGemmOpLowering : public ConversionPattern {
 void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
                                        MLIRContext *ctx) {
   patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
-  patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
 }
@@ -90,25 +90,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
 // or outputs. This decision affects only ONNX operations with optional
 // arguments not ONNX operations with variadic operands.
 
-def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
-    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
-  let summary = "ONNX general matrix multiply operation without bias.";
-  let description = [{
-
-    The "onnx.Gemm" generic matrix multiplication without bias.
-
-  }];
-
-  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
-           AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
-           DefaultValuedAttr<F32Attr, "1.0">:$alpha,
-           DefaultValuedAttr<F32Attr, "1.0">:$beta,
-           DefaultValuedAttr<I64Attr, "0">:$transA,
-           DefaultValuedAttr<I64Attr, "0">:$transB);
-
-  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
-}
-
 def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let hasCanonicalizer = 1;
@@ -565,32 +565,6 @@ void ONNXGemmOp::inferShapes() {
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
 
-// GemmNoBias
-
-void ONNXGemmNoBiasOp::inferShapes() {
-  // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
-    return;
-  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
-
-  int64_t M, N, K_A, K_B;
-  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
-  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
-  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
-  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
-
-  if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
-    emitError("Tensor shapes mismatched.");
-  }
-
-  SmallVector<int64_t, 2> dims;
-  dims.emplace_back(M);
-  dims.emplace_back(N);
-  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
-}
-
 /// BatchNormalizationTestMode
 void ONNXBatchNormalizationTestModeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
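
With ONNXGemmNoBiasOp gone, the remaining ONNXGemmOp::inferShapes() has to tolerate a none-typed bias operand on its own. The hunk above only shows the deleted duplicate, not how the surviving function handles this, so the following is just a hedged C++ sketch of such a guard, reusing only names that appear in the removed code:

// Sketch only (not the code added by this commit): skip bias-related checks
// when the third operand carries NoneType, i.e. no bias was given.
bool hasBias = !getOperand(2).getType().isa<NoneType>();
if (hasBias) {
  // Bias-specific checks (e.g. that C's shape is broadcastable to M x N)
  // would go here; without a bias, the result shape comes from A and B alone,
  // exactly as in the removed ONNXGemmNoBiasOp::inferShapes().
}
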
@@ -118,7 +118,6 @@ public:
         op->getName().getStringRef() != "onnx.Identity" &&
         op->getName().getStringRef() != "onnx.MatMul" &&
         op->getName().getStringRef() != "onnx.Gemm" &&
-        op->getName().getStringRef() != "onnx.GemmNoBias" &&
         op->getName().getStringRef() != "onnx.Reshape" &&
         op->getName().getStringRef() != "onnx.Transpose" &&
         op->getName().getStringRef() != "onnx.ReduceMax" &&
@@ -806,35 +806,6 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
   // CHECK: }
 }
 
-func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
-  %0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
-  "std.return"(%0) : (tensor<*xf32>) -> ()
-
-  // CHECK-LABEL: test_gemm_no_bias
-  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
-  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
-  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
-  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
-  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops  {
-  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
-  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
-  // CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
-  // CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
-  // CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
-  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
-  // CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: }
-  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
-  // CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: }
-  // CHECK: return [[RES]] : memref<10x10xf32>
-  // CHECK: }
-}
-
 func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()