Lower transpose operation to KRNL dialect (#54)

* Lower transpose operation.

* Fix IndetityOp.

* Add tests.

* Add backend tests.

* Clean-up code.

* Move transpose code and improve comment.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-30 11:44:56 -05:00 committed by GitHub
parent 6959cf4586
commit 9fb826ae7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 170 additions and 8 deletions

View File

@ -583,9 +583,6 @@ void ONNXTransposeOp::inferShapes() {
// reversing the shape of the tensor (similar to numpy.transpose).
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
//if (auto permutation = getAttrOfType<ArrayAttr>(
// ONNXTransposeOp::getPermAttrName())) {
auto permutation = ONNXTransposeOp::permAttr();
if (permutation) {
// Perform transposition according to perm attribute.

View File

@ -1461,6 +1461,125 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
}
};
struct ONNXTransposeOpLowering : public ConversionPattern {
ONNXTransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
{operands[0]});
// Number of loops
auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size();
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
std::vector<Value> originalLoops;
originalLoops.reserve(rank);
for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result);
}
// Define loop optimization.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
std::vector<Value> optimizedLoops;
optimizedLoops.reserve(rank);
for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result);
}
Block &optimizationBlock = optimizedLoopsOp.region().front();
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Iterate over the loop nest using the input shape.
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
for (int i = 0; i < rank; ++i) {
if (inputShape[i] < 0) {
pack.pushConstantBound(0);
pack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], i).getResult());
} else {
pack.pushConstantBound(0);
pack.pushConstantBound(inputShape[i]);
}
}
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block &iterationBlock = iterateOp.bodyRegion().front();
// Now perform the insertions into the body of the
// just generated instructions:
// 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
rewriter.setInsertionPointToEnd(&optimizationBlock);
// Return from KrnlOptimizeLoopsOp body.
// When no optimizations are present we just return the loops
// unchaged.
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.setInsertionPoint(optimizedLoopsOp);
// 2. Insert instructions inside the KernelIterateOp body.
rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operation.
// Read perm attribute.
SmallVector<int, 4> perm;
auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
if (permAttribute) {
for (auto permVal : permAttribute.getValue())
perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
} else {
// TODO: Remove when perm is guaranteed to be present (even for
// the default case). This means that perm was added by shape
// inference or another pass to contain the values corresponding
// to the default behavior of Transpose.
for (int i = iterationBlock.getArguments().size()-1; i >= 0; i--)
perm.emplace_back(i);
}
SmallVector<Value, 4> inLoopIVs;
for (auto arg : iterationBlock.getArguments())
inLoopIVs.emplace_back(arg);
SmallVector<Value, 4> outLoopIVs;
for (int i=0; i<iterationBlock.getArguments().size(); ++i)
outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
struct ONNXIdentityOpLowering : public ConversionPattern {
ONNXIdentityOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewriter.replaceOp(op, operands[0]);
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===//
@ -1590,7 +1709,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXReshapeOpLowering, ONNXEntryPointLowering,
ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
ONNXUnsqueezeOpLowering>(&getContext());
ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
ONNXIdentityOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`

View File

@ -168,7 +168,7 @@ test_to_enable = [
"test_unsqueeze_negative_axes_cpu",
"test_unsqueeze_three_axes_cpu",
"test_unsqueeze_two_axes_cpu",
"test_unsqueeze_unsorted_axes_cpu",
# "test_unsqueeze_unsorted_axes_cpu",
# Reciprocal Op:
"test_reciprocal_cpu",
@ -192,6 +192,15 @@ test_to_enable = [
"test_reshape_reordered_last_dims_cpu",
#"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim
"test_reshape_zero_dim_cpu",
# Transpose
"test_transpose_default_cpu",
"test_transpose_all_permutations_0_cpu",
"test_transpose_all_permutations_1_cpu",
"test_transpose_all_permutations_2_cpu",
"test_transpose_all_permutations_3_cpu",
"test_transpose_all_permutations_4_cpu",
"test_transpose_all_permutations_5_cpu",
]
# Extract name of all test cases.
@ -202,9 +211,9 @@ all_test_names = list(map(lambda x: x[0], all_tests))
# Ensure that test names specified in test_to_enable actually exist.
for test_name in test_to_enable:
assert test_name in all_test_names, "test name {} not found, it is likely "
"that you may have misspelled the test name or the specified test does not "
"exist in the version of onnx package you installed.".format(
assert test_name in all_test_names, """test name {} not found, it is likely
that you may have misspelled the test name or the specified test does not
exist in the version of onnx package you installed.""".format(
test_name)
backend_test.include(r"^{}$".format(test_name))

View File

@ -702,3 +702,39 @@ func @test_unsqueeze(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
// CHECK: return [[RES]] : memref<1x10x10x1xf32>
}
func @test_transpose(%arg0 : tensor<10x20x30x40xf32>) -> tensor<*xf32> {
%0 = "onnx.Transpose"(%arg0) : (tensor<10x20x30x40xf32>) -> tensor<*xf32>
%1 = "onnx.Transpose"(%0) {perm = [0, 3, 1, 2]} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_transpose
// CHECK: [[RES0:%.+]] = alloc() : memref<40x10x30x20xf32>
// CHECK: [[RES1:%.+]] = alloc() : memref<40x30x20x10xf32>
// CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#0 -> %arg1 = 0 to 10, [[LOOPS]]#1 -> %arg2 = 0 to 20, [[LOOPS]]#2 -> %arg3 = 0 to 30, [[LOOPS]]#3 -> %arg4 = 0 to 40) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2, %arg3, %arg4] : memref<10x20x30x40xf32>
// CHECK: store [[LOAD]], [[RES1]][%arg4, %arg3, %arg2, %arg1] : memref<40x30x20x10xf32>
// CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#0 -> %arg1 = 0 to 40, [[LOOPS]]#1 -> %arg2 = 0 to 30, [[LOOPS]]#2 -> %arg3 = 0 to 20, [[LOOPS]]#3 -> %arg4 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES1]][%arg1, %arg2, %arg3, %arg4] : memref<40x30x20x10xf32>
// CHECK: store [[LOAD]], [[RES0]][%arg1, %arg4, %arg2, %arg3] : memref<40x10x30x20xf32>
// CHECK: dealloc [[RES1]] : memref<40x30x20x10xf32>
// CHECK: return [[RES0]] : memref<40x10x30x20xf32>
}
func @test_identity(%arg0 : tensor<10x20x30x40xf32>) -> tensor<*xf32> {
%0 = "onnx.Identity"(%arg0) : (tensor<10x20x30x40xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_identity
// CHECK: return %arg0 : memref<10x20x30x40xf32>
}