Handle SAME_LOWER and SAME_UPPER.
This commit is contained in:
parent
ec9e023f04
commit
169236a8fc
|
@ -412,7 +412,7 @@ void ONNXReshapeOp::inferShapes() {
|
||||||
void ONNXTransposeOp::inferShapes() {
|
void ONNXTransposeOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!getOperand().getType().isa<RankedTensorType>())
|
if (!getOperand().getType().isa<RankedTensorType>())
|
||||||
emitError("Shape tensor not ranked.");
|
return;
|
||||||
|
|
||||||
// Naive transposition which handles the default case of
|
// Naive transposition which handles the default case of
|
||||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||||
|
@ -464,6 +464,7 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!getOperand(1).getType().isa<RankedTensorType>())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||||
auto dataShape = dataTy.getShape();
|
auto dataShape = dataTy.getShape();
|
||||||
|
@ -492,34 +493,37 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
// Insert number of filters being applied (number of output channels).
|
// Insert number of filters being applied (number of output channels).
|
||||||
dims.emplace_back(weightShape[0]);
|
dims.emplace_back(weightShape[0]);
|
||||||
|
|
||||||
// Spatial dimensions are computed using the formula:
|
// Spatial dimensions of the output are computed using the formula:
|
||||||
//
|
//
|
||||||
// dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
|
// dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
|
||||||
//
|
//
|
||||||
SmallVector<int64_t, 2> spatialDims;
|
SmallVector<int64_t, 2> outSpatialDims;
|
||||||
// Number of spatial dimensions.
|
// Number of spatial dimensions.
|
||||||
int32_t nDims = dataShape.size() - 2;
|
int32_t nDims = dataShape.size() - 2;
|
||||||
|
|
||||||
// Initialize dimensions based on the input spatial dimensions.
|
// Initialize dimensions based on the input spatial dimensions.
|
||||||
for (int i = 2; i < dataShape.size(); ++i)
|
for (int i = 2; i < dataShape.size(); ++i)
|
||||||
spatialDims.emplace_back(dataShape[i]);
|
outSpatialDims.emplace_back(dataShape[i]);
|
||||||
|
|
||||||
// Use kernel_shape attribute if present otherwise use size from weight
|
// Use kernel_shape attribute if present otherwise use size from weight
|
||||||
// argument.
|
// argument.
|
||||||
if (auto kernel_shape = getAttrOfType<ArrayAttr>(
|
SmallVector<int64_t, 2> kernelDims;
|
||||||
|
if (auto kernelShape = getAttrOfType<ArrayAttr>(
|
||||||
ONNXConvOp::getKernelShapeAttrName())) {
|
ONNXConvOp::getKernelShapeAttrName())) {
|
||||||
if (kernel_shape.getValue().size() != nDims)
|
if (kernelShape.getValue().size() != nDims)
|
||||||
emitError("kernel_shape length incompatible with spatial dimensions.");
|
emitError("kernel_shape length incompatible with spatial dimensions.");
|
||||||
for (int i = 0; i < nDims; ++i) {
|
for (int i = 0; i < nDims; ++i)
|
||||||
int64_t kernelDim =
|
kernelDims[i] =
|
||||||
(kernel_shape.getValue()[i]).cast<IntegerAttr>().getInt();
|
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||||
spatialDims[i] -= kernelDim;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < nDims; ++i)
|
for (int i = 0; i < nDims; ++i)
|
||||||
spatialDims[i] -= weightShape[i + 2];
|
kernelDims[i] = weightShape[i + 2];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Subtract kernel dimensions from input data dimensions.
|
||||||
|
for (int i = 0; i < nDims; ++i)
|
||||||
|
outSpatialDims[i] -= kernelDims[i];
|
||||||
|
|
||||||
// Add padding information.
|
// Add padding information.
|
||||||
if (autoPad == "NOTSET") {
|
if (autoPad == "NOTSET") {
|
||||||
// Use pads to determine the padding. If attribute is not
|
// Use pads to determine the padding. If attribute is not
|
||||||
|
@ -533,18 +537,23 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
for (int i = 0; i < nDims; ++i) {
|
for (int i = 0; i < nDims; ++i) {
|
||||||
// Padding for beginning of axis.
|
// Padding for beginning of axis.
|
||||||
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
|
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||||
spatialDims[i] += p;
|
outSpatialDims[i] += p;
|
||||||
// Padding for end of axis.
|
// Padding for end of axis.
|
||||||
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
|
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
|
||||||
spatialDims[i] += p;
|
outSpatialDims[i] += p;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||||
|
// Pad input so that output size matches input size.
|
||||||
|
// Each spatial dimension needs to be padded by:
|
||||||
|
//
|
||||||
|
// ( K - 1 ) / 2
|
||||||
|
//
|
||||||
|
// where K is a kernel spatial dimension.
|
||||||
|
for (int i = 0; i < nDims; ++i)
|
||||||
|
outSpatialDims[i] += floor((kernelDims[i] - 1) / 2);
|
||||||
} else if (autoPad == "VALID") {
|
} else if (autoPad == "VALID") {
|
||||||
// TODO
|
// No padding
|
||||||
} else if (autoPad == "SAME_UPPER") {
|
|
||||||
// TODO
|
|
||||||
} else if (autoPad == "SAME_LOWER") {
|
|
||||||
// TODO
|
|
||||||
} else {
|
} else {
|
||||||
emitError("Unexpected attribute value for auto_pad.");
|
emitError("Unexpected attribute value for auto_pad.");
|
||||||
}
|
}
|
||||||
|
@ -557,14 +566,14 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
for (int i = 0; i < nDims; ++i) {
|
for (int i = 0; i < nDims; ++i) {
|
||||||
int64_t stride =
|
int64_t stride =
|
||||||
(strides.getValue()[i]).cast<IntegerAttr>().getInt();
|
(strides.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||||
spatialDims[i] = floor(spatialDims[i] / stride);
|
outSpatialDims[i] = floor(outSpatialDims[i] / stride);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < nDims; ++i)
|
for (int i = 0; i < nDims; ++i)
|
||||||
spatialDims[i] += 1;
|
outSpatialDims[i] += 1;
|
||||||
|
|
||||||
dims.append(spatialDims.begin(), spatialDims.end());
|
dims.append(outSpatialDims.begin(), outSpatialDims.end());
|
||||||
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue