diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp
index 8b5e296..01e4fa9 100644
--- a/src/pass/lower_frontend_to_krnl.cpp
+++ b/src/pass/lower_frontend_to_krnl.cpp
@@ -1118,6 +1118,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     } else {
       auto memRefShape = memRefType.getShape();
+      auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
       SmallVector<Value, 4> allocOperands;
       for (int i = 0; i < memRefShape.size(); ++i) {
         // The shape array can always be used to construct shape information of
@@ -1126,6 +1127,24 @@
             loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
         // Load index from array of indices.
         Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
+        // If a dimension is zero, the actual dimension value is taken from the
+        // input tensor.
+        if (i < inputShape.size()) {
+          Value dimVal;
+          auto dimTy = loadedVal.getType().cast<IntegerType>();
+          if (inputShape[i] < 0) {
+            Value dim = rewriter.create<DimOp>(loc, operands[0], i);
+            dimVal = rewriter.create<IndexCastOp>(loc, dim, dimTy);
+          } else {
+            dimVal = rewriter.create<ConstantOp>(
+                loc, rewriter.getIntegerAttr(dimTy, inputShape[i]));
+          }
+          auto zero = rewriter.create<ConstantOp>(
+              loc, rewriter.getIntegerAttr(dimTy, 0));
+          auto isZero =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero);
+          loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
+        }
         // Check if the loaded index is already the correct width of 64 bits.
         // Convert the value to a 64 bit integer if needed.
         Value int64LoadedVal = loadedVal;
diff --git a/test/backend/test.py b/test/backend/test.py
index 7cebffc..ae71d85 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -160,7 +160,15 @@ test_to_enable = [
     "test_softsign_example_cpu",
 
     # ReshapeOp:
+    "test_reshape_extended_dims_cpu",
+    #"test_reshape_negative_dim_cpu", <- handle negative dim
+    #"test_reshape_negative_extended_dims_cpu", <- handle negative dim
+    "test_reshape_one_dim_cpu",
+    "test_reshape_reduced_dims_cpu",
     "test_reshape_reordered_all_dims_cpu",
+    "test_reshape_reordered_last_dims_cpu",
+    #"test_reshape_zero_and_negative_dim_cpu", <- handle negative dim
+    "test_reshape_zero_dim_cpu",
 ]
 
 # Extract name of all test cases.
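
Note (not part of the patch): the new code above implements the ONNX Reshape rule that a 0 in the target shape means "keep the corresponding input dimension"; -1 (inferred) dimensions are not handled yet, which is why the three negative-dim backend tests stay disabled. A minimal Python sketch of that rule, assuming a hypothetical helper name resolve_reshape_dims:

def resolve_reshape_dims(input_shape, target_shape):
    # A 0 in the target shape copies the input dimension at the same position,
    # mirroring the constant/dim + cmpi + select sequence emitted by the lowering.
    # -1 (infer the dimension) is deliberately not handled, matching the disabled tests.
    return [input_shape[i] if d == 0 and i < len(input_shape) else d
            for i, d in enumerate(target_shape)]

# e.g. resolve_reshape_dims([3, 10], [0, 5, 2, 1]) -> [3, 5, 2, 1]
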
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index fcf0bfe..b74b65f 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -305,14 +305,23 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
   // CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64
   // CHECK: %[[INDEX_0:.+]] = constant 0 : index
   // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32>
-  // CHECK: [[EXT_0:%.+]] = zexti [[LOAD_0]] : i32 to i64
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[DIM_0_CAST:%.+]] = index_cast [[DIM_0]] : index to i32
+  // CHECK: [[CONSTANT_0:%.+]] = constant 0 : i32
+  // CHECK: [[CMP:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_0]] : i32
+  // CHECK: [[SELECT_0:%.+]] = select [[CMP]], [[DIM_0_CAST]], [[LOAD_0]] : i32
+  // CHECK: [[EXT_0:%.+]] = zexti [[SELECT_0]] : i32 to i64
   // CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64
-  // CHECK: [[CAST_0:%.+]] = index_cast [[LOAD_0]] : i32 to index
+  // CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_0]] : i32 to index
   // CHECK: %[[INDEX_1:.+]] = constant 1 : index
   // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32>
-  // CHECK: [[EXT_1:%.+]] = zexti [[LOAD_1]] : i32 to i64
+  // CHECK: [[CONSTANT_1:%.+]] = constant 10 : i32
+  // CHECK: [[CONSTANT_2:%.+]] = constant 0 : i32
+  // CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_2]] : i32
+  // CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_1]], [[LOAD_1]] : i32
+  // CHECK: [[EXT_1:%.+]] = zexti [[SELECT_1]] : i32 to i64
   // CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64
-  // CHECK: [[CAST_1:%.+]] = index_cast [[LOAD_1]] : i32 to index
+  // CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_1]] : i32 to index
   // CHECK: %[[INDEX_2:.+]] = constant 2 : index
   // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32>
   // CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
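
Note (not part of the patch): the updated CHECK lines cover both forms of the select. Dimension 0 of the ?x10 input is dynamic, so its value comes from a runtime dim/index_cast, while dimension 1 is folded to constant 10; each resolved dimension is then widened to i64 and multiplied into the allocation size in bytes. A rough Python sketch of that size computation, assuming a 4-byte element type as in the test:

def reshape_alloc_bytes(input_shape, target_shape, elem_bytes=4):
    # Mirrors the zexti/muli chain checked above: start from the element size
    # and multiply by each target dimension, substituting the input dimension
    # wherever the target shape holds a 0.
    total = elem_bytes
    for i, d in enumerate(target_shape):
        total *= input_shape[i] if d == 0 and i < len(input_shape) else d
    return total

# e.g. reshape_alloc_bytes((3, 10), (0, 5, 2, 1)) -> 120
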