Remove special GemmNoBias since we can handle it using NoneType bias (#100)

* Remove special GemmNoBias since we can handle it using NoneType bias

* Remove GemmNoBias from onnx.md

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Tung D. Le 2020-02-25 14:20:43 +09:00 committed by GitHub
parent 732317cd5a
commit a720f9a7b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 1 additions and 106 deletions

View File

@ -1558,33 +1558,6 @@ ONNX Gather operation
1. `output`: memref of any type values or tensor of any type values 1. `output`: memref of any type values or tensor of any type values
### onnx.GemmNoBias (ONNXGemmNoBiasOp)
ONNX general matrix multiply operation without bias.
#### Description:
The "onnx.Gemm" generic matrix multiplication without bias.
#### Operands:
1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
#### Attributes:
| Attribute | MLIR Type | Description |
| :-------: | :-------: | ----------- |
| `alpha` | `FloatAttr` | 32-bit float attribute attribute |
| `beta` | `FloatAttr` | 32-bit float attribute attribute |
| `transA` | `IntegerAttr` | 64-bit integer attribute attribute |
| `transB` | `IntegerAttr` | 64-bit integer attribute attribute |
#### Results:
1. `o_Y`: memref of any type values or tensor of any type values
### onnx.Gemm (ONNXGemmOp) ### onnx.Gemm (ONNXGemmOp)
ONNX Gemm operation ONNX Gemm operation

View File

@ -17,9 +17,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
// The first predicate is unnecessary when we remove ONXGemmNoBiasOp. bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
bool hasBias = (operands.size() == 3) &&
(!op->getOperand(2).getType().isa<NoneType>());
Value A, B, C; Value A, B, C;
A = operands[0]; A = operands[0];
@ -215,5 +213,4 @@ struct ONNXGemmOpLowering : public ConversionPattern {
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns, void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx) { MLIRContext *ctx) {
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx); patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
} }

View File

@ -90,25 +90,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
// or outputs. This decision affects only ONNX operations with optional // or outputs. This decision affects only ONNX operations with optional
// arguments not ONNX operations with variadic operands. // arguments not ONNX operations with variadic operands.
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation without bias.";
let description = [{
The "onnx.Gemm" generic matrix multiplication without bias.
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
DefaultValuedAttr<F32Attr, "1.0">:$alpha,
DefaultValuedAttr<F32Attr, "1.0">:$beta,
DefaultValuedAttr<I64Attr, "0">:$transA,
DefaultValuedAttr<I64Attr, "0">:$transB);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias", def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1; let hasCanonicalizer = 1;

View File

@ -565,32 +565,6 @@ void ONNXGemmOp::inferShapes() {
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
} }
// GemmNoBias
void ONNXGemmNoBiasOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
emitError("Tensor shapes mismatched.");
}
SmallVector<int64_t, 2> dims;
dims.emplace_back(M);
dims.emplace_back(N);
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
/// BatchNormalizationTestMode /// BatchNormalizationTestMode
void ONNXBatchNormalizationTestModeOp::inferShapes() { void ONNXBatchNormalizationTestModeOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.

View File

@ -118,7 +118,6 @@ public:
op->getName().getStringRef() != "onnx.Identity" && op->getName().getStringRef() != "onnx.Identity" &&
op->getName().getStringRef() != "onnx.MatMul" && op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" && op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.GemmNoBias" &&
op->getName().getStringRef() != "onnx.Reshape" && op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose" && op->getName().getStringRef() != "onnx.Transpose" &&
op->getName().getStringRef() != "onnx.ReduceMax" && op->getName().getStringRef() != "onnx.ReduceMax" &&

View File

@ -806,35 +806,6 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
// CHECK: } // CHECK: }
} }
func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
%0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_gemm_no_bias
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
// CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
// CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
// CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
// CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
// CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
// CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
// CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
// CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
// CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: }
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
// CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<10x10xf32>
// CHECK: }
}
func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()