Add verifier to check for required attributes.

This commit is contained in:
Doru Bercea 2020-01-20 11:16:27 -05:00
parent 51b0f4c9dd
commit ab8e2f9a1b
4 changed files with 102 additions and 2 deletions

View File

@ -104,7 +104,7 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
}
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Conv operation with no Bias operand.";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
@ -112,6 +112,8 @@ def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
let verifier = [{ return ::verify(*this); }];
}
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",

View File

@ -448,6 +448,94 @@ LogicalResult verify(ONNXTransposeOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// Conv
// For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the
// other classes based on Conv.
void ONNXConvNoBiasOp::inferShapes() {
// Generic shape for data input X and weight tensor W:
// X: (N x C x D1 x D2 ... x Dn)
// W: (M x C/group x k1 x k2 x ... x kn)
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto dataTy = getOperand(0)->getType().cast<RankedTensorType>();
auto weightTy = getOperand(1)->getType().cast<RankedTensorType>();
auto dataShape = dataTy.getShape();
auto weightShape = weightTy.getShape();
if (dataShape.size() != weightShape.size())
emitError("ConvNoBias: weight size not compatible with data size.");
// Group is a required attribute and should have default value of 1.
int64_t group = getAttrOfType<IntegerAttr>(
ONNXConvOp::getGroupAttrName()).getInt();
if (!group)
emitError("ConvNoBias: group attribute missing.");
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (dataShape[1] != (weightShape[1] * group))
emitError("ConvNoBias: channel dimension mismatch.");
// Required attributes.
auto auto_pad = getAttrOfType<StringAttr>(
ONNXConvOp::getAutoPadAttrName());
auto pads = getAttrOfType<ArrayAttr>(
ONNXConvOp::getPadsAttrName());
SmallVector<int64_t, 2> dims;
// Insert batch size.
dims.emplace_back(dataShape[0]);
// Insert number of filters being applied (number of output channels).
dims.emplace_back(weightShape[0]);
// // Compute the spatial dimensions.
// SmallVector<int64_t, 2> spatialDims;
// // Number of spatial dimensions.
// int32_t nDims = dataTy.size() - 2;
// // Initialize dimenions based on the input and weight spatial dimensions.
// for (int i = 2; i < dataTy.size(); ++i)
// spatialDims.emplace_back(dataTy[i] - weightTy[i]);
// // Add padding information.
// if () {
// for (int i = 0; i < nDims; ++i) {
// // Padding for beginning of axis.
// int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
// spatialDims[i] += p;
// // Padding for end of axis.
// p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
// spatialDims[i] += p;
// }
// } else if () {
// // Attribute pads has not been provided.
// }
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
}
LogicalResult verify(ONNXConvNoBiasOp op) {
auto module = op.getParentOfType<ModuleOp>();
if (!module)
op.emitError("expected to belong to a module");
auto autoPadAttr = op.getAttrOfType<StringAttr>(
ONNXConvOp::getAutoPadAttrName());
if (!autoPadAttr)
op.emitError("ONNXConvNoBiasOp: auto_pad attribute not specified.");
auto groupAttr =
op.getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
if (!groupAttr)
op.emitError("ONNXConvNoBiasOp: group attribute not specified.");
return success();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

View File

@ -324,6 +324,15 @@ def ONNXConvOp:ONNX_Op<"Conv",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
let extraClassDeclaration = [{
static StringRef getAutoPadAttrName() { return "auto_pad"; }
static StringRef getDilationsAttrName() { return "dilations"; }
static StringRef getGroupAttrName() { return "group"; }
static StringRef getKernelShapeAttrName() { return "kernel_shape"; }
static StringRef getPadsAttrName() { return "pads"; }
static StringRef getStridesAttrName() { return "strides"; }
}];
}
def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",

View File

@ -117,7 +117,8 @@ public:
op->getName().getStringRef() != "onnx.GemmNoBias" &&
op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose" &&
op->getName().getStringRef() != "onnx.Softmax")
op->getName().getStringRef() != "onnx.Softmax" &&
op->getName().getStringRef() != "onnx.ConvNoBias")
return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
return !result_type.isa<RankedTensorType>();