Add support for Unsqueeze (#50)

* Infer shape for Unsqueeze

* Lower Unsqueeze

* Revise

* Turn off backend tests

* Compute tensorSize for static shape

* Compute tensorSize with unknown dims

* Edit tests

* Update the use of attributes

* Add e2e tests

* Use SmallVector

* Remove return

* Check whether the operand is ranked or not

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-01-30 00:46:02 +09:00 committed by GitHub
parent 5b44169aaa
commit 9e82d388f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 147 additions and 4 deletions

View File

@ -268,7 +268,7 @@ def gen_schema(schema) :
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
'Softplus', 'Softsign', 'Sqrt'] 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
CanonicalList=['Add', 'Identity'] CanonicalList=['Add', 'Identity']
manual_code = dict([ manual_code = dict([
('DummyExample', ' let extraClassDeclaration = [{ \n'+ ('DummyExample', ' let extraClassDeclaration = [{ \n'+

View File

@ -724,6 +724,46 @@ void ONNXConvNoBiasOp::inferShapes() {
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType())); getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
} }
//===----------------------------------------------------------------------===//
// Unsqueeze
void ONNXUnsqueezeOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>())
return;
auto operandTy = getOperand().getType().cast<RankedTensorType>();
int inRank = operandTy.getRank();
ArrayAttr axisAttrs = axesAttr();
SmallVector<int, 4> axes;
int outRank = 0;
if (axisAttrs) {
outRank = inRank + axisAttrs.getValue().size();
for (auto axisAttr : axisAttrs.getValue()) {
int axis = axisAttr.cast<IntegerAttr>().getInt();
axis = axis >= 0 ? axis : (outRank + axis);
// Valid range
assert(axis >= -outRank && axis <= outRank - 1);
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.emplace_back(axis);
else
emitError("Duplicated axes.");
}
} else {
emitError("Axes attribute is required.");
}
SmallVector<int64_t, 4> dims;
for (int i = 0, j = 0; i < outRank || j < inRank; ++i) {
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
dims.emplace_back(1);
} else {
dims.emplace_back(operandTy.getShape()[j++]);
}
}
getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -3502,7 +3502,7 @@ def ONNXUniqueOp:ONNX_Op<"Unique",
} }
def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze", def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Unsqueeze operation"; let summary = "ONNX Unsqueeze operation";
let description = [{ let description = [{
"Insert single-dimensional entries to the shape of an input tensor (`data`)." "Insert single-dimensional entries to the shape of an input tensor (`data`)."

View File

@ -1176,6 +1176,79 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
} }
}; };
struct ONNXUnsqueezeOpLowering : public ConversionPattern {
ONNXUnsqueezeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
int outRank = tensorType.getRank();
// Assume that `axes` has been validated by shape inference.
// So, here we just get it.
ArrayAttr axisAttrs = llvm::dyn_cast<ONNXUnsqueezeOp>(op).axesAttr();
SmallVector<int, 4> axes;
for (auto axisAttr : axisAttrs.getValue()) {
int axis = axisAttr.cast<IntegerAttr>().getInt();
axis = axis >= 0 ? axis : (outRank + axis);
axes.emplace_back(axis);
}
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
Value alloc;
// Compute size in bytes.
Value tensorSize = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
getMemRefEltSizeInBytes(memRefType)));
bool insertDealloc = checkInsertDealloc(op);
auto memRefShape = memRefType.getShape();
if (hasAllConstantDimensions(memRefType)) {
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
for (int i = 0; i < memRefShape.size(); ++i) {
Value dimVal = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
memRefShape[i]));
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
}
} else {
// Unknown dimensions are always the operand's dimensions.
SmallVector<Value, 4> allocOperands;
for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
Value dimVal = nullptr;
if (memRefShape[outIdx] < 0) {
Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
dimVal = rewriter.create<IndexCastOp>(
loc, index, rewriter.getIntegerType(64));
allocOperands.emplace_back(index);
} else {
dimVal = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
memRefShape[outIdx]));
}
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
if (std::find(axes.begin(), axes.end(), outIdx) == axes.end())
inIdx++;
}
alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
auto *parentBlock = alloc.getDefiningOp()->getBlock();
if (insertDealloc) {
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
}
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point. // EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1304,7 +1377,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXReshapeOpLowering, ONNXEntryPointLowering, ONNXReshapeOpLowering, ONNXEntryPointLowering,
ONNXSoftmaxOpLowering>(&getContext()); ONNXSoftmaxOpLowering, ONNXUnsqueezeOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the // With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal` // conversion. The conversion will signal failure if any of our `illegal`

View File

@ -121,7 +121,8 @@ public:
op->getName().getStringRef() != "onnx.Transpose" && op->getName().getStringRef() != "onnx.Transpose" &&
op->getName().getStringRef() != "onnx.Softmax" && op->getName().getStringRef() != "onnx.Softmax" &&
op->getName().getStringRef() != "onnx.Sqrt" && op->getName().getStringRef() != "onnx.Sqrt" &&
op->getName().getStringRef() != "onnx.ConvNoBias") op->getName().getStringRef() != "onnx.ConvNoBias" &&
op->getName().getStringRef() != "onnx.Unsqueeze")
return false; return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) { return llvm::any_of(op->getResultTypes(), [](Type result_type) {
return !result_type.isa<RankedTensorType>(); return !result_type.isa<RankedTensorType>();

View File

@ -147,6 +147,16 @@ test_to_enable = [
"test_sum_one_input_cpu", "test_sum_one_input_cpu",
"test_sum_two_inputs_cpu", "test_sum_two_inputs_cpu",
# Unsqueeze Op:
"test_unsqueeze_axis_0_cpu",
"test_unsqueeze_axis_1_cpu",
"test_unsqueeze_axis_2_cpu",
"test_unsqueeze_axis_3_cpu",
"test_unsqueeze_negative_axes_cpu",
"test_unsqueeze_three_axes_cpu",
"test_unsqueeze_two_axes_cpu",
"test_unsqueeze_unsorted_axes_cpu",
# Reciprocal Op: # Reciprocal Op:
"test_reciprocal_cpu", "test_reciprocal_cpu",
"test_reciprocal_example_cpu", "test_reciprocal_example_cpu",

View File

@ -652,3 +652,22 @@ func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK: return [[RES]] : memref<?x10xf32> // CHECK: return [[RES]] : memref<?x10xf32>
} }
func @test_unsqueeze(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Unsqueeze"(%arg0) {axes=[0,3]} : (tensor<10x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_unsqueeze
// CHECK: [[RES:%.+]] = alloc() : memref<1x10x10x1xf32>
// CHECK: [[INBYTES:%.+]] = constant 4 : i64
// CHECK: [[DIM1:%.+]] = constant 1 : i64
// CHECK: [[SIZE1:%.+]] = muli [[INBYTES]], [[DIM1]] : i64
// CHECK: [[DIM2:%.+]] = constant 10 : i64
// CHECK: [[SIZE2:%.+]] = muli [[SIZE1]], [[DIM2]] : i64
// CHECK: [[DIM3:%.+]] = constant 10 : i64
// CHECK: [[SIZE3:%.+]] = muli [[SIZE2]], [[DIM3]] : i64
// CHECK: [[DIM4:%.+]] = constant 1 : i64
// CHECK: [[SIZE4:%.+]] = muli [[SIZE3]], [[DIM4]] : i64
// CHECK: "krnl.memcpy"([[RES]], %arg0, [[SIZE4]]) : (memref<1x10x10x1xf32>, memref<10x10xf32>, i64) -> ()
// CHECK: return [[RES]] : memref<1x10x10x1xf32>
}