use input/output operation names, use helper for attribute function and int values (#106)
This commit is contained in:
parent
3b1c29c078
commit
3a88361b17
|
@ -406,12 +406,12 @@ void ONNXIdentityOp::inferShapes() {
|
||||||
|
|
||||||
void ONNXMatMulOp::inferShapes() {
|
void ONNXMatMulOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!A().getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!B().getType().isa<RankedTensorType>())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto lhsTy = A().getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto rhsTy = B().getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
SmallVector<int64_t, 2> dims;
|
SmallVector<int64_t, 2> dims;
|
||||||
auto lhsShape = lhsTy.getShape();
|
auto lhsShape = lhsTy.getShape();
|
||||||
|
@ -419,14 +419,14 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
|
|
||||||
if (lhsShape.size() < 1 && rhsShape.size() < 1) {
|
if (lhsShape.size() < 1 && rhsShape.size() < 1) {
|
||||||
// Multiplication by scalars is not allowed.
|
// Multiplication by scalars is not allowed.
|
||||||
emitError("Multiplication by scalar arguments not allowed.");
|
emitError("Multiplication by scalar arguments not allowed");
|
||||||
} else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
|
} else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
|
||||||
// Special case when both arrays are 1-dimensional and according to
|
// Special case when both arrays are 1-dimensional and according to
|
||||||
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
|
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
|
||||||
// need to be removed after the multiplication but cannot be removed if all
|
// need to be removed after the multiplication but cannot be removed if all
|
||||||
// sizes are 1.
|
// sizes are 1.
|
||||||
if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
|
if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
emitError("Attempt to multiply incompatible matrices");
|
||||||
dims.emplace_back(1);
|
dims.emplace_back(1);
|
||||||
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
|
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
|
||||||
// If the first argument is 1-D, it is promoted to a matrix by prepending a
|
// If the first argument is 1-D, it is promoted to a matrix by prepending a
|
||||||
|
@ -441,7 +441,7 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
unsigned rhsRank = rhsShape.size();
|
unsigned rhsRank = rhsShape.size();
|
||||||
if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
||||||
lhsShape[0] != rhsShape[rhsRank - 2])
|
lhsShape[0] != rhsShape[rhsRank - 2])
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
emitError("Attempt to multiply incompatible matrices");
|
||||||
|
|
||||||
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
||||||
dims.emplace_back(rhsShape[i]);
|
dims.emplace_back(rhsShape[i]);
|
||||||
|
@ -459,7 +459,7 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
unsigned lhsRank = lhsShape.size();
|
unsigned lhsRank = lhsShape.size();
|
||||||
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
|
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
|
||||||
lhsShape[lhsRank - 1] != rhsShape[0])
|
lhsShape[lhsRank - 1] != rhsShape[0])
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
emitError("Attempt to multiply incompatible matrices");
|
||||||
|
|
||||||
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
|
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
|
||||||
dims.emplace_back(lhsShape[i]);
|
dims.emplace_back(lhsShape[i]);
|
||||||
|
@ -473,7 +473,7 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
unsigned lhsRank = lhsShape.size();
|
unsigned lhsRank = lhsShape.size();
|
||||||
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
|
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
|
||||||
lhsShape[lhsRank - 1] != rhsShape[0])
|
lhsShape[lhsRank - 1] != rhsShape[0])
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
emitError("Attempt to multiply incompatible matrices");
|
||||||
|
|
||||||
for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
|
for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
|
||||||
dims.emplace_back(lhsShape[i]);
|
dims.emplace_back(lhsShape[i]);
|
||||||
|
@ -487,7 +487,7 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
unsigned rhsRank = rhsShape.size();
|
unsigned rhsRank = rhsShape.size();
|
||||||
if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
||||||
lhsShape[1] != rhsShape[rhsRank - 2])
|
lhsShape[1] != rhsShape[rhsRank - 2])
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
emitError("Attempt to multiply incompatible matrices");
|
||||||
|
|
||||||
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
||||||
dims.emplace_back(rhsShape[i]);
|
dims.emplace_back(rhsShape[i]);
|
||||||
|
@ -503,7 +503,7 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
unsigned rhsRank = rhsShape.size();
|
unsigned rhsRank = rhsShape.size();
|
||||||
if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
||||||
lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
|
lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
emitError("Attempt to multiply incompatible matrices");
|
||||||
|
|
||||||
// Check and perform broadcasting for the shapes.
|
// Check and perform broadcasting for the shapes.
|
||||||
SmallVector<int64_t, 2> lhsBcastShape;
|
SmallVector<int64_t, 2> lhsBcastShape;
|
||||||
|
@ -513,7 +513,7 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
||||||
rhsBcastShape.emplace_back(rhsShape[i]);
|
rhsBcastShape.emplace_back(rhsShape[i]);
|
||||||
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
|
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
|
||||||
emitError("Broadcasted dimensions are incompatible.");
|
emitError("Broadcasted dimensions are incompatible");
|
||||||
|
|
||||||
dims.emplace_back(lhsShape[lhsRank - 2]);
|
dims.emplace_back(lhsShape[lhsRank - 2]);
|
||||||
dims.emplace_back(rhsShape[rhsRank - 1]);
|
dims.emplace_back(rhsShape[rhsRank - 1]);
|
||||||
|
@ -528,7 +528,7 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
|
|
||||||
// Check legality of matrix multiplication.
|
// Check legality of matrix multiplication.
|
||||||
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
|
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
emitError("Attempt to multiply incompatible matrices");
|
||||||
|
|
||||||
if (rhsShape.size() > 1)
|
if (rhsShape.size() > 1)
|
||||||
dims.emplace_back(rhsShape[1]);
|
dims.emplace_back(rhsShape[1]);
|
||||||
|
@ -542,14 +542,14 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
// Gemm
|
// Gemm
|
||||||
|
|
||||||
void ONNXGemmOp::inferShapes() {
|
void ONNXGemmOp::inferShapes() {
|
||||||
bool hasBias = !getOperand(2).getType().isa<NoneType>();
|
bool hasBias = !C().getType().isa<NoneType>();
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!A().getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>() ||
|
!B().getType().isa<RankedTensorType>() ||
|
||||||
(hasBias && !getOperand(2).getType().isa<RankedTensorType>()))
|
(hasBias && !C().getType().isa<RankedTensorType>()))
|
||||||
return;
|
return;
|
||||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto lhsTy = A().getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto rhsTy = B().getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
int64_t M, N, K_A, K_B;
|
int64_t M, N, K_A, K_B;
|
||||||
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
|
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
|
||||||
|
@ -558,12 +558,12 @@ void ONNXGemmOp::inferShapes() {
|
||||||
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
|
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
|
||||||
|
|
||||||
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
|
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
|
||||||
emitError("Tensor shapes mismatched.");
|
emitError("Tensor shapes mismatched");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hasBias) {
|
if (hasBias) {
|
||||||
// Check whether bias is unidirectional broadcasting or not.
|
// Check whether bias is unidirectional broadcasting or not.
|
||||||
auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
|
auto biasTy = C().getType().cast<RankedTensorType>();
|
||||||
auto shape = biasTy.getShape();
|
auto shape = biasTy.getShape();
|
||||||
int rank = shape.size();
|
int rank = shape.size();
|
||||||
if ((rank > 2) ||
|
if ((rank > 2) ||
|
||||||
|
@ -571,7 +571,7 @@ void ONNXGemmOp::inferShapes() {
|
||||||
N != shape[rank - 1] && shape[rank - 1] != 1) ||
|
N != shape[rank - 1] && shape[rank - 1] != 1) ||
|
||||||
(rank == 2 && shape[rank - 2] != -1 && M != -1 &&
|
(rank == 2 && shape[rank - 2] != -1 && M != -1 &&
|
||||||
M != shape[rank - 2] && shape[rank - 2] != 1)) {
|
M != shape[rank - 2] && shape[rank - 2] != 1)) {
|
||||||
emitError("Bias shape mismatched.");
|
emitError("Bias shape mismatched");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -584,50 +584,50 @@ void ONNXGemmOp::inferShapes() {
|
||||||
/// BatchNormalizationTestMode
|
/// BatchNormalizationTestMode
|
||||||
void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!X().getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>() ||
|
!scale().getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(2).getType().isa<RankedTensorType>() ||
|
!B().getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(3).getType().isa<RankedTensorType>() ||
|
!mean().getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(4).getType().isa<RankedTensorType>())
|
!var().getType().isa<RankedTensorType>())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto input = getOperand(0).getType().cast<RankedTensorType>();
|
auto inputTensorTy = X().getType().cast<RankedTensorType>();
|
||||||
auto scale = getOperand(1).getType().cast<RankedTensorType>();
|
auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
|
||||||
auto bias = getOperand(2).getType().cast<RankedTensorType>();
|
auto biasTensorTy = B().getType().cast<RankedTensorType>();
|
||||||
auto mean = getOperand(3).getType().cast<RankedTensorType>();
|
auto meanTensorTy = mean().getType().cast<RankedTensorType>();
|
||||||
auto variance = getOperand(4).getType().cast<RankedTensorType>();
|
auto varianceTensorTy = var().getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
// Check whether the shapes of scale, bias, mean and variance are valid.
|
// Check whether the shapes of scale, bias, mean and variance are valid.
|
||||||
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
|
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
|
||||||
// In case of N, C is assumed to be 1.
|
// In case of N, C is assumed to be 1.
|
||||||
// Shapes of scale, bias, mean and variance must be C.
|
// Shapes of scale, bias, mean and variance must be C.
|
||||||
int64_t c = -1;
|
int64_t c = -1;
|
||||||
if (input.getShape().size() == 1) {
|
if (inputTensorTy.getShape().size() == 1) {
|
||||||
c = 1;
|
c = 1;
|
||||||
} else if (input.getShape().size() > 2) {
|
} else if (inputTensorTy.getShape().size() > 2) {
|
||||||
c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1;
|
c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
|
||||||
} else {
|
} else {
|
||||||
emitError("Wrong rank for the input.");
|
emitError("Wrong rank for the input");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (c != -1) {
|
if (c != -1) {
|
||||||
auto s = scale.getShape();
|
auto s = scaleTensorTy.getShape();
|
||||||
auto b = bias.getShape();
|
auto b = biasTensorTy.getShape();
|
||||||
auto m = mean.getShape();
|
auto m = meanTensorTy.getShape();
|
||||||
auto v = variance.getShape();
|
auto v = varianceTensorTy.getShape();
|
||||||
|
|
||||||
if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
|
if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
|
||||||
emitError("Wrong rank for the scale.");
|
emitError("Wrong rank for the scale");
|
||||||
if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
|
if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
|
||||||
emitError("Wrong rank for the bias.");
|
emitError("Wrong rank for the bias");
|
||||||
if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
|
if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
|
||||||
emitError("Wrong rank for the mean.");
|
emitError("Wrong rank for the mean");
|
||||||
if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
|
if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
|
||||||
emitError("Wrong rank for the variance.");
|
emitError("Wrong rank for the variance");
|
||||||
}
|
}
|
||||||
|
|
||||||
// The output tensor of the same shape as the input.
|
// The output tensor of the same shape as the input.
|
||||||
getResult().setType(getOperand(0).getType());
|
getResult().setType(X().getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO:
|
// TODO:
|
||||||
|
@ -640,21 +640,21 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
||||||
|
|
||||||
void ONNXReshapeOp::inferShapes() {
|
void ONNXReshapeOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape tensor is specified.
|
// Cannot infer shape if no shape tensor is specified.
|
||||||
if (!getOperand(1).getType().isa<RankedTensorType>())
|
if (!shape().getType().isa<RankedTensorType>())
|
||||||
emitError("Shape tensor not ranked.");
|
emitError("Shape tensor not ranked");
|
||||||
|
|
||||||
auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto inputTensorTy = data().getType().cast<RankedTensorType>();
|
||||||
auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
// Only rank 1 shape tensors are supported.
|
// Only rank 1 shape tensors are supported.
|
||||||
if (shapeTensorTy.getShape().size() != 1)
|
if (shapeTensorTy.getShape().size() != 1)
|
||||||
emitError("Shape tensor must have rank one.");
|
emitError("Shape tensor must have rank one");
|
||||||
|
|
||||||
int64_t outputRank = shapeTensorTy.getShape()[0];
|
int64_t outputRank = shapeTensorTy.getShape()[0];
|
||||||
|
|
||||||
// Shape tensor must have constant shape.
|
// Shape tensor must have constant shape.
|
||||||
if (outputRank < 0)
|
if (outputRank < 0)
|
||||||
emitError("Shape tensor must have constant shape.");
|
emitError("Shape tensor must have constant shape");
|
||||||
|
|
||||||
SmallVector<int64_t, 2> dims;
|
SmallVector<int64_t, 2> dims;
|
||||||
for (int i = 0; i < outputRank; ++i)
|
for (int i = 0; i < outputRank; ++i)
|
||||||
|
@ -670,12 +670,12 @@ void ONNXReshapeOp::inferShapes() {
|
||||||
|
|
||||||
void ONNXTransposeOp::inferShapes() {
|
void ONNXTransposeOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!getOperand().getType().isa<RankedTensorType>())
|
if (!data().getType().isa<RankedTensorType>())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
// Naive transposition which handles the default case of
|
// Naive transposition which handles the default case of
|
||||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
auto arrayTy = data().getType().cast<RankedTensorType>();
|
||||||
SmallVector<int64_t, 2> dims;
|
SmallVector<int64_t, 2> dims;
|
||||||
auto permutation = ONNXTransposeOp::permAttr();
|
auto permutation = ONNXTransposeOp::permAttr();
|
||||||
if (permutation) {
|
if (permutation) {
|
||||||
|
@ -697,7 +697,7 @@ void ONNXTransposeOp::inferShapes() {
|
||||||
|
|
||||||
void ONNXReduceMaxOp::inferShapes() {
|
void ONNXReduceMaxOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||||
emitError("Shape tensor not ranked.");
|
emitError("Shape tensor not ranked");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -711,7 +711,7 @@ void ONNXReduceMaxOp::inferShapes() {
|
||||||
|
|
||||||
void ONNXReduceMinOp::inferShapes() {
|
void ONNXReduceMinOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||||
emitError("Shape tensor not ranked.");
|
emitError("Shape tensor not ranked");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -725,7 +725,7 @@ void ONNXReduceMinOp::inferShapes() {
|
||||||
|
|
||||||
void ONNXReduceProdOp::inferShapes() {
|
void ONNXReduceProdOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||||
emitError("Shape tensor not ranked.");
|
emitError("Shape tensor not ranked");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -739,7 +739,7 @@ void ONNXReduceProdOp::inferShapes() {
|
||||||
|
|
||||||
void ONNXReduceSumOp::inferShapes() {
|
void ONNXReduceSumOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||||
emitError("Shape tensor not ranked.");
|
emitError("Shape tensor not ranked");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -758,22 +758,22 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
// W: (M x C/group x k1 x k2 x ... x kn)
|
// W: (M x C/group x k1 x k2 x ... x kn)
|
||||||
|
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!X().getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!W().getType().isa<RankedTensorType>())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto dataTy = X().getType().cast<RankedTensorType>();
|
||||||
auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto weightTy = W().getType().cast<RankedTensorType>();
|
||||||
auto dataShape = dataTy.getShape();
|
auto dataShape = dataTy.getShape();
|
||||||
auto weightShape = weightTy.getShape();
|
auto weightShape = weightTy.getShape();
|
||||||
|
|
||||||
// Lowest supported convolution is a one dimensional convolution.
|
// Lowest supported convolution is a one dimensional convolution.
|
||||||
if (dataShape.size() < 3)
|
if (dataShape.size() < 3)
|
||||||
emitError("Data input shape must be at least (NxCxD1).");
|
emitError("Data input shape must be at least (NxCxD1)");
|
||||||
|
|
||||||
// Check that shape of weight and data have same length.
|
// Check that shape of weight and data have same length.
|
||||||
if (dataShape.size() != weightShape.size())
|
if (dataShape.size() != weightShape.size())
|
||||||
emitError("Weight size not compatible with data size.");
|
emitError("Weight size not compatible with data size");
|
||||||
|
|
||||||
// Required attribute auto_pad defaults to NOTSET.
|
// Required attribute auto_pad defaults to NOTSET.
|
||||||
auto autoPad = auto_pad();
|
auto autoPad = auto_pad();
|
||||||
|
@ -782,7 +782,7 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
|
ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
|
||||||
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
||||||
if (dataShape[1] != (weightShape[1] * group))
|
if (dataShape[1] != (weightShape[1] * group))
|
||||||
emitError("Channel dimension mismatch.");
|
emitError("Channel dimension mismatch");
|
||||||
|
|
||||||
// Note: the value of the group attribut only impacts the way the
|
// Note: the value of the group attribut only impacts the way the
|
||||||
// computation is carried out and not the actual output size.
|
// computation is carried out and not the actual output size.
|
||||||
|
@ -812,11 +812,10 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
// argument.
|
// argument.
|
||||||
SmallVector<int64_t, 2> kernelDims;
|
SmallVector<int64_t, 2> kernelDims;
|
||||||
if (auto kernelShape = kernel_shapeAttr()) {
|
if (auto kernelShape = kernel_shapeAttr()) {
|
||||||
if (kernelShape.getValue().size() != nDims)
|
if (ArrayAttrSize(kernelShape) != nDims)
|
||||||
emitError("kernel_shape length incompatible with spatial dimensions.");
|
emitError("kernel_shape length incompatible with spatial dimensions");
|
||||||
for (int i = 0; i < nDims; ++i)
|
for (int i = 0; i < nDims; ++i)
|
||||||
kernelDims.emplace_back(
|
kernelDims.emplace_back(ArrayAttrIntVal(kernelShape, i));
|
||||||
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
|
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < nDims; ++i)
|
for (int i = 0; i < nDims; ++i)
|
||||||
kernelDims.emplace_back(weightShape[i + 2]);
|
kernelDims.emplace_back(weightShape[i + 2]);
|
||||||
|
@ -834,13 +833,11 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
// From a dimensionality perspective the kernel size becomes the dilated
|
// From a dimensionality perspective the kernel size becomes the dilated
|
||||||
// kernel size.
|
// kernel size.
|
||||||
if (auto dilations = dilationsAttr()) {
|
if (auto dilations = dilationsAttr()) {
|
||||||
if (dilations.getValue().size() != nDims)
|
if (ArrayAttrSize(dilations) != nDims)
|
||||||
emitError("dilations length incompatible with spatial dimensions.");
|
emitError("dilations length incompatible with spatial dimensions");
|
||||||
for (int i = 0; i < nDims; ++i)
|
for (int i = 0; i < nDims; ++i)
|
||||||
kernelDims[i] =
|
kernelDims[i] =
|
||||||
(kernelDims[i] + 1) *
|
(kernelDims[i] + 1) * ArrayAttrIntVal(dilations, i) - 1;
|
||||||
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() -
|
|
||||||
1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subtract kernel dimensions from input data dimensions.
|
// Subtract kernel dimensions from input data dimensions.
|
||||||
|
@ -853,16 +850,14 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
// present then pads is considered to be all zeros (no padding).
|
// present then pads is considered to be all zeros (no padding).
|
||||||
if (auto pads = padsAttr()) {
|
if (auto pads = padsAttr()) {
|
||||||
// pads consists of two entries for each spatial axis.
|
// pads consists of two entries for each spatial axis.
|
||||||
if (pads.getValue().size() != 2 * nDims)
|
if (ArrayAttrSize(pads) != 2 * nDims)
|
||||||
emitError("pads size is not twice the spatial size.");
|
emitError("pads size is not twice the spatial size");
|
||||||
|
|
||||||
for (int i = 0; i < nDims; ++i) {
|
for (int i = 0; i < nDims; ++i) {
|
||||||
// Padding for beginning of axis.
|
// Padding for beginning of axis.
|
||||||
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
|
outSpatialDims[i] += ArrayAttrIntVal(pads, i);
|
||||||
outSpatialDims[i] += p;
|
|
||||||
// Padding for end of axis.
|
// Padding for end of axis.
|
||||||
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
|
outSpatialDims[i] += ArrayAttrIntVal(pads, i + nDims);
|
||||||
outSpatialDims[i] += p;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||||
|
@ -878,15 +873,15 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
} else if (autoPad == "VALID") {
|
} else if (autoPad == "VALID") {
|
||||||
// No padding
|
// No padding
|
||||||
} else {
|
} else {
|
||||||
emitError("Unexpected attribute value for auto_pad.");
|
emitError("Unexpected attribute value for auto_pad");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Strides
|
// Strides
|
||||||
if (auto strides = ONNXConvNoBiasOp::stridesAttr()) {
|
if (auto strides = ONNXConvNoBiasOp::stridesAttr()) {
|
||||||
if (strides.getValue().size() != nDims)
|
if (ArrayAttrSize(strides) != nDims)
|
||||||
emitError("strides length incompatible with spatial dimensions.");
|
emitError("strides length incompatible with spatial dimensions");
|
||||||
for (int i = 0; i < nDims; ++i) {
|
for (int i = 0; i < nDims; ++i) {
|
||||||
int64_t stride = strides.getValue()[i].cast<IntegerAttr>().getInt();
|
int64_t stride = ArrayAttrIntVal(strides, i);
|
||||||
outSpatialDims[i] = floor(outSpatialDims[i] / stride);
|
outSpatialDims[i] = floor(outSpatialDims[i] / stride);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1013,7 +1008,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (autoPad != "VALID") {
|
} else if (autoPad != "VALID") {
|
||||||
emitError("auto_pad of unknown / unsupported value.");
|
emitError("auto_pad of unknown / unsupported value");
|
||||||
}
|
}
|
||||||
// Set pads values in attributes.
|
// Set pads values in attributes.
|
||||||
{
|
{
|
||||||
|
@ -1044,7 +1039,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
|
||||||
}
|
}
|
||||||
yShape[kernelOffset + i] = res;
|
yShape[kernelOffset + i] = res;
|
||||||
}
|
}
|
||||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
auto arrayTy = X().getType().cast<RankedTensorType>();
|
||||||
getResult().setType(RankedTensorType::get(yShape, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(yShape, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1053,10 +1048,10 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
|
||||||
// Unsqueeze
|
// Unsqueeze
|
||||||
|
|
||||||
void ONNXUnsqueezeOp::inferShapes() {
|
void ONNXUnsqueezeOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>())
|
if (!data().getType().isa<RankedTensorType>())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
auto operandTy = data().getType().cast<RankedTensorType>();
|
||||||
int inRank = operandTy.getRank();
|
int inRank = operandTy.getRank();
|
||||||
|
|
||||||
ArrayAttr axisAttrs = axesAttr();
|
ArrayAttr axisAttrs = axesAttr();
|
||||||
|
@ -1072,10 +1067,10 @@ void ONNXUnsqueezeOp::inferShapes() {
|
||||||
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
|
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
|
||||||
axes.emplace_back(axis);
|
axes.emplace_back(axis);
|
||||||
else
|
else
|
||||||
emitError("Duplicated axes.");
|
emitError("Duplicated axes");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
emitError("Axes attribute is required.");
|
emitError("Axes attribute is required");
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<int64_t, 4> dims;
|
SmallVector<int64_t, 4> dims;
|
||||||
|
|
Loading…
Reference in New Issue