diff --git a/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp b/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp
index c41ab85..ca46002 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp
@@ -67,6 +67,14 @@ struct ONNXGatherOpLowering : public ConversionPattern {
     else
       return emitError(loc, "unsupported dynamic dimensions");
 
+    // Get the size of the "axis"th dimension of data.
+    auto zeroConst = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
+    auto axisIndexConst = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(rewriter.getIndexType(), axisIndex));
+    auto sizeAxisVal = rewriter.create<KrnlDimOp>(
+        loc, rewriter.getIndexType(), data, axisIndexConst);
+
     // Create the loops
     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
     Block &iterationBlock = iterateOp.bodyRegion().front();
@@ -76,24 +84,32 @@ struct ONNXGatherOpLowering : public ConversionPattern {
     rewriter.setInsertionPointToStart(&iterationBlock);
 
     // Handle the operations.
-    // Read first the indices[jj] into indexVal.
+    // Read first the indices[jj] into rawIndexVal.
     SmallVector<Value, 4> indicesMemRefVal;
     for (int j = 0; j < indicesRank; ++j)
       indicesMemRefVal.emplace_back(
           iterationBlock.getArguments()[jIndexStart + j]);
     auto indexValInteger =
         rewriter.create<AffineLoadOp>(loc, indices, indicesMemRefVal);
-    auto indexVal = rewriter.create<IndexCastOp>(
+    auto rawIndexVal = rewriter.create<IndexCastOp>(
         loc, indexValInteger, rewriter.getIndexType());
+    // When the raw index value is negative, add the array dimension size to it.
+    auto negativeIndexVal =
+        rewriter.create<AddIOp>(loc, rawIndexVal, sizeAxisVal);
+    // Select the proper value for the negative and non-negative cases.
+    auto isNegative = rewriter.create<CmpIOp>(
+        loc, CmpIPredicate::slt, rawIndexVal, zeroConst);
+    auto indexVal = rewriter.create<SelectOp>(
+        loc, isNegative, negativeIndexVal, rawIndexVal);
 
     // Then read input data into DataVal: first add ii's.
     SmallVector<Value, 4> dataMemRefVal;
     for (int i = 0; i < axisIndex; ++i)
       dataMemRefVal.emplace_back(
           iterationBlock.getArguments()[iIndexStart + i]);
-    // Then add indices[jj] (indexVal)
+    // Then add indices[jj] (indexVal).
     dataMemRefVal.emplace_back(indexVal);
-    // Then add kk's
+    // Then add kk's.
     for (int k = axisIndex + 1; k < dataRank; ++k)
       dataMemRefVal.emplace_back(
           iterationBlock.getArguments()[kIndexStart + k]);
diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp
index f5350f1..f1440b5 100644
--- a/src/MainUtils.cpp
+++ b/src/MainUtils.cpp
@@ -409,6 +409,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) {
   // from ONNX dialect to Standard dialect exposes additional canonicalization
   // oppertunities.
   pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(createDisconnectKrnlDimFromAllocPass());
 
   // TODO: make this pass optional:
   pm.addPass(mlir::createKrnlEnableMemoryPoolPass());
diff --git a/test/backend/test.py b/test/backend/test.py
index a25f891..817e788 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -158,9 +158,9 @@ test_to_enable = [
     "test_exp_example_cpu",
 
     # Gather Op:
-    #"test_gather_0",
-    #"test_gather_1",
-    #"test_gather_negative_indices",
+    "test_gather_0_cpu",
+    "test_gather_1_cpu",
+    "test_gather_negative_indices_cpu",
 
     # Gemm Op:
     "test_gemm_all_attributes_cpu",
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 14d84ce..0c30e44 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -2181,10 +2181,16 @@ func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
   // CHECK: [[ALLOC:%.+]] = alloc() : memref<2x2x2xf32>
   // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "{{.*}}", shape = [2, 2], value = dense<{{\[+}}0, 1], [1, 2{{\]+}}> : tensor<2x2xi64>} : () -> memref<2x2xi64>
   // CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[ZERO:%.+]] = constant 0 : index
+  // CHECK: [[DIM_INDEX:%.+]] = constant 0 : index
+  // CHECK: [[DIM:%.+]] = "krnl.dim"(%arg0, [[DIM_INDEX]]) : (memref<3x2xf32>, index) -> index
   // CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 2, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 2, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
   // CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG1]], [[ARG2]]{{.}} : memref<2x2xi64>
   // CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
-  // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[AFFINE2]], [[ARG3]]{{.}} : memref<3x2xf32>
+  // CHECK: [[AFFINE3:%.+]] = addi [[AFFINE2]], [[DIM]] : index
+  // CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index
+  // CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index
+  // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[AFFINE4]], [[ARG3]]{{.}} : memref<3x2xf32>
   // CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<2x2x2xf32>
 }
 
@@ -2200,10 +2206,16 @@ func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> {
   // CHECK: [[ALLOC:%.+]] = alloc() : memref<1x3x2xf32>
   // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64>
   // CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[ZERO:%.+]] = constant 0 : index
+  // CHECK: [[DIM_INDEX:%.+]] = constant 1 : index
+  // CHECK: [[DIM:%.+]] = "krnl.dim"(%arg0, [[DIM_INDEX]]) : (memref<3x3xf32>, index) -> index
   // CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 3, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 1, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
   // CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG2]], [[ARG3]]{{.}} : memref<1x2xi64>
   // CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
-  // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE2]]{{.}} : memref<3x3xf32>
+  // CHECK: [[AFFINE3:%.+]] = addi [[AFFINE2]], [[DIM]] : index
+  // CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index
+  // CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index
+  // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE4]]{{.}} : memref<3x3xf32>
   // CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<1x3x2xf32>
 }
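
Note: the addi/cmpi/select sequence emitted above implements ONNX's wrap-around semantics for negative Gather indices: an index i < 0 along the gather axis reads as i + dim_size. A minimal standalone C++ sketch of the per-element scalar computation this lowering performs; the function name is hypothetical and purely illustrative, it is not part of the patch:

#include <cassert>
#include <cstdint>

// Normalize a possibly negative ONNX Gather index into [0, axisSize).
// Mirrors the emitted IR: select(cmpi slt(raw, 0), addi(raw, size), raw).
int64_t normalizeGatherIndex(int64_t rawIndex, int64_t axisSize) {
  return rawIndex < 0 ? rawIndex + axisSize : rawIndex;
}

int main() {
  // With data of shape <3x2> and axis = 0 (as in test_gather_axis0),
  // index -1 selects the last row along axis 0, i.e. row 2.
  assert(normalizeGatherIndex(-1, 3) == 2);
  assert(normalizeGatherIndex(1, 3) == 1);
  return 0;
}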