Gather ONNX to Kernel Lowering (#294)
* Define krnl.permute op. * Support krnl.permute operation. * Properly remove loop references. * Re-push, Github was down. * Need to debug interpretOp error. * Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code. * Introduce permute, unroll operations. * More debug. * Remove std::set. * krnl.terminate fails to be converted. * Pass all tests, need to add legal ops as well as part of the conversion target. * Change test format to new permute spec. * Bug fix for nested iterate op lowering. * Simplify error reporting. * Fix compilation error. * Increase comments coverage. * Remove unnecessary imports. * Re-trigger Jenkins * Add permute/unroll tests. * Retrigger Jenkins * initial implementation of gather * added tests * format * remove affine load for second load, as it uses an indirection * changes suggested by reviewers * remove backend tests until I can verify them locally Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
fa04c32a0c
commit
3a5aa7ee31
|
@ -22,6 +22,7 @@ add_library(OMONNXToKrnl
|
||||||
Tensor/Constant.cpp
|
Tensor/Constant.cpp
|
||||||
Tensor/Concat.cpp
|
Tensor/Concat.cpp
|
||||||
Tensor/Split.cpp
|
Tensor/Split.cpp
|
||||||
|
Tensor/Gather.cpp
|
||||||
ConvertONNXToKrnl.cpp)
|
ConvertONNXToKrnl.cpp)
|
||||||
target_link_libraries(OMONNXToKrnl
|
target_link_libraries(OMONNXToKrnl
|
||||||
onnx)
|
onnx)
|
||||||
|
|
|
@ -96,6 +96,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
|
||||||
populateLoweringONNXPadOpPattern(patterns, &getContext());
|
populateLoweringONNXPadOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
|
populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXTransposeOpPattern(patterns, &getContext());
|
populateLoweringONNXTransposeOpPattern(patterns, &getContext());
|
||||||
|
populateLoweringONNXGatherOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
|
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXConstantOpPattern(patterns, &getContext());
|
populateLoweringONNXConstantOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXConcatOpPattern(patterns, &getContext());
|
populateLoweringONNXConcatOpPattern(patterns, &getContext());
|
||||||
|
|
|
@ -223,6 +223,9 @@ void populateLoweringONNXUnsqueezeOpPattern(
|
||||||
void populateLoweringONNXTransposeOpPattern(
|
void populateLoweringONNXTransposeOpPattern(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||||
|
|
||||||
|
void populateLoweringONNXGatherOpPattern(
|
||||||
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||||
|
|
||||||
void populateLoweringONNXPadConstantValuePadOpPattern(
|
void populateLoweringONNXPadConstantValuePadOpPattern(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,117 @@
|
||||||
|
//===----------------Gather.cpp - Lowering Gather Op----------------------=== //
|
||||||
|
//
|
||||||
|
// Copyright 2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file lowers the ONNX Gather Operator to Krnl dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
struct ONNXGatherOpLowering : public ConversionPattern {
|
||||||
|
ONNXGatherOpLowering(MLIRContext *ctx)
|
||||||
|
: ConversionPattern(mlir::ONNXGatherOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
ONNXGatherOpAdaptor operandAdaptor(operands);
|
||||||
|
ONNXGatherOp gatherOp = llvm::cast<ONNXGatherOp>(op);
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
// get input operands, shapes, and rank
|
||||||
|
Value data = operandAdaptor.data();
|
||||||
|
auto dataShape = data.getType().cast<MemRefType>().getShape();
|
||||||
|
int64_t dataRank = dataShape.size();
|
||||||
|
Value indices = operandAdaptor.indices();
|
||||||
|
auto indicesShape = indices.getType().cast<MemRefType>().getShape();
|
||||||
|
int64_t indicesRank = indicesShape.size();
|
||||||
|
int64_t axisIndex = gatherOp.axis().getSExtValue();
|
||||||
|
// get output info
|
||||||
|
auto outputMemRefType = convertToMemRefType(*op->result_type_begin());
|
||||||
|
auto outputMemRefShape = outputMemRefType.getShape();
|
||||||
|
int64_t outputRank = outputMemRefShape.size();
|
||||||
|
/*
|
||||||
|
The pattern that we are using is that of numpy.take.
|
||||||
|
|
||||||
|
Ni, Nk = data.shape[:axis], data.shape[axis+1:]
|
||||||
|
Nj = indices.shape
|
||||||
|
for ii in ndindex(Ni):
|
||||||
|
for jj in ndindex(Nj):
|
||||||
|
for kk in ndindex(Nk):
|
||||||
|
out[ii + jj + kk] = data[ii + (indices[jj],) + kk]
|
||||||
|
*/
|
||||||
|
// Define loops and iteration trip counts (equivalent to size of output)
|
||||||
|
std::vector<Value> originalLoops;
|
||||||
|
defineLoops(rewriter, loc, originalLoops, outputRank);
|
||||||
|
KrnlIterateOperandPack pack(rewriter, originalLoops);
|
||||||
|
int iIndexStart = 0;
|
||||||
|
for (int ii = 0; ii < axisIndex; ++ii)
|
||||||
|
addDimensionToPack(rewriter, loc, pack, data, ii);
|
||||||
|
// Then iterates over the Nj (indices matrix), jj indices in above algo.
|
||||||
|
int jIndexStart = iIndexStart + axisIndex;
|
||||||
|
for (int jj = 0; jj < indicesRank; ++jj)
|
||||||
|
addDimensionToPack(rewriter, loc, pack, indices, jj);
|
||||||
|
// Finally iterates over the Nk (data after axis), kk indices in above algo.
|
||||||
|
int kIndexStart = jIndexStart + indicesRank - (axisIndex + 1);
|
||||||
|
for (int kk = axisIndex + 1; kk < dataRank; ++kk)
|
||||||
|
addDimensionToPack(rewriter, loc, pack, data, kk);
|
||||||
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
|
Value alloc;
|
||||||
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
if (hasAllConstantDimensions(outputMemRefType))
|
||||||
|
alloc =
|
||||||
|
insertAllocAndDealloc(outputMemRefType, loc, rewriter, insertDealloc);
|
||||||
|
else
|
||||||
|
return emitError(loc, "unsupported dynamic dimensions");
|
||||||
|
|
||||||
|
// Create the loops
|
||||||
|
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||||
|
Block &iterationBlock = iterateOp.bodyRegion().front();
|
||||||
|
|
||||||
|
// Now perform the insertions into the body of the just generated loops.
|
||||||
|
// Insert instructions inside the KernelIterateOp body.
|
||||||
|
rewriter.setInsertionPointToStart(&iterationBlock);
|
||||||
|
|
||||||
|
// Handle the operations.
|
||||||
|
// Read first the indices[jj] into indexVal.
|
||||||
|
SmallVector<Value, 4> indicesMemRefVal;
|
||||||
|
for (int j = 0; j < indicesRank; ++j)
|
||||||
|
indicesMemRefVal.emplace_back(
|
||||||
|
iterationBlock.getArguments()[jIndexStart + j]);
|
||||||
|
auto indexValInteger =
|
||||||
|
rewriter.create<AffineLoadOp>(loc, indices, indicesMemRefVal);
|
||||||
|
auto indexVal = rewriter.create<IndexCastOp>(
|
||||||
|
loc, indexValInteger, rewriter.getIndexType());
|
||||||
|
|
||||||
|
// Then read input data into DataVal: first add ii's.
|
||||||
|
SmallVector<Value, 4> dataMemRefVal;
|
||||||
|
for (int i = 0; i < axisIndex; ++i)
|
||||||
|
dataMemRefVal.emplace_back(
|
||||||
|
iterationBlock.getArguments()[iIndexStart + i]);
|
||||||
|
// Then add indices[jj] (indexVal)
|
||||||
|
dataMemRefVal.emplace_back(indexVal);
|
||||||
|
// Then add kk's
|
||||||
|
for (int k = axisIndex + 1; k < dataRank; ++k)
|
||||||
|
dataMemRefVal.emplace_back(
|
||||||
|
iterationBlock.getArguments()[kIndexStart + k]);
|
||||||
|
auto dataVal = rewriter.create<LoadOp>(loc, data, dataMemRefVal);
|
||||||
|
|
||||||
|
// Then store the value in the output.
|
||||||
|
SmallVector<Value, 4> outputMemRefVal;
|
||||||
|
for (int n = 0; n < iterationBlock.getArguments().size(); ++n)
|
||||||
|
outputMemRefVal.emplace_back(iterationBlock.getArguments()[n]);
|
||||||
|
rewriter.create<AffineStoreOp>(loc, dataVal, alloc, outputMemRefVal);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, alloc);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void populateLoweringONNXGatherOpPattern(
|
||||||
|
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||||
|
patterns.insert<ONNXGatherOpLowering>(ctx);
|
||||||
|
}
|
|
@ -29,6 +29,14 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
else
|
else
|
||||||
|
// TODO: While the code below appears to nominally handle the alloc of
|
||||||
|
// data in presence of dynamic dimensions, this appears to be false as the
|
||||||
|
// operand passed here "{data}" reflect the input sizes and does not
|
||||||
|
// reflect the transpose. Indeed, if an input is 4x3x2 then the output
|
||||||
|
// would be 2x3x4. If any of the 2,3,or 4 are dynamic dimensions, then we
|
||||||
|
// simply pass below the operand of the not-transposed input data to
|
||||||
|
// determine the dynamic sizes of the to-be-transposed data. At this time,
|
||||||
|
// there is also no dynamic size lowering tests...
|
||||||
alloc = insertAllocAndDealloc(
|
alloc = insertAllocAndDealloc(
|
||||||
memRefType, loc, rewriter, insertDealloc, {data});
|
memRefType, loc, rewriter, insertDealloc, {data});
|
||||||
|
|
||||||
|
|
|
@ -512,6 +512,8 @@ mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const {
|
||||||
if (parser.parseGreater())
|
if (parser.parseGreater())
|
||||||
return Type();
|
return Type();
|
||||||
return SeqType::get(elementTypes);
|
return SeqType::get(elementTypes);
|
||||||
|
} else {
|
||||||
|
llvm_unreachable("Unexpected onnxmlir keyword");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -157,6 +157,11 @@ test_to_enable = [
|
||||||
"test_exp_cpu",
|
"test_exp_cpu",
|
||||||
"test_exp_example_cpu",
|
"test_exp_example_cpu",
|
||||||
|
|
||||||
|
# Gather Op:
|
||||||
|
#"test_gather_0",
|
||||||
|
#"test_gather_1",
|
||||||
|
#"test_gather_negative_indices",
|
||||||
|
|
||||||
# Gemm Op:
|
# Gemm Op:
|
||||||
"test_gemm_all_attributes_cpu",
|
"test_gemm_all_attributes_cpu",
|
||||||
"test_gemm_alpha_cpu",
|
"test_gemm_alpha_cpu",
|
||||||
|
|
|
@ -2132,3 +2132,41 @@ func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> {
|
||||||
// CHECK: affine.store [[FPTRUNC]], [[RES]][%arg1] : memref<10xf32>
|
// CHECK: affine.store [[FPTRUNC]], [[RES]][%arg1] : memref<10xf32>
|
||||||
// CHECK: return [[RES]] : memref<10xf32>
|
// CHECK: return [[RES]] : memref<10xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test gather along axis 0, first example in ONNX for Gather.
|
||||||
|
func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
|
||||||
|
%indices = "onnx.Constant"() {value = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64>
|
||||||
|
%0 = "onnx.Gather"(%arg0, %indices) {axis = 0} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2x2xf32>
|
||||||
|
"std.return"(%0) : (tensor<2x2x2xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_gather_axis0
|
||||||
|
// CHECK: [[ALLOC:%.+]] = alloc() : memref<2x2x2xf32>
|
||||||
|
// CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "{{.*}}", shape = [2, 2], value = dense<{{\[+}}0, 1], [1, 2{{\]+}}> : tensor<2x2xi64>} : () -> memref<2x2xi64>
|
||||||
|
// CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
|
||||||
|
// CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 2, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 2, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
|
||||||
|
// CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG1]], [[ARG2]]{{.}} : memref<2x2xi64>
|
||||||
|
// CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
|
||||||
|
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[AFFINE2]], [[ARG3]]{{.}} : memref<3x2xf32>
|
||||||
|
// CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<2x2x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test gather along axis 1, second example in ONNX for Gather.
|
||||||
|
func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> {
|
||||||
|
%indices = "onnx.Constant"() {value = dense<[[0, 2]]> : tensor<1x2xi64>} : () -> tensor<1x2xi64>
|
||||||
|
%0 = "onnx.Gather"(%arg0, %indices) {axis = 1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x3x2xf32>
|
||||||
|
"std.return"(%0) : (tensor<1x3x2xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_gather_axis1
|
||||||
|
// CHECK: [[ALLOC:%.+]] = alloc() : memref<1x3x2xf32>
|
||||||
|
// CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64>
|
||||||
|
// CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
|
||||||
|
// CHECK: krnl.iterate([[LOOP]]#0, [[LOOP]]#1, [[LOOP]]#2) with ([[LOOP]]#0 -> [[ARG1:%.+]] = 0 to 3, [[LOOP]]#1 -> [[ARG2:%.+]] = 0 to 1, [[LOOP]]#2 -> [[ARG3:%.+]] = 0 to 2) {
|
||||||
|
// CHECK: [[AFFINE1:%.+]] = affine.load [[GLOBAL]]{{.}}[[ARG2]], [[ARG3]]{{.}} : memref<1x2xi64>
|
||||||
|
// CHECK: [[AFFINE2:%.+]] = index_cast [[AFFINE1]] : i64 to index
|
||||||
|
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE2]]{{.}} : memref<3x3xf32>
|
||||||
|
// CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<1x3x2xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue