diff --git a/src/transform/lower_to_llvm.cpp b/src/transform/lower_to_llvm.cpp index de6e671..7d01207 100644 --- a/src/transform/lower_to_llvm.cpp +++ b/src/transform/lower_to_llvm.cpp @@ -41,6 +41,23 @@ static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName, return SymbolRefAttr::get(funcName, context); } +static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) { + // Usually a MemRef is a 5-element struct, where the 4th and 5th elements in + // this struct are arrays whose size is the rank of the tensor. In the event + // that the corresponding tensor of this MemRef is a scalar, the 4th and 5th + // elements will have 0-length, which in turn causes the MemRef struct to + // degenerate into a 3-element struct. For more information, refer to + // https://github.com/llvm/llvm-project/blob/master/mlir/docs/ConversionToLLVMDialect.md#memref-types. + auto numElems = memRefTy.getStructNumElements(); + assert((numElems == 3 || numElems == 5) && + "Expect MemRef type to contain either 3 or 5 elements."); + + if (numElems == 3) + return 0; // MemRef refers to a scalar. + else + return memRefTy.getStructElementType(3).getArrayNumElements(); +} + //===----------------------------------------------------------------------===// // KRNL to LLVM: KrnlMemcpyOpLowering //===----------------------------------------------------------------------===// @@ -91,9 +108,8 @@ public: // Memcpy call rewriter.create( loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), - ArrayRef( - {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size, - isVolatile})); + ArrayRef({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, + int64Size, isVolatile})); rewriter.eraseOp(op); return matchSuccess(); @@ -116,7 +132,8 @@ private: auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect); auto llvmFnType = LLVM::LLVMType::getFunctionTy( llvmVoidTy, - ArrayRef({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}), + ArrayRef( + {llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}), false); // Insert the memcpy function into the body of the parent module. @@ -261,8 +278,7 @@ public: // it in the wrapped Output. auto outMemRef = outputMemRefs.getResult(0); auto outMemRefTy = outMemRef.getType().dyn_cast(); - auto outMemRefRank = - outMemRefTy.getStructElementType(3).getArrayNumElements(); + auto outMemRefRank = getRankFromMemRefType(outMemRefTy); auto outMemRefRankVal = rewriter.create( loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank)); auto outDynMemRef = callApi(rewriter, loc, apiRegistry, @@ -376,7 +392,7 @@ private: rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)})); // Get rank, sizes array ptr and strides array ptr. - auto rank = memRefTy.getStructElementType(3).getArrayNumElements(); + auto rank = getRankFromMemRefType(memRefTy); auto sizesArrayPtr = callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef}); auto stridesArrayPtr = @@ -428,7 +444,7 @@ private: callApi(rewriter, loc, apiRegistry, API::SET_DATA, {outDynMemRef, outMemRefDataPtr}); - auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements(); + auto rank = getRankFromMemRefType(outMemRefTy); auto sizesArrayPtr = callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef}); auto stridesArrayPtr = diff --git a/test/backend/test.py b/test/backend/test.py index fff7da2..18abd4a 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -99,7 +99,7 @@ test_to_enable = [ "test_gemm_beta_cpu", "test_gemm_default_matrix_bias_cpu", # "test_gemm_default_no_bias_cpu", <- error, need support for optional operands - # "test_gemm_default_scalar_bias_cpu", <- error, shapes mismatch, why? + "test_gemm_default_scalar_bias_cpu", "test_gemm_default_single_elem_vector_bias_cpu", "test_gemm_default_vector_bias_cpu", "test_gemm_default_zero_bias_cpu",