Remove special GemmNoBias since we can handle it using NoneType bias (#100)
* Remove special GemmNoBias since we can handle it using NoneType bias
* Remove GemmNoBias from onnx.md

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
732317cd5a
commit
a720f9a7b2
|
@ -1558,33 +1558,6 @@ ONNX Gather operation
|
|||
|
||||
1. `output`: memref of any type values or tensor of any type values
|
||||
|
||||
### onnx.GemmNoBias (ONNXGemmNoBiasOp)
|
||||
ONNX general matrix multiply operation without bias.
|
||||
|
||||
#### Description:
|
||||
|
||||
|
||||
The "onnx.Gemm" generic matrix multiplication without bias.
|
||||
|
||||
|
||||
#### Operands:
|
||||
|
||||
1. `A`: memref of any type values or tensor of any type values
|
||||
1. `B`: memref of any type values or tensor of any type values
|
||||
|
||||
#### Attributes:
|
||||
|
||||
| Attribute | MLIR Type | Description |
|
||||
| :-------: | :-------: | ----------- |
|
||||
| `alpha` | `FloatAttr` | 32-bit float attribute |
|
||||
| `beta` | `FloatAttr` | 32-bit float attribute |
|
||||
| `transA` | `IntegerAttr` | 64-bit integer attribute |
|
||||
| `transB` | `IntegerAttr` | 64-bit integer attribute |
|
||||
|
||||
#### Results:
|
||||
|
||||
1. `o_Y`: memref of any type values or tensor of any type values
|
||||
|
||||
### onnx.Gemm (ONNXGemmOp)
|
||||
ONNX Gemm operation
|
||||
|
||||
|
|
|
@ -17,9 +17,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
// The first predicate is unnecessary when we remove ONNXGemmNoBiasOp.
|
||||
bool hasBias = (operands.size() == 3) &&
|
||||
(!op->getOperand(2).getType().isa<NoneType>());
|
||||
bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
|
||||
|
||||
Value A, B, C;
|
||||
A = operands[0];
|
||||
|
@ -215,5 +213,4 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx) {
|
||||
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
|
||||
patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
|
||||
}
|
||||
|
|
|
@ -90,25 +90,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
|
|||
// or outputs. This decision affects only ONNX operations with optional
|
||||
// arguments not ONNX operations with variadic operands.
|
||||
|
||||
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX general matrix multiply operation without bias.";
|
||||
let description = [{
|
||||
|
||||
The "onnx.Gemm" generic matrix multiplication without bias.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
|
||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
|
||||
DefaultValuedAttr<F32Attr, "1.0">:$alpha,
|
||||
DefaultValuedAttr<F32Attr, "1.0">:$beta,
|
||||
DefaultValuedAttr<I64Attr, "0">:$transA,
|
||||
DefaultValuedAttr<I64Attr, "0">:$transB);
|
||||
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
||||
}
|
||||
|
||||
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let hasCanonicalizer = 1;
|
||||
|
|
|
@ -565,32 +565,6 @@ void ONNXGemmOp::inferShapes() {
|
|||
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
}
|
||||
|
||||
// GemmNoBias
|
||||
|
||||
void ONNXGemmNoBiasOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
|
||||
int64_t M, N, K_A, K_B;
|
||||
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
|
||||
K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
|
||||
N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
|
||||
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
|
||||
|
||||
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
|
||||
emitError("Tensor shapes mismatched.");
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 2> dims;
|
||||
dims.emplace_back(M);
|
||||
dims.emplace_back(N);
|
||||
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
}
|
||||
|
||||
/// BatchNormalizationTestMode
|
||||
void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
|
|
|
@ -118,7 +118,6 @@ public:
|
|||
op->getName().getStringRef() != "onnx.Identity" &&
|
||||
op->getName().getStringRef() != "onnx.MatMul" &&
|
||||
op->getName().getStringRef() != "onnx.Gemm" &&
|
||||
op->getName().getStringRef() != "onnx.GemmNoBias" &&
|
||||
op->getName().getStringRef() != "onnx.Reshape" &&
|
||||
op->getName().getStringRef() != "onnx.Transpose" &&
|
||||
op->getName().getStringRef() != "onnx.ReduceMax" &&
|
||||
|
|
|
@ -806,35 +806,6 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
|
|||
// CHECK: }
|
||||
}
|
||||
|
||||
func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_gemm_no_bias
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
|
||||
// CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
|
||||
// CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
|
||||
// CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
|
||||
// CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
|
||||
// CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
|
||||
// CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
|
||||
// CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
|
||||
// CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
|
||||
// CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
|
||||
// CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
|
||||
// CHECK: }
|
||||
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
|
||||
// CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
|
||||
// CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<10x10xf32>
|
||||
// CHECK: }
|
||||
}
|
||||
|
||||
func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
|
Loading…
Reference in New Issue