Add verifier to check for required attributes.
This commit is contained in:
parent
51b0f4c9dd
commit
ab8e2f9a1b
|
@ -104,7 +104,7 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
|
|||
}
|
||||
|
||||
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX Conv operation with no Bias operand.";
|
||||
let description = [{
|
||||
"The convolution operator consumes an input tensor and a filter, and"
|
||||
|
@ -112,6 +112,8 @@ def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
|||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||
|
|
|
@ -448,6 +448,94 @@ LogicalResult verify(ONNXTransposeOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Conv
|
||||
|
||||
// For this operation, we define the attributes once in the original Conv
|
||||
// operation class. There is no need to redefine the attribute names for the
|
||||
// other classes based on Conv.
|
||||
void ONNXConvNoBiasOp::inferShapes() {
|
||||
// Generic shape for data input X and weight tensor W:
|
||||
// X: (N x C x D1 x D2 ... x Dn)
|
||||
// W: (M x C/group x k1 x k2 x ... x kn)
|
||||
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
auto dataTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||
auto weightTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||
auto dataShape = dataTy.getShape();
|
||||
auto weightShape = weightTy.getShape();
|
||||
|
||||
if (dataShape.size() != weightShape.size())
|
||||
emitError("ConvNoBias: weight size not compatible with data size.");
|
||||
|
||||
// Group is a required attribute and should have default value of 1.
|
||||
int64_t group = getAttrOfType<IntegerAttr>(
|
||||
ONNXConvOp::getGroupAttrName()).getInt();
|
||||
if (!group)
|
||||
emitError("ConvNoBias: group attribute missing.");
|
||||
|
||||
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
||||
if (dataShape[1] != (weightShape[1] * group))
|
||||
emitError("ConvNoBias: channel dimension mismatch.");
|
||||
|
||||
// Required attributes.
|
||||
auto auto_pad = getAttrOfType<StringAttr>(
|
||||
ONNXConvOp::getAutoPadAttrName());
|
||||
auto pads = getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getPadsAttrName());
|
||||
|
||||
SmallVector<int64_t, 2> dims;
|
||||
// Insert batch size.
|
||||
dims.emplace_back(dataShape[0]);
|
||||
// Insert number of filters being applied (number of output channels).
|
||||
dims.emplace_back(weightShape[0]);
|
||||
|
||||
// // Compute the spatial dimensions.
|
||||
// SmallVector<int64_t, 2> spatialDims;
|
||||
// // Number of spatial dimensions.
|
||||
// int32_t nDims = dataTy.size() - 2;
|
||||
// // Initialize dimenions based on the input and weight spatial dimensions.
|
||||
// for (int i = 2; i < dataTy.size(); ++i)
|
||||
// spatialDims.emplace_back(dataTy[i] - weightTy[i]);
|
||||
// // Add padding information.
|
||||
// if () {
|
||||
// for (int i = 0; i < nDims; ++i) {
|
||||
// // Padding for beginning of axis.
|
||||
// int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
// spatialDims[i] += p;
|
||||
// // Padding for end of axis.
|
||||
// p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
|
||||
// spatialDims[i] += p;
|
||||
// }
|
||||
// } else if () {
|
||||
// // Attribute pads has not been provided.
|
||||
// }
|
||||
|
||||
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
|
||||
}
|
||||
|
||||
LogicalResult verify(ONNXConvNoBiasOp op) {
|
||||
auto module = op.getParentOfType<ModuleOp>();
|
||||
if (!module)
|
||||
op.emitError("expected to belong to a module");
|
||||
|
||||
auto autoPadAttr = op.getAttrOfType<StringAttr>(
|
||||
ONNXConvOp::getAutoPadAttrName());
|
||||
if (!autoPadAttr)
|
||||
op.emitError("ONNXConvNoBiasOp: auto_pad attribute not specified.");
|
||||
|
||||
auto groupAttr =
|
||||
op.getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
|
||||
if (!groupAttr)
|
||||
op.emitError("ONNXConvNoBiasOp: group attribute not specified.");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -324,6 +324,15 @@ def ONNXConvOp:ONNX_Op<"Conv",
|
|||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getAutoPadAttrName() { return "auto_pad"; }
|
||||
static StringRef getDilationsAttrName() { return "dilations"; }
|
||||
static StringRef getGroupAttrName() { return "group"; }
|
||||
static StringRef getKernelShapeAttrName() { return "kernel_shape"; }
|
||||
static StringRef getPadsAttrName() { return "pads"; }
|
||||
static StringRef getStridesAttrName() { return "strides"; }
|
||||
}];
|
||||
}
|
||||
|
||||
def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
|
||||
|
|
|
@ -117,7 +117,8 @@ public:
|
|||
op->getName().getStringRef() != "onnx.GemmNoBias" &&
|
||||
op->getName().getStringRef() != "onnx.Reshape" &&
|
||||
op->getName().getStringRef() != "onnx.Transpose" &&
|
||||
op->getName().getStringRef() != "onnx.Softmax")
|
||||
op->getName().getStringRef() != "onnx.Softmax" &&
|
||||
op->getName().getStringRef() != "onnx.ConvNoBias")
|
||||
return false;
|
||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||
return !result_type.isa<RankedTensorType>();
|
||||
|
|
Loading…
Reference in New Issue