Fix scalar entry point parameter lowering issue. (#78)
* Fix scalar entry point parameter lowering issue.
* Enable scalar bias test.
* Nit. Improve comments and remove debug code.
* Make helper function static, move to upfront position.
* Move helper function to top of the file.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
e5677bba1f
commit
937bbec265
|
@ -41,6 +41,23 @@ static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
|
||||||
return SymbolRefAttr::get(funcName, context);
|
return SymbolRefAttr::get(funcName, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
  // Derive the tensor rank from the struct layout of a lowered MemRef type.
  //
  // The common layout is a struct with 5 elements whose 4th and 5th elements
  // are arrays with one entry per tensor dimension. When the MemRef describes
  // a scalar those two arrays would be 0-length, so the struct collapses to
  // its first 3 elements. See
  // https://github.com/llvm/llvm-project/blob/master/mlir/docs/ConversionToLLVMDialect.md#memref-types.
  auto numElems = memRefTy.getStructNumElements();
  assert((numElems == 3 || numElems == 5) &&
         "Expect MemRef type to contain either 3 or 5 elements.");

  // Degenerate 3-element struct: the MemRef refers to a scalar, rank is 0.
  if (numElems == 3)
    return 0;

  // Otherwise the rank is the length of the array stored at struct index 3.
  return memRefTy.getStructElementType(3).getArrayNumElements();
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// KRNL to LLVM: KrnlMemcpyOpLowering
|
// KRNL to LLVM: KrnlMemcpyOpLowering
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -91,9 +108,8 @@ public:
|
||||||
// Memcpy call
|
// Memcpy call
|
||||||
rewriter.create<CallOp>(
|
rewriter.create<CallOp>(
|
||||||
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||||
ArrayRef<Value>(
|
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
|
||||||
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size,
|
int64Size, isVolatile}));
|
||||||
isVolatile}));
|
|
||||||
|
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
|
@ -116,7 +132,8 @@ private:
|
||||||
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
|
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
|
||||||
auto llvmFnType = LLVM::LLVMType::getFunctionTy(
|
auto llvmFnType = LLVM::LLVMType::getFunctionTy(
|
||||||
llvmVoidTy,
|
llvmVoidTy,
|
||||||
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
ArrayRef<mlir::LLVM::LLVMType>(
|
||||||
|
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
||||||
false);
|
false);
|
||||||
|
|
||||||
// Insert the memcpy function into the body of the parent module.
|
// Insert the memcpy function into the body of the parent module.
|
||||||
|
@ -261,8 +278,7 @@ public:
|
||||||
// it in the wrapped Output.
|
// it in the wrapped Output.
|
||||||
auto outMemRef = outputMemRefs.getResult(0);
|
auto outMemRef = outputMemRefs.getResult(0);
|
||||||
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
|
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
|
||||||
auto outMemRefRank =
|
auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
|
||||||
outMemRefTy.getStructElementType(3).getArrayNumElements();
|
|
||||||
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
|
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
|
||||||
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
||||||
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
||||||
|
@ -376,7 +392,7 @@ private:
|
||||||
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
|
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
|
||||||
|
|
||||||
// Get rank, sizes array ptr and strides array ptr.
|
// Get rank, sizes array ptr and strides array ptr.
|
||||||
auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
|
auto rank = getRankFromMemRefType(memRefTy);
|
||||||
auto sizesArrayPtr =
|
auto sizesArrayPtr =
|
||||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
|
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
|
||||||
auto stridesArrayPtr =
|
auto stridesArrayPtr =
|
||||||
|
@ -428,7 +444,7 @@ private:
|
||||||
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
||||||
{outDynMemRef, outMemRefDataPtr});
|
{outDynMemRef, outMemRefDataPtr});
|
||||||
|
|
||||||
auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
|
auto rank = getRankFromMemRefType(outMemRefTy);
|
||||||
auto sizesArrayPtr =
|
auto sizesArrayPtr =
|
||||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
|
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
|
||||||
auto stridesArrayPtr =
|
auto stridesArrayPtr =
|
||||||
|
|
|
@ -99,7 +99,7 @@ test_to_enable = [
|
||||||
"test_gemm_beta_cpu",
|
"test_gemm_beta_cpu",
|
||||||
"test_gemm_default_matrix_bias_cpu",
|
"test_gemm_default_matrix_bias_cpu",
|
||||||
# "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
|
# "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
|
||||||
# "test_gemm_default_scalar_bias_cpu", <- error, shapes mismatch, why?
|
"test_gemm_default_scalar_bias_cpu",
|
||||||
"test_gemm_default_single_elem_vector_bias_cpu",
|
"test_gemm_default_single_elem_vector_bias_cpu",
|
||||||
"test_gemm_default_vector_bias_cpu",
|
"test_gemm_default_vector_bias_cpu",
|
||||||
"test_gemm_default_zero_bias_cpu",
|
"test_gemm_default_zero_bias_cpu",
|
||||||
|
|
Loading…
Reference in New Issue