Gather ONNX to Kernel Lowering (#294)

* Define krnl.permute op.

* Support krnl.permute operation.

* Properly remove loop references.

* Re-push; GitHub was down.

* Need to debug interpretOp error.

* Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code.

* Introduce permute, unroll operations.

* More debug.

* Remove std::set.

* krnl.terminate fails to be converted.

* Pass all tests; still need to add legal ops as part of the conversion target.

* Change test format to new permute spec.

* Bug fix for nested iterate op lowering.

* Simplify error reporting.

* Fix compilation error.

* Increase comments coverage.

* Remove unnecessary imports.

* Re-trigger Jenkins

* Add permute/unroll tests.

* Retrigger Jenkins

* Initial implementation of Gather.

* Added tests.

* Format.

* Remove the affine load for the second load, as it uses an indirection.

* Apply changes suggested by reviewers.

* Remove backend tests until I can verify them locally.

Co-authored-by: Tian Jin <tjingrant@gmail.com>
Alexandre Eichenberger 2020-09-11 15:36:23 -04:00 committed by GitHub
commit 3a5aa7ee31 (parent fa04c32a0c)
8 changed files with 175 additions and 0 deletions


@@ -22,6 +22,7 @@ add_library(OMONNXToKrnl
Tensor/Constant.cpp
Tensor/Concat.cpp
Tensor/Split.cpp
Tensor/Gather.cpp
ConvertONNXToKrnl.cpp)
target_link_libraries(OMONNXToKrnl
onnx)


@@ -96,6 +96,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
populateLoweringONNXPadOpPattern(patterns, &getContext());
populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
populateLoweringONNXTransposeOpPattern(patterns, &getContext());
populateLoweringONNXGatherOpPattern(patterns, &getContext());
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
populateLoweringONNXConstantOpPattern(patterns, &getContext());
populateLoweringONNXConcatOpPattern(patterns, &getContext());


@@ -223,6 +223,9 @@ void populateLoweringONNXUnsqueezeOpPattern(
void populateLoweringONNXTransposeOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXGatherOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXPadConstantValuePadOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);


@@ -0,0 +1,117 @@
//===----------------Gather.cpp - Lowering Gather Op----------------------===//
//
// Copyright 2020 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Gather Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
using namespace mlir;
struct ONNXGatherOpLowering : public ConversionPattern {
ONNXGatherOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXGatherOp::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
ONNXGatherOpAdaptor operandAdaptor(operands);
ONNXGatherOp gatherOp = llvm::cast<ONNXGatherOp>(op);
auto loc = op->getLoc();
// get input operands, shapes, and rank
Value data = operandAdaptor.data();
auto dataShape = data.getType().cast<MemRefType>().getShape();
int64_t dataRank = dataShape.size();
Value indices = operandAdaptor.indices();
auto indicesShape = indices.getType().cast<MemRefType>().getShape();
int64_t indicesRank = indicesShape.size();
int64_t axisIndex = gatherOp.axis().getSExtValue();
// get output info
auto outputMemRefType = convertToMemRefType(*op->result_type_begin());
auto outputMemRefShape = outputMemRefType.getShape();
int64_t outputRank = outputMemRefShape.size();
/*
The pattern that we are using is that of numpy.take.
Ni, Nk = data.shape[:axis], data.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
for jj in ndindex(Nj):
for kk in ndindex(Nk):
out[ii + jj + kk] = data[ii + (indices[jj],) + kk]
*/
// Define loops and iteration trip counts (equivalent to size of output)
std::vector<Value> originalLoops;
defineLoops(rewriter, loc, originalLoops, outputRank);
KrnlIterateOperandPack pack(rewriter, originalLoops);
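// First iterates over the Ni (data before axis), ii indices in above algo.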
int iIndexStart = 0;
for (int ii = 0; ii < axisIndex; ++ii)
addDimensionToPack(rewriter, loc, pack, data, ii);
// Then iterates over the Nj (indices matrix), jj indices in above algo.
int jIndexStart = iIndexStart + axisIndex;
for (int jj = 0; jj < indicesRank; ++jj)
addDimensionToPack(rewriter, loc, pack, indices, jj);
// Finally iterates over the Nk (data after axis), kk indices in above algo.
int kIndexStart = jIndexStart + indicesRank - (axisIndex + 1);
for (int kk = axisIndex + 1; kk < dataRank; ++kk)
addDimensionToPack(rewriter, loc, pack, data, kk);
// Insert an allocation and deallocation for the result of this operation.
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(outputMemRefType))
alloc =
insertAllocAndDealloc(outputMemRefType, loc, rewriter, insertDealloc);
else
return emitError(loc, "unsupported dynamic dimensions");
// Create the loops
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block &iterationBlock = iterateOp.bodyRegion().front();
// Now perform the insertions into the body of the just generated loops.
// Insert instructions inside the KernelIterateOp body.
rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operations.
// Read first the indices[jj] into indexVal.
SmallVector<Value, 4> indicesMemRefVal;
for (int j = 0; j < indicesRank; ++j)
indicesMemRefVal.emplace_back(
iterationBlock.getArguments()[jIndexStart + j]);
auto indexValInteger =
rewriter.create<AffineLoadOp>(loc, indices, indicesMemRefVal);
auto indexVal = rewriter.create<IndexCastOp>(
loc, indexValInteger, rewriter.getIndexType());
// Then read input data into DataVal: first add ii's.
SmallVector<Value, 4> dataMemRefVal;
for (int i = 0; i < axisIndex; ++i)
dataMemRefVal.emplace_back(
iterationBlock.getArguments()[iIndexStart + i]);
// Then add indices[jj] (indexVal)
dataMemRefVal.emplace_back(indexVal);
// Then add kk's
for (int k = axisIndex + 1; k < dataRank; ++k)
dataMemRefVal.emplace_back(
iterationBlock.getArguments()[kIndexStart + k]);
auto dataVal = rewriter.create<LoadOp>(loc, data, dataMemRefVal);
// Then store the value in the output.
SmallVector<Value, 4> outputMemRefVal;
for (int n = 0; n < iterationBlock.getArguments().size(); ++n)
outputMemRefVal.emplace_back(iterationBlock.getArguments()[n]);
rewriter.create<AffineStoreOp>(loc, dataVal, alloc, outputMemRefVal);
rewriter.replaceOp(op, alloc);
return success();
}
};
void populateLoweringONNXGatherOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXGatherOpLowering>(ctx);
}
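For reference, the numpy.take pattern described in the comment above can be written as a small standalone routine over flat, row-major buffers. This is a minimal sketch of the gather semantics the lowering implements, not onnx-mlir code; the function name gatherRef and its flattened-shape arguments are illustrative assumptions.

#include <cassert>
#include <cstdint>
#include <vector>

// out[ii + jj + kk] = data[ii + (indices[jj],) + kk], with data viewed as
// (Ni, dataShape[axis], Nk) and the output as (Ni, Nj, Nk), all row-major.
std::vector<float> gatherRef(const std::vector<float> &data,
    const std::vector<int64_t> &dataShape, const std::vector<int64_t> &indices,
    const std::vector<int64_t> &indicesShape, int64_t axis) {
  int64_t ni = 1, nj = 1, nk = 1;
  for (int64_t i = 0; i < axis; ++i)
    ni *= dataShape[i];
  for (int64_t d : indicesShape)
    nj *= d;
  for (size_t i = axis + 1; i < dataShape.size(); ++i)
    nk *= dataShape[i];
  int64_t axisDim = dataShape[axis];
  std::vector<float> out(ni * nj * nk);
  for (int64_t ii = 0; ii < ni; ++ii)
    for (int64_t jj = 0; jj < nj; ++jj)
      for (int64_t kk = 0; kk < nk; ++kk) {
        // The data subscript goes through an indirection (indices[jj]), which
        // is why the lowering above uses a plain load rather than affine.load
        // for the data access.
        int64_t idx = indices[jj];
        assert(idx >= 0 && idx < axisDim && "gather index out of range");
        out[(ii * nj + jj) * nk + kk] = data[(ii * axisDim + idx) * nk + kk];
      }
  return out;
}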


@@ -29,6 +29,14 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
// TODO: While the code below nominally handles the allocation of data in
// the presence of dynamic dimensions, this appears to be incorrect: the
// operand passed here, "{data}", reflects the input sizes and does not
// reflect the transpose. Indeed, if an input is 4x3x2, then the output
// would be 2x3x4. If any of the 2, 3, or 4 are dynamic dimensions, we
// simply pass the operand of the not-transposed input data to determine
// the dynamic sizes of the to-be-transposed data. At this time, there are
// also no dynamic-size lowering tests.
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {data});
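As the TODO above points out, output dimension i of a transpose has the size of input dimension perm[i], so dynamic output sizes would have to be taken from the permuted input dimensions rather than from the input operand as-is. A minimal sketch of that shape mapping, independent of the MLIR APIs (the helper name transposedShape is illustrative, not onnx-mlir code):

#include <cstdint>
#include <vector>

// Output dimension i of a transpose has the size of input dimension perm[i];
// e.g., inShape = {4, 3, 2} with perm = {2, 1, 0} yields {2, 3, 4}.
std::vector<int64_t> transposedShape(
    const std::vector<int64_t> &inShape, const std::vector<int64_t> &perm) {
  std::vector<int64_t> outShape(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    outShape[i] = inShape[perm[i]];
  return outShape;
}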


@@ -512,6 +512,8 @@ mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const {
if (parser.parseGreater())
return Type();
return SeqType::get(elementTypes);
} else {
llvm_unreachable("Unexpected onnxmlir keyword");
}
}


@@ -157,6 +157,11 @@ test_to_enable = [
"test_exp_cpu",
"test_exp_example_cpu",
# Gather Op:
#"test_gather_0",
#"test_gather_1",
#"test_gather_negative_indices",
# Gemm Op:
"test_gemm_all_attributes_cpu",
"test_gemm_alpha_cpu",


@@ -2132,3 +2132,41 @@ func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> {
// CHECK: affine.store [[FPTRUNC]], [[RES]][%arg1] : memref<10xf32>
// CHECK: return [[RES]] : memref<10xf32>
}
// -----
// Test gather along axis 0, first example in ONNX for Gather.
func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
%indices = "onnx.Constant"() {value = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64>
%0 = "onnx.Gather"(%arg0, %indices) {axis = 0} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2x2xf32>
"std.return"(%0) : (tensor<2x2x2xf32>) -> ()
// CHECK-LABEL: test_gather_axis0
// CHECK: [[ALLOC:%.+]] = alloc() : memref<2x2x2xf32>
// CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "{{.*}}", shape = [2, 2], value = dense<{{\[+}}0, 1], [1, 2{{\]+}}> : tensor<2x2xi64>} : () -> memref<2x2xi64>
// CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
// CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 2, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 2, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
// CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG1]], [[ARG2]]{{.}} : memref<2x2xi64>
// CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[AFFINE2]], [[ARG3]]{{.}} : memref<3x2xf32>
// CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<2x2x2xf32>
}
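// For reference, the first Gather example in the ONNX operator documentation,
// which this test is modeled on, uses data = [[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]
// and indices = [[0, 1], [1, 2]]; with axis = 0 the expected output is
// [[[1.0, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]],
// i.e., output[i, j, :] = data[indices[i, j], :].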
// -----
// Test gather along axis 1, second example in ONNX for Gather.
func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> {
%indices = "onnx.Constant"() {value = dense<[[0, 2]]> : tensor<1x2xi64>} : () -> tensor<1x2xi64>
%0 = "onnx.Gather"(%arg0, %indices) {axis = 1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x3x2xf32>
"std.return"(%0) : (tensor<1x3x2xf32>) -> ()
// CHECK-LABEL: test_gather_axis1
// CHECK: [[ALLOC:%.+]] = alloc() : memref<1x3x2xf32>
// CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64>
// CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
// CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 3, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 1, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
// CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG2]], [[ARG3]]{{.}} : memref<1x2xi64>
// CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE2]]{{.}} : memref<3x3xf32>
// CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<1x3x2xf32>
}