Fix scalar entry point parameter lowering issue. (#78)
* Fix scalar entry point parameter lowering issue. * Enable scalar bias test. * Nit. Improve comments and remove debug code. * Make helper function static, move to upfront position. * Move helper function to top of the file. Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
e5677bba1f
commit
937bbec265
|
@ -41,6 +41,23 @@ static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
|
|||
return SymbolRefAttr::get(funcName, context);
|
||||
}
|
||||
|
||||
static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
|
||||
// Usually a MemRef is a 5-element struct, where the 4th and 5th elements in
|
||||
// this struct are arrays whose size is the rank of the tensor. In the event
|
||||
// that the corresponding tensor of this MemRef is a scalar, the 4th and 5th
|
||||
// elements will have 0-length, which in turn causes the MemRef struct to
|
||||
// degenerate into a 3-element struct. For more information, refer to
|
||||
// https://github.com/llvm/llvm-project/blob/master/mlir/docs/ConversionToLLVMDialect.md#memref-types.
|
||||
auto numElems = memRefTy.getStructNumElements();
|
||||
assert((numElems == 3 || numElems == 5) &&
|
||||
"Expect MemRef type to contain either 3 or 5 elements.");
|
||||
|
||||
if (numElems == 3)
|
||||
return 0; // MemRef refers to a scalar.
|
||||
else
|
||||
return memRefTy.getStructElementType(3).getArrayNumElements();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KRNL to LLVM: KrnlMemcpyOpLowering
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -91,9 +108,8 @@ public:
|
|||
// Memcpy call
|
||||
rewriter.create<CallOp>(
|
||||
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||
ArrayRef<Value>(
|
||||
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size,
|
||||
isVolatile}));
|
||||
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
|
||||
int64Size, isVolatile}));
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
|
@ -116,7 +132,8 @@ private:
|
|||
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
|
||||
auto llvmFnType = LLVM::LLVMType::getFunctionTy(
|
||||
llvmVoidTy,
|
||||
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
||||
ArrayRef<mlir::LLVM::LLVMType>(
|
||||
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
||||
false);
|
||||
|
||||
// Insert the memcpy function into the body of the parent module.
|
||||
|
@ -261,8 +278,7 @@ public:
|
|||
// it in the wrapped Output.
|
||||
auto outMemRef = outputMemRefs.getResult(0);
|
||||
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
|
||||
auto outMemRefRank =
|
||||
outMemRefTy.getStructElementType(3).getArrayNumElements();
|
||||
auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
|
||||
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
||||
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
||||
|
@ -376,7 +392,7 @@ private:
|
|||
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
|
||||
|
||||
// Get rank, sizes array ptr and strides array ptr.
|
||||
auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
|
||||
auto rank = getRankFromMemRefType(memRefTy);
|
||||
auto sizesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
|
||||
auto stridesArrayPtr =
|
||||
|
@ -428,7 +444,7 @@ private:
|
|||
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
||||
{outDynMemRef, outMemRefDataPtr});
|
||||
|
||||
auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
|
||||
auto rank = getRankFromMemRefType(outMemRefTy);
|
||||
auto sizesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
|
||||
auto stridesArrayPtr =
|
||||
|
|
|
@ -99,7 +99,7 @@ test_to_enable = [
|
|||
"test_gemm_beta_cpu",
|
||||
"test_gemm_default_matrix_bias_cpu",
|
||||
# "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
|
||||
# "test_gemm_default_scalar_bias_cpu", <- error, shapes mismatch, why?
|
||||
"test_gemm_default_scalar_bias_cpu",
|
||||
"test_gemm_default_single_elem_vector_bias_cpu",
|
||||
"test_gemm_default_vector_bias_cpu",
|
||||
"test_gemm_default_zero_bias_cpu",
|
||||
|
|
Loading…
Reference in New Issue