Gather ONNX to Kernel Lowering (#294)
* Define krnl.permute op.
* Support krnl.permute operation.
* Properly remove loop references.
* Re-push; GitHub was down.
* Need to debug interpretOp error.
* Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code.
* Introduce permute, unroll operations.
* More debug.
* Remove std::set.
* krnl.terminate fails to be converted.
* Pass all tests; still need to add legal ops as part of the conversion target.
* Change test format to new permute spec.
* Bug fix for nested iterate op lowering.
* Simplify error reporting.
* Fix compilation error.
* Increase comment coverage.
* Remove unnecessary imports.
* Re-trigger Jenkins.
* Add permute/unroll tests.
* Re-trigger Jenkins.
* Initial implementation of gather.
* Added tests.
* Format.
* Remove the affine load for the second load, as it uses an indirection.
* Changes suggested by reviewers.
* Remove backend tests until I can verify them locally.

Co-authored-by: Tian Jin <tjingrant@gmail.com>
parent fa04c32a0c
commit 3a5aa7ee31
@@ -22,6 +22,7 @@ add_library(OMONNXToKrnl
         Tensor/Constant.cpp
         Tensor/Concat.cpp
         Tensor/Split.cpp
+        Tensor/Gather.cpp
         ConvertONNXToKrnl.cpp)
 target_link_libraries(OMONNXToKrnl
         onnx)

@@ -96,6 +96,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   populateLoweringONNXPadOpPattern(patterns, &getContext());
   populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
   populateLoweringONNXTransposeOpPattern(patterns, &getContext());
+  populateLoweringONNXGatherOpPattern(patterns, &getContext());
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
   populateLoweringONNXConstantOpPattern(patterns, &getContext());
   populateLoweringONNXConcatOpPattern(patterns, &getContext());

@@ -223,6 +223,9 @@ void populateLoweringONNXUnsqueezeOpPattern(
 void populateLoweringONNXTransposeOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
+void populateLoweringONNXGatherOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 void populateLoweringONNXPadConstantValuePadOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 

@@ -0,0 +1,117 @@
+//===----------------Gather.cpp - Lowering Gather Op----------------------===//
+//
+// Copyright 2020 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Gather Operator to Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+
+using namespace mlir;
+
+struct ONNXGatherOpLowering : public ConversionPattern {
+  ONNXGatherOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXGatherOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    ONNXGatherOpAdaptor operandAdaptor(operands);
+    ONNXGatherOp gatherOp = llvm::cast<ONNXGatherOp>(op);
+    auto loc = op->getLoc();
+    // Get input operands, shapes, and ranks.
+    Value data = operandAdaptor.data();
+    auto dataShape = data.getType().cast<MemRefType>().getShape();
+    int64_t dataRank = dataShape.size();
+    Value indices = operandAdaptor.indices();
+    auto indicesShape = indices.getType().cast<MemRefType>().getShape();
+    int64_t indicesRank = indicesShape.size();
+    int64_t axisIndex = gatherOp.axis().getSExtValue();
+    // Get output info.
+    auto outputMemRefType = convertToMemRefType(*op->result_type_begin());
+    auto outputMemRefShape = outputMemRefType.getShape();
+    int64_t outputRank = outputMemRefShape.size();
+    /*
+      The pattern that we are using is that of numpy.take.
+
+      Ni, Nk = data.shape[:axis], data.shape[axis+1:]
+      Nj = indices.shape
+      for ii in ndindex(Ni):
+        for jj in ndindex(Nj):
+          for kk in ndindex(Nk):
+            out[ii + jj + kk] = data[ii + (indices[jj],) + kk]
+    */
+    // Define loops and iteration trip counts (equivalent to size of output).
+    std::vector<Value> originalLoops;
+    defineLoops(rewriter, loc, originalLoops, outputRank);
+    KrnlIterateOperandPack pack(rewriter, originalLoops);
+    int iIndexStart = 0;
+    for (int ii = 0; ii < axisIndex; ++ii)
+      addDimensionToPack(rewriter, loc, pack, data, ii);
+    // Then iterate over the Nj (indices matrix), jj indices in above algo.
+    int jIndexStart = iIndexStart + axisIndex;
+    for (int jj = 0; jj < indicesRank; ++jj)
+      addDimensionToPack(rewriter, loc, pack, indices, jj);
+    // Finally iterate over the Nk (data after axis), kk indices in above algo.
+    int kIndexStart = jIndexStart + indicesRank - (axisIndex + 1);
+    for (int kk = axisIndex + 1; kk < dataRank; ++kk)
+      addDimensionToPack(rewriter, loc, pack, data, kk);
+    // Insert an allocation and deallocation for the result of this operation.
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(outputMemRefType))
+      alloc =
+          insertAllocAndDealloc(outputMemRefType, loc, rewriter, insertDealloc);
+    else
+      return emitError(loc, "unsupported dynamic dimensions");
+
+    // Create the loops.
+    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
+    Block &iterationBlock = iterateOp.bodyRegion().front();
+
+    // Now perform the insertions into the body of the just generated loops.
+    // Insert instructions inside the KrnlIterateOp body.
+    rewriter.setInsertionPointToStart(&iterationBlock);
+
+    // Handle the operations.
+    // First read indices[jj] into indexVal.
+    SmallVector<Value, 4> indicesMemRefVal;
+    for (int j = 0; j < indicesRank; ++j)
+      indicesMemRefVal.emplace_back(
+          iterationBlock.getArguments()[jIndexStart + j]);
+    auto indexValInteger =
+        rewriter.create<AffineLoadOp>(loc, indices, indicesMemRefVal);
+    auto indexVal = rewriter.create<IndexCastOp>(
+        loc, indexValInteger, rewriter.getIndexType());
+
+    // Then read the input data into dataVal: first add the ii's.
+    SmallVector<Value, 4> dataMemRefVal;
+    for (int i = 0; i < axisIndex; ++i)
+      dataMemRefVal.emplace_back(
+          iterationBlock.getArguments()[iIndexStart + i]);
+    // Then add indices[jj] (indexVal).
+    dataMemRefVal.emplace_back(indexVal);
+    // Then add the kk's.
+    for (int k = axisIndex + 1; k < dataRank; ++k)
+      dataMemRefVal.emplace_back(
+          iterationBlock.getArguments()[kIndexStart + k]);
+    auto dataVal = rewriter.create<LoadOp>(loc, data, dataMemRefVal);
+
+    // Then store the value in the output.
+    SmallVector<Value, 4> outputMemRefVal;
+    for (int n = 0; n < iterationBlock.getArguments().size(); ++n)
+      outputMemRefVal.emplace_back(iterationBlock.getArguments()[n]);
+    rewriter.create<AffineStoreOp>(loc, dataVal, alloc, outputMemRefVal);
+
+    rewriter.replaceOp(op, alloc);
+
+    return success();
+  }
+};
+
+void populateLoweringONNXGatherOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXGatherOpLowering>(ctx);
+}
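The comment block in Gather.cpp above describes the lowering in terms of numpy.take, and, per the commit message, the generated data load is a standard (non-affine) load because its first index is an indirection through the indices tensor. As a sanity check (not part of the commit; gather_reference is a name chosen here for illustration), this minimal NumPy sketch runs the same loop nest and compares it against np.take on the first Gather example from the ONNX operator spec:

import numpy as np

def gather_reference(data, indices, axis=0):
    # Same loop nest as the Gather.cpp comment: walk the output index
    # space Ni x Nj x Nk and indirect through indices along `axis`.
    Ni, Nk = data.shape[:axis], data.shape[axis + 1:]
    Nj = indices.shape
    out = np.empty(Ni + Nj + Nk, dtype=data.dtype)
    for ii in np.ndindex(Ni):
        for jj in np.ndindex(Nj):
            for kk in np.ndindex(Nk):
                out[ii + jj + kk] = data[ii + (indices[jj],) + kk]
    return out

# First Gather example from the ONNX spec (axis = 0).
data = np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np.float32)
indices = np.array([[0, 1], [1, 2]])
assert np.array_equal(gather_reference(data, indices, axis=0),
                      np.take(data, indices, axis=0))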
@@ -29,6 +29,14 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     else
+      // TODO: While the code below appears to nominally handle the alloc of
+      // data in the presence of dynamic dimensions, this appears to be false:
+      // the operand passed here, "{data}", reflects the input sizes, not the
+      // transposed ones. Indeed, if an input is 4x3x2 then the output would
+      // be 2x3x4. If any of the 2, 3, or 4 are dynamic dimensions, then we
+      // simply pass below the operand of the not-transposed input data to
+      // determine the dynamic sizes of the to-be-transposed data. At this
+      // time, there are also no dynamic size lowering tests...
       alloc = insertAllocAndDealloc(
           memRefType, loc, rewriter, insertDealloc, {data});

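To make the TODO above concrete, here is a small sketch (not part of the commit; transposed_shape is a hypothetical helper): the output dimension sizes of a transpose come from the permuted input dimensions, so sizing the allocation from the untransposed {data} operand reads the wrong slot whenever the permutation moves a dynamic dimension.

def transposed_shape(input_shape, perm=None):
    # ONNX Transpose with no perm attribute reverses the dimensions,
    # so a 4x3x2 input yields a 2x3x4 output.
    if perm is None:
        perm = list(reversed(range(len(input_shape))))
    return [input_shape[p] for p in perm]

# Mark a dynamic dimension as -1: it moves from position 0 to position 2,
# so the output's dynamic size must be read from input dim 0, not dim 2.
assert transposed_shape([4, 3, 2]) == [2, 3, 4]
assert transposed_shape([-1, 3, 2]) == [2, 3, -1]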
@@ -512,6 +512,8 @@ mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const {
     if (parser.parseGreater())
       return Type();
     return SeqType::get(elementTypes);
+  } else {
+    llvm_unreachable("Unexpected onnxmlir keyword");
   }
 }
 

@@ -157,6 +157,11 @@ test_to_enable = [
     "test_exp_cpu",
     "test_exp_example_cpu",
 
+    # Gather Op:
+    #"test_gather_0",
+    #"test_gather_1",
+    #"test_gather_negative_indices",
+
     # Gemm Op:
     "test_gemm_all_attributes_cpu",
     "test_gemm_alpha_cpu",

@@ -2132,3 +2132,41 @@ func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> {
   // CHECK: affine.store [[FPTRUNC]], [[RES]][%arg1] : memref<10xf32>
   // CHECK: return [[RES]] : memref<10xf32>
 }
+
+// -----
+
+// Test gather along axis 0, first example in ONNX for Gather.
+func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
+  %indices = "onnx.Constant"() {value = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64>
+  %0 = "onnx.Gather"(%arg0, %indices) {axis = 0} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2x2xf32>
+  "std.return"(%0) : (tensor<2x2x2xf32>) -> ()
+
+  // CHECK-LABEL: test_gather_axis0
+  // CHECK: [[ALLOC:%.+]] = alloc() : memref<2x2x2xf32>
+  // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "{{.*}}", shape = [2, 2], value = dense<{{\[+}}0, 1], [1, 2{{\]+}}> : tensor<2x2xi64>} : () -> memref<2x2xi64>
+  // CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
+  // CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 2, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 2, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
+  // CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG1]], [[ARG2]]{{.}} : memref<2x2xi64>
+  // CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
+  // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[AFFINE2]], [[ARG3]]{{.}} : memref<3x2xf32>
+  // CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<2x2x2xf32>
+}
+
+// -----
+
+// Test gather along axis 1, second example in ONNX for Gather.
+func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> {
+  %indices = "onnx.Constant"() {value = dense<[[0, 2]]> : tensor<1x2xi64>} : () -> tensor<1x2xi64>
+  %0 = "onnx.Gather"(%arg0, %indices) {axis = 1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x3x2xf32>
+  "std.return"(%0) : (tensor<1x3x2xf32>) -> ()
+
+  // CHECK-LABEL: test_gather_axis1
+  // CHECK: [[ALLOC:%.+]] = alloc() : memref<1x3x2xf32>
+  // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64>
+  // CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
+  // CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 3, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 1, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
+  // CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG2]], [[ARG3]]{{.}} : memref<1x2xi64>
+  // CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
+  // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE2]]{{.}} : memref<3x3xf32>
+  // CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<1x3x2xf32>
+}
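As a cross-check on the two tests above (a sketch, not part of the commit; krnl_loop_bounds is a hypothetical helper): the trip counts of the generated krnl.iterate are the concatenation Ni ++ Nj ++ Nk from the lowering, which is why test_gather_axis0 iterates 2, 2, 2 and test_gather_axis1 iterates 3, 1, 2.

def krnl_loop_bounds(data_shape, indices_shape, axis):
    # Trip counts of the krnl.iterate: data dims before the axis (Ni),
    # then the indices dims (Nj), then data dims after the axis (Nk).
    return (list(data_shape[:axis]) + list(indices_shape)
            + list(data_shape[axis + 1:]))

assert krnl_loop_bounds((3, 2), (2, 2), axis=0) == [2, 2, 2]  # test_gather_axis0
assert krnl_loop_bounds((3, 3), (1, 2), axis=1) == [3, 1, 2]  # test_gather_axis1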