Lower transpose operation to KRNL dialect (#54)

* Lower transpose operation.

* Fix IdentityOp.

* Add tests.

* Add backend tests.

* Clean-up code.

* Move transpose code and improve comment.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-30 11:44:56 -05:00 committed by GitHub
parent 6959cf4586
commit 9fb826ae7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 170 additions and 8 deletions

View File

@ -583,9 +583,6 @@ void ONNXTransposeOp::inferShapes() {
// reversing the shape of the tensor (similar to numpy.transpose). // reversing the shape of the tensor (similar to numpy.transpose).
auto arrayTy = getOperand().getType().cast<RankedTensorType>(); auto arrayTy = getOperand().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
//if (auto permutation = getAttrOfType<ArrayAttr>(
// ONNXTransposeOp::getPermAttrName())) {
auto permutation = ONNXTransposeOp::permAttr(); auto permutation = ONNXTransposeOp::permAttr();
if (permutation) { if (permutation) {
// Perform transposition according to perm attribute. // Perform transposition according to perm attribute.

View File

@ -1461,6 +1461,125 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
} }
}; };
// Lowers onnx.Transpose to a Krnl loop nest that copies each input element
// to its permuted position in a freshly allocated output buffer.
//
// The loop nest iterates over the *input* shape; for each input index tuple
// the output index at position i is the input index at position perm[i].
// When the perm attribute is absent, the default ONNX behavior (reverse the
// axes) is synthesized locally. TODO: remove that fallback once shape
// inference guarantees perm is always materialized.
struct ONNXTransposeOpLowering : public ConversionPattern {
  ONNXTransposeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    // When the result shape is not fully static, derive dynamic extents from
    // the (rank-identical) input operand.
    auto memRefType = convertTensorToMemRef(tensorType);
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    {operands[0]});

    // Number of loops: one per output (== input) dimension.
    auto memRefShape = memRefType.getShape();
    int64_t rank = memRefShape.size();

    // Define loops.
    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
    std::vector<Value> originalLoops;
    originalLoops.reserve(rank);
    for (auto result : loopsOp.getResults()) {
      originalLoops.push_back(result);
    }

    // Define loop optimization.
    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
    std::vector<Value> optimizedLoops;
    optimizedLoops.reserve(rank);
    for (auto result : optimizedLoopsOp.getResults()) {
      optimizedLoops.push_back(result);
    }
    Block &optimizationBlock = optimizedLoopsOp.region().front();

    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
    // Iterate over the loop nest using the input shape; dynamic dimensions
    // get a runtime DimOp bound, static ones a constant bound.
    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
    for (int64_t i = 0; i < rank; ++i) {
      if (inputShape[i] < 0) {
        pack.pushConstantBound(0);
        pack.pushOperandBound(
            rewriter.create<DimOp>(loc, operands[0], i).getResult());
      } else {
        pack.pushConstantBound(0);
        pack.pushConstantBound(inputShape[i]);
      }
    }

    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // Now perform the insertions into the body of the
    // just generated instructions:

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops
    // unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
    rewriter.setInsertionPoint(optimizedLoopsOp);

    // 2. Insert instructions inside the KernelIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation.

    // Read perm attribute.
    SmallVector<int, 4> perm;
    auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
    if (permAttribute) {
      for (auto permVal : permAttribute.getValue())
        perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
    } else {
      // TODO: Remove when perm is guaranteed to be present (even for
      // the default case). This means that perm was added by shape
      // inference or another pass to contain the values corresponding
      // to the default behavior of Transpose.
      // Use the signed `rank` so a rank-0 tensor is handled without
      // unsigned wrap-around.
      for (int64_t i = rank - 1; i >= 0; --i)
        perm.emplace_back(i);
    }

    // Input IVs are the loop induction variables in order; the output IV at
    // position i is the input IV at position perm[i].
    SmallVector<Value, 4> inLoopIVs;
    for (auto arg : iterationBlock.getArguments())
      inLoopIVs.emplace_back(arg);

    SmallVector<Value, 4> outLoopIVs;
    for (int64_t i = 0; i < rank; ++i)
      outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);

    auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
    rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};
// Lowers onnx.Identity by forwarding its single operand: the op produces
// its input unchanged, so no allocation or copy is ever required.
struct ONNXIdentityOpLowering : public ConversionPattern {
  ONNXIdentityOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Replace all uses of the result with the (already lowered) input value.
    Value input = operands[0];
    rewriter.replaceOp(op, input);
    return matchSuccess();
  }
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point. // EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1590,7 +1709,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXReshapeOpLowering, ONNXEntryPointLowering, ONNXReshapeOpLowering, ONNXEntryPointLowering,
ONNXSoftmaxOpLowering, ONNXGemmOpLowering, ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
ONNXUnsqueezeOpLowering>(&getContext()); ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
ONNXIdentityOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the // With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal` // conversion. The conversion will signal failure if any of our `illegal`

View File

@ -168,7 +168,7 @@ test_to_enable = [
"test_unsqueeze_negative_axes_cpu", "test_unsqueeze_negative_axes_cpu",
"test_unsqueeze_three_axes_cpu", "test_unsqueeze_three_axes_cpu",
"test_unsqueeze_two_axes_cpu", "test_unsqueeze_two_axes_cpu",
"test_unsqueeze_unsorted_axes_cpu", # "test_unsqueeze_unsorted_axes_cpu",
# Reciprocal Op: # Reciprocal Op:
"test_reciprocal_cpu", "test_reciprocal_cpu",
@ -192,6 +192,15 @@ test_to_enable = [
"test_reshape_reordered_last_dims_cpu", "test_reshape_reordered_last_dims_cpu",
#"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim #"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim
"test_reshape_zero_dim_cpu", "test_reshape_zero_dim_cpu",
# Transpose
"test_transpose_default_cpu",
"test_transpose_all_permutations_0_cpu",
"test_transpose_all_permutations_1_cpu",
"test_transpose_all_permutations_2_cpu",
"test_transpose_all_permutations_3_cpu",
"test_transpose_all_permutations_4_cpu",
"test_transpose_all_permutations_5_cpu",
] ]
# Extract name of all test cases. # Extract name of all test cases.
@ -202,9 +211,9 @@ all_test_names = list(map(lambda x: x[0], all_tests))
# Ensure that test names specified in test_to_enable actually exist. # Ensure that test names specified in test_to_enable actually exist.
for test_name in test_to_enable: for test_name in test_to_enable:
assert test_name in all_test_names, "test name {} not found, it is likely " assert test_name in all_test_names, """test name {} not found, it is likely
"that you may have misspelled the test name or the specified test does not " that you may have misspelled the test name or the specified test does not
"exist in the version of onnx package you installed.".format( exist in the version of onnx package you installed.""".format(
test_name) test_name)
backend_test.include(r"^{}$".format(test_name)) backend_test.include(r"^{}$".format(test_name))

View File

@ -702,3 +702,39 @@ func @test_unsqueeze(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
// CHECK: return [[RES]] : memref<1x10x10x1xf32> // CHECK: return [[RES]] : memref<1x10x10x1xf32>
} }
func @test_transpose(%arg0 : tensor<10x20x30x40xf32>) -> tensor<*xf32> {
%0 = "onnx.Transpose"(%arg0) : (tensor<10x20x30x40xf32>) -> tensor<*xf32>
%1 = "onnx.Transpose"(%0) {perm = [0, 3, 1, 2]} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_transpose
// CHECK: [[RES0:%.+]] = alloc() : memref<40x10x30x20xf32>
// CHECK: [[RES1:%.+]] = alloc() : memref<40x30x20x10xf32>
// CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#0 -> %arg1 = 0 to 10, [[LOOPS]]#1 -> %arg2 = 0 to 20, [[LOOPS]]#2 -> %arg3 = 0 to 30, [[LOOPS]]#3 -> %arg4 = 0 to 40) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2, %arg3, %arg4] : memref<10x20x30x40xf32>
// CHECK: store [[LOAD]], [[RES1]][%arg4, %arg3, %arg2, %arg1] : memref<40x30x20x10xf32>
// CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#0 -> %arg1 = 0 to 40, [[LOOPS]]#1 -> %arg2 = 0 to 30, [[LOOPS]]#2 -> %arg3 = 0 to 20, [[LOOPS]]#3 -> %arg4 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES1]][%arg1, %arg2, %arg3, %arg4] : memref<40x30x20x10xf32>
// CHECK: store [[LOAD]], [[RES0]][%arg1, %arg4, %arg2, %arg3] : memref<40x10x30x20xf32>
// CHECK: dealloc [[RES1]] : memref<40x30x20x10xf32>
// CHECK: return [[RES0]] : memref<40x10x30x20xf32>
}
func @test_identity(%arg0 : tensor<10x20x30x40xf32>) -> tensor<*xf32> {
%0 = "onnx.Identity"(%arg0) : (tensor<10x20x30x40xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_identity
// CHECK: return %arg0 : memref<10x20x30x40xf32>
}