Gather: fix for negative indices (#313)
* Define krnl.permute op. * Support krnl.permute operation. * Properly remove loop references. * Re-push, Github was down. * Need to debug interpretOp error. * Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code. * Introduce permute, unroll operations. * More debug. * Remove std::set. * krnl.terminate fails to be converted. * Pass all tests, need to add legal ops as well as part of the conversion target. * Change test format to new permute spec. * Bug fix for nested iterate op lowering. * Simplify error reporting. * Fix compilation error. * Increase comments coverage. * Remove unnecessary imports. * Re-trigger Jenkins * Add permute/unroll tests. * Retrigger Jenkins * changes to support negative indices Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com> * use krnl.dim now Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com> * move comment Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com> * updated test for krnl-dim pattern Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com> Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
f16983a7c4
commit
f0c5b99229
|
@ -67,6 +67,14 @@ struct ONNXGatherOpLowering : public ConversionPattern {
|
|||
else
|
||||
return emitError(loc, "unsupported dynamic dimensions");
|
||||
|
||||
// Get the size of the "axis"th dimension of data.
|
||||
auto zeroConst = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
||||
auto axisIndexConst = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), axisIndex));
|
||||
auto sizeAxisVal = rewriter.create<KrnlDimOp>(
|
||||
loc, rewriter.getIndexType(), data, axisIndexConst);
|
||||
|
||||
// Create the loops
|
||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||
Block &iterationBlock = iterateOp.bodyRegion().front();
|
||||
|
@ -76,24 +84,32 @@ struct ONNXGatherOpLowering : public ConversionPattern {
|
|||
rewriter.setInsertionPointToStart(&iterationBlock);
|
||||
|
||||
// Handle the operations.
|
||||
// Read first the indices[jj] into indexVal.
|
||||
// Read first the indices[jj] into rawIndexVal.
|
||||
SmallVector<Value, 4> indicesMemRefVal;
|
||||
for (int j = 0; j < indicesRank; ++j)
|
||||
indicesMemRefVal.emplace_back(
|
||||
iterationBlock.getArguments()[jIndexStart + j]);
|
||||
auto indexValInteger =
|
||||
rewriter.create<AffineLoadOp>(loc, indices, indicesMemRefVal);
|
||||
auto indexVal = rewriter.create<IndexCastOp>(
|
||||
auto rawIndexVal = rewriter.create<IndexCastOp>(
|
||||
loc, indexValInteger, rewriter.getIndexType());
|
||||
// When raw index value is negative, must add array dimension size to it.
|
||||
auto negativeIndexVal =
|
||||
rewriter.create<AddIOp>(loc, rawIndexVal, sizeAxisVal);
|
||||
// Select value for non-negative or negative case.
|
||||
auto isNegative = rewriter.create<CmpIOp>(
|
||||
loc, CmpIPredicate::slt, rawIndexVal, zeroConst);
|
||||
auto indexVal = rewriter.create<SelectOp>(
|
||||
loc, isNegative, negativeIndexVal, rawIndexVal);
|
||||
|
||||
// Then read input data into DataVal: first add ii's.
|
||||
SmallVector<Value, 4> dataMemRefVal;
|
||||
for (int i = 0; i < axisIndex; ++i)
|
||||
dataMemRefVal.emplace_back(
|
||||
iterationBlock.getArguments()[iIndexStart + i]);
|
||||
// Then add indices[jj] (indexVal)
|
||||
// Then add indices[jj] (indexVal).
|
||||
dataMemRefVal.emplace_back(indexVal);
|
||||
// Then add kk's
|
||||
// Then add kk's.
|
||||
for (int k = axisIndex + 1; k < dataRank; ++k)
|
||||
dataMemRefVal.emplace_back(
|
||||
iterationBlock.getArguments()[kIndexStart + k]);
|
||||
|
|
|
@ -409,6 +409,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) {
|
|||
// from ONNX dialect to Standard dialect exposes additional canonicalization
|
||||
// oppertunities.
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(createDisconnectKrnlDimFromAllocPass());
|
||||
|
||||
// TODO: make this pass optional:
|
||||
pm.addPass(mlir::createKrnlEnableMemoryPoolPass());
|
||||
|
|
|
@ -158,9 +158,9 @@ test_to_enable = [
|
|||
"test_exp_example_cpu",
|
||||
|
||||
# Gather Op:
|
||||
#"test_gather_0",
|
||||
#"test_gather_1",
|
||||
#"test_gather_negative_indices",
|
||||
"test_gather_0_cpu",
|
||||
"test_gather_1_cpu",
|
||||
"test_gather_negative_indices_cpu",
|
||||
|
||||
# Gemm Op:
|
||||
"test_gemm_all_attributes_cpu",
|
||||
|
|
|
@ -2181,10 +2181,16 @@ func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
|
|||
// CHECK: [[ALLOC:%.+]] = alloc() : memref<2x2x2xf32>
|
||||
// CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "{{.*}}", shape = [2, 2], value = dense<{{\[+}}0, 1], [1, 2{{\]+}}> : tensor<2x2xi64>} : () -> memref<2x2xi64>
|
||||
// CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
|
||||
// CHECK: [[ZERO:%.+]] = constant 0 : index
|
||||
// CHECK: [[DIM_INDEX:%.+]] = constant 0 : index
|
||||
// CHECK: [[DIM:%.+]] = "krnl.dim"(%arg0, [[DIM_INDEX]]) : (memref<3x2xf32>, index) -> index
|
||||
// CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 2, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 2, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
|
||||
// CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG1]], [[ARG2]]{{.}} : memref<2x2xi64>
|
||||
// CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
|
||||
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[AFFINE2]], [[ARG3]]{{.}} : memref<3x2xf32>
|
||||
// CHECK: [[AFFINE3:%.+]] = addi [[AFFINE2]], [[DIM]] : index
|
||||
// CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index
|
||||
// CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index
|
||||
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[AFFINE4]], [[ARG3]]{{.}} : memref<3x2xf32>
|
||||
// CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<2x2x2xf32>
|
||||
}
|
||||
|
||||
|
@ -2200,10 +2206,16 @@ func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> {
|
|||
// CHECK: [[ALLOC:%.+]] = alloc() : memref<1x3x2xf32>
|
||||
// CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64>
|
||||
// CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
|
||||
// CHECK: [[ZERO:%.+]] = constant 0 : index
|
||||
// CHECK: [[DIM_INDEX:%.+]] = constant 1 : index
|
||||
// CHECK: [[DIM:%.+]] = "krnl.dim"(%arg0, [[DIM_INDEX]]) : (memref<3x3xf32>, index) -> index
|
||||
// CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 3, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 1, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
|
||||
// CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG2]], [[ARG3]]{{.}} : memref<1x2xi64>
|
||||
// CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
|
||||
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE2]]{{.}} : memref<3x3xf32>
|
||||
// CHECK: [[AFFINE3:%.+]] = addi [[AFFINE2]], [[DIM]] : index
|
||||
// CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index
|
||||
// CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index
|
||||
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE4]]{{.}} : memref<3x3xf32>
|
||||
// CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<1x3x2xf32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue