Create some helper functions to emit constant op for a specific type (#7)

* emitConstantOp with a given type

* Helper functions to create infinity constants

* Use new constant helper functions for MaxPoolSingleOut

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Gheorghe-Teodor Bercea 2020-03-05 14:21:00 -05:00 committed by GitHub
parent 8e1b30e133
commit 8a992b619f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 255 additions and 135 deletions

View File

@ -103,8 +103,8 @@ Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto two = emitConstantOp(rewriter, loc, elementType, 2);
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg);
@ -127,8 +127,8 @@ Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto two = emitConstantOp(rewriter, loc, elementType, 2);
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg);
@ -152,8 +152,8 @@ Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg);
auto result = rewriter.create<DivFOp>(
@ -184,8 +184,8 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);
@ -217,8 +217,8 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto lessThanZero =
@ -246,7 +246,7 @@ Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
@ -271,7 +271,7 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
@ -301,7 +301,7 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
auto exp = rewriter.create<ExpOp>(loc, operand);
@ -328,7 +328,7 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>(
Value operand = operands[0];
auto elementType = result_types[0];
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto result = rewriter.create<DivFOp>(loc, one, operand);
return result;
@ -347,7 +347,7 @@ Value mapToLowerScalarOp<ONNXSoftplusOp>(
auto elementType = result_types[0];
auto exp = rewriter.create<ExpOp>(loc, operand);
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto add = rewriter.create<AddFOp>(loc, exp, one);
auto result = rewriter.create<LogOp>(loc, add);
@ -367,7 +367,7 @@ Value mapToLowerScalarOp<ONNXSoftsignOp>(
auto elementType = result_types[0];
auto abs = rewriter.create<AbsFOp>(loc, operand);
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto add = rewriter.create<AddFOp>(loc, abs, one);
auto result = rewriter.create<DivFOp>(loc, operand, add);
@ -384,19 +384,18 @@ Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
auto loc = op->getLoc();
Value operand = operands[0];
Type element_type = operands.front().getType();
Type elementType = operands.front().getType();
// TODO: unsigned int should be supported separately?
if (element_type.isa<IntegerType>()) {
if (elementType.isa<IntegerType>()) {
// %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0),
// ConstantOp 1,
// ConstantOp -1)
// ONNXSignOp(%X) = SelectOP(CmpIOp(EQ, %X, ConstantOp 0),
// ConstantOp 0,
// %Y)
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
auto minusOne =
rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(-1));
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto minusOne = emitConstantOp(rewriter, loc, elementType, -1);
auto plusPredicate =
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
auto plusSelect =
@ -406,18 +405,16 @@ Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
auto result =
rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
return result;
} else if (element_type.isa<FloatType>()) {
} else if (elementType.isa<FloatType>()) {
// %Y = SelectOP(CmpFOp(OGT, %X, ConstantOp 0),
// ConstantOp 1,
// ConstantOp -1)
// ONNXSignOp(%X) = SelectOP(CmpFOp(OEQ, %X, ConstantOp 0),
// ConstantOp 0,
// %Y)
auto zero =
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto minusOne =
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0f));
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto minusOne = emitConstantOp(rewriter, loc, elementType, -1);
auto plusPredicate =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
auto plusSelect =

View File

@ -156,8 +156,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
}
// Initialize the output of A*B
auto zero = rewriter.create<ConstantOp>(
loc, FloatAttr::get(memRefType.getElementType(), 0));
auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0);
rewriter.create<StoreOp>(loc, zero, alloc, loopMNIVs);
// Compute A*B

View File

@ -37,16 +37,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
auto memRefShape = memRefType.getShape();
// A value zero
Value zero;
if (elementType.isa<IntegerType>()) {
zero = rewriter.create<ConstantOp>(
loc, IntegerAttr::get(memRefType.getElementType(), 0));
} else if (elementType.isa<FloatType>()) {
zero = rewriter.create<ConstantOp>(
loc, FloatAttr::get(memRefType.getElementType(), 0));
} else {
emitError(loc, "unsupported element type");
}
auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0);
// Insert an allocation and deallocation for the result of this operation.
Value alloc;

View File

@ -14,43 +14,27 @@ using namespace mlir;
// Identity values
template <>
float getIdentityValue<float, ONNXReduceMaxOp>(){
return (float)-std::numeric_limits<float>::infinity();
Value getIdentityValue<ONNXReduceMaxOp>(
ConversionPatternRewriter &rewriter, Location loc, Type type) {
return emitNegativeInfinityConstantOp(rewriter, loc, type);
}
template <>
int getIdentityValue<int, ONNXReduceMaxOp>(){
return std::numeric_limits<int>::min();
Value getIdentityValue<ONNXReduceMinOp>(
ConversionPatternRewriter &rewriter, Location loc, Type type) {
return emitPositiveInfinityConstantOp(rewriter, loc, type);
}
template <>
float getIdentityValue<float, ONNXReduceMinOp>(){
return (float)std::numeric_limits<float>::infinity();
Value getIdentityValue<ONNXReduceProdOp>(
ConversionPatternRewriter &rewriter, Location loc, Type type) {
return emitConstantOp(rewriter, loc, type, 1);
}
template <>
int getIdentityValue<int, ONNXReduceMinOp>(){
return std::numeric_limits<int>::max();
}
template <>
float getIdentityValue<float, ONNXReduceProdOp>(){
return (float)1.0;
}
template <>
int getIdentityValue<int, ONNXReduceProdOp>(){
return 1;
}
template <>
float getIdentityValue<float, ONNXReduceSumOp>(){
return (float)0;
}
template <>
int getIdentityValue<int, ONNXReduceSumOp>(){
return 0;
Value getIdentityValue<ONNXReduceSumOp>(
ConversionPatternRewriter &rewriter, Location loc, Type type) {
return emitConstantOp(rewriter, loc, type, 0);
}
// Scalar ops
@ -234,18 +218,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
loopIVs.push_back(arg);
}
Value identity;
if (elementOutType.isa<FloatType>()) {
identity = rewriter.create<ConstantOp>(
loc, FloatAttr::get(elementOutType,
getIdentityValue<float, ONNXReductionOp>()));
} else if (elementOutType.isa<IntegerType>()) {
identity = rewriter.create<ConstantOp>(
loc, IntegerAttr::get(elementOutType,
getIdentityValue<int, ONNXReductionOp>()));
} else {
emitError(loc, "unsupported element type");
}
Value identity =
getIdentityValue<ONNXReductionOp>(rewriter, loc, elementOutType);
rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);
// Define an Krnl loop to do reduction.

View File

@ -48,8 +48,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value zero =
rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
Value zero = emitConstantOp(rewriter, loc, elementType, 0);
Value negInfinity = rewriter.create<ConstantOp>(
loc,
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));

View File

@ -92,8 +92,7 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
int64_t kernelsPerGroup = floor(kernelShape[0] / group);
auto kernelsPerGroupValue =
rewriter.create<ConstantIndexOp>(loc, kernelsPerGroup);
auto zero = rewriter.create<ConstantOp>(
loc, FloatAttr::get(memRefType.getElementType(), 0));
auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0);
Value subchannels;
if (kernelShape[1] < 0) {
subchannels = rewriter.create<DimOp>(loc, kernelOperand, 1).getResult();

View File

@ -14,13 +14,9 @@ using namespace mlir;
// Identity values
template <>
float getIdentityValue<float, ONNXMaxPoolSingleOutOp>() {
return (float)-std::numeric_limits<float>::infinity();
}
template <>
int getIdentityValue<int, ONNXMaxPoolSingleOutOp>() {
return std::numeric_limits<int>::min();
Value getIdentityValue<ONNXMaxPoolSingleOutOp>(
ConversionPatternRewriter &rewriter, Location loc, Type type) {
return emitNegativeInfinityConstantOp(rewriter, loc, type);
}
template <>
@ -204,18 +200,8 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
resultIndices.emplace_back(outerLoops.getInductionVar(i));
// 2.1 Emit: R[n][c][r1][r2] = negative_infinity;
Value identity;
if (resultElementType.isa<FloatType>()) {
identity = rewriter.create<ConstantOp>(
loc, FloatAttr::get(resultElementType,
getIdentityValue<float, ONNXMaxPoolSingleOutOp>()));
} else if (resultElementType.isa<IntegerType>()) {
identity = rewriter.create<ConstantOp>(
loc, IntegerAttr::get(resultElementType,
getIdentityValue<int, ONNXMaxPoolSingleOutOp>()));
} else {
emitError(loc, "unsupported element type");
}
Value identity = getIdentityValue<ONNXMaxPoolSingleOutOp>(
rewriter, loc, resultElementType);
rewriter.create<StoreOp>(loc, identity, alloc, resultIndices);
// 2.2 Define inner loops.

View File

@ -322,3 +322,166 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
}
return newLoopIVs;
}
// Emit a std.constant holding `value` converted to the given element `type`.
// Use this function for small values only to avoid unexpected loss in type
// casting. Returns a null Value (after emitting an error) for unsupported
// types.
Value emitConstantOp(ConversionPatternRewriter &rewriter, Location loc,
    Type type, double value) {
  Attribute constantAttr;
  auto typeKind = type.getKind();
  if (typeKind == StandardTypes::F16) {
    constantAttr = rewriter.getF16FloatAttr((float)value);
  } else if (typeKind == StandardTypes::F32) {
    constantAttr = rewriter.getF32FloatAttr((float)value);
  } else if (typeKind == StandardTypes::F64) {
    constantAttr = rewriter.getF64FloatAttr(value);
  } else if (typeKind == StandardTypes::Integer) {
    auto width = type.cast<IntegerType>().getWidth();
    if (width == 1) {
      // i1 constants are materialized as BoolAttr. Honor the requested value
      // rather than unconditionally emitting `false`.
      constantAttr = rewriter.getBoolAttr(value != 0);
    } else {
      // Sign-extend into `width` bits so negative constants (e.g. -1) are
      // represented correctly for every integer width.
      constantAttr = rewriter.getIntegerAttr(
          type, APInt(width, (int64_t)value, /*isSigned=*/true));
    }
  } else if (typeKind == StandardTypes::Index) {
    constantAttr = rewriter.getIntegerAttr(type, (int64_t)value);
  } else {
    // Bail out instead of building a ConstantOp with a null attribute,
    // which would assert inside MLIR.
    emitError(loc, "unsupported element type");
    return nullptr;
  }
  return rewriter.create<ConstantOp>(loc, constantAttr);
}
// Emit a positive infinity constant of a specific type.
// Supported types: F16, F32, F64 and integers of any width.
// For integer types the maximum representable value is emitted instead.
// Returns a null Value (after emitting an error) for unsupported types.
Value emitPositiveInfinityConstantOp(
    ConversionPatternRewriter &rewriter, Location loc, Type type) {
  Attribute constantAttr;
  auto typeKind = type.getKind();
  if (typeKind == StandardTypes::F16) {
    // 0x7C00
    constantAttr =
        rewriter.getF16FloatAttr(std::numeric_limits<float>::infinity());
  } else if (typeKind == StandardTypes::F32) {
    // 0x7F800000
    constantAttr =
        rewriter.getF32FloatAttr(std::numeric_limits<float>::infinity());
  } else if (typeKind == StandardTypes::F64) {
    // 0x7FF0000000000000
    constantAttr =
        rewriter.getF64FloatAttr(std::numeric_limits<double>::infinity());
  } else if (typeKind == StandardTypes::Integer) {
    auto width = type.cast<IntegerType>().getWidth();
    // The latest llvm-project includes a patch which allows getting the sign
    // of IntegerType:
    // https://github.com/llvm/llvm-project/commit/35b685270b410f6a1351c2a527021f22330c25b9
    // as follows:
    //   auto isSigned = type.cast<IntegerType>().isSigned();
    // TODO (tungld): update the following statement once our llvm-project is
    // upgraded to include the patch.
    auto isSigned = true;
    // APInt computes the max value for any bit width, so no per-width
    // dispatch is needed.
    APInt value =
        isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
    constantAttr = rewriter.getIntegerAttr(type, value);
  } else {
    // Bail out instead of building a ConstantOp with a null attribute.
    emitError(loc, "unsupported element type");
    return nullptr;
  }
  return rewriter.create<ConstantOp>(loc, constantAttr);
}
// Emit a negative infinity constant of a specific type.
// Supported types: F16, F32, F64 and integers of any width.
// For float types this is the negation of positive infinity; for integer
// types the minimum representable value is emitted instead (0 when treated
// as unsigned). Returns a null Value (after emitting an error) for
// unsupported types.
Value emitNegativeInfinityConstantOp(
    ConversionPatternRewriter &rewriter, Location loc, Type type) {
  Attribute constantAttr;
  auto typeKind = type.getKind();
  if (typeKind == StandardTypes::F16) {
    // 0xFC00
    constantAttr =
        rewriter.getF16FloatAttr(-std::numeric_limits<float>::infinity());
  } else if (typeKind == StandardTypes::F32) {
    // 0xFF800000
    constantAttr =
        rewriter.getF32FloatAttr(-std::numeric_limits<float>::infinity());
  } else if (typeKind == StandardTypes::F64) {
    // 0xFFF0000000000000
    constantAttr =
        rewriter.getF64FloatAttr(-std::numeric_limits<double>::infinity());
  } else if (typeKind == StandardTypes::Integer) {
    auto width = type.cast<IntegerType>().getWidth();
    // The latest llvm-project includes a patch which allows getting the sign
    // of IntegerType:
    // https://github.com/llvm/llvm-project/commit/35b685270b410f6a1351c2a527021f22330c25b9
    // as follows:
    //   auto isSigned = type.cast<IntegerType>().isSigned();
    // TODO (tungld): update the following statement once our llvm-project is
    // upgraded to include the patch.
    auto isSigned = true;
    // APInt computes the min value for any bit width, so no per-width
    // dispatch is needed.
    APInt value =
        isSigned ? APInt::getSignedMinValue(width) : APInt::getMinValue(width);
    constantAttr = rewriter.getIntegerAttr(type, value);
  } else {
    // Bail out instead of building a ConstantOp with a null attribute.
    emitError(loc, "unsupported element type");
    return nullptr;
  }
  return rewriter.create<ConstantOp>(loc, constantAttr);
}

View File

@ -98,6 +98,25 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value> broadcastedDims);
// Emit a constant of a specific type.
// Use this function for small values only to avoid unexpected loss in type
// casting.
Value emitConstantOp(
ConversionPatternRewriter &rewriter, Location loc, Type type, double value);
// Emit a positive infinity constant of a specific type.
// Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
// In case of Integer, emit the maximum value.
Value emitPositiveInfinityConstantOp(
ConversionPatternRewriter &rewriter, Location loc, Type type);
// Emit a negative infinity constant of a specific type.
// Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
// In case of Float, emit the negative of the positive infinity.
// In case of Integer, emit the minimum value.
Value emitNegativeInfinityConstantOp(
ConversionPatternRewriter &rewriter, Location loc, Type type);
//===----------------------------------------------------------------------===//
// This is to get a scalar operation of a given type for a specific operation.
//===----------------------------------------------------------------------===//
@ -112,11 +131,13 @@ using ScalarFOp = typename ScalarOp<FOp>::FOp;
template <typename IOp>
using ScalarIOp = typename ScalarOp<IOp>::IOp;
// Get the identity element of a operation.
// Get the identity element of an operation.
// Return NULL if the function does not have identity.
template <typename DataType, typename Op>
DataType getIdentityValue() {
return NULL;
// Specialize this for a new Op.
template <typename Op>
Value getIdentityValue(
ConversionPatternRewriter &rewriter, Location loc, Type type) {
return nullptr;
}
//===----------------------------------------------------------------------===//

View File

@ -28,9 +28,8 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
Value alloc;
// Compute size in bytes using the input tensor.
Value tensorSize = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
getMemRefEltSizeInBytes(memRefType)));
Value tensorSize = emitConstantOp(rewriter, loc,
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
for (int i = 0; i < inputShape.size(); ++i) {
Value dimVal;
if (inputShape[i] < 0) {
@ -38,9 +37,8 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
dimVal =
rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
} else {
dimVal = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
inputShape[i]));
dimVal = emitConstantOp(
rewriter, loc, rewriter.getIntegerType(64), inputShape[i]);
}
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
}
@ -59,13 +57,11 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
// If the reduction is negative, then the shape array contains a negative
// dimension. Otherwise, the reduction is the same as the one computed
// from the input tensor.
Value tensorSizeFromShape = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
getMemRefEltSizeInBytes(memRefType)));
Value tensorSizeFromShape = emitConstantOp(rewriter, loc,
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
SmallVector<Value, 4> DimInfo;
for (int i = 0; i < memRefShape.size(); ++i) {
Value index = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
// Load index from array of indices.
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
// If a dimension is zero, the actual dimension value is taken from the
@ -81,11 +77,10 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
} else {
dimVal = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(loadedValType, inputShape[i]));
dimVal =
emitConstantOp(rewriter, loc, loadedValType, inputShape[i]);
}
auto zero = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(loadedValType, 0));
auto zero = emitConstantOp(rewriter, loc, loadedValType, 0);
auto isZero =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero);
loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
@ -104,15 +99,14 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
// Reverse tensorSizeFromShape since it is negative if the shape array has
// a negative dimension. This is safe since we only use it to compute the
// actual value for the negative dimension.
auto zero = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
auto zero = emitConstantOp(rewriter, loc, rewriter.getIntegerType(64), 0);
tensorSizeFromShape =
rewriter.create<SubIOp>(loc, zero, tensorSizeFromShape);
// Obtain operands for AllocOp.
SmallVector<Value, 4> allocOperands;
auto negOne = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));
auto negOne =
emitConstantOp(rewriter, loc, rewriter.getIntegerType(64), -1);
for (int i = 0; i < memRefShape.size(); ++i) {
auto dimVal = DimInfo[i];

View File

@ -37,18 +37,16 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
Value alloc;
// Compute size in bytes.
Value tensorSize = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
getMemRefEltSizeInBytes(memRefType)));
Value tensorSize = emitConstantOp(rewriter, loc,
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
bool insertDealloc = checkInsertDealloc(op);
auto memRefShape = memRefType.getShape();
if (hasAllConstantDimensions(memRefType)) {
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
for (int i = 0; i < memRefShape.size(); ++i) {
Value dimVal = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
memRefShape[i]));
Value dimVal = emitConstantOp(
rewriter, loc, rewriter.getIntegerType(64), memRefShape[i]);
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
}
} else {
@ -62,9 +60,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
loc, index, rewriter.getIntegerType(64));
allocOperands.emplace_back(index);
} else {
dimVal = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
memRefShape[outIdx]));
dimVal = emitConstantOp(
rewriter, loc, rewriter.getIntegerType(64), memRefShape[outIdx]);
}
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
if (std::find(axes.begin(), axes.end(), outIdx) == axes.end())