Add support of GemmNoBias (#91)
* Add support of GemmNoBias
* Fix a wrong indentation
parent a3f042220e
commit f1d20e368f
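Editorial note (not part of the commit): ONNX Gemm computes Y = alpha * A' * B' + beta * C, where A' is A transposed when transA != 0 (likewise B' for transB) and C is an optional bias, unidirectionally broadcast to the shape of A' * B'. GemmNoBias is the two-operand variant with no C. The commit templates the existing Gemm lowering over the op class and branches on operands.size(), so both ops share one code path. As a reference for the diffs below, a scalar sketch of the computation being lowered, assuming row-major matrices and a C already broadcast to M x N:

```cpp
// Sketch of the math the generated krnl loops implement.
// has_bias == false corresponds to onnx.GemmNoBias.
void gemm(int M, int N, int K, float alpha, const float *A, const float *B,
          float beta, const float *C, bool has_bias, float *Y) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f; // A*B accumulation (innermost krnl.iterate)
      for (int k = 0; k < K; ++k)
        acc += A[i * K + k] * B[k * N + j];
      float alphaAB = alpha * acc;
      // With a bias, add beta*C; otherwise store alpha*A*B directly,
      // mirroring the new if (has_bias) / else in the lowering below.
      Y[i * N + j] = has_bias ? alphaAB + beta * C[i * N + j] : alphaAB;
    }
}
```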
@@ -8,31 +8,34 @@
 //
 //
 //===----------------------------------------------------------------------===//

+template <typename GemmOp>
 struct ONNXGemmOpLowering : public ConversionPattern {
   ONNXGemmOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
+      : ConversionPattern(GemmOp::getOperationName(), 1, ctx) {}

   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
+    auto has_bias = (operands.size() == 3);

     Value A, B, C;
     A = operands[0];
     B = operands[1];
+    if (has_bias)
       C = operands[2];

     auto memRefType = convertToMemRefType(*op->result_type_begin());

     auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
-        llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
+        llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
     auto betaAttr = FloatAttr::get(memRefType.getElementType(),
-        llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
+        llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
     auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
     auto beta = rewriter.create<ConstantOp>(loc, betaAttr);

-    bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
-    bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
+    bool isTransA = (llvm::dyn_cast<GemmOp>(op).transA() != 0);
+    bool isTransB = (llvm::dyn_cast<GemmOp>(op).transB() != 0);

     // Insert an allocation and deallocation for the result of this operation.
     Value alloc;
@@ -116,6 +119,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     // GemmOp supports unidirectional broadcasting from C to A*B.
     // Hence, it must be enough to get broadcasting information for C only.
     std::map<int, Value> broadcastedDimInfo;
+    if (has_bias) {
       auto shape = C.getType().cast<MemRefType>().getShape();
       for (int i = 0; i < shape.size(); ++i) {
         if (shape[i] < 0) {
@@ -126,6 +130,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
           broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
         }
       }
+    }

     auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);

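The new guard matters because the broadcasting analysis calls C.getType() on a Value that is never assigned for GemmNoBias; with no bias, broadcastedDimInfo simply stays empty. For intuition, unidirectional broadcasting means only C's shape must be inspected: each dimension of C either matches the corresponding result dimension or is broadcast across it. A rough standalone sketch of that decision, using plain int64_t shapes with -1 for dynamic sizes (a hypothetical helper, not onnx-mlir code):

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical sketch: which dims of bias C are broadcast against the
// result of A*B. C aligns to the trailing dims of the result; a size-1
// dim broadcasts, and a dynamic (-1) dim must be decided at runtime,
// which is what the shape[i] < 0 path in the lowering handles.
std::map<int, bool> broadcastInfo(const std::vector<int64_t> &cShape,
                                  const std::vector<int64_t> &resShape) {
  std::map<int, bool> info;
  int offset = static_cast<int>(resShape.size() - cShape.size());
  for (int i = 0; i < static_cast<int>(cShape.size()); ++i)
    info[i] = (cShape[i] == 1 && resShape[i + offset] != 1);
  return info;
}
```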
@@ -155,14 +160,18 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);

     // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
-    auto loopCIVs = getLoopIVsForBroadcasting(
-        loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
-    auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
     auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
     auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
+    if (has_bias) {
+      auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
+                                                broadcastedDimInfo);
+      auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
       auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
       auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
       rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);
+    } else {
+      rewriter.create<StoreOp>(loc, alphaAB, alloc, loopMNIVs);
+    }

     // Insert instructions to do matrix multiplication: A*B
     Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
@@ -203,5 +212,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {

 void populateLoweringONNXGemmOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ONNXGemmOpLowering>(ctx);
+  patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
+  patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
 }
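Both instantiations go into the same pattern list, so the dialect-conversion driver dispatches on the root op name: ONNXGemmOp matches the three-operand pattern, ONNXGemmNoBiasOp the two-operand one. For context, a hedged sketch of how such a list is typically consumed in MLIR of this era (the function and dialect names here are assumptions, not taken from this commit):

```cpp
// Sketch only: a typical dialect-conversion driver for these patterns.
void lowerGemmOps(ModuleOp module, MLIRContext *ctx) {
  ConversionTarget target(*ctx);
  target.addLegalDialect<KrnlOpsDialect>(); // assumed legal target dialect
  OwningRewritePatternList patterns;
  populateLoweringONNXGemmOpPattern(patterns, ctx);
  // Each onnx.Gemm / onnx.GemmNoBias op is rewritten by its matching pattern.
  if (failed(applyPartialConversion(module, target, patterns)))
    module.emitError("lowering Gemm ops to Krnl failed");
}
```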
@@ -586,9 +586,20 @@ void ONNXGemmNoBiasOp::inferShapes() {
     return;
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+
+  int64_t M, N, K_A, K_B;
+  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
+  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
+  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
+  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
+
+  if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
+    emitError("Tensor shapes mismatched.");
+  }
+
   SmallVector<int64_t, 2> dims;
-  dims.emplace_back(lhsTy.getShape()[0]);
-  dims.emplace_back(rhsTy.getShape()[1]);
+  dims.emplace_back(M);
+  dims.emplace_back(N);
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
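Worked example, using the shapes from the new test_gemm_no_bias test below: A is 5x10 with transA = 1, so M = 10 and K_A = 5; B is 5x10 with transB = 0, so K_B = 5 and N = 10. K_A equals K_B, and the inferred result type is tensor<10x10xf32>. A standalone restatement of the arithmetic (editorial sketch, not onnx-mlir code):

```cpp
#include <cstdint>
#include <utility>

// Restates the inferShapes() arithmetic above for 2-D shapes {d0, d1};
// -1 denotes a dynamic dimension. Returns {M, N}, the result shape.
std::pair<int64_t, int64_t> gemmResultShape(int64_t a0, int64_t a1,
                                            bool transA, int64_t b0,
                                            int64_t b1, bool transB) {
  int64_t M = transA ? a1 : a0;
  int64_t KA = transA ? a0 : a1;
  int64_t N = transB ? b0 : b1;
  int64_t KB = transB ? b1 : b0;
  if (KA != -1 && KB != -1 && KA != KB)
    return {-1, -1}; // shapes mismatched (the op emits an error instead)
  return {M, N};
}

// gemmResultShape(5, 10, /*transA=*/true, 5, 10, /*transB=*/false)
// yields {10, 10}, matching the tensor<10x10xf32> in the test below.
```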
@@ -98,7 +98,7 @@ test_to_enable = [
    "test_gemm_alpha_cpu",
    "test_gemm_beta_cpu",
    "test_gemm_default_matrix_bias_cpu",
-    # "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
+    "test_gemm_default_no_bias_cpu",
    "test_gemm_default_scalar_bias_cpu",
    "test_gemm_default_single_elem_vector_bias_cpu",
    "test_gemm_default_vector_bias_cpu",
@@ -795,12 +795,42 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tensor<10xf32>) -> tensor<*xf32> {
   // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
   // CHECK: store [[SUM]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
   // CHECK: }
-  // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
   // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32>
   // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
+  // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
   // CHECK: [[BETA_C:%.+]] = mulf [[BETA]], [[C]] : f32
   // CHECK: [[Y_RES:%.+]] = addf [[ALPHA_AB]], [[BETA_C]] : f32
   // CHECK: store [[Y_RES]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<10x10xf32>
+  // CHECK: }
+}
+
+func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
+  %0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_gemm_no_bias
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
+  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
+  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
+  // CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
+  // CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
+  // CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
+  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
+  // CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
+  // CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: }
   // CHECK: return [[RES]] : memref<10x10xf32>
   // CHECK: }
 }
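Note the contrast with the biased test above: in the no-bias CHECK sequence, [[ALPHA_AB]] is stored straight back into [[RES]], with no [[C]] load, no mulf with [[BETA]], and no final addf. This exercises the else branch added to the lowering.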