constant folding for transpose of constant tensors (#171)

* added constant folding for transpose of constant tensors

* format

* responding to reviews
This commit is contained in:
Alexandre Eichenberger 2020-06-17 10:42:06 -04:00 committed by GitHub
parent 742e817722
commit 82d2caa542
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 133 additions and 43 deletions

View File

@ -68,17 +68,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
// Read perm attribute.
SmallVector<int, 4> perm;
auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
if (permAttribute) {
assert(permAttribute && "permute attribute expected to be defined here");
for (auto permVal : permAttribute.getValue())
perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
} else {
// TODO: Remove when perm is guaranteed to be present (even for
// the default case). This means that perm was added by shape
// inference or another pass to contain the values corresponding
// to the default behavior of Transpose.
for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
perm.emplace_back(i);
}
SmallVector<Value, 4> inLoopIVs;
for (auto arg : iterationBlock.getArguments())

View File

@ -1217,16 +1217,21 @@ LogicalResult ONNXTransposeOp::inferShapes() {
auto arrayTy = data().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
auto permutation = ONNXTransposeOp::permAttr();
if (permutation) {
if (!permutation) {
// Generate reverse order for default transpose operation.
SmallVector<int64_t, 4> defaultVals;
auto builder = mlir::Builder(getContext());
auto rank = arrayTy.getShape().size();
for (int i = rank - 1; i >= 0; --i)
defaultVals.emplace_back(i);
// Set default attribute.
ArrayRef<int64_t> defaultRefs(defaultVals);
permAttr(builder.getI64ArrayAttr(defaultRefs));
permutation = permAttr();
}
// Perform transposition according to perm attribute.
for (auto perm : permutation.getValue())
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
} else {
// Default
for (auto dim : llvm::reverse(arrayTy.getShape()))
dims.emplace_back(dim);
}
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
return success();
}

View File

@ -39,7 +39,7 @@ namespace {
// The methods are:
//
// ComputeConstProppElementwiseBinary and ComputeConstProppElementwiseUnary
// ComputeConstPropElementwiseBinary and ComputeConstPropElementwiseUnary
// and they need to be templated with an ONNX Operation (presumably).
//
// Then you need to add rules on how to transform the patterns; look into
@ -56,20 +56,19 @@ namespace {
// attribute.
template <typename OP>
Attribute ComputeConstProppElementwiseBinary(PatternRewriter &rewriter,
Attribute ComputeConstPropElementwiseBinary(PatternRewriter &rewriter,
Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
llvm_unreachable("unkonwn operation");
}
template <>
Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
Attribute &secondAttr) {
if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
double res = lhsVal + rhsVal;
// printf(" %f + %f -> %f\n", lhsVal, rhsVal, res);
// Could use the APFloat interface to emulate the results, are ok to simply
// perform them in the highest possible precision.
return rewriter.getFloatAttr(elementType, res);
@ -78,14 +77,13 @@ Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
uint64_t res = lhsVal + rhsVal;
// printf(" %llu + %llu -> %llu\n", lhsVal, rhsVal, res);
return rewriter.getIntegerAttr(elementType, res);
}
llvm_unreachable("constant propagation for AddOp: unkonwn data type");
}
template <>
Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
Attribute &secondAttr) {
if (elementType.isa<FloatType>()) {
@ -104,7 +102,7 @@ Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
}
template <>
Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
Attribute &secondAttr) {
if (elementType.isa<FloatType>()) {
@ -133,11 +131,10 @@ Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
// dimension size is equal to 1.
template <typename ElementwiseBinaryOp>
void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
// printf("recurse with free %d/%d\n", lhsFreeRank, rhsFreeRank);
if (lhsFreeRank == 0) {
// Fully defined ranks.
assert(
@ -145,7 +142,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
auto lhsElementAttr = lhsAttr.getValue(ArrayRef<uint64_t>(lhsIndices));
auto rhsElementAttr = rhsAttr.getValue(ArrayRef<uint64_t>(rhsIndices));
auto elementaryType = lhsAttr.getType().getElementType();
auto res = ComputeConstProppElementwiseBinary<ElementwiseBinaryOp>(
auto res = ComputeConstPropElementwiseBinary<ElementwiseBinaryOp>(
rewriter, elementaryType, lhsElementAttr, rhsElementAttr);
resVector.emplace_back(res);
} else if (lhsFreeRank > rhsFreeRank) {
@ -156,7 +153,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
for (int i = 0; i < lhsSize; ++i) {
lhsIndices[lhsIndex] = i;
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
rhsFreeRank);
}
@ -168,7 +165,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
for (int i = 0; i < rhsSize; ++i) {
rhsIndices[rhsIndex] = i;
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank,
rhsFreeRank - 1);
}
@ -192,7 +189,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
lhsIndices[lhsIndex] = i;
if (rhsSize > 1)
rhsIndices[rhsIndex] = i;
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
rhsFreeRank - 1);
}
@ -217,7 +214,7 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
SmallVector<uint64_t, 4> lhsIndices(lhsRank, 0);
SmallVector<uint64_t, 4> rhsIndices(rhsRank, 0);
std::vector<Attribute> resVector;
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
lhsDenseAttr, rhsDenseAttr, lhsIndices, rhsIndices, lhsRank, rhsRank);
ArrayRef<Attribute> resRef(resVector);
return DenseElementsAttr::get(resType, resRef);
@ -228,13 +225,13 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
//===----------------------------------------------------------------------===//
template <typename OP>
Attribute ComputeConstProppElementwiseUnary(
Attribute ComputeConstPropElementwiseUnary(
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
llvm_unreachable("unkonwn operation");
}
template <>
Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
if (elementType.isa<FloatType>()) {
double val = attr.cast<FloatAttr>().getValueAsDouble();
@ -250,15 +247,14 @@ Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
}
template <typename ElementwiseUnaryOp>
void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
SmallVector<uint64_t, 4> &indices, int freeRank) {
// printf("recurse with free %d\n", freeRank);
if (freeRank == 0) {
// Fully defined ranks.
auto elementAttr = attr.getValue(ArrayRef<uint64_t>(indices));
auto elementaryType = attr.getType().getElementType();
auto res = ComputeConstProppElementwiseUnary<ElementwiseUnaryOp>(
auto res = ComputeConstPropElementwiseUnary<ElementwiseUnaryOp>(
rewriter, elementaryType, elementAttr);
resVector.emplace_back(res);
} else {
@ -269,7 +265,7 @@ void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
int size = attr.getType().getShape()[index];
for (int i = 0; i < size; ++i) {
indices[index] = i;
RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
RecurseConstPropElementwiseUnary<ElementwiseUnaryOp>(
rewriter, resVector, attr, indices, freeRank - 1);
}
}
@ -289,12 +285,61 @@ DenseElementsAttr ConstPropElementwiseUnary(
auto rank = denseAttr.getType().getShape().size();
SmallVector<uint64_t, 4> indices(rank, 0);
std::vector<Attribute> resVector;
RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
RecurseConstPropElementwiseUnary<ElementwiseUnaryOp>(
rewriter, resVector, denseAttr, indices, rank);
ArrayRef<Attribute> resRef(resVector);
return DenseElementsAttr::get(resType, resRef);
}
//===----------------------------------------------------------------------===//
// Code to perform constant propagation for transpose.
//===----------------------------------------------------------------------===//
// Recursive helper for folding a transpose of a constant tensor.
//
// Each recursion level fixes one more coordinate of `indices`, which always
// indexes the SOURCE tensor `attr`. Level d (d = rank - freeRank) iterates
// over source dimension perm[d], so elements are appended to `resVector` with
// perm[0] as the slowest-varying source dimension and perm[rank - 1] as the
// fastest-varying one.
//
//   rewriter  - forwarded unchanged through the recursion; not otherwise used.
//   resVector - output; receives one element of `attr` per fully fixed index.
//   attr      - constant tensor being transposed.
//   indices   - scratch coordinate vector into `attr`, one entry per dimension.
//   perm      - permutation; perm[d] is the source dimension visited at level d.
//   freeRank  - number of levels still to be fixed; 0 means copy one element.
void RecurseConstPropTranspose(PatternRewriter &rewriter,
    std::vector<Attribute> &resVector, DenseElementsAttr &attr,
    SmallVector<uint64_t, 4> &indices, SmallVector<uint64_t, 4> &perm,
    int freeRank) {
  // Base case: all coordinates are set, copy the addressed source element.
  if (freeRank == 0) {
    resVector.emplace_back(attr.getValue(ArrayRef<uint64_t>(indices)));
    return;
  }

  // Recursive case: pick the source dimension assigned to this level and
  // sweep over its whole extent.
  auto srcShape = attr.getType().getShape();
  int srcRank = srcShape.size();
  int srcDim = perm[srcRank - freeRank];
  int extent = srcShape[srcDim];
  int remaining = freeRank - 1;
  for (int pos = 0; pos < extent; ++pos) {
    indices[srcDim] = pos;
    RecurseConstPropTranspose(
        rewriter, resVector, attr, indices, perm, remaining);
  }
}
// Folds a transpose of a constant tensor into a new constant attribute.
//
//   rewriter   - pattern rewriter, forwarded to the recursive helper.
//   resOperand - value produced by the transpose being folded; its type must
//                already be a RankedTensorType (it is cast, not checked) and
//                is used as the type of the returned attribute.
//   attr       - value attribute of the constant input; must be a
//                DenseElementsAttr.
//   permAttr   - the transpose's perm attribute; must be non-null, hold one
//                integer per dimension of `attr`, and each entry must be a
//                valid dimension index of `attr`.
//
// Returns a DenseElementsAttr of type `resOperand.getType()` whose elements
// are those of `attr`, emitted in the order given by the permutation.
DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
    Value resOperand, Attribute &attr, ArrayAttr &permAttr) {
  // Read dense attribute, the constant tensor we are transforming.
  DenseElementsAttr denseAttr =
      attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
  assert(denseAttr && "expected dense attribute");
  ShapedType resType = resOperand.getType().cast<RankedTensorType>();
  auto rank = denseAttr.getType().getShape().size();
  // Read permute vector.
  SmallVector<uint64_t, 4> perm;
  assert(permAttr && "permute attribute expected to be defined here");
  for (auto permVal : permAttr.getValue()) {
    int64_t dim = permVal.cast<IntegerAttr>().getInt();
    // Each entry is later used to index both the source shape and the
    // index vector, so it must name an existing dimension.
    assert(dim >= 0 && static_cast<uint64_t>(dim) < rank &&
           "permute attribute value out of range");
    perm.emplace_back(dim);
  }
  // The recursion reads perm[0] .. perm[rank - 1]; a vector of any other
  // length would be read out of bounds or describe a different tensor.
  assert(perm.size() == rank &&
         "permute attribute expected to have one entry per dimension");
  // Init index vector.
  SmallVector<uint64_t, 4> indices(rank, 0);
  std::vector<Attribute> resVector;
  // Copy using permute order.
  RecurseConstPropTranspose(
      rewriter, resVector, denseAttr, indices, perm, rank);
  ArrayRef<Attribute> resRef(resVector);
  return DenseElementsAttr::get(resType, resRef);
}
//===----------------------------------------------------------------------===//
// Pattern definition.
//===----------------------------------------------------------------------===//

View File

@ -63,6 +63,9 @@ def CreateNegOfConst :
def CreateMulOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
// Native call into the C++ helper ConstPropTranspose. The three pattern
// arguments are passed through in order as its resOperand, attr and permAttr
// parameters; the builder is supplied as its rewriter.
def CreateTransposeOfConst :
NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">;
//===----------------------------------------------------------------------===//
// Patterns to enable opportunities with elementwise ADD operations.
//===----------------------------------------------------------------------===//
@ -229,4 +232,17 @@ def MulConstProp : Pat<
// Additional constraints (no sparse)
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
//===----------------------------------------------------------------------===//
// Patterns to enable opportunities with Transpose operations.
//===----------------------------------------------------------------------===//
// Transpose of a constant is a constant holding the same elements in
// permuted order.
def TransposeofConst : Pat<
// From TransposeOp(c, p)
(ONNXTransposeOp:$resOp (ONNXConstantOp $s, $v), $p),
// To c' where c' is transposed attribute
(ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)),
// Only applies when the constant's first attribute $s is null
// (presumably the sparse value, as in the "no sparse" constraints of the
// other patterns in this file -- confirm against ONNXConstantOp).
[(AttributeIsNull:$s)]>;
#endif // ONNX_CONSTPROP

View File

@ -1,4 +1,4 @@
// RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s
// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file | FileCheck %s
//===----------------------------------------------------------------------===//
@ -195,3 +195,35 @@ func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}
//===----------------------------------------------------------------------===//
/// Transpose tests.
// -----
// CHECK-LABEL: test_default_transpose_const_1
func @test_default_transpose_const_1() -> tensor<*xi32> {
// No perm attribute is given, so the dimensions are expected to be reversed:
// the 2x3x4 constant should fold into a single 4x3x2 constant.
%0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
%1 = "onnx.Transpose"(%0) : (tensor<2x3x4xi32>) -> tensor<*xi32>
"std.return"(%1) : (tensor<*xi32>) -> ()
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 211], [121, 221], [131, 231]{{.}}, [{{.}}112, 212], [122, 222], [132, 232]{{.}}, [{{.}}113, 213], [123, 223], [133, 233]{{.}}, [{{.}}114, 214], [124, 224], [134, 234]{{.}}]> : tensor<4x3x2xi32>} : () -> tensor<4x3x2xi32>
// CHECK: return [[RES]] : tensor<4x3x2xi32>
}
// -----
// CHECK-LABEL: test_default_transpose_const_2
func @test_default_transpose_const_2() -> tensor<*xi32> {
// Explicit perm = [0, 2, 1] swaps the last two dimensions:
// the 2x3x4 constant should fold into a single 2x4x3 constant.
%0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
%1 = "onnx.Transpose"(%0) {perm = [0, 2, 1]} : (tensor<2x3x4xi32>) -> tensor<*xi32>
"std.return"(%1) : (tensor<*xi32>) -> ()
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 121, 131], [112, 122, 132], [113, 123, 133], [114, 124, 134]{{.}}, [{{.}}211, 221, 231], [212, 222, 232], [213, 223, 233], [214, 224, 234]{{.}}]> : tensor<2x4x3xi32>} : () -> tensor<2x4x3xi32>
// CHECK: return [[RES]] : tensor<2x4x3xi32>
}
// -----
// CHECK-LABEL: test_default_transpose_const_3
func @test_default_transpose_const_3() -> tensor<*xi32> {
// Explicit perm = [1, 0, 2] swaps the first two dimensions:
// the 2x3x4 constant should fold into a single 3x2x4 constant.
%0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
%1 = "onnx.Transpose"(%0) {perm = [1, 0, 2]} : (tensor<2x3x4xi32>) -> tensor<*xi32>
"std.return"(%1) : (tensor<*xi32>) -> ()
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32>
// CHECK: return [[RES]] : tensor<3x2x4xi32>
}

View File

@ -12,7 +12,7 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_default_transpose
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
}