Add support for Unsqueeze (#50)
* Infer shape for Unsqueeze
* Lower Unsqueeze
* Revise
* Turn off backend tests
* Compute tensorSize for static shape
* Compute tensorSize with unknown dims
* Edit tests
* Update the use of attributes
* Add e2e tests
* Use SmallVector
* Remove return
* Check whether the operand is ranked or not

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
commit 9e82d388f0 (parent 5b44169aaa)
@@ -268,7 +268,7 @@ def gen_schema(schema) :
                     'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
                     'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
                     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
-                    'Softplus', 'Softsign', 'Sqrt']
+                    'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
 CanonicalList=['Add', 'Identity']
 manual_code = dict([
     ('DummyExample', ' let extraClassDeclaration = [{ \n'+
@@ -724,6 +724,46 @@ void ONNXConvNoBiasOp::inferShapes() {
   getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
 }
 
+//===----------------------------------------------------------------------===//
+// Unsqueeze
+
+void ONNXUnsqueezeOp::inferShapes() {
+  if (!getOperand().getType().isa<RankedTensorType>())
+    return;
+
+  auto operandTy = getOperand().getType().cast<RankedTensorType>();
+  int inRank = operandTy.getRank();
+
+  ArrayAttr axisAttrs = axesAttr();
+  SmallVector<int, 4> axes;
+  int outRank = 0;
+  if (axisAttrs) {
+    outRank = inRank + axisAttrs.getValue().size();
+    for (auto axisAttr : axisAttrs.getValue()) {
+      int axis = axisAttr.cast<IntegerAttr>().getInt();
+      axis = axis >= 0 ? axis : (outRank + axis);
+      // Each axis must fall in the valid range [-outRank, outRank - 1].
+      assert(axis >= -outRank && axis <= outRank - 1);
+      if (std::find(axes.begin(), axes.end(), axis) == axes.end())
+        axes.emplace_back(axis);
+      else
+        emitError("Duplicated axes.");
+    }
+  } else {
+    emitError("Axes attribute is required.");
+  }
+
+  SmallVector<int64_t, 4> dims;
+  for (int i = 0, j = 0; i < outRank || j < inRank; ++i) {
+    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
+      dims.emplace_back(1);
+    } else {
+      dims.emplace_back(operandTy.getShape()[j++]);
+    }
+  }
+  getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
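The shape rule above is easiest to check on a concrete case. The following is a minimal standalone sketch of the same normalize-then-merge logic, assuming plain C++ containers in place of MLIR's RankedTensorType, ArrayAttr, and SmallVector; unsqueezeShape is a hypothetical helper used only for illustration, not part of this commit.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Compute the output shape of Unsqueeze: normalize negative axes
// against the output rank, then emit a 1 wherever an axis was
// requested and copy the next input dimension everywhere else.
std::vector<int64_t> unsqueezeShape(const std::vector<int64_t> &inShape,
                                    std::vector<int> axes) {
  int inRank = inShape.size();
  int outRank = inRank + axes.size();
  for (int &axis : axes) {
    if (axis < 0)
      axis += outRank;                // e.g. -1 becomes outRank - 1
    assert(axis >= 0 && axis < outRank && "axis out of range");
  }
  std::vector<int64_t> dims;
  for (int i = 0, j = 0; i < outRank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
      dims.push_back(1);              // inserted size-1 dimension
    else
      dims.push_back(inShape[j++]);   // carried over from the input
  }
  return dims;
}

int main() {
  // Mirrors the e2e test further down: a 10x10 input with
  // axes = [0, 3] becomes 1x10x10x1.
  for (int64_t d : unsqueezeShape({10, 10}, {0, 3}))
    std::cout << d << ' ';            // prints: 1 10 10 1
  std::cout << '\n';
}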
@@ -3502,7 +3502,7 @@ def ONNXUniqueOp:ONNX_Op<"Unique",
 }
 
 def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Unsqueeze operation";
   let description = [{
   "Insert single-dimensional entries to the shape of an input tensor (`data`)."
@@ -1176,6 +1176,79 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
   }
 };
 
+struct ONNXUnsqueezeOpLowering : public ConversionPattern {
+  ONNXUnsqueezeOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto loc = op->getLoc();
+    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+    int outRank = tensorType.getRank();
+
+    // `axes` has already been validated by shape inference, so here we
+    // simply read it.
+    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXUnsqueezeOp>(op).axesAttr();
+    SmallVector<int, 4> axes;
+    for (auto axisAttr : axisAttrs.getValue()) {
+      int axis = axisAttr.cast<IntegerAttr>().getInt();
+      axis = axis >= 0 ? axis : (outRank + axis);
+      axes.emplace_back(axis);
+    }
+
+    // Insert an allocation and deallocation for the result of this operation.
+    auto memRefType = convertTensorToMemRef(tensorType);
+    Value alloc;
+
+    // Compute size in bytes.
+    Value tensorSize = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                     getMemRefEltSizeInBytes(memRefType)));
+
+    bool insertDealloc = checkInsertDealloc(op);
+    auto memRefShape = memRefType.getShape();
+    if (hasAllConstantDimensions(memRefType)) {
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+      for (int i = 0; i < memRefShape.size(); ++i) {
+        Value dimVal = rewriter.create<ConstantOp>(
+            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                         memRefShape[i]));
+        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
+      }
+    } else {
+      // Unknown output dimensions always come from the operand's dimensions.
+      SmallVector<Value, 4> allocOperands;
+      for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
+        Value dimVal = nullptr;
+        if (memRefShape[outIdx] < 0) {
+          Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
+          dimVal = rewriter.create<IndexCastOp>(
+              loc, index, rewriter.getIntegerType(64));
+          allocOperands.emplace_back(index);
+        } else {
+          dimVal = rewriter.create<ConstantOp>(
+              loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                           memRefShape[outIdx]));
+        }
+        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
+        if (std::find(axes.begin(), axes.end(), outIdx) == axes.end())
+          inIdx++;
+      }
+      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
+      auto *parentBlock = alloc.getDefiningOp()->getBlock();
+      if (insertDealloc) {
+        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+        dealloc.getOperation()->moveBefore(&parentBlock->back());
+      }
+    }
+
+    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
+    rewriter.replaceOp(op, alloc);
+    return matchSuccess();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // EntryPoint Op lowering to Krnl Entry Point.
 //===----------------------------------------------------------------------===//
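Because inserting size-1 dimensions never changes the row-major linear layout of the data, the lowering does not shuffle any elements: it allocates the output buffer and issues a single flat krnl.memcpy of tensorSize bytes. Below is a standalone sketch of the size computation for the dynamic-shape branch, assuming plain C++ in place of the rewriter-created constant, dim, and muli ops; computeTensorSizeInBytes is a hypothetical name for illustration only.

#include <algorithm>
#include <cstdint>
#include <vector>

// Walk the output shape; a dynamic (-1) output dimension is resolved
// from the operand's shape, and the operand index advances only past
// output dimensions that were not inserted by `axes` -- the same
// outIdx/inIdx bookkeeping as the loop in the lowering above.
int64_t computeTensorSizeInBytes(const std::vector<int64_t> &outShape,
                                 const std::vector<int64_t> &inShape,
                                 const std::vector<int> &axes,
                                 int64_t eltSizeInBytes) {
  int64_t size = eltSizeInBytes;
  for (int outIdx = 0, inIdx = 0; outIdx < (int)outShape.size(); ++outIdx) {
    int64_t dim = outShape[outIdx] < 0 ? inShape[inIdx] : outShape[outIdx];
    size *= dim;
    if (std::find(axes.begin(), axes.end(), outIdx) == axes.end())
      ++inIdx;                        // this output dim came from the input
  }
  return size;
}

// Example: an f32 operand of runtime shape 7x10 unsqueezed with
// axes = [0, 3] produces a 1x7x10x1 output (outShape = {1, -1, 10, 1}
// at compile time), so the copy size is 4 * 1 * 7 * 10 * 1 = 280 bytes.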
@@ -1304,7 +1377,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
                   ONNXReshapeOpLowering, ONNXEntryPointLowering,
-                  ONNXSoftmaxOpLowering>(&getContext());
+                  ONNXSoftmaxOpLowering, ONNXUnsqueezeOpLowering>(&getContext());
 
   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
@@ -121,7 +121,8 @@ public:
           op->getName().getStringRef() != "onnx.Transpose" &&
           op->getName().getStringRef() != "onnx.Softmax" &&
           op->getName().getStringRef() != "onnx.Sqrt" &&
-          op->getName().getStringRef() != "onnx.ConvNoBias")
+          op->getName().getStringRef() != "onnx.ConvNoBias" &&
+          op->getName().getStringRef() != "onnx.Unsqueeze")
         return false;
       return llvm::any_of(op->getResultTypes(), [](Type result_type) {
         return !result_type.isa<RankedTensorType>();
@@ -147,6 +147,16 @@ test_to_enable = [
     "test_sum_one_input_cpu",
     "test_sum_two_inputs_cpu",
 
+    # Unsqueeze Op:
+    "test_unsqueeze_axis_0_cpu",
+    "test_unsqueeze_axis_1_cpu",
+    "test_unsqueeze_axis_2_cpu",
+    "test_unsqueeze_axis_3_cpu",
+    "test_unsqueeze_negative_axes_cpu",
+    "test_unsqueeze_three_axes_cpu",
+    "test_unsqueeze_two_axes_cpu",
+    "test_unsqueeze_unsorted_axes_cpu",
+
     # Reciprocal Op:
     "test_reciprocal_cpu",
     "test_reciprocal_example_cpu",
@@ -652,3 +652,22 @@ func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : memref<?x10xf32>
 }
 
+func @test_unsqueeze(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Unsqueeze"(%arg0) {axes=[0,3]} : (tensor<10x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_unsqueeze
+  // CHECK: [[RES:%.+]] = alloc() : memref<1x10x10x1xf32>
+  // CHECK: [[INBYTES:%.+]] = constant 4 : i64
+  // CHECK: [[DIM1:%.+]] = constant 1 : i64
+  // CHECK: [[SIZE1:%.+]] = muli [[INBYTES]], [[DIM1]] : i64
+  // CHECK: [[DIM2:%.+]] = constant 10 : i64
+  // CHECK: [[SIZE2:%.+]] = muli [[SIZE1]], [[DIM2]] : i64
+  // CHECK: [[DIM3:%.+]] = constant 10 : i64
+  // CHECK: [[SIZE3:%.+]] = muli [[SIZE2]], [[DIM3]] : i64
+  // CHECK: [[DIM4:%.+]] = constant 1 : i64
+  // CHECK: [[SIZE4:%.+]] = muli [[SIZE3]], [[DIM4]] : i64
+  // CHECK: "krnl.memcpy"([[RES]], %arg0, [[SIZE4]]) : (memref<1x10x10x1xf32>, memref<10x10xf32>, i64) -> ()
+  // CHECK: return [[RES]] : memref<1x10x10x1xf32>
+}
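The CHECK lines retrace the lowering's size chain exactly: starting from the f32 element size (4 bytes), each output dimension is multiplied in turn, giving 4 * 1 * 10 * 10 * 1 = 400 bytes passed to krnl.memcpy.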