Gather: fix for negative indices (#313)

* Define krnl.permute op.

* Support krnl.permute operation.

* Properly remove loop references.

* Re-push, GitHub was down.

* Need to debug interpretOp error.

* Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code.

* Introduce permute, unroll operations.

* More debug.

* Remove std::set.

* krnl.terminate fails to be converted.

* Pass all tests, need to add legal ops as well as part of the conversion target.

* Change test format to new permute spec.

* Bug fix for nested iterate op lowering.

* Simplify error reporting.

* Fix compilation error.

* Increase comments coverage.

* Remove unnecessary imports.

* Re-trigger Jenkins

* Add permute/unroll tests.

* Retrigger Jenkins

* changes to support negative indices

Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>

* use krnl.dim now

Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>

* move comment

Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>

* updated test for krnl-dim pattern

Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>

Co-authored-by: Tian Jin <tjingrant@gmail.com>
Authored by Alexandre Eichenberger on 2020-09-24 14:02:49 -04:00, committed by GitHub
commit f0c5b99229, parent f16983a7c4
4 changed files with 38 additions and 9 deletions
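
For context before the diffs: the ONNX Gather operator permits negative entries in indices, which count back from the end of the gathered axis, so an index i < 0 selects position i + dimSize. The lines below are a minimal plain-C++ sketch of that rule, not onnx-mlir code; the helper name normalizeIndex is invented for illustration.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the compare/select this commit emits:
// a negative index wraps around by adding the size of the gathered axis.
static int64_t normalizeIndex(int64_t rawIndex, int64_t dimSize) {
  return rawIndex < 0 ? rawIndex + dimSize : rawIndex;
}

int main() {
  std::vector<float> data = {1.0f, 2.0f, 3.0f}; // gathered axis has size 3
  assert(normalizeIndex(0, data.size()) == 0);  // non-negative indices pass through
  assert(normalizeIndex(-1, data.size()) == 2); // -1 selects the last element
  assert(data[normalizeIndex(-3, data.size())] == 1.0f);
  return 0;
}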

@@ -67,6 +67,14 @@ struct ONNXGatherOpLowering : public ConversionPattern {
else
return emitError(loc, "unsupported dynamic dimensions");
// Get the size of the "axis"th dimension of data.
auto zeroConst = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
auto axisIndexConst = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), axisIndex));
auto sizeAxisVal = rewriter.create<KrnlDimOp>(
loc, rewriter.getIndexType(), data, axisIndexConst);
// Create the loops
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block &iterationBlock = iterateOp.bodyRegion().front();
@@ -76,24 +84,32 @@ struct ONNXGatherOpLowering : public ConversionPattern {
rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operations.
// Read first the indices[jj] into indexVal.
// Read first the indices[jj] into rawIndexVal.
SmallVector<Value, 4> indicesMemRefVal;
for (int j = 0; j < indicesRank; ++j)
indicesMemRefVal.emplace_back(
iterationBlock.getArguments()[jIndexStart + j]);
auto indexValInteger =
rewriter.create<AffineLoadOp>(loc, indices, indicesMemRefVal);
auto indexVal = rewriter.create<IndexCastOp>(
auto rawIndexVal = rewriter.create<IndexCastOp>(
loc, indexValInteger, rewriter.getIndexType());
// When raw index value is negative, must add array dimension size to it.
auto negativeIndexVal =
rewriter.create<AddIOp>(loc, rawIndexVal, sizeAxisVal);
// Select value for non-negative or negative case.
auto isNegative = rewriter.create<CmpIOp>(
loc, CmpIPredicate::slt, rawIndexVal, zeroConst);
auto indexVal = rewriter.create<SelectOp>(
loc, isNegative, negativeIndexVal, rawIndexVal);
// Then read input data into DataVal: first add ii's.
SmallVector<Value, 4> dataMemRefVal;
for (int i = 0; i < axisIndex; ++i)
dataMemRefVal.emplace_back(
iterationBlock.getArguments()[iIndexStart + i]);
// Then add indices[jj] (indexVal)
// Then add indices[jj] (indexVal).
dataMemRefVal.emplace_back(indexVal);
// Then add kk's
// Then add kk's.
for (int k = axisIndex + 1; k < dataRank; ++k)
dataMemRefVal.emplace_back(
iterationBlock.getArguments()[kIndexStart + k]);
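
To make the loop structure above easier to follow, here is an illustrative plain-C++ rendering of what the lowering computes for the axis-0 shapes used in the tests further down (3x2 data, 2x2 indices, 2x2x2 output). The function name gatherAxis0 and the use of std::array are assumptions of the sketch only: the jj loops walk the indices shape, the kk loop walks the data dimensions after the axis (there are no ii loops when the axis is 0), and each loaded index goes through the same add/compare/select before it is used to address data.

#include <array>
#include <cstdint>

std::array<std::array<std::array<float, 2>, 2>, 2> gatherAxis0(
    const std::array<std::array<float, 2>, 3> &data,
    const std::array<std::array<int64_t, 2>, 2> &indices) {
  std::array<std::array<std::array<float, 2>, 2>, 2> out{};
  const int64_t dimSize = 3; // what krnl.dim returns for the gathered axis
  for (int j0 = 0; j0 < 2; ++j0)     // jj loops over the indices shape
    for (int j1 = 0; j1 < 2; ++j1)
      for (int k = 0; k < 2; ++k) {  // kk loop over trailing data dims
        int64_t raw = indices[j0][j1];
        int64_t idx = raw < 0 ? raw + dimSize : raw; // the new select
        out[j0][j1][k] = data[idx][k];
      }
  return out;
}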

@@ -409,6 +409,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) {
// from ONNX dialect to Standard dialect exposes additional canonicalization
// oppertunities.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(createDisconnectKrnlDimFromAllocPass());
// TODO: make this pass optional:
pm.addPass(mlir::createKrnlEnableMemoryPoolPass());
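
For orientation, the hunk above only adds one line to the pipeline; the sketch below restates the ordering in isolation, assuming the usual mlir::PassManager API and that the onnx-mlir pass-creation functions shown above are declared in headers already in scope. Placing the new pass before the memory-pool pass is presumably what lets krnl.dim users be rewritten before later passes touch the allocs they refer to.

#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Sketch only; the real pipeline lives in addONNXToKrnlPasses above.
// The two onnx-mlir pass constructors are assumed to be declared elsewhere.
void buildKrnlDimOrdering(mlir::PassManager &pm) {
  pm.addPass(mlir::createCanonicalizerPass());        // simplify first
  pm.addPass(createDisconnectKrnlDimFromAllocPass()); // new in this commit
  pm.addPass(mlir::createKrnlEnableMemoryPoolPass()); // then memory pooling
}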

@@ -158,9 +158,9 @@ test_to_enable = [
"test_exp_example_cpu",
# Gather Op:
#"test_gather_0",
#"test_gather_1",
#"test_gather_negative_indices",
"test_gather_0_cpu",
"test_gather_1_cpu",
"test_gather_negative_indices_cpu",
# Gemm Op:
"test_gemm_all_attributes_cpu",

@@ -2181,10 +2181,16 @@ func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
// CHECK: [[ALLOC:%.+]] = alloc() : memref<2x2x2xf32>
// CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "{{.*}}", shape = [2, 2], value = dense<{{\[+}}0, 1], [1, 2{{\]+}}> : tensor<2x2xi64>} : () -> memref<2x2xi64>
// CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
// CHECK: [[ZERO:%.+]] = constant 0 : index
// CHECK: [[DIM_INDEX:%.+]] = constant 0 : index
// CHECK: [[DIM:%.+]] = "krnl.dim"(%arg0, [[DIM_INDEX]]) : (memref<3x2xf32>, index) -> index
// CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 2, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 2, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
// CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG1]], [[ARG2]]{{.}} : memref<2x2xi64>
// CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[AFFINE2]], [[ARG3]]{{.}} : memref<3x2xf32>
// CHECK: [[AFFINE3:%.+]] = addi [[AFFINE2]], [[DIM]] : index
// CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index
// CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[AFFINE4]], [[ARG3]]{{.}} : memref<3x2xf32>
// CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<2x2x2xf32>
}
@@ -2200,10 +2206,16 @@ func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> {
// CHECK: [[ALLOC:%.+]] = alloc() : memref<1x3x2xf32>
// CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64>
// CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
// CHECK: [[ZERO:%.+]] = constant 0 : index
// CHECK: [[DIM_INDEX:%.+]] = constant 1 : index
// CHECK: [[DIM:%.+]] = "krnl.dim"(%arg0, [[DIM_INDEX]]) : (memref<3x3xf32>, index) -> index
// CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 3, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 1, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
// CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG2]], [[ARG3]]{{.}} : memref<1x2xi64>
// CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE2]]{{.}} : memref<3x3xf32>
// CHECK: [[AFFINE3:%.+]] = addi [[AFFINE2]], [[DIM]] : index
// CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index
// CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE4]]{{.}} : memref<3x3xf32>
// CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<1x3x2xf32>
}
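
Finally, a small trace of what the CHECK'd addi/cmpi/select sequence evaluates to, using the constant indices from the axis-0 test (all non-negative, so the select keeps the raw value) plus one hypothetical negative index to exercise the wrap-around branch. Plain C++ again, with names chosen to mirror the CHECK captures; it is an illustration, not generated code.

#include <cassert>
#include <cstdint>

// affine2 is the index_cast result, dim is what krnl.dim returns (3 here).
static int64_t emittedSelect(int64_t affine2, int64_t dim) {
  int64_t affine3 = affine2 + dim; // addi
  bool cmp = affine2 < 0;          // cmpi "slt" against the zero constant
  return cmp ? affine3 : affine2;  // select -> AFFINE4, used by the load
}

int main() {
  // Indices from the test's krnl.global [[0, 1], [1, 2]]:
  assert(emittedSelect(0, 3) == 0);
  assert(emittedSelect(1, 3) == 1);
  assert(emittedSelect(2, 3) == 2);
  // Hypothetical negative index, not present in the test data:
  assert(emittedSelect(-1, 3) == 2);
  return 0;
}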