From adad9e24bdc5f1cad9599017005fa72ebaddf5fd Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Wed, 12 Feb 2020 00:37:47 +0900
Subject: [PATCH] Add support of negative dimensions (#66)

Co-authored-by: Gheorghe-Teodor Bercea
---
 src/pass/lower_frontend_to_krnl.cpp | 80 ++++++++++++++++++++++----
 test/backend/test.py                |  6 +-
 test/mlir/onnx/onnx_lowering.mlir   | 89 +++++++++++++++++++----------
 3 files changed, 131 insertions(+), 44 deletions(-)

diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp
index e2354c9..265d92c 100644
--- a/src/pass/lower_frontend_to_krnl.cpp
+++ b/src/pass/lower_frontend_to_krnl.cpp
@@ -1279,44 +1279,73 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
     auto loc = op->getLoc();
 
     // Insert an allocation and deallocation for the result of this operation.
     auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefShape = memRefType.getShape();
     Value alloc;
 
-    // Compute size in bytes.
+    // Compute size in bytes using the input tensor.
     Value tensorSize = rewriter.create<ConstantOp>(
         loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                      getMemRefEltSizeInBytes(memRefType)));
+    for (int i = 0; i < inputShape.size(); ++i) {
+      Value dimVal;
+      if (inputShape[i] < 0) {
+        Value dim = rewriter.create<DimOp>(loc, operands[0], i);
+        dimVal =
+            rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
+      } else {
+        dimVal = rewriter.create<ConstantOp>(
+            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                         inputShape[i]));
+      }
+      tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
+    }
+
     bool insertDealloc = checkInsertDealloc(op);
     if (hasAllConstantDimensions(memRefType)) {
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     } else {
-      auto memRefShape = memRefType.getShape();
-      auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
-      SmallVector<Value, 4> allocOperands;
+      // If a dimension is zero, the actual dimension value is taken from the
+      // input tensor.
+      //
+      // If the shape array has a negative dimension (-1), we compute its actual
+      // dimension value from the other dimensions. But we don't have enough
+      // information about the other dimensions at this point. So, we need to
+      // scan the shape first to calculate the reduction of all of the dimensions.
+      // If the reduction is negative, then the shape array contains a negative
+      // dimension. Otherwise, the reduction is the same as the one computed
+      // from the input tensor.
+      Value tensorSizeFromShape = rewriter.create<ConstantOp>(
+          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                       getMemRefEltSizeInBytes(memRefType)));
+      SmallVector<Value, 4> DimInfo;
       for (int i = 0; i < memRefShape.size(); ++i) {
-        // The shape array can always be used to construct shape information of
-        // the result.
         Value index = rewriter.create<ConstantOp>(
             loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
         // Load index from array of indices.
         Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
         // If a dimension is zero, the actual dimension value is taken from the
         // input tensor.
+        //
+        // If a dimension is negative, it is computed from the other dimensions.
+        // But we don't have enough information about the other dimensions at
+        // this point. So, we leave it as it is (-1), and compute it later.
         if (i < inputShape.size()) {
           Value dimVal;
-          auto dimTy = loadedVal.getType().cast<IntegerType>();
+          auto loadedValType = loadedVal.getType().cast<IntegerType>();
           if (inputShape[i] < 0) {
             Value dim = rewriter.create<DimOp>(loc, operands[0], i);
-            dimVal = rewriter.create<IndexCastOp>(loc, dim, dimTy);
+            dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
           } else {
             dimVal = rewriter.create<ConstantOp>(
-                loc, rewriter.getIntegerAttr(dimTy, inputShape[i]));
+                loc, rewriter.getIntegerAttr(loadedValType, inputShape[i]));
           }
           auto zero = rewriter.create<ConstantOp>(
-              loc, rewriter.getIntegerAttr(dimTy, 0));
+              loc, rewriter.getIntegerAttr(loadedValType, 0));
           auto isZero =
               rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero);
           loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
@@ -1327,9 +1356,36 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
         if (loadedVal.getType().cast<IntegerType>().getWidth() < 64)
           int64LoadedVal = rewriter.create<ZeroExtendIOp>(
               loc, loadedVal, rewriter.getIntegerType(64));
-        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
+        tensorSizeFromShape =
+            rewriter.create<MulIOp>(loc, tensorSizeFromShape, int64LoadedVal);
+        // Store intermediate results to use later.
+        DimInfo.emplace_back(int64LoadedVal);
+      }
+      // Negate tensorSizeFromShape since it is negative if the shape array has
+      // a negative dimension. This is safe since we only use it to compute the
+      // actual value for the negative dimension.
+      auto zero = rewriter.create<ConstantOp>(
+          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+      tensorSizeFromShape =
+          rewriter.create<SubIOp>(loc, zero, tensorSizeFromShape);
+
+      // Obtain operands for AllocOp.
+      SmallVector<Value, 4> allocOperands;
+      auto negOne = rewriter.create<ConstantOp>(
+          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));
+
+      for (int i = 0; i < memRefShape.size(); ++i) {
+        auto dimVal = DimInfo[i];
+        auto isNegOne =
+            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dimVal, negOne);
+        // If the dimension is negative, compute its value from the other
+        // dimensions.
+        auto actualDimVal =
+            rewriter.create<SignedDivIOp>(loc, tensorSize, tensorSizeFromShape);
+        auto loadedVal =
+            rewriter.create<SelectOp>(loc, isNegOne, actualDimVal, dimVal);
         allocOperands.push_back(rewriter.create<IndexCastOp>(
-            loc, loadedVal, rewriter.getIndexType()));
+            loc, loadedVal, rewriter.getIndexType()));
       }
       AllocOp allocateMemref =
           rewriter.create<AllocOp>(loc, memRefType, allocOperands);
diff --git a/test/backend/test.py b/test/backend/test.py
index 1c520aa..4a375a6 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -224,13 +224,13 @@ test_to_enable = [
 
     # ReshapeOp:
     "test_reshape_extended_dims_cpu",
-    #"test_reshape_negative_dim_cpu", <- handle nagative dim
-    #"test_reshape_negative_extended_dims_cpu", <- handle nagative dim
+    "test_reshape_negative_dim_cpu",
+    "test_reshape_negative_extended_dims_cpu",
     "test_reshape_one_dim_cpu",
     "test_reshape_reduced_dims_cpu",
     "test_reshape_reordered_all_dims_cpu",
     "test_reshape_reordered_last_dims_cpu",
-    #"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim
+    "test_reshape_zero_and_negative_dim_cpu",
     "test_reshape_zero_dim_cpu",
 
     # Transpose
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index e1724b6..ff16551 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -302,38 +302,69 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_reshape
-  // CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64
-  // CHECK: %[[INDEX_0:.+]] = constant 0 : index
-  // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32>
+  // CHECK: [[TYPE_IN_BYTES_0:%.+]] = constant 4 : i64
   // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: [[DIM_0_CAST:%.+]] = index_cast [[DIM_0]] : index to i32
-  // CHECK: [[CONSTANT_0:%.+]] = constant 0 : i32
-  // CHECK: [[CMP:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_0]] : i32
-  // CHECK: [[SELECT_0:%.+]] = select [[CMP]], [[DIM_0_CAST]], [[LOAD_0]] : i32
-  // CHECK: [[EXT_0:%.+]] = zexti [[SELECT_0]] : i32 to i64
-  // CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64
-  // CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_0]] : i32 to index
-  // CHECK: %[[INDEX_1:.+]] = constant 1 : index
-  // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32>
-  // CHECK: [[CONSTANT_1:%.+]] = constant 10 : i32
+  // CHECK: [[DIM_0_CAST:%.+]] = index_cast [[DIM_0]] : index to i64
+  // CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES_0]], [[DIM_0_CAST]] : i64
+  // CHECK: [[CONSTANT_0:%.+]] = constant 10 : i64
+  // CHECK: [[TENSOR_SIZE:%.+]] = muli [[MUL_0]], [[CONSTANT_0]] : i64
+
+  // CHECK: [[TYPE_IN_BYTES_1:%.+]] = constant 4 : i64
+  // CHECK: %[[CONSTANT_1:.+]] = constant 0 : index
+  // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[CONSTANT_1]]] : memref<4xi32>
+  // CHECK: [[DIM_1:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[DIM_1_CAST:%.+]] = index_cast [[DIM_1]] : index to i32
   // CHECK: [[CONSTANT_2:%.+]] = constant 0 : i32
-  // CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_2]] : i32
-  // CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_1]], [[LOAD_1]] : i32
-  // CHECK: [[EXT_1:%.+]] = zexti [[SELECT_1]] : i32 to i64
-  // CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64
-  // CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_1]] : i32 to index
-  // CHECK: %[[INDEX_2:.+]] = constant 2 : index
-  // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32>
-  // CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
-  // CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[EXT_2]] : i64
-  // CHECK: [[CAST_2:%.+]] = index_cast [[LOAD_2]] : i32 to index
-  // CHECK: %[[INDEX_3:.+]] = constant 3 : index
-  // CHECK: [[LOAD_3:%.+]] = load %arg1[%[[INDEX_3]]] : memref<4xi32>
-  // CHECK: [[EXT_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
-  // CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[EXT_3]] : i64
-  // CHECK: [[CAST_3:%.+]] = index_cast [[LOAD_3]] : i32 to index
+  // CHECK: [[CMP_0:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_2]] : i32
+  // CHECK: [[SELECT_0:%.+]] = select [[CMP_0]], [[DIM_1_CAST]], [[LOAD_0]] : i32
+  // CHECK: [[ZEXTI_0:%.+]] = zexti [[SELECT_0]] : i32 to i64
+  // CHECK: [[MUL_1:%.+]] = muli [[TYPE_IN_BYTES_1]], [[ZEXTI_0]] : i64
+
+  // CHECK: %[[CONSTANT_3:.+]] = constant 1 : index
+  // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[CONSTANT_3]]] : memref<4xi32>
+  // CHECK: [[CONSTANT_3:%.+]] = constant 10 : i32
+  // CHECK: [[CONSTANT_4:%.+]] = constant 0 : i32
+  // CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_4]] : i32
+  // CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_3]], [[LOAD_1]] : i32
+  // CHECK: [[ZEXTI_1:%.+]] = zexti [[SELECT_1]] : i32 to i64
+  // CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[ZEXTI_1]] : i64
+
+  // CHECK: %[[CONSTANT_5:.+]] = constant 2 : index
+  // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[CONSTANT_5]]] : memref<4xi32>
+  // CHECK: [[ZEXTI_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
+  // CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[ZEXTI_2]] : i64
+
+  // CHECK: %[[CONSTANT_6:.+]] = constant 3 : index
+  // CHECK: [[LOAD_3:%.+]] = load %arg1[%[[CONSTANT_6]]] : memref<4xi32>
+  // CHECK: [[ZEXTI_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
+  // CHECK: [[MUL_4:%.+]] = muli [[MUL_3]], [[ZEXTI_3]] : i64
+
+  // CHECK: [[CONSTANT_7:%.+]] = constant 0 : i64
+  // CHECK: [[SUB_0:%.+]] = subi [[CONSTANT_7]], [[MUL_4]] : i64
+
+  // CHECK: [[CONSTANT_8:%.+]] = constant -1 : i64
+  // CHECK: [[CMP_2:%.+]] = cmpi "eq", [[ZEXTI_0]], [[CONSTANT_8]] : i64
+  // CHECK: [[DIVISIGNED_0:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
+  // CHECK: [[SELECT_2:%.+]] = select [[CMP_2]], [[DIVISIGNED_0]], [[ZEXTI_0]] : i64
+  // CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_2]] : i64 to index
+
+  // CHECK: [[CMP_3:%.+]] = cmpi "eq", [[ZEXTI_1]], [[CONSTANT_8]] : i64
+  // CHECK: [[DIVISIGNED_1:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
+  // CHECK: [[SELECT_3:%.+]] = select [[CMP_3]], [[DIVISIGNED_1]], [[ZEXTI_1]] : i64
+  // CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_3]] : i64 to index
+
+  // CHECK: [[CMP_4:%.+]] = cmpi "eq", [[ZEXTI_2]], [[CONSTANT_8]] : i64
+  // CHECK: [[DIVISIGNED_2:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
+  // CHECK: [[SELECT_4:%.+]] = select [[CMP_4]], [[DIVISIGNED_2]], [[ZEXTI_2]] : i64
+  // CHECK: [[CAST_2:%.+]] = index_cast [[SELECT_4]] : i64 to index
+
+  // CHECK: [[CMP_5:%.+]] = cmpi "eq", [[ZEXTI_3]], [[CONSTANT_8]] : i64
+  // CHECK: [[DIVISIGNED_3:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
+  // CHECK: [[SELECT_5:%.+]] = select [[CMP_5]], [[DIVISIGNED_3]], [[ZEXTI_3]] : i64
+  // CHECK: [[CAST_3:%.+]] = index_cast [[SELECT_5]] : i64 to index
+
   // CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32>
-  // CHECK: "krnl.memcpy"([[ALLOC]], %arg0, [[MUL_3]]) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()
+  // CHECK: "krnl.memcpy"([[ALLOC]], %arg0, [[TENSOR_SIZE]]) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()
   // CHECK: return [[ALLOC]] : memref<?x?x?x?xf32>
 }
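For readers unfamiliar with the ONNX Reshape conventions this patch handles, the runtime arithmetic that the lowered code performs can be sketched in plain C++. This is only an illustration of the rules described in the comments above (a 0 entry copies the input dimension, a single -1 entry is inferred from the remaining dimensions); the function name and layout below are hypothetical and are not part of the patch, which emits the equivalent MLIR operations (muli, subi, divi_signed, select) instead.

// Minimal sketch (assumed helper, not part of the patch) of resolving an
// ONNX Reshape shape array that may contain 0 and -1 entries.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> resolveReshapeDims(const std::vector<int64_t> &inputShape,
                                        std::vector<int64_t> shape) {
  // Total number of elements in the input tensor.
  int64_t inputSize = 1;
  for (int64_t d : inputShape)
    inputSize *= d;

  // Product of the known output dimensions; a 0 entry copies the matching
  // input dimension, and a -1 entry is skipped and inferred afterwards.
  int64_t knownSize = 1;
  int64_t inferredPos = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0 && i < inputShape.size())
      shape[i] = inputShape[i];
    if (shape[i] == -1)
      inferredPos = static_cast<int64_t>(i);
    else
      knownSize *= shape[i];
  }

  // At most one -1 is allowed; its value makes the element counts match.
  if (inferredPos >= 0) {
    assert(knownSize != 0 && inputSize % knownSize == 0);
    shape[inferredPos] = inputSize / knownSize;
  }
  return shape;
}

int main() {
  // Reshaping a 3x10 input with shape array {0, -1, 2} yields 3x5x2.
  auto dims = resolveReshapeDims({3, 10}, {0, -1, 2});
  assert(dims.size() == 3 && dims[0] == 3 && dims[1] == 5 && dims[2] == 2);
  return 0;
}

The generated IR cannot form the "product of known dimensions" directly: when a -1 is present, the product over the raw shape values is negative. Negating that product (the subi above) turns it into the product of the other dimensions, so a single divi_signed by it recovers the inferred dimension, and the select keeps the loaded value whenever the entry is not -1.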