diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt
index 9d44980..2a99113 100644
--- a/src/Conversion/ONNXToKrnl/CMakeLists.txt
+++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -22,6 +22,7 @@ add_library(OMONNXToKrnl
     Tensor/Constant.cpp
     Tensor/Concat.cpp
     Tensor/Split.cpp
+    Tensor/Gather.cpp
     ConvertONNXToKrnl.cpp)
 
 target_link_libraries(OMONNXToKrnl onnx)
diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
index 883e888..047f888 100644
--- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
+++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -96,6 +96,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   populateLoweringONNXPadOpPattern(patterns, &getContext());
   populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
   populateLoweringONNXTransposeOpPattern(patterns, &getContext());
+  populateLoweringONNXGatherOpPattern(patterns, &getContext());
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
   populateLoweringONNXConstantOpPattern(patterns, &getContext());
   populateLoweringONNXConcatOpPattern(patterns, &getContext());
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index e9e9416..b5ff470 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -223,6 +223,9 @@ void populateLoweringONNXUnsqueezeOpPattern(
 void populateLoweringONNXTransposeOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
+void populateLoweringONNXGatherOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 void populateLoweringONNXPadConstantValuePadOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp b/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp
new file mode 100644
index 0000000..e215e51
--- /dev/null
+++ b/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp
@@ -0,0 +1,117 @@
+//===----------------Gather.cpp - Lowering Gather Op----------------------===//
+//
+// Copyright 2020 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Gather operator to the Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+
+using namespace mlir;
+
+struct ONNXGatherOpLowering : public ConversionPattern {
+  ONNXGatherOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXGatherOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    ONNXGatherOpAdaptor operandAdaptor(operands);
+    ONNXGatherOp gatherOp = llvm::cast<ONNXGatherOp>(op);
+    auto loc = op->getLoc();
+    // Get the input operands, shapes, and ranks.
+    Value data = operandAdaptor.data();
+    auto dataShape = data.getType().cast<MemRefType>().getShape();
+    int64_t dataRank = dataShape.size();
+    Value indices = operandAdaptor.indices();
+    auto indicesShape = indices.getType().cast<MemRefType>().getShape();
+    int64_t indicesRank = indicesShape.size();
+    int64_t axisIndex = gatherOp.axis().getSExtValue();
+    // Get the output info.
+    auto outputMemRefType = convertToMemRefType(*op->result_type_begin());
+    auto outputMemRefShape = outputMemRefType.getShape();
+    int64_t outputRank = outputMemRefShape.size();
+    /*
+      The pattern that we are using is that of numpy.take:
+
+      Ni, Nk = data.shape[:axis], data.shape[axis+1:]
+      Nj = indices.shape
+      for ii in ndindex(Ni):
+        for jj in ndindex(Nj):
+          for kk in ndindex(Nk):
+            out[ii + jj + kk] = data[ii + (indices[jj],) + kk]
+    */
+    // Define loops and iteration trip counts (equivalent to the size of the
+    // output).
+    std::vector<Value> originalLoops;
+    defineLoops(rewriter, loc, originalLoops, outputRank);
+    KrnlIterateOperandPack pack(rewriter, originalLoops);
+    // First iterate over Ni (data before the axis), the ii indices in the
+    // algorithm above.
+    int iIndexStart = 0;
+    for (int ii = 0; ii < axisIndex; ++ii)
+      addDimensionToPack(rewriter, loc, pack, data, ii);
+    // Then iterate over Nj (the indices matrix), the jj indices in the
+    // algorithm above.
+    int jIndexStart = iIndexStart + axisIndex;
+    for (int jj = 0; jj < indicesRank; ++jj)
+      addDimensionToPack(rewriter, loc, pack, indices, jj);
+    // Finally, iterate over Nk (data after the axis), the kk indices in the
+    // algorithm above.
+    int kIndexStart = jIndexStart + indicesRank - (axisIndex + 1);
+    for (int kk = axisIndex + 1; kk < dataRank; ++kk)
+      addDimensionToPack(rewriter, loc, pack, data, kk);
+    // Insert an allocation and deallocation for the result of this operation.
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(outputMemRefType))
+      alloc =
+          insertAllocAndDealloc(outputMemRefType, loc, rewriter, insertDealloc);
+    else
+      return emitError(loc, "unsupported dynamic dimensions");
+
+    // Create the loops.
+    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
+    Block &iterationBlock = iterateOp.bodyRegion().front();
+
+    // Now perform the insertions into the body of the just-generated loops.
+    // Insert instructions inside the KrnlIterateOp body.
+    rewriter.setInsertionPointToStart(&iterationBlock);
+
+    // Handle the operations.
+    // First read indices[jj] into indexVal.
+    SmallVector<Value, 4> indicesMemRefVal;
+    for (int j = 0; j < indicesRank; ++j)
+      indicesMemRefVal.emplace_back(
+          iterationBlock.getArguments()[jIndexStart + j]);
+    auto indexValInteger =
+        rewriter.create<AffineLoadOp>(loc, indices, indicesMemRefVal);
+    auto indexVal = rewriter.create<IndexCastOp>(
+        loc, indexValInteger, rewriter.getIndexType());
+
+    // Then read the input data into dataVal: first add the ii's.
+    SmallVector<Value, 4> dataMemRefVal;
+    for (int i = 0; i < axisIndex; ++i)
+      dataMemRefVal.emplace_back(
+          iterationBlock.getArguments()[iIndexStart + i]);
+    // Then add indices[jj] (indexVal).
+    dataMemRefVal.emplace_back(indexVal);
+    // Then add the kk's.
+    for (int k = axisIndex + 1; k < dataRank; ++k)
+      dataMemRefVal.emplace_back(
+          iterationBlock.getArguments()[kIndexStart + k]);
+    auto dataVal = rewriter.create<LoadOp>(loc, data, dataMemRefVal);
+
+    // Then store the value into the output.
+    SmallVector<Value, 4> outputMemRefVal;
+    for (unsigned n = 0; n < iterationBlock.getArguments().size(); ++n)
+      outputMemRefVal.emplace_back(iterationBlock.getArguments()[n]);
+    rewriter.create<AffineStoreOp>(loc, dataVal, alloc, outputMemRefVal);
+
+    rewriter.replaceOp(op, alloc);
+
+    return success();
+  }
+};
+
+void populateLoweringONNXGatherOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXGatherOpLowering>(ctx);
+}
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp b/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp
index d912f9e..7c5fcaa 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp
@@ -29,6 +29,14 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     else
+      // TODO: While the code below nominally handles the allocation of data
+      // in the presence of dynamic dimensions, this appears to be wrong: the
+      // operand passed here ("{data}") reflects the input sizes, not the
+      // sizes after the transpose. Indeed, if an input is 4x3x2, the output
+      // is 2x3x4. If any of the dimensions 2, 3, or 4 is dynamic, we simply
+      // pass the operand of the not-yet-transposed input data below to
+      // determine the dynamic sizes of the to-be-transposed data. There are
+      // also no dynamic-size lowering tests at this time...
       alloc = insertAllocAndDealloc(
           memRefType, loc, rewriter, insertDealloc, {data});
diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index aaf375a..d3463e8 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -512,6 +512,8 @@ mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const {
     if (parser.parseGreater())
       return Type();
     return SeqType::get(elementTypes);
+  } else {
+    llvm_unreachable("Unexpected onnxmlir keyword");
   }
 }
diff --git a/test/backend/test.py b/test/backend/test.py
index 07887b2..6e58d79 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -157,6 +157,11 @@ test_to_enable = [
     "test_exp_cpu",
     "test_exp_example_cpu",
 
+    # Gather Op:
+    #"test_gather_0",
+    #"test_gather_1",
+    #"test_gather_negative_indices",
+
     # Gemm Op:
     "test_gemm_all_attributes_cpu",
     "test_gemm_alpha_cpu",
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 351c2c8..9e0d131 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -2132,3 +2132,41 @@ func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> {
   // CHECK: affine.store [[FPTRUNC]], [[RES]][%arg1] : memref<10xf32>
   // CHECK: return [[RES]] : memref<10xf32>
 }
+
+// -----
+
+// Test gather along axis 0 (first Gather example in the ONNX spec).
+func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
+  %indices = "onnx.Constant"() {value = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64>
+  %0 = "onnx.Gather"(%arg0, %indices) {axis = 0} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2x2xf32>
+  "std.return"(%0) : (tensor<2x2x2xf32>) -> ()
+
+  // CHECK-LABEL: test_gather_axis0
+  // CHECK: [[ALLOC:%.+]] = alloc() : memref<2x2x2xf32>
+  // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "{{.*}}", shape = [2, 2], value = dense<{{\[+}}0, 1], [1, 2{{\]+}}> : tensor<2x2xi64>} : () -> memref<2x2xi64>
+  // CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
+  // CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 2, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 2, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
+  // CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG1]], [[ARG2]]{{.}} : memref<2x2xi64>
+  // CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
+  // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[AFFINE2]], [[ARG3]]{{.}} : memref<3x2xf32>
+  // CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<2x2x2xf32>
+}
+
+// -----
+
+// Test gather along axis 1 (second Gather example in the ONNX spec). Note
+// that the output shape is Ni x Nj x Nk = 3 x 1 x 2, matching the loop trip
+// counts below.
+func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<3x1x2xf32> {
+  %indices = "onnx.Constant"() {value = dense<[[0, 2]]> : tensor<1x2xi64>} : () -> tensor<1x2xi64>
+  %0 = "onnx.Gather"(%arg0, %indices) {axis = 1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32>
+  "std.return"(%0) : (tensor<3x1x2xf32>) -> ()
+
+  // CHECK-LABEL: test_gather_axis1
+  // CHECK: [[ALLOC:%.+]] = alloc() : memref<3x1x2xf32>
+  // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64>
+  // CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
+  // CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 3, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 1, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
+  // CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG2]], [[ARG3]]{{.}} : memref<1x2xi64>
+  // CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
+  // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE2]]{{.}} : memref<3x3xf32>
+  // CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<3x1x2xf32>
+}
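For reference, the numpy.take pattern cited in Gather.cpp can be checked against NumPy directly. The sketch below is a minimal Python model of the loop nest the lowering generates, with one ii/jj/kk loop per output dimension; it assumes NumPy, and gather_ref plus the sample data values are illustrative only, not part of this patch.

    import numpy as np

    def gather_ref(data, indices, axis=0):
        # Ni: dims before the axis; Nj: indices shape; Nk: dims after the axis.
        Ni, Nk = data.shape[:axis], data.shape[axis + 1:]
        Nj = indices.shape
        out = np.empty(Ni + Nj + Nk, dtype=data.dtype)
        for ii in np.ndindex(*Ni):            # ii loops (before the axis)
            for jj in np.ndindex(*Nj):        # jj loops (over the indices)
                for kk in np.ndindex(*Nk):    # kk loops (after the axis)
                    out[ii + jj + kk] = data[ii + (indices[jj],) + kk]
        return out

    # Mirrors test_gather_axis0 above: axis 0, 3x2 data, 2x2 indices -> 2x2x2.
    data = np.arange(6, dtype=np.float32).reshape(3, 2)
    indices = np.array([[0, 1], [1, 2]], dtype=np.int64)
    assert np.array_equal(gather_ref(data, indices, axis=0),
                          np.take(data, indices, axis=0))

np.take with a multi-dimensional index array follows the same Ni + Nj + Nk output-shape rule, so for the axis-1 test (3x3 data, 1x2 indices) the reference yields a 3x1x2 result, matching test_gather_axis1.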