Add support for Unsqueeze (#50)
* Infer shape for Unsqueeze * Lower Unsqueeze * Revise * Turn off backend tests * Compute tensorSize for static shape * Compute tensorSize with unknown dims * Edit tests * Update the use of attributes * Add e2e tests * Use SmallVector * Remove return * Check whether the operand is ranked or not Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
5b44169aaa
commit
9e82d388f0
|
@ -268,7 +268,7 @@ def gen_schema(schema) :
|
|||
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
||||
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
|
||||
'Softplus', 'Softsign', 'Sqrt']
|
||||
'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
|
||||
CanonicalList=['Add', 'Identity']
|
||||
manual_code = dict([
|
||||
('DummyExample', ' let extraClassDeclaration = [{ \n'+
|
||||
|
|
|
@ -724,6 +724,46 @@ void ONNXConvNoBiasOp::inferShapes() {
|
|||
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unsqueeze
|
||||
|
||||
void ONNXUnsqueezeOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>())
|
||||
return;
|
||||
|
||||
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
||||
int inRank = operandTy.getRank();
|
||||
|
||||
ArrayAttr axisAttrs = axesAttr();
|
||||
SmallVector<int, 4> axes;
|
||||
int outRank = 0;
|
||||
if (axisAttrs) {
|
||||
outRank = inRank + axisAttrs.getValue().size();
|
||||
for (auto axisAttr : axisAttrs.getValue()) {
|
||||
int axis = axisAttr.cast<IntegerAttr>().getInt();
|
||||
axis = axis >= 0 ? axis : (outRank + axis);
|
||||
// Valid range
|
||||
assert(axis >= -outRank && axis <= outRank - 1);
|
||||
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
|
||||
axes.emplace_back(axis);
|
||||
else
|
||||
emitError("Duplicated axes.");
|
||||
}
|
||||
} else {
|
||||
emitError("Axes attribute is required.");
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> dims;
|
||||
for (int i = 0, j = 0; i < outRank || j < inRank; ++i) {
|
||||
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
|
||||
dims.emplace_back(1);
|
||||
} else {
|
||||
dims.emplace_back(operandTy.getShape()[j++]);
|
||||
}
|
||||
}
|
||||
getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -3502,7 +3502,7 @@ def ONNXUniqueOp:ONNX_Op<"Unique",
|
|||
}
|
||||
|
||||
def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX Unsqueeze operation";
|
||||
let description = [{
|
||||
"Insert single-dimensional entries to the shape of an input tensor (`data`)."
|
||||
|
|
|
@ -1176,6 +1176,79 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
}
|
||||
};
|
||||
|
||||
struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
||||
ONNXUnsqueezeOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
int outRank = tensorType.getRank();
|
||||
|
||||
// Assume that `axes` has been validated by shape inference.
|
||||
// So, here we just get it.
|
||||
ArrayAttr axisAttrs = llvm::dyn_cast<ONNXUnsqueezeOp>(op).axesAttr();
|
||||
SmallVector<int, 4> axes;
|
||||
for (auto axisAttr : axisAttrs.getValue()) {
|
||||
int axis = axisAttr.cast<IntegerAttr>().getInt();
|
||||
axis = axis >= 0 ? axis : (outRank + axis);
|
||||
axes.emplace_back(axis);
|
||||
}
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
Value alloc;
|
||||
|
||||
// Compute size in bytes.
|
||||
Value tensorSize = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
getMemRefEltSizeInBytes(memRefType)));
|
||||
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
auto memRefShape = memRefType.getShape();
|
||||
if (hasAllConstantDimensions(memRefType)) {
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
for (int i = 0; i < memRefShape.size(); ++i) {
|
||||
Value dimVal = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
memRefShape[i]));
|
||||
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
|
||||
}
|
||||
} else {
|
||||
// Unknown dimensions are always the operand's dimensions.
|
||||
SmallVector<Value, 4> allocOperands;
|
||||
for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
|
||||
Value dimVal = nullptr;
|
||||
if (memRefShape[outIdx] < 0) {
|
||||
Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
|
||||
dimVal = rewriter.create<IndexCastOp>(
|
||||
loc, index, rewriter.getIntegerType(64));
|
||||
allocOperands.emplace_back(index);
|
||||
} else {
|
||||
dimVal = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
memRefShape[outIdx]));
|
||||
}
|
||||
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
|
||||
if (std::find(axes.begin(), axes.end(), outIdx) == axes.end())
|
||||
inIdx++;
|
||||
}
|
||||
alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
|
||||
auto *parentBlock = alloc.getDefiningOp()->getBlock();
|
||||
if (insertDealloc) {
|
||||
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
|
||||
dealloc.getOperation()->moveBefore(&parentBlock->back());
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
|
||||
rewriter.replaceOp(op, alloc);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EntryPoint Op lowering to Krnl Entry Point.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1304,7 +1377,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
|||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
||||
ONNXReshapeOpLowering, ONNXEntryPointLowering,
|
||||
ONNXSoftmaxOpLowering>(&getContext());
|
||||
ONNXSoftmaxOpLowering, ONNXUnsqueezeOpLowering>(&getContext());
|
||||
|
||||
// With the target and rewrite patterns defined, we can now attempt the
|
||||
// conversion. The conversion will signal failure if any of our `illegal`
|
||||
|
|
|
@ -121,7 +121,8 @@ public:
|
|||
op->getName().getStringRef() != "onnx.Transpose" &&
|
||||
op->getName().getStringRef() != "onnx.Softmax" &&
|
||||
op->getName().getStringRef() != "onnx.Sqrt" &&
|
||||
op->getName().getStringRef() != "onnx.ConvNoBias")
|
||||
op->getName().getStringRef() != "onnx.ConvNoBias" &&
|
||||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
||||
return false;
|
||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||
return !result_type.isa<RankedTensorType>();
|
||||
|
|
|
@ -147,6 +147,16 @@ test_to_enable = [
|
|||
"test_sum_one_input_cpu",
|
||||
"test_sum_two_inputs_cpu",
|
||||
|
||||
# Unsqueeze Op:
|
||||
"test_unsqueeze_axis_0_cpu",
|
||||
"test_unsqueeze_axis_1_cpu",
|
||||
"test_unsqueeze_axis_2_cpu",
|
||||
"test_unsqueeze_axis_3_cpu",
|
||||
"test_unsqueeze_negative_axes_cpu",
|
||||
"test_unsqueeze_three_axes_cpu",
|
||||
"test_unsqueeze_two_axes_cpu",
|
||||
"test_unsqueeze_unsorted_axes_cpu",
|
||||
|
||||
# Reciprocal Op:
|
||||
"test_reciprocal_cpu",
|
||||
"test_reciprocal_example_cpu",
|
||||
|
|
|
@ -652,3 +652,22 @@ func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
|||
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||
}
|
||||
|
||||
func @test_unsqueeze(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Unsqueeze"(%arg0) {axes=[0,3]} : (tensor<10x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_unsqueeze
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<1x10x10x1xf32>
|
||||
// CHECK: [[INBYTES:%.+]] = constant 4 : i64
|
||||
// CHECK: [[DIM1:%.+]] = constant 1 : i64
|
||||
// CHECK: [[SIZE1:%.+]] = muli [[INBYTES]], [[DIM1]] : i64
|
||||
// CHECK: [[DIM2:%.+]] = constant 10 : i64
|
||||
// CHECK: [[SIZE2:%.+]] = muli [[SIZE1]], [[DIM2]] : i64
|
||||
// CHECK: [[DIM3:%.+]] = constant 10 : i64
|
||||
// CHECK: [[SIZE3:%.+]] = muli [[SIZE2]], [[DIM3]] : i64
|
||||
// CHECK: [[DIM4:%.+]] = constant 1 : i64
|
||||
// CHECK: [[SIZE4:%.+]] = muli [[SIZE3]], [[DIM4]] : i64
|
||||
// CHECK: "krnl.memcpy"([[RES]], %arg0, [[SIZE4]]) : (memref<1x10x10x1xf32>, memref<10x10xf32>, i64) -> ()
|
||||
// CHECK: return [[RES]] : memref<1x10x10x1xf32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue