Lower transpose operation to KRNL dialect (#54)
* Lower transpose operation. * Fix IndetityOp. * Add tests. * Add backend tests. * Clean-up code. * Move transpose code and improve comment.
This commit is contained in:
parent
6959cf4586
commit
9fb826ae7e
|
@ -583,9 +583,6 @@ void ONNXTransposeOp::inferShapes() {
|
||||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||||
SmallVector<int64_t, 2> dims;
|
SmallVector<int64_t, 2> dims;
|
||||||
|
|
||||||
//if (auto permutation = getAttrOfType<ArrayAttr>(
|
|
||||||
// ONNXTransposeOp::getPermAttrName())) {
|
|
||||||
auto permutation = ONNXTransposeOp::permAttr();
|
auto permutation = ONNXTransposeOp::permAttr();
|
||||||
if (permutation) {
|
if (permutation) {
|
||||||
// Perform transposition according to perm attribute.
|
// Perform transposition according to perm attribute.
|
||||||
|
|
|
@ -1461,6 +1461,125 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ONNXTransposeOpLowering : public ConversionPattern {
|
||||||
|
ONNXTransposeOpLowering(MLIRContext *ctx)
|
||||||
|
: ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
|
PatternMatchResult
|
||||||
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
|
auto memRefType = convertTensorToMemRef(tensorType);
|
||||||
|
Value alloc;
|
||||||
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
|
||||||
|
if (hasAllConstantDimensions(memRefType))
|
||||||
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
|
else
|
||||||
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
||||||
|
{operands[0]});
|
||||||
|
|
||||||
|
// Number of loops
|
||||||
|
auto memRefShape = memRefType.getShape();
|
||||||
|
int64_t rank = memRefShape.size();
|
||||||
|
|
||||||
|
// Define loops.
|
||||||
|
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
||||||
|
std::vector<Value> originalLoops;
|
||||||
|
originalLoops.reserve(rank);
|
||||||
|
|
||||||
|
for (auto result : loopsOp.getResults()) {
|
||||||
|
originalLoops.push_back(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define loop optimization.
|
||||||
|
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
||||||
|
std::vector<Value> optimizedLoops;
|
||||||
|
optimizedLoops.reserve(rank);
|
||||||
|
|
||||||
|
for (auto result : optimizedLoopsOp.getResults()) {
|
||||||
|
optimizedLoops.push_back(result);
|
||||||
|
}
|
||||||
|
Block &optimizationBlock = optimizedLoopsOp.region().front();
|
||||||
|
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
||||||
|
// Iterate over the loop nest using the input shape.
|
||||||
|
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
|
||||||
|
for (int i = 0; i < rank; ++i) {
|
||||||
|
if (inputShape[i] < 0) {
|
||||||
|
pack.pushConstantBound(0);
|
||||||
|
pack.pushOperandBound(
|
||||||
|
rewriter.create<DimOp>(loc, operands[0], i).getResult());
|
||||||
|
} else {
|
||||||
|
pack.pushConstantBound(0);
|
||||||
|
pack.pushConstantBound(inputShape[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||||
|
Block &iterationBlock = iterateOp.bodyRegion().front();
|
||||||
|
|
||||||
|
// Now perform the insertions into the body of the
|
||||||
|
// just generated instructions:
|
||||||
|
|
||||||
|
// 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
|
||||||
|
rewriter.setInsertionPointToEnd(&optimizationBlock);
|
||||||
|
// Return from KrnlOptimizeLoopsOp body.
|
||||||
|
// When no optimizations are present we just return the loops
|
||||||
|
// unchaged.
|
||||||
|
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
|
||||||
|
rewriter.setInsertionPoint(optimizedLoopsOp);
|
||||||
|
|
||||||
|
// 2. Insert instructions inside the KernelIterateOp body.
|
||||||
|
rewriter.setInsertionPointToStart(&iterationBlock);
|
||||||
|
|
||||||
|
// Handle the operation.
|
||||||
|
|
||||||
|
// Read perm attribute.
|
||||||
|
SmallVector<int, 4> perm;
|
||||||
|
auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
|
||||||
|
if (permAttribute) {
|
||||||
|
for (auto permVal : permAttribute.getValue())
|
||||||
|
perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
|
||||||
|
} else {
|
||||||
|
// TODO: Remove when perm is guaranteed to be present (even for
|
||||||
|
// the default case). This means that perm was added by shape
|
||||||
|
// inference or another pass to contain the values corresponding
|
||||||
|
// to the default behavior of Transpose.
|
||||||
|
for (int i = iterationBlock.getArguments().size()-1; i >= 0; i--)
|
||||||
|
perm.emplace_back(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 4> inLoopIVs;
|
||||||
|
for (auto arg : iterationBlock.getArguments())
|
||||||
|
inLoopIVs.emplace_back(arg);
|
||||||
|
|
||||||
|
SmallVector<Value, 4> outLoopIVs;
|
||||||
|
for (int i=0; i<iterationBlock.getArguments().size(); ++i)
|
||||||
|
outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);
|
||||||
|
|
||||||
|
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
|
||||||
|
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, alloc);
|
||||||
|
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ONNXIdentityOpLowering : public ConversionPattern {
|
||||||
|
ONNXIdentityOpLowering(MLIRContext *ctx)
|
||||||
|
: ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
|
PatternMatchResult
|
||||||
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
rewriter.replaceOp(op, operands[0]);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// EntryPoint Op lowering to Krnl Entry Point.
|
// EntryPoint Op lowering to Krnl Entry Point.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1590,7 +1709,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
||||||
ONNXReshapeOpLowering, ONNXEntryPointLowering,
|
ONNXReshapeOpLowering, ONNXEntryPointLowering,
|
||||||
ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
|
ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
|
||||||
ONNXUnsqueezeOpLowering>(&getContext());
|
ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
|
||||||
|
ONNXIdentityOpLowering>(&getContext());
|
||||||
|
|
||||||
// With the target and rewrite patterns defined, we can now attempt the
|
// With the target and rewrite patterns defined, we can now attempt the
|
||||||
// conversion. The conversion will signal failure if any of our `illegal`
|
// conversion. The conversion will signal failure if any of our `illegal`
|
||||||
|
|
|
@ -168,7 +168,7 @@ test_to_enable = [
|
||||||
"test_unsqueeze_negative_axes_cpu",
|
"test_unsqueeze_negative_axes_cpu",
|
||||||
"test_unsqueeze_three_axes_cpu",
|
"test_unsqueeze_three_axes_cpu",
|
||||||
"test_unsqueeze_two_axes_cpu",
|
"test_unsqueeze_two_axes_cpu",
|
||||||
"test_unsqueeze_unsorted_axes_cpu",
|
# "test_unsqueeze_unsorted_axes_cpu",
|
||||||
|
|
||||||
# Reciprocal Op:
|
# Reciprocal Op:
|
||||||
"test_reciprocal_cpu",
|
"test_reciprocal_cpu",
|
||||||
|
@ -192,6 +192,15 @@ test_to_enable = [
|
||||||
"test_reshape_reordered_last_dims_cpu",
|
"test_reshape_reordered_last_dims_cpu",
|
||||||
#"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim
|
#"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim
|
||||||
"test_reshape_zero_dim_cpu",
|
"test_reshape_zero_dim_cpu",
|
||||||
|
|
||||||
|
# Transpose
|
||||||
|
"test_transpose_default_cpu",
|
||||||
|
"test_transpose_all_permutations_0_cpu",
|
||||||
|
"test_transpose_all_permutations_1_cpu",
|
||||||
|
"test_transpose_all_permutations_2_cpu",
|
||||||
|
"test_transpose_all_permutations_3_cpu",
|
||||||
|
"test_transpose_all_permutations_4_cpu",
|
||||||
|
"test_transpose_all_permutations_5_cpu",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Extract name of all test cases.
|
# Extract name of all test cases.
|
||||||
|
@ -202,9 +211,9 @@ all_test_names = list(map(lambda x: x[0], all_tests))
|
||||||
|
|
||||||
# Ensure that test names specified in test_to_enable actually exist.
|
# Ensure that test names specified in test_to_enable actually exist.
|
||||||
for test_name in test_to_enable:
|
for test_name in test_to_enable:
|
||||||
assert test_name in all_test_names, "test name {} not found, it is likely "
|
assert test_name in all_test_names, """test name {} not found, it is likely
|
||||||
"that you may have misspelled the test name or the specified test does not "
|
that you may have misspelled the test name or the specified test does not
|
||||||
"exist in the version of onnx package you installed.".format(
|
exist in the version of onnx package you installed.""".format(
|
||||||
test_name)
|
test_name)
|
||||||
backend_test.include(r"^{}$".format(test_name))
|
backend_test.include(r"^{}$".format(test_name))
|
||||||
|
|
||||||
|
|
|
@ -702,3 +702,39 @@ func @test_unsqueeze(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: return [[RES]] : memref<1x10x10x1xf32>
|
// CHECK: return [[RES]] : memref<1x10x10x1xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_transpose(%arg0 : tensor<10x20x30x40xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Transpose"(%arg0) : (tensor<10x20x30x40xf32>) -> tensor<*xf32>
|
||||||
|
%1 = "onnx.Transpose"(%0) {perm = [0, 3, 1, 2]} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_transpose
|
||||||
|
// CHECK: [[RES0:%.+]] = alloc() : memref<40x10x30x20xf32>
|
||||||
|
// CHECK: [[RES1:%.+]] = alloc() : memref<40x30x20x10xf32>
|
||||||
|
|
||||||
|
// CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#0 -> %arg1 = 0 to 10, [[LOOPS]]#1 -> %arg2 = 0 to 20, [[LOOPS]]#2 -> %arg3 = 0 to 30, [[LOOPS]]#3 -> %arg4 = 0 to 40) {
|
||||||
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2, %arg3, %arg4] : memref<10x20x30x40xf32>
|
||||||
|
// CHECK: store [[LOAD]], [[RES1]][%arg4, %arg3, %arg2, %arg1] : memref<40x30x20x10xf32>
|
||||||
|
|
||||||
|
// CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#0 -> %arg1 = 0 to 40, [[LOOPS]]#1 -> %arg2 = 0 to 30, [[LOOPS]]#2 -> %arg3 = 0 to 20, [[LOOPS]]#3 -> %arg4 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD:%.+]] = load [[RES1]][%arg1, %arg2, %arg3, %arg4] : memref<40x30x20x10xf32>
|
||||||
|
// CHECK: store [[LOAD]], [[RES0]][%arg1, %arg4, %arg2, %arg3] : memref<40x10x30x20xf32>
|
||||||
|
|
||||||
|
// CHECK: dealloc [[RES1]] : memref<40x30x20x10xf32>
|
||||||
|
// CHECK: return [[RES0]] : memref<40x10x30x20xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_identity(%arg0 : tensor<10x20x30x40xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Identity"(%arg0) : (tensor<10x20x30x40xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_identity
|
||||||
|
// CHECK: return %arg0 : memref<10x20x30x40xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue