Create some helper functions to emit constant op for a specific type (#7)
* emitConstantOp with a given type * Helper functions to create infinity constants * Use new constant helper functions for MaxPoolSingleOut Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
8e1b30e133
commit
8a992b619f
|
@ -103,8 +103,8 @@ Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
Value operand = operands[0];
|
||||
auto elementType = result_types[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
|
||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
auto two = emitConstantOp(rewriter, loc, elementType, 2);
|
||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||
|
@ -127,8 +127,8 @@ Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
Value operand = operands[0];
|
||||
auto elementType = result_types[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
|
||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
auto two = emitConstantOp(rewriter, loc, elementType, 2);
|
||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||
|
@ -152,8 +152,8 @@ Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
|
|||
Value operand = operands[0];
|
||||
auto elementType = result_types[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||
auto result = rewriter.create<DivFOp>(
|
||||
|
@ -184,8 +184,8 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
|||
llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
|
||||
auto elementType = result_types[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
||||
auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);
|
||||
|
||||
|
@ -217,8 +217,8 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
|
||||
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||
llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
|
||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||
auto lessThanZero =
|
||||
|
@ -246,7 +246,7 @@ Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
Value operand = operands[0];
|
||||
auto elementType = result_types[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
auto lessThanZero =
|
||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
||||
auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
|
||||
|
@ -271,7 +271,7 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
|
|||
|
||||
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||
llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
|
||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
||||
auto lessThanZero =
|
||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
||||
|
@ -301,7 +301,7 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
|
||||
auto elementType = result_types[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
||||
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
|
||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||
|
@ -328,7 +328,7 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>(
|
|||
Value operand = operands[0];
|
||||
auto elementType = result_types[0];
|
||||
|
||||
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||
auto result = rewriter.create<DivFOp>(loc, one, operand);
|
||||
|
||||
return result;
|
||||
|
@ -347,7 +347,7 @@ Value mapToLowerScalarOp<ONNXSoftplusOp>(
|
|||
auto elementType = result_types[0];
|
||||
|
||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||
auto add = rewriter.create<AddFOp>(loc, exp, one);
|
||||
auto result = rewriter.create<LogOp>(loc, add);
|
||||
|
||||
|
@ -367,7 +367,7 @@ Value mapToLowerScalarOp<ONNXSoftsignOp>(
|
|||
auto elementType = result_types[0];
|
||||
|
||||
auto abs = rewriter.create<AbsFOp>(loc, operand);
|
||||
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||
auto add = rewriter.create<AddFOp>(loc, abs, one);
|
||||
auto result = rewriter.create<DivFOp>(loc, operand, add);
|
||||
|
||||
|
@ -384,19 +384,18 @@ Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
|
||||
auto loc = op->getLoc();
|
||||
Value operand = operands[0];
|
||||
Type element_type = operands.front().getType();
|
||||
Type elementType = operands.front().getType();
|
||||
// TODO: unsigned int should be supported separately?
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
if (elementType.isa<IntegerType>()) {
|
||||
// %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0),
|
||||
// ConstantOp 1,
|
||||
// COnstantOp -1)
|
||||
// ONNXSignOp(%X) = SelectOP(CmpIOp(EQ, %X, ConstantOp 0),
|
||||
// ConstantOp 0,
|
||||
// %Y)
|
||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
|
||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
|
||||
auto minusOne =
|
||||
rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(-1));
|
||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||
auto minusOne = emitConstantOp(rewriter, loc, elementType, -1);
|
||||
auto plusPredicate =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
|
||||
auto plusSelect =
|
||||
|
@ -406,18 +405,16 @@ Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
auto result =
|
||||
rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
|
||||
return result;
|
||||
} else if (element_type.isa<FloatType>()) {
|
||||
} else if (elementType.isa<FloatType>()) {
|
||||
// %Y = SelectOP(CmpFOp(OGT, %X, ConstantOp 0),
|
||||
// ConstantOp 1,
|
||||
// ConstantOp -1)
|
||||
// ONNXSignOp(%X) = SelectOP(CmpFOp(OEQ, %X, ConstantOp 0),
|
||||
// ConstantOp 0,
|
||||
// %Y)
|
||||
auto zero =
|
||||
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
||||
auto minusOne =
|
||||
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0f));
|
||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||
auto minusOne = emitConstantOp(rewriter, loc, elementType, -1);
|
||||
auto plusPredicate =
|
||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
|
||||
auto plusSelect =
|
||||
|
|
|
@ -156,8 +156,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
}
|
||||
|
||||
// Initialize the output of A*B
|
||||
auto zero = rewriter.create<ConstantOp>(
|
||||
loc, FloatAttr::get(memRefType.getElementType(), 0));
|
||||
auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0);
|
||||
rewriter.create<StoreOp>(loc, zero, alloc, loopMNIVs);
|
||||
|
||||
// Compute A*B
|
||||
|
|
|
@ -37,16 +37,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
auto memRefShape = memRefType.getShape();
|
||||
|
||||
// A value zero
|
||||
Value zero;
|
||||
if (elementType.isa<IntegerType>()) {
|
||||
zero = rewriter.create<ConstantOp>(
|
||||
loc, IntegerAttr::get(memRefType.getElementType(), 0));
|
||||
} else if (elementType.isa<FloatType>()) {
|
||||
zero = rewriter.create<ConstantOp>(
|
||||
loc, FloatAttr::get(memRefType.getElementType(), 0));
|
||||
} else {
|
||||
emitError(loc, "unsupported element type");
|
||||
}
|
||||
auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0);
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
Value alloc;
|
||||
|
|
|
@ -14,43 +14,27 @@ using namespace mlir;
|
|||
|
||||
// Identity values
|
||||
template <>
|
||||
float getIdentityValue<float, ONNXReduceMaxOp>(){
|
||||
return (float)-std::numeric_limits<float>::infinity();
|
||||
Value getIdentityValue<ONNXReduceMaxOp>(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type type) {
|
||||
return emitNegativeInfinityConstantOp(rewriter, loc, type);
|
||||
}
|
||||
|
||||
template <>
|
||||
int getIdentityValue<int, ONNXReduceMaxOp>(){
|
||||
return std::numeric_limits<int>::min();
|
||||
Value getIdentityValue<ONNXReduceMinOp>(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type type) {
|
||||
return emitPositiveInfinityConstantOp(rewriter, loc, type);
|
||||
}
|
||||
|
||||
template <>
|
||||
float getIdentityValue<float, ONNXReduceMinOp>(){
|
||||
return (float)std::numeric_limits<float>::infinity();
|
||||
Value getIdentityValue<ONNXReduceProdOp>(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type type) {
|
||||
return emitConstantOp(rewriter, loc, type, 1);
|
||||
}
|
||||
|
||||
template <>
|
||||
int getIdentityValue<int, ONNXReduceMinOp>(){
|
||||
return std::numeric_limits<int>::max();
|
||||
}
|
||||
|
||||
template <>
|
||||
float getIdentityValue<float, ONNXReduceProdOp>(){
|
||||
return (float)1.0;
|
||||
}
|
||||
|
||||
template <>
|
||||
int getIdentityValue<int, ONNXReduceProdOp>(){
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <>
|
||||
float getIdentityValue<float, ONNXReduceSumOp>(){
|
||||
return (float)0;
|
||||
}
|
||||
|
||||
template <>
|
||||
int getIdentityValue<int, ONNXReduceSumOp>(){
|
||||
return 0;
|
||||
Value getIdentityValue<ONNXReduceSumOp>(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type type) {
|
||||
return emitConstantOp(rewriter, loc, type, 0);
|
||||
}
|
||||
|
||||
// Scalar ops
|
||||
|
@ -234,18 +218,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
|
|||
loopIVs.push_back(arg);
|
||||
}
|
||||
|
||||
Value identity;
|
||||
if (elementOutType.isa<FloatType>()) {
|
||||
identity = rewriter.create<ConstantOp>(
|
||||
loc, FloatAttr::get(elementOutType,
|
||||
getIdentityValue<float, ONNXReductionOp>()));
|
||||
} else if (elementOutType.isa<IntegerType>()) {
|
||||
identity = rewriter.create<ConstantOp>(
|
||||
loc, IntegerAttr::get(elementOutType,
|
||||
getIdentityValue<int, ONNXReductionOp>()));
|
||||
} else {
|
||||
emitError(loc, "unsupported element type");
|
||||
}
|
||||
Value identity =
|
||||
getIdentityValue<ONNXReductionOp>(rewriter, loc, elementOutType);
|
||||
rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);
|
||||
|
||||
// Define an Krnl loop to do reduction.
|
||||
|
|
|
@ -48,8 +48,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
|||
MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
|
||||
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
|
||||
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
|
||||
Value zero =
|
||||
rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
Value zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
Value negInfinity = rewriter.create<ConstantOp>(
|
||||
loc,
|
||||
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
|
||||
|
|
|
@ -92,8 +92,7 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
|||
int64_t kernelsPerGroup = floor(kernelShape[0] / group);
|
||||
auto kernelsPerGroupValue =
|
||||
rewriter.create<ConstantIndexOp>(loc, kernelsPerGroup);
|
||||
auto zero = rewriter.create<ConstantOp>(
|
||||
loc, FloatAttr::get(memRefType.getElementType(), 0));
|
||||
auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0);
|
||||
Value subchannels;
|
||||
if (kernelShape[1] < 0) {
|
||||
subchannels = rewriter.create<DimOp>(loc, kernelOperand, 1).getResult();
|
||||
|
|
|
@ -14,13 +14,9 @@ using namespace mlir;
|
|||
|
||||
// Identity values
|
||||
template <>
|
||||
float getIdentityValue<float, ONNXMaxPoolSingleOutOp>() {
|
||||
return (float)-std::numeric_limits<float>::infinity();
|
||||
}
|
||||
|
||||
template <>
|
||||
int getIdentityValue<int, ONNXMaxPoolSingleOutOp>() {
|
||||
return std::numeric_limits<int>::min();
|
||||
Value getIdentityValue<ONNXMaxPoolSingleOutOp>(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type type) {
|
||||
return emitNegativeInfinityConstantOp(rewriter, loc, type);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -204,18 +200,8 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
|
|||
resultIndices.emplace_back(outerLoops.getInductionVar(i));
|
||||
|
||||
// 2.1 Emit: R[n][c][r1][r2] = negative_infinity;
|
||||
Value identity;
|
||||
if (resultElementType.isa<FloatType>()) {
|
||||
identity = rewriter.create<ConstantOp>(
|
||||
loc, FloatAttr::get(resultElementType,
|
||||
getIdentityValue<float, ONNXMaxPoolSingleOutOp>()));
|
||||
} else if (resultElementType.isa<IntegerType>()) {
|
||||
identity = rewriter.create<ConstantOp>(
|
||||
loc, IntegerAttr::get(resultElementType,
|
||||
getIdentityValue<int, ONNXMaxPoolSingleOutOp>()));
|
||||
} else {
|
||||
emitError(loc, "unsupported element type");
|
||||
}
|
||||
Value identity = getIdentityValue<ONNXMaxPoolSingleOutOp>(
|
||||
rewriter, loc, resultElementType);
|
||||
rewriter.create<StoreOp>(loc, identity, alloc, resultIndices);
|
||||
|
||||
// 2.2 Define inner loops.
|
||||
|
|
|
@ -322,3 +322,166 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
return newLoopIVs;
|
||||
}
|
||||
|
||||
Value emitConstantOp(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Type type, double value) {
|
||||
Attribute constantAttr;
|
||||
auto typeKind = type.getKind();
|
||||
if (typeKind == StandardTypes::F16) {
|
||||
constantAttr = rewriter.getF16FloatAttr((float)value);
|
||||
} else if (typeKind == StandardTypes::F32) {
|
||||
constantAttr = rewriter.getF32FloatAttr((float)value);
|
||||
} else if (typeKind == StandardTypes::F64) {
|
||||
constantAttr = rewriter.getF64FloatAttr(value);
|
||||
} else if (typeKind == StandardTypes::Integer) {
|
||||
auto width = type.cast<IntegerType>().getWidth();
|
||||
if (width == 1) {
|
||||
constantAttr = rewriter.getBoolAttr(false);
|
||||
} else {
|
||||
constantAttr =
|
||||
rewriter.getIntegerAttr(type, APInt(width, (int64_t)value));
|
||||
}
|
||||
} else if (typeKind == StandardTypes::Index) {
|
||||
constantAttr = rewriter.getIntegerAttr(type, (int64_t)value);
|
||||
} else {
|
||||
emitError(loc, "unsupported element type");
|
||||
}
|
||||
|
||||
return rewriter.create<ConstantOp>(loc, constantAttr);
|
||||
}
|
||||
|
||||
Value emitPositiveInfinityConstantOp(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type type) {
|
||||
Attribute constantAttr;
|
||||
auto typeKind = type.getKind();
|
||||
if (typeKind == StandardTypes::F16) {
|
||||
// 0x7C00
|
||||
float value = std::numeric_limits<float>::infinity();
|
||||
constantAttr = rewriter.getF16FloatAttr(value);
|
||||
} else if (typeKind == StandardTypes::F32) {
|
||||
// 0x7F800000
|
||||
float value = std::numeric_limits<float>::infinity();
|
||||
constantAttr = rewriter.getF32FloatAttr(value);
|
||||
} else if (typeKind == StandardTypes::F64) {
|
||||
// 0x7FF0000000000000
|
||||
double value = std::numeric_limits<double>::infinity();
|
||||
constantAttr = rewriter.getF64FloatAttr(value);
|
||||
} else if (typeKind == StandardTypes::Integer) {
|
||||
auto width = type.cast<IntegerType>().getWidth();
|
||||
// The latest llvm-project includes a patch which allows getting the sign of
|
||||
// IntegerType:
|
||||
// https://github.com/llvm/llvm-project/commit/35b685270b410f6a1351c2a527021f22330c25b9
|
||||
// as follows:
|
||||
// auto isSigned = type.cast<IntegerType>().isSigned();
|
||||
// TODO (tungld): update the following statement once our llvm-project is
|
||||
// upgraded to include the patch.
|
||||
auto isSigned = true;
|
||||
if (width == 8) {
|
||||
if (isSigned) {
|
||||
int8_t value = std::numeric_limits<int8_t>::max();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
} else {
|
||||
uint8_t value = std::numeric_limits<uint8_t>::max();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
}
|
||||
} else if (width == 16) {
|
||||
if (isSigned) {
|
||||
int16_t value = std::numeric_limits<int16_t>::max();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
} else {
|
||||
uint16_t value = std::numeric_limits<uint16_t>::max();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
}
|
||||
} else if (width == 32) {
|
||||
if (isSigned) {
|
||||
int32_t value = std::numeric_limits<int32_t>::max();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
} else {
|
||||
uint32_t value = std::numeric_limits<uint32_t>::max();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
}
|
||||
} else if (width == 64) {
|
||||
if (isSigned) {
|
||||
int64_t value = std::numeric_limits<int64_t>::max();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
} else {
|
||||
uint64_t value = std::numeric_limits<uint64_t>::max();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
}
|
||||
} else {
|
||||
emitError(loc, "unsupported element type");
|
||||
}
|
||||
} else {
|
||||
emitError(loc, "unsupported element type");
|
||||
}
|
||||
|
||||
return rewriter.create<ConstantOp>(loc, constantAttr);
|
||||
}
|
||||
|
||||
Value emitNegativeInfinityConstantOp(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type type) {
|
||||
Attribute constantAttr;
|
||||
auto typeKind = type.getKind();
|
||||
if (typeKind == StandardTypes::F16) {
|
||||
// 0xFC00
|
||||
float value = -std::numeric_limits<float>::infinity();
|
||||
constantAttr = rewriter.getF16FloatAttr(value);
|
||||
} else if (typeKind == StandardTypes::F32) {
|
||||
// 0xFF800000
|
||||
float value = -std::numeric_limits<float>::infinity();
|
||||
constantAttr = rewriter.getF32FloatAttr(value);
|
||||
} else if (typeKind == StandardTypes::F64) {
|
||||
// 0xFFF0000000000000
|
||||
double value = -std::numeric_limits<double>::infinity();
|
||||
constantAttr = rewriter.getF64FloatAttr(value);
|
||||
} else if (typeKind == StandardTypes::Integer) {
|
||||
auto width = type.cast<IntegerType>().getWidth();
|
||||
// The latest llvm-project includes a patch which allows getting the sign of
|
||||
// IntegerType:
|
||||
// https://github.com/llvm/llvm-project/commit/35b685270b410f6a1351c2a527021f22330c25b9
|
||||
// as follows:
|
||||
// auto isSigned = type.cast<IntegerType>().isSigned();
|
||||
// TODO (tungld): update the following statement once our llvm-project is
|
||||
// upgraded to include the patch.
|
||||
auto isSigned = true;
|
||||
if (width == 8) {
|
||||
if (isSigned) {
|
||||
int8_t value = std::numeric_limits<int8_t>::min();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
} else {
|
||||
uint8_t value = std::numeric_limits<uint8_t>::min();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
}
|
||||
} else if (width == 16) {
|
||||
if (isSigned) {
|
||||
int16_t value = std::numeric_limits<int16_t>::min();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
} else {
|
||||
uint16_t value = std::numeric_limits<uint16_t>::min();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
}
|
||||
} else if (width == 32) {
|
||||
if (isSigned) {
|
||||
int32_t value = std::numeric_limits<int32_t>::min();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
} else {
|
||||
uint32_t value = std::numeric_limits<uint32_t>::min();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
}
|
||||
} else if (width == 64) {
|
||||
if (isSigned) {
|
||||
int64_t value = std::numeric_limits<int64_t>::min();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
} else {
|
||||
uint64_t value = std::numeric_limits<uint64_t>::min();
|
||||
constantAttr = rewriter.getIntegerAttr(type, APInt(width, value));
|
||||
}
|
||||
} else {
|
||||
emitError(loc, "unsupported element type");
|
||||
}
|
||||
} else {
|
||||
emitError(loc, "unsupported element type");
|
||||
}
|
||||
|
||||
return rewriter.create<ConstantOp>(loc, constantAttr);
|
||||
}
|
||||
|
|
|
@ -98,6 +98,25 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
|
|||
ArrayRef<Value> loopIVs, Value operand,
|
||||
std::map<int, Value> broadcastedDims);
|
||||
|
||||
// Emit a constant of a specific type.
|
||||
// Use this function for small values only to avoid unexpected loss in type
|
||||
// casting.
|
||||
Value emitConstantOp(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type type, double value);
|
||||
|
||||
// Emit a positive infinity constant of a specific type.
|
||||
// Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
|
||||
// In case of Integer, emit the maximum value.
|
||||
Value emitPositiveInfinityConstantOp(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type type);
|
||||
|
||||
// Emit a negative infinity constant of a specific type.
|
||||
// Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
|
||||
// In case of Float, emit the negative of the positive infinity.
|
||||
// In case of Integer, emit the minimum value.
|
||||
Value emitNegativeInfinityConstantOp(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type type);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This is to get a scalar operation of a given type for a specific operation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -112,11 +131,13 @@ using ScalarFOp = typename ScalarOp<FOp>::FOp;
|
|||
template <typename IOp>
|
||||
using ScalarIOp = typename ScalarOp<IOp>::IOp;
|
||||
|
||||
// Get the identity element of a operation.
|
||||
// Get the identity element of an operation.
|
||||
// Return NULL if the function does not have identity.
|
||||
template <typename DataType, typename Op>
|
||||
DataType getIdentityValue() {
|
||||
return NULL;
|
||||
// Specialize this for a new Op.
|
||||
template <typename Op>
|
||||
Value getIdentityValue(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type type) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -28,9 +28,8 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
Value alloc;
|
||||
|
||||
// Compute size in bytes using the input tensor.
|
||||
Value tensorSize = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
getMemRefEltSizeInBytes(memRefType)));
|
||||
Value tensorSize = emitConstantOp(rewriter, loc,
|
||||
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
|
||||
for (int i = 0; i < inputShape.size(); ++i) {
|
||||
Value dimVal;
|
||||
if (inputShape[i] < 0) {
|
||||
|
@ -38,9 +37,8 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
dimVal =
|
||||
rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
|
||||
} else {
|
||||
dimVal = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
inputShape[i]));
|
||||
dimVal = emitConstantOp(
|
||||
rewriter, loc, rewriter.getIntegerType(64), inputShape[i]);
|
||||
}
|
||||
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
|
||||
}
|
||||
|
@ -59,13 +57,11 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
// If the reduction is negative, then the shape array contains a negative
|
||||
// dimension. Otherwise, the reduction is the same as the one computed
|
||||
// from the input tensor.
|
||||
Value tensorSizeFromShape = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
getMemRefEltSizeInBytes(memRefType)));
|
||||
Value tensorSizeFromShape = emitConstantOp(rewriter, loc,
|
||||
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
|
||||
SmallVector<Value, 4> DimInfo;
|
||||
for (int i = 0; i < memRefShape.size(); ++i) {
|
||||
Value index = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
||||
Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
|
||||
// Load index from array of indices.
|
||||
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
||||
// If a dimension is zero, the actual dimension value is taken from the
|
||||
|
@ -81,11 +77,10 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
|
||||
dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
|
||||
} else {
|
||||
dimVal = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(loadedValType, inputShape[i]));
|
||||
dimVal =
|
||||
emitConstantOp(rewriter, loc, loadedValType, inputShape[i]);
|
||||
}
|
||||
auto zero = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(loadedValType, 0));
|
||||
auto zero = emitConstantOp(rewriter, loc, loadedValType, 0);
|
||||
auto isZero =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero);
|
||||
loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
|
||||
|
@ -104,15 +99,14 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
// Reverse tensorSizeFromShape since it is negative if the shape array has
|
||||
// a negative dimension. This is safe since we only use it to compute the
|
||||
// actual value for the negative dimension.
|
||||
auto zero = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||
auto zero = emitConstantOp(rewriter, loc, rewriter.getIntegerType(64), 0);
|
||||
tensorSizeFromShape =
|
||||
rewriter.create<SubIOp>(loc, zero, tensorSizeFromShape);
|
||||
|
||||
// Obtain operands for AllocOp.
|
||||
SmallVector<Value, 4> allocOperands;
|
||||
auto negOne = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));
|
||||
auto negOne =
|
||||
emitConstantOp(rewriter, loc, rewriter.getIntegerType(64), -1);
|
||||
|
||||
for (int i = 0; i < memRefShape.size(); ++i) {
|
||||
auto dimVal = DimInfo[i];
|
||||
|
|
|
@ -37,18 +37,16 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
|||
Value alloc;
|
||||
|
||||
// Compute size in bytes.
|
||||
Value tensorSize = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
getMemRefEltSizeInBytes(memRefType)));
|
||||
Value tensorSize = emitConstantOp(rewriter, loc,
|
||||
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
|
||||
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
auto memRefShape = memRefType.getShape();
|
||||
if (hasAllConstantDimensions(memRefType)) {
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
for (int i = 0; i < memRefShape.size(); ++i) {
|
||||
Value dimVal = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
memRefShape[i]));
|
||||
Value dimVal = emitConstantOp(
|
||||
rewriter, loc, rewriter.getIntegerType(64), memRefShape[i]);
|
||||
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
|
||||
}
|
||||
} else {
|
||||
|
@ -62,9 +60,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
|||
loc, index, rewriter.getIntegerType(64));
|
||||
allocOperands.emplace_back(index);
|
||||
} else {
|
||||
dimVal = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
memRefShape[outIdx]));
|
||||
dimVal = emitConstantOp(
|
||||
rewriter, loc, rewriter.getIntegerType(64), memRefShape[outIdx]);
|
||||
}
|
||||
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
|
||||
if (std::find(axes.begin(), axes.end(), outIdx) == axes.end())
|
||||
|
|
Loading…
Reference in New Issue