constant folding for transpose of constant tensors (#171)
* added constant folding for transpose of constant tensors * format * responding to reviews
This commit is contained in:
parent
742e817722
commit
82d2caa542
|
@ -68,17 +68,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
|||
// Read perm attribute.
|
||||
SmallVector<int, 4> perm;
|
||||
auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
|
||||
if (permAttribute) {
|
||||
for (auto permVal : permAttribute.getValue())
|
||||
perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
|
||||
} else {
|
||||
// TODO: Remove when perm is guaranteed to be present (even for
|
||||
// the default case). This means that perm was added by shape
|
||||
// inference or another pass to contain the values corresponding
|
||||
// to the default behavior of Transpose.
|
||||
for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
|
||||
perm.emplace_back(i);
|
||||
}
|
||||
assert(permAttribute && "permute attribute expected to be defined here");
|
||||
for (auto permVal : permAttribute.getValue())
|
||||
perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
|
||||
|
||||
SmallVector<Value, 4> inLoopIVs;
|
||||
for (auto arg : iterationBlock.getArguments())
|
||||
|
|
|
@ -1217,16 +1217,21 @@ LogicalResult ONNXTransposeOp::inferShapes() {
|
|||
auto arrayTy = data().getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims;
|
||||
auto permutation = ONNXTransposeOp::permAttr();
|
||||
if (permutation) {
|
||||
// Perform transposition according to perm attribute.
|
||||
for (auto perm : permutation.getValue())
|
||||
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
|
||||
} else {
|
||||
// Default
|
||||
for (auto dim : llvm::reverse(arrayTy.getShape()))
|
||||
dims.emplace_back(dim);
|
||||
if (!permutation) {
|
||||
// Generate revese order for default transpose operation.
|
||||
SmallVector<int64_t, 4> defaultVals;
|
||||
auto builder = mlir::Builder(getContext());
|
||||
auto rank = arrayTy.getShape().size();
|
||||
for (int i = rank - 1; i >= 0; --i)
|
||||
defaultVals.emplace_back(i);
|
||||
// Set default attribute.
|
||||
ArrayRef<int64_t> defaultRefs(defaultVals);
|
||||
permAttr(builder.getI64ArrayAttr(defaultRefs));
|
||||
permutation = permAttr();
|
||||
}
|
||||
|
||||
// Perform transposition according to perm attribute.
|
||||
for (auto perm : permutation.getValue())
|
||||
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
|
||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ namespace {
|
|||
|
||||
// The methods are:
|
||||
//
|
||||
// ComputeConstProppElementwiseBinary and ComputeConstProppElementwiseUnary
|
||||
// ComputeConstPropElementwiseBinary and ComputeConstPropElementwiseUnary
|
||||
// and they need to be tempalted wtih an ONNX Operation (presuably).
|
||||
//
|
||||
// Then you need to add rules on how to transform the patterns; look into
|
||||
|
@ -56,20 +56,19 @@ namespace {
|
|||
// attribute.
|
||||
|
||||
template <typename OP>
|
||||
Attribute ComputeConstProppElementwiseBinary(PatternRewriter &rewriter,
|
||||
Attribute ComputeConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||
Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
|
||||
llvm_unreachable("unkonwn operation");
|
||||
}
|
||||
|
||||
template <>
|
||||
Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
|
||||
Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||
Attribute &secondAttr) {
|
||||
if (elementType.isa<FloatType>()) {
|
||||
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||
double res = lhsVal + rhsVal;
|
||||
// printf(" %f + %f -> %f\n", lhsVal, rhsVal, res);
|
||||
// Could use the APFloat interface to emulate the results, are ok to simply
|
||||
// perform them in the highest possible precision.
|
||||
return rewriter.getFloatAttr(elementType, res);
|
||||
|
@ -78,14 +77,13 @@ Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
|
|||
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
|
||||
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
|
||||
uint64_t res = lhsVal + rhsVal;
|
||||
// printf(" %llu + %llu -> %llu\n", lhsVal, rhsVal, res);
|
||||
return rewriter.getIntegerAttr(elementType, res);
|
||||
}
|
||||
llvm_unreachable("constant propagation for AddOp: unkonwn data type");
|
||||
}
|
||||
|
||||
template <>
|
||||
Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
|
||||
Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||
Attribute &secondAttr) {
|
||||
if (elementType.isa<FloatType>()) {
|
||||
|
@ -104,7 +102,7 @@ Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
|
|||
}
|
||||
|
||||
template <>
|
||||
Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
|
||||
Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||
Attribute &secondAttr) {
|
||||
if (elementType.isa<FloatType>()) {
|
||||
|
@ -133,11 +131,10 @@ Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
|
|||
// dimension size is equal to 1.
|
||||
|
||||
template <typename ElementwiseBinaryOp>
|
||||
void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
|
||||
void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||
std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
|
||||
DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
|
||||
SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
|
||||
// printf("recurse with free %d/%d\n", lhsFreeRank, rhsFreeRank);
|
||||
if (lhsFreeRank == 0) {
|
||||
// Fully defined ranks.
|
||||
assert(
|
||||
|
@ -145,7 +142,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
|
|||
auto lhsElementAttr = lhsAttr.getValue(ArrayRef<uint64_t>(lhsIndices));
|
||||
auto rhsElementAttr = rhsAttr.getValue(ArrayRef<uint64_t>(rhsIndices));
|
||||
auto elementaryType = lhsAttr.getType().getElementType();
|
||||
auto res = ComputeConstProppElementwiseBinary<ElementwiseBinaryOp>(
|
||||
auto res = ComputeConstPropElementwiseBinary<ElementwiseBinaryOp>(
|
||||
rewriter, elementaryType, lhsElementAttr, rhsElementAttr);
|
||||
resVector.emplace_back(res);
|
||||
} else if (lhsFreeRank > rhsFreeRank) {
|
||||
|
@ -156,7 +153,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
|
|||
int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
|
||||
for (int i = 0; i < lhsSize; ++i) {
|
||||
lhsIndices[lhsIndex] = i;
|
||||
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
|
||||
rhsFreeRank);
|
||||
}
|
||||
|
@ -168,7 +165,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
|
|||
int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
|
||||
for (int i = 0; i < rhsSize; ++i) {
|
||||
rhsIndices[rhsIndex] = i;
|
||||
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank,
|
||||
rhsFreeRank - 1);
|
||||
}
|
||||
|
@ -192,7 +189,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
|
|||
lhsIndices[lhsIndex] = i;
|
||||
if (rhsSize > 1)
|
||||
rhsIndices[rhsIndex] = i;
|
||||
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
|
||||
rhsFreeRank - 1);
|
||||
}
|
||||
|
@ -217,7 +214,7 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
|||
SmallVector<uint64_t, 4> lhsIndices(lhsRank, 0);
|
||||
SmallVector<uint64_t, 4> rhsIndices(rhsRank, 0);
|
||||
std::vector<Attribute> resVector;
|
||||
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
|
||||
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
|
||||
lhsDenseAttr, rhsDenseAttr, lhsIndices, rhsIndices, lhsRank, rhsRank);
|
||||
ArrayRef<Attribute> resRef(resVector);
|
||||
return DenseElementsAttr::get(resType, resRef);
|
||||
|
@ -228,13 +225,13 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename OP>
|
||||
Attribute ComputeConstProppElementwiseUnary(
|
||||
Attribute ComputeConstPropElementwiseUnary(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
||||
llvm_unreachable("unkonwn operation");
|
||||
}
|
||||
|
||||
template <>
|
||||
Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
|
||||
Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
||||
if (elementType.isa<FloatType>()) {
|
||||
double val = attr.cast<FloatAttr>().getValueAsDouble();
|
||||
|
@ -250,15 +247,14 @@ Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
|
|||
}
|
||||
|
||||
template <typename ElementwiseUnaryOp>
|
||||
void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
|
||||
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
|
||||
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
||||
SmallVector<uint64_t, 4> &indices, int freeRank) {
|
||||
// printf("recurse with free %d\n", freeRank);
|
||||
if (freeRank == 0) {
|
||||
// Fully defined ranks.
|
||||
auto elementAttr = attr.getValue(ArrayRef<uint64_t>(indices));
|
||||
auto elementaryType = attr.getType().getElementType();
|
||||
auto res = ComputeConstProppElementwiseUnary<ElementwiseUnaryOp>(
|
||||
auto res = ComputeConstPropElementwiseUnary<ElementwiseUnaryOp>(
|
||||
rewriter, elementaryType, elementAttr);
|
||||
resVector.emplace_back(res);
|
||||
} else {
|
||||
|
@ -269,7 +265,7 @@ void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
|
|||
int size = attr.getType().getShape()[index];
|
||||
for (int i = 0; i < size; ++i) {
|
||||
indices[index] = i;
|
||||
RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
|
||||
RecurseConstPropElementwiseUnary<ElementwiseUnaryOp>(
|
||||
rewriter, resVector, attr, indices, freeRank - 1);
|
||||
}
|
||||
}
|
||||
|
@ -289,12 +285,61 @@ DenseElementsAttr ConstPropElementwiseUnary(
|
|||
auto rank = denseAttr.getType().getShape().size();
|
||||
SmallVector<uint64_t, 4> indices(rank, 0);
|
||||
std::vector<Attribute> resVector;
|
||||
RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
|
||||
RecurseConstPropElementwiseUnary<ElementwiseUnaryOp>(
|
||||
rewriter, resVector, denseAttr, indices, rank);
|
||||
ArrayRef<Attribute> resRef(resVector);
|
||||
return DenseElementsAttr::get(resType, resRef);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Code to perform constant propagation for transpose.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void RecurseConstPropTranspose(PatternRewriter &rewriter,
|
||||
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
||||
SmallVector<uint64_t, 4> &indices, SmallVector<uint64_t, 4> &perm,
|
||||
int freeRank) {
|
||||
if (freeRank == 0) {
|
||||
// Fully defined ranks.
|
||||
auto res = attr.getValue(ArrayRef<uint64_t>(indices));
|
||||
resVector.emplace_back(res);
|
||||
} else {
|
||||
// Recurse.
|
||||
auto shape = attr.getType().getShape();
|
||||
int rank = shape.size();
|
||||
int index = perm[rank - freeRank];
|
||||
int size = attr.getType().getShape()[index];
|
||||
for (int i = 0; i < size; ++i) {
|
||||
indices[index] = i;
|
||||
RecurseConstPropTranspose(
|
||||
rewriter, resVector, attr, indices, perm, freeRank - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
|
||||
Value resOperand, Attribute &attr, ArrayAttr &permAttr) {
|
||||
// Read dense attribute, the constant tensor we are transforming.
|
||||
DenseElementsAttr denseAttr =
|
||||
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||
assert(denseAttr && "expected dense attribute");
|
||||
ShapedType resType = resOperand.getType().cast<RankedTensorType>();
|
||||
auto rank = denseAttr.getType().getShape().size();
|
||||
// Read permute vector.
|
||||
SmallVector<uint64_t, 4> perm;
|
||||
assert(permAttr && "permute attribute expected to be defined here");
|
||||
for (auto permVal : permAttr.getValue())
|
||||
perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
|
||||
// Init indice vector.
|
||||
SmallVector<uint64_t, 4> indices(rank, 0);
|
||||
std::vector<Attribute> resVector;
|
||||
// Copy using permute order.
|
||||
RecurseConstPropTranspose(
|
||||
rewriter, resVector, denseAttr, indices, perm, rank);
|
||||
ArrayRef<Attribute> resRef(resVector);
|
||||
return DenseElementsAttr::get(resType, resRef);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern definition.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -63,6 +63,9 @@ def CreateNegOfConst :
|
|||
def CreateMulOfTwoConst :
|
||||
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
|
||||
|
||||
def CreateTransposeOfConst :
|
||||
NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Patterns to enable opportunities with elementwise ADD operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -229,4 +232,17 @@ def MulConstProp : Pat<
|
|||
// Mulitional constraints (no sparse)
|
||||
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Patterns to enable opportunities with Transpose operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Neg of constant is simly -const
|
||||
def TransposeofConst : Pat<
|
||||
// From TransposeOp(c, p)
|
||||
(ONNXTransposeOp:$resOp (ONNXConstantOp $s, $v), $p),
|
||||
// To c' where c' is transposed attribute
|
||||
(ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)),
|
||||
[(AttributeIsNull:$s)]>;
|
||||
|
||||
|
||||
#endif // ONNX_CONSTPROP
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s
|
||||
// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file | FileCheck %s
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -195,3 +195,35 @@ func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
|
|||
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Transpose tests.
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: test_default_transpose_const_1
|
||||
func @test_default_transpose_const_1() -> tensor<*xi32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
|
||||
%1 = "onnx.Transpose"(%0) : (tensor<2x3x4xi32>) -> tensor<*xi32>
|
||||
"std.return"(%1) : (tensor<*xi32>) -> ()
|
||||
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 211], [121, 221], [131, 231]{{.}}, [{{.}}112, 212], [122, 222], [132, 232]{{.}}, [{{.}}113, 213], [123, 223], [133, 233]{{.}}, [{{.}}114, 214], [124, 224], [134, 234]{{.}}]> : tensor<4x3x2xi32>} : () -> tensor<4x3x2xi32>
|
||||
// CHECK: return [[RES]] : tensor<4x3x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: test_default_transpose_const_2
|
||||
func @test_default_transpose_const_2() -> tensor<*xi32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
|
||||
%1 = "onnx.Transpose"(%0) {perm = [0, 2, 1]} : (tensor<2x3x4xi32>) -> tensor<*xi32>
|
||||
"std.return"(%1) : (tensor<*xi32>) -> ()
|
||||
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 121, 131], [112, 122, 132], [113, 123, 133], [114, 124, 134]{{.}}, [{{.}}211, 221, 231], [212, 222, 232], [213, 223, 233], [214, 224, 234]{{.}}]> : tensor<2x4x3xi32>} : () -> tensor<2x4x3xi32>
|
||||
// CHECK: return [[RES]] : tensor<2x4x3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: test_default_transpose_const_3
|
||||
func @test_default_transpose_const_3() -> tensor<*xi32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
|
||||
%1 = "onnx.Transpose"(%0) {perm = [1, 0, 2]} : (tensor<2x3x4xi32>) -> tensor<*xi32>
|
||||
"std.return"(%1) : (tensor<*xi32>) -> ()
|
||||
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32>
|
||||
// CHECK: return [[RES]] : tensor<3x2x4xi32>
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
|||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_default_transpose
|
||||
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
|
||||
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
|
||||
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue