diff --git a/src/conversion/onnx_to_krnl/math/elementwise.cpp b/src/conversion/onnx_to_krnl/math/elementwise.cpp index 55d4cda..7ab36af 100644 --- a/src/conversion/onnx_to_krnl/math/elementwise.cpp +++ b/src/conversion/onnx_to_krnl/math/elementwise.cpp @@ -103,8 +103,8 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, Value operand = operands[0]; auto elementType = result_types[0]; - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); + auto zero = emitConstantOp(rewriter, loc, elementType, 0); + auto two = emitConstantOp(rewriter, loc, elementType, 2); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); @@ -127,8 +127,8 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, Value operand = operands[0]; auto elementType = result_types[0]; - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); + auto zero = emitConstantOp(rewriter, loc, elementType, 0); + auto two = emitConstantOp(rewriter, loc, elementType, 2); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); @@ -152,8 +152,8 @@ Value mapToLowerScalarOp(Operation *op, Value operand = operands[0]; auto elementType = result_types[0]; - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto zero = emitConstantOp(rewriter, loc, elementType, 0); + auto one = emitConstantOp(rewriter, loc, elementType, 1); auto neg = rewriter.create(loc, zero, operand); auto negExp = rewriter.create(loc, neg); auto result = rewriter.create( @@ -184,8 +184,8 @@ Value mapToLowerScalarOp( llvm::dyn_cast(op).beta().convertToFloat()); auto elementType = result_types[0]; - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto zero = emitConstantOp(rewriter, loc, elementType, 0); + auto one = emitConstantOp(rewriter, loc, elementType, 1); auto alpha = rewriter.create(loc, alphaAttribute); auto beta = rewriter.create(loc, betaAttribute); @@ -217,8 +217,8 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), llvm::dyn_cast(op).alpha().convertToFloat()); - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto zero = emitConstantOp(rewriter, loc, elementType, 0); + auto one = emitConstantOp(rewriter, loc, elementType, 1); auto alpha = rewriter.create(loc, alphaAttribute); auto exp = rewriter.create(loc, operand); auto lessThanZero = @@ -246,7 +246,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, Value operand = operands[0]; auto elementType = result_types[0]; - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create(loc, lessThanZero, zero, operand); @@ -271,7 +271,7 @@ Value mapToLowerScalarOp(Operation *op, auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), llvm::dyn_cast(op).alpha().convertToFloat()); - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto alpha = rewriter.create(loc, alphaAttribute); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); @@ -301,7 +301,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, llvm::dyn_cast(op).gamma().convertToFloat()); auto elementType = result_types[0]; - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto alpha = rewriter.create(loc, alphaAttribute); auto gamma = rewriter.create(loc, gammaAttribute); auto exp = rewriter.create(loc, operand); @@ -328,7 +328,7 @@ Value mapToLowerScalarOp( Value operand = operands[0]; auto elementType = result_types[0]; - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto one = emitConstantOp(rewriter, loc, elementType, 1); auto result = rewriter.create(loc, one, operand); return result; @@ -347,7 +347,7 @@ Value mapToLowerScalarOp( auto elementType = result_types[0]; auto exp = rewriter.create(loc, operand); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto one = emitConstantOp(rewriter, loc, elementType, 1); auto add = rewriter.create(loc, exp, one); auto result = rewriter.create(loc, add); @@ -367,7 +367,7 @@ Value mapToLowerScalarOp( auto elementType = result_types[0]; auto abs = rewriter.create(loc, operand); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto one = emitConstantOp(rewriter, loc, elementType, 1); auto add = rewriter.create(loc, abs, one); auto result = rewriter.create(loc, operand, add); @@ -384,19 +384,18 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, auto loc = op->getLoc(); Value operand = operands[0]; - Type element_type = operands.front().getType(); + Type elementType = operands.front().getType(); // TODO: unsigned int should be supported separately? - if (element_type.isa()) { + if (elementType.isa()) { // %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0), // ConstantOp 1, // COnstantOp -1) // ONNXSignOp(%X) = SelectOP(CmpIOp(EQ, %X, ConstantOp 0), // ConstantOp 0, // %Y) - auto zero = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); - auto one = rewriter.create(loc, rewriter.getI32IntegerAttr(1)); - auto minusOne = - rewriter.create(loc, rewriter.getI32IntegerAttr(-1)); + auto zero = emitConstantOp(rewriter, loc, elementType, 0); + auto one = emitConstantOp(rewriter, loc, elementType, 1); + auto minusOne = emitConstantOp(rewriter, loc, elementType, -1); auto plusPredicate = rewriter.create(loc, CmpIPredicate::sgt, operand, zero); auto plusSelect = @@ -406,18 +405,16 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, auto result = rewriter.create(loc, zeroPredicate, zero, plusSelect); return result; - } else if (element_type.isa()) { + } else if (elementType.isa()) { // %Y = SelectOP(CmpFOp(OGT, %X, ConstantOp 0), // ConstantOp 1, // ConstantOp -1) // ONNXSignOp(%X) = SelectOP(CmpFOp(OEQ, %X, ConstantOp 0), // ConstantOp 0, // %Y) - auto zero = - rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); - auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); - auto minusOne = - rewriter.create(loc, rewriter.getF32FloatAttr(-1.0f)); + auto zero = emitConstantOp(rewriter, loc, elementType, 0); + auto one = emitConstantOp(rewriter, loc, elementType, 1); + auto minusOne = emitConstantOp(rewriter, loc, elementType, -1); auto plusPredicate = rewriter.create(loc, CmpFPredicate::OGT, operand, zero); auto plusSelect = diff --git a/src/conversion/onnx_to_krnl/math/gemm.cpp b/src/conversion/onnx_to_krnl/math/gemm.cpp index 0eed272..33399cb 100644 --- a/src/conversion/onnx_to_krnl/math/gemm.cpp +++ b/src/conversion/onnx_to_krnl/math/gemm.cpp @@ -156,8 +156,7 @@ struct ONNXGemmOpLowering : public ConversionPattern { } // Initialize the output of A*B - auto zero = rewriter.create( - loc, FloatAttr::get(memRefType.getElementType(), 0)); + auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0); rewriter.create(loc, zero, alloc, loopMNIVs); // Compute A*B diff --git a/src/conversion/onnx_to_krnl/math/matmul.cpp b/src/conversion/onnx_to_krnl/math/matmul.cpp index a3cb26a..2a6b7f2 100644 --- a/src/conversion/onnx_to_krnl/math/matmul.cpp +++ b/src/conversion/onnx_to_krnl/math/matmul.cpp @@ -37,16 +37,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern { auto memRefShape = memRefType.getShape(); // A value zero - Value zero; - if (elementType.isa()) { - zero = rewriter.create( - loc, IntegerAttr::get(memRefType.getElementType(), 0)); - } else if (elementType.isa()) { - zero = rewriter.create( - loc, FloatAttr::get(memRefType.getElementType(), 0)); - } else { - emitError(loc, "unsupported element type"); - } + auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0); // Insert an allocation and deallocation for the result of this operation. Value alloc; diff --git a/src/conversion/onnx_to_krnl/math/reduction.cpp b/src/conversion/onnx_to_krnl/math/reduction.cpp index 42b074a..e169cc6 100644 --- a/src/conversion/onnx_to_krnl/math/reduction.cpp +++ b/src/conversion/onnx_to_krnl/math/reduction.cpp @@ -14,43 +14,27 @@ using namespace mlir; // Identity values template <> -float getIdentityValue(){ - return (float)-std::numeric_limits::infinity(); +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + return emitNegativeInfinityConstantOp(rewriter, loc, type); } template <> -int getIdentityValue(){ - return std::numeric_limits::min(); +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + return emitPositiveInfinityConstantOp(rewriter, loc, type); } template <> -float getIdentityValue(){ - return (float)std::numeric_limits::infinity(); +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + return emitConstantOp(rewriter, loc, type, 1); } template <> -int getIdentityValue(){ - return std::numeric_limits::max(); -} - -template <> -float getIdentityValue(){ - return (float)1.0; -} - -template <> -int getIdentityValue(){ - return 1; -} - -template <> -float getIdentityValue(){ - return (float)0; -} - -template <> -int getIdentityValue(){ - return 0; +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + return emitConstantOp(rewriter, loc, type, 0); } // Scalar ops @@ -234,18 +218,8 @@ struct ONNXReductionOpLowering : public ConversionPattern { loopIVs.push_back(arg); } - Value identity; - if (elementOutType.isa()) { - identity = rewriter.create( - loc, FloatAttr::get(elementOutType, - getIdentityValue())); - } else if (elementOutType.isa()) { - identity = rewriter.create( - loc, IntegerAttr::get(elementOutType, - getIdentityValue())); - } else { - emitError(loc, "unsupported element type"); - } + Value identity = + getIdentityValue(rewriter, loc, elementOutType); rewriter.create(loc, identity, alloc, loopIVs); // Define an Krnl loop to do reduction. diff --git a/src/conversion/onnx_to_krnl/math/softmax.cpp b/src/conversion/onnx_to_krnl/math/softmax.cpp index 3277635..b32d15b 100644 --- a/src/conversion/onnx_to_krnl/math/softmax.cpp +++ b/src/conversion/onnx_to_krnl/math/softmax.cpp @@ -48,8 +48,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0); Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); - Value zero = - rewriter.create(loc, FloatAttr::get(elementType, 0)); + Value zero = emitConstantOp(rewriter, loc, elementType, 0); Value negInfinity = rewriter.create( loc, FloatAttr::get(elementType, -std::numeric_limits::infinity())); diff --git a/src/conversion/onnx_to_krnl/nn/conv.cpp b/src/conversion/onnx_to_krnl/nn/conv.cpp index 851668a..5b56a74 100644 --- a/src/conversion/onnx_to_krnl/nn/conv.cpp +++ b/src/conversion/onnx_to_krnl/nn/conv.cpp @@ -92,8 +92,7 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { int64_t kernelsPerGroup = floor(kernelShape[0] / group); auto kernelsPerGroupValue = rewriter.create(loc, kernelsPerGroup); - auto zero = rewriter.create( - loc, FloatAttr::get(memRefType.getElementType(), 0)); + auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0); Value subchannels; if (kernelShape[1] < 0) { subchannels = rewriter.create(loc, kernelOperand, 1).getResult(); diff --git a/src/conversion/onnx_to_krnl/nn/pooling.cpp b/src/conversion/onnx_to_krnl/nn/pooling.cpp index 0015b8a..17e5b9d 100644 --- a/src/conversion/onnx_to_krnl/nn/pooling.cpp +++ b/src/conversion/onnx_to_krnl/nn/pooling.cpp @@ -14,13 +14,9 @@ using namespace mlir; // Identity values template <> -float getIdentityValue() { - return (float)-std::numeric_limits::infinity(); -} - -template <> -int getIdentityValue() { - return std::numeric_limits::min(); +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + return emitNegativeInfinityConstantOp(rewriter, loc, type); } template <> @@ -204,18 +200,8 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { resultIndices.emplace_back(outerLoops.getInductionVar(i)); // 2.1 Emit: R[n][c][r1][r2] = negative_infinity; - Value identity; - if (resultElementType.isa()) { - identity = rewriter.create( - loc, FloatAttr::get(resultElementType, - getIdentityValue())); - } else if (resultElementType.isa()) { - identity = rewriter.create( - loc, IntegerAttr::get(resultElementType, - getIdentityValue())); - } else { - emitError(loc, "unsupported element type"); - } + Value identity = getIdentityValue( + rewriter, loc, resultElementType); rewriter.create(loc, identity, alloc, resultIndices); // 2.2 Define inner loops. diff --git a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.cpp b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.cpp index 16bc499..e8f3f58 100644 --- a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.cpp +++ b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.cpp @@ -322,3 +322,166 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, } return newLoopIVs; } + +Value emitConstantOp(ConversionPatternRewriter &rewriter, Location loc, + Type type, double value) { + Attribute constantAttr; + auto typeKind = type.getKind(); + if (typeKind == StandardTypes::F16) { + constantAttr = rewriter.getF16FloatAttr((float)value); + } else if (typeKind == StandardTypes::F32) { + constantAttr = rewriter.getF32FloatAttr((float)value); + } else if (typeKind == StandardTypes::F64) { + constantAttr = rewriter.getF64FloatAttr(value); + } else if (typeKind == StandardTypes::Integer) { + auto width = type.cast().getWidth(); + if (width == 1) { + constantAttr = rewriter.getBoolAttr(false); + } else { + constantAttr = + rewriter.getIntegerAttr(type, APInt(width, (int64_t)value)); + } + } else if (typeKind == StandardTypes::Index) { + constantAttr = rewriter.getIntegerAttr(type, (int64_t)value); + } else { + emitError(loc, "unsupported element type"); + } + + return rewriter.create(loc, constantAttr); +} + +Value emitPositiveInfinityConstantOp( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + Attribute constantAttr; + auto typeKind = type.getKind(); + if (typeKind == StandardTypes::F16) { + // 0x7C00 + float value = std::numeric_limits::infinity(); + constantAttr = rewriter.getF16FloatAttr(value); + } else if (typeKind == StandardTypes::F32) { + // 0x7F800000 + float value = std::numeric_limits::infinity(); + constantAttr = rewriter.getF32FloatAttr(value); + } else if (typeKind == StandardTypes::F64) { + // 0x7FF0000000000000 + double value = std::numeric_limits::infinity(); + constantAttr = rewriter.getF64FloatAttr(value); + } else if (typeKind == StandardTypes::Integer) { + auto width = type.cast().getWidth(); + // The latest llvm-project includes a patch which allows getting the sign of + // IntegerType: + // https://github.com/llvm/llvm-project/commit/35b685270b410f6a1351c2a527021f22330c25b9 + // as follows: + // auto isSigned = type.cast().isSigned(); + // TODO (tungld): update the following statement once our llvm-project is + // upgraded to include the patch. + auto isSigned = true; + if (width == 8) { + if (isSigned) { + int8_t value = std::numeric_limits::max(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } else { + uint8_t value = std::numeric_limits::max(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } + } else if (width == 16) { + if (isSigned) { + int16_t value = std::numeric_limits::max(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } else { + uint16_t value = std::numeric_limits::max(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } + } else if (width == 32) { + if (isSigned) { + int32_t value = std::numeric_limits::max(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } else { + uint32_t value = std::numeric_limits::max(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } + } else if (width == 64) { + if (isSigned) { + int64_t value = std::numeric_limits::max(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } else { + uint64_t value = std::numeric_limits::max(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } + } else { + emitError(loc, "unsupported element type"); + } + } else { + emitError(loc, "unsupported element type"); + } + + return rewriter.create(loc, constantAttr); +} + +Value emitNegativeInfinityConstantOp( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + Attribute constantAttr; + auto typeKind = type.getKind(); + if (typeKind == StandardTypes::F16) { + // 0xFC00 + float value = -std::numeric_limits::infinity(); + constantAttr = rewriter.getF16FloatAttr(value); + } else if (typeKind == StandardTypes::F32) { + // 0xFF800000 + float value = -std::numeric_limits::infinity(); + constantAttr = rewriter.getF32FloatAttr(value); + } else if (typeKind == StandardTypes::F64) { + // 0xFFF0000000000000 + double value = -std::numeric_limits::infinity(); + constantAttr = rewriter.getF64FloatAttr(value); + } else if (typeKind == StandardTypes::Integer) { + auto width = type.cast().getWidth(); + // The latest llvm-project includes a patch which allows getting the sign of + // IntegerType: + // https://github.com/llvm/llvm-project/commit/35b685270b410f6a1351c2a527021f22330c25b9 + // as follows: + // auto isSigned = type.cast().isSigned(); + // TODO (tungld): update the following statement once our llvm-project is + // upgraded to include the patch. + auto isSigned = true; + if (width == 8) { + if (isSigned) { + int8_t value = std::numeric_limits::min(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } else { + uint8_t value = std::numeric_limits::min(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } + } else if (width == 16) { + if (isSigned) { + int16_t value = std::numeric_limits::min(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } else { + uint16_t value = std::numeric_limits::min(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } + } else if (width == 32) { + if (isSigned) { + int32_t value = std::numeric_limits::min(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } else { + uint32_t value = std::numeric_limits::min(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } + } else if (width == 64) { + if (isSigned) { + int64_t value = std::numeric_limits::min(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } else { + uint64_t value = std::numeric_limits::min(); + constantAttr = rewriter.getIntegerAttr(type, APInt(width, value)); + } + } else { + emitError(loc, "unsupported element type"); + } + } else { + emitError(loc, "unsupported element type"); + } + + return rewriter.create(loc, constantAttr); +} diff --git a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp index df4b65d..7750c16 100644 --- a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp +++ b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp @@ -98,6 +98,25 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, ArrayRef loopIVs, Value operand, std::map broadcastedDims); +// Emit a constant of a specific type. +// Use this function for small values only to avoid unexpected loss in type +// casting. +Value emitConstantOp( + ConversionPatternRewriter &rewriter, Location loc, Type type, double value); + +// Emit a positive infinity constant of a specific type. +// Supported types: F16, F32, F64, Int8, Int16, Int32, Int64. +// In case of Integer, emit the maximum value. +Value emitPositiveInfinityConstantOp( + ConversionPatternRewriter &rewriter, Location loc, Type type); + +// Emit a negative infinity constant of a specific type. +// Supported types: F16, F32, F64, Int8, Int16, Int32, Int64. +// In case of Float, emit the negative of the positive infinity. +// In case of Integer, emit the minimum value. +Value emitNegativeInfinityConstantOp( + ConversionPatternRewriter &rewriter, Location loc, Type type); + //===----------------------------------------------------------------------===// // This is to get a scalar operation of a given type for a specific operation. //===----------------------------------------------------------------------===// @@ -112,11 +131,13 @@ using ScalarFOp = typename ScalarOp::FOp; template using ScalarIOp = typename ScalarOp::IOp; -// Get the identity element of a operation. +// Get the identity element of an operation. // Return NULL if the function does not have identity. -template -DataType getIdentityValue() { - return NULL; +// Specialize this for a new Op. +template +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + return nullptr; } //===----------------------------------------------------------------------===// diff --git a/src/conversion/onnx_to_krnl/tensor/reshape.cpp b/src/conversion/onnx_to_krnl/tensor/reshape.cpp index 6489a71..9e99f2d 100644 --- a/src/conversion/onnx_to_krnl/tensor/reshape.cpp +++ b/src/conversion/onnx_to_krnl/tensor/reshape.cpp @@ -28,9 +28,8 @@ struct ONNXReshapeOpLowering : public ConversionPattern { Value alloc; // Compute size in bytes using the input tensor. - Value tensorSize = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - getMemRefEltSizeInBytes(memRefType))); + Value tensorSize = emitConstantOp(rewriter, loc, + rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)); for (int i = 0; i < inputShape.size(); ++i) { Value dimVal; if (inputShape[i] < 0) { @@ -38,9 +37,8 @@ struct ONNXReshapeOpLowering : public ConversionPattern { dimVal = rewriter.create(loc, dim, rewriter.getIntegerType(64)); } else { - dimVal = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - inputShape[i])); + dimVal = emitConstantOp( + rewriter, loc, rewriter.getIntegerType(64), inputShape[i]); } tensorSize = rewriter.create(loc, tensorSize, dimVal); } @@ -59,13 +57,11 @@ struct ONNXReshapeOpLowering : public ConversionPattern { // If the reduction is negative, then the shape array contains a negative // dimension. Otherwise, the reduction is the same as the one computed // from the input tensor. - Value tensorSizeFromShape = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - getMemRefEltSizeInBytes(memRefType))); + Value tensorSizeFromShape = emitConstantOp(rewriter, loc, + rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)); SmallVector DimInfo; for (int i = 0; i < memRefShape.size(); ++i) { - Value index = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); + Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i); // Load index from array of indices. Value loadedVal = rewriter.create(loc, operands[1], index); // If a dimension is zero, the actual dimension value is taken from the @@ -81,11 +77,10 @@ struct ONNXReshapeOpLowering : public ConversionPattern { Value dim = rewriter.create(loc, operands[0], i); dimVal = rewriter.create(loc, dim, loadedValType); } else { - dimVal = rewriter.create( - loc, rewriter.getIntegerAttr(loadedValType, inputShape[i])); + dimVal = + emitConstantOp(rewriter, loc, loadedValType, inputShape[i]); } - auto zero = rewriter.create( - loc, rewriter.getIntegerAttr(loadedValType, 0)); + auto zero = emitConstantOp(rewriter, loc, loadedValType, 0); auto isZero = rewriter.create(loc, CmpIPredicate::eq, loadedVal, zero); loadedVal = rewriter.create(loc, isZero, dimVal, loadedVal); @@ -104,15 +99,14 @@ struct ONNXReshapeOpLowering : public ConversionPattern { // Reverse tensorSizeFromShape since it is negative if the shape array has // a negative dimension. This is safe since we only use it to compute the // actual value for the negative dimension. - auto zero = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + auto zero = emitConstantOp(rewriter, loc, rewriter.getIntegerType(64), 0); tensorSizeFromShape = rewriter.create(loc, zero, tensorSizeFromShape); // Obtain operands for AllocOp. SmallVector allocOperands; - auto negOne = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1)); + auto negOne = + emitConstantOp(rewriter, loc, rewriter.getIntegerType(64), -1); for (int i = 0; i < memRefShape.size(); ++i) { auto dimVal = DimInfo[i]; diff --git a/src/conversion/onnx_to_krnl/tensor/unsqueeze.cpp b/src/conversion/onnx_to_krnl/tensor/unsqueeze.cpp index 070a91c..1c5f3ec 100644 --- a/src/conversion/onnx_to_krnl/tensor/unsqueeze.cpp +++ b/src/conversion/onnx_to_krnl/tensor/unsqueeze.cpp @@ -37,18 +37,16 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern { Value alloc; // Compute size in bytes. - Value tensorSize = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - getMemRefEltSizeInBytes(memRefType))); + Value tensorSize = emitConstantOp(rewriter, loc, + rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)); bool insertDealloc = checkInsertDealloc(op); auto memRefShape = memRefType.getShape(); if (hasAllConstantDimensions(memRefType)) { alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); for (int i = 0; i < memRefShape.size(); ++i) { - Value dimVal = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - memRefShape[i])); + Value dimVal = emitConstantOp( + rewriter, loc, rewriter.getIntegerType(64), memRefShape[i]); tensorSize = rewriter.create(loc, tensorSize, dimVal); } } else { @@ -62,9 +60,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern { loc, index, rewriter.getIntegerType(64)); allocOperands.emplace_back(index); } else { - dimVal = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - memRefShape[outIdx])); + dimVal = emitConstantOp( + rewriter, loc, rewriter.getIntegerType(64), memRefShape[outIdx]); } tensorSize = rewriter.create(loc, tensorSize, dimVal); if (std::find(axes.begin(), axes.end(), outIdx) == axes.end())