constant folding for transpose of constant tensors (#171)
* added constant folding for transpose of constant tensors * format * responding to reviews
This commit is contained in:
parent
742e817722
commit
82d2caa542
|
@ -68,17 +68,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
||||||
// Read perm attribute.
|
// Read perm attribute.
|
||||||
SmallVector<int, 4> perm;
|
SmallVector<int, 4> perm;
|
||||||
auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
|
auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
|
||||||
if (permAttribute) {
|
assert(permAttribute && "permute attribute expected to be defined here");
|
||||||
for (auto permVal : permAttribute.getValue())
|
for (auto permVal : permAttribute.getValue())
|
||||||
perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
|
perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
|
||||||
} else {
|
|
||||||
// TODO: Remove when perm is guaranteed to be present (even for
|
|
||||||
// the default case). This means that perm was added by shape
|
|
||||||
// inference or another pass to contain the values corresponding
|
|
||||||
// to the default behavior of Transpose.
|
|
||||||
for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
|
|
||||||
perm.emplace_back(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value, 4> inLoopIVs;
|
SmallVector<Value, 4> inLoopIVs;
|
||||||
for (auto arg : iterationBlock.getArguments())
|
for (auto arg : iterationBlock.getArguments())
|
||||||
|
|
|
@ -1217,16 +1217,21 @@ LogicalResult ONNXTransposeOp::inferShapes() {
|
||||||
auto arrayTy = data().getType().cast<RankedTensorType>();
|
auto arrayTy = data().getType().cast<RankedTensorType>();
|
||||||
SmallVector<int64_t, 2> dims;
|
SmallVector<int64_t, 2> dims;
|
||||||
auto permutation = ONNXTransposeOp::permAttr();
|
auto permutation = ONNXTransposeOp::permAttr();
|
||||||
if (permutation) {
|
if (!permutation) {
|
||||||
|
// Generate revese order for default transpose operation.
|
||||||
|
SmallVector<int64_t, 4> defaultVals;
|
||||||
|
auto builder = mlir::Builder(getContext());
|
||||||
|
auto rank = arrayTy.getShape().size();
|
||||||
|
for (int i = rank - 1; i >= 0; --i)
|
||||||
|
defaultVals.emplace_back(i);
|
||||||
|
// Set default attribute.
|
||||||
|
ArrayRef<int64_t> defaultRefs(defaultVals);
|
||||||
|
permAttr(builder.getI64ArrayAttr(defaultRefs));
|
||||||
|
permutation = permAttr();
|
||||||
|
}
|
||||||
// Perform transposition according to perm attribute.
|
// Perform transposition according to perm attribute.
|
||||||
for (auto perm : permutation.getValue())
|
for (auto perm : permutation.getValue())
|
||||||
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
|
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
|
||||||
} else {
|
|
||||||
// Default
|
|
||||||
for (auto dim : llvm::reverse(arrayTy.getShape()))
|
|
||||||
dims.emplace_back(dim);
|
|
||||||
}
|
|
||||||
|
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,7 @@ namespace {
|
||||||
|
|
||||||
// The methods are:
|
// The methods are:
|
||||||
//
|
//
|
||||||
// ComputeConstProppElementwiseBinary and ComputeConstProppElementwiseUnary
|
// ComputeConstPropElementwiseBinary and ComputeConstPropElementwiseUnary
|
||||||
// and they need to be tempalted wtih an ONNX Operation (presuably).
|
// and they need to be tempalted wtih an ONNX Operation (presuably).
|
||||||
//
|
//
|
||||||
// Then you need to add rules on how to transform the patterns; look into
|
// Then you need to add rules on how to transform the patterns; look into
|
||||||
|
@ -56,20 +56,19 @@ namespace {
|
||||||
// attribute.
|
// attribute.
|
||||||
|
|
||||||
template <typename OP>
|
template <typename OP>
|
||||||
Attribute ComputeConstProppElementwiseBinary(PatternRewriter &rewriter,
|
Attribute ComputeConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||||
Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
|
Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
|
||||||
llvm_unreachable("unkonwn operation");
|
llvm_unreachable("unkonwn operation");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
|
Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||||
Attribute &secondAttr) {
|
Attribute &secondAttr) {
|
||||||
if (elementType.isa<FloatType>()) {
|
if (elementType.isa<FloatType>()) {
|
||||||
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
double res = lhsVal + rhsVal;
|
double res = lhsVal + rhsVal;
|
||||||
// printf(" %f + %f -> %f\n", lhsVal, rhsVal, res);
|
|
||||||
// Could use the APFloat interface to emulate the results, are ok to simply
|
// Could use the APFloat interface to emulate the results, are ok to simply
|
||||||
// perform them in the highest possible precision.
|
// perform them in the highest possible precision.
|
||||||
return rewriter.getFloatAttr(elementType, res);
|
return rewriter.getFloatAttr(elementType, res);
|
||||||
|
@ -78,14 +77,13 @@ Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
|
||||||
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
|
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
|
||||||
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
|
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
|
||||||
uint64_t res = lhsVal + rhsVal;
|
uint64_t res = lhsVal + rhsVal;
|
||||||
// printf(" %llu + %llu -> %llu\n", lhsVal, rhsVal, res);
|
|
||||||
return rewriter.getIntegerAttr(elementType, res);
|
return rewriter.getIntegerAttr(elementType, res);
|
||||||
}
|
}
|
||||||
llvm_unreachable("constant propagation for AddOp: unkonwn data type");
|
llvm_unreachable("constant propagation for AddOp: unkonwn data type");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
|
Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||||
Attribute &secondAttr) {
|
Attribute &secondAttr) {
|
||||||
if (elementType.isa<FloatType>()) {
|
if (elementType.isa<FloatType>()) {
|
||||||
|
@ -104,7 +102,7 @@ Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
|
Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||||
Attribute &secondAttr) {
|
Attribute &secondAttr) {
|
||||||
if (elementType.isa<FloatType>()) {
|
if (elementType.isa<FloatType>()) {
|
||||||
|
@ -133,11 +131,10 @@ Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
|
||||||
// dimension size is equal to 1.
|
// dimension size is equal to 1.
|
||||||
|
|
||||||
template <typename ElementwiseBinaryOp>
|
template <typename ElementwiseBinaryOp>
|
||||||
void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
|
void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||||
std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
|
std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
|
||||||
DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
|
DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
|
||||||
SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
|
SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
|
||||||
// printf("recurse with free %d/%d\n", lhsFreeRank, rhsFreeRank);
|
|
||||||
if (lhsFreeRank == 0) {
|
if (lhsFreeRank == 0) {
|
||||||
// Fully defined ranks.
|
// Fully defined ranks.
|
||||||
assert(
|
assert(
|
||||||
|
@ -145,7 +142,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
|
||||||
auto lhsElementAttr = lhsAttr.getValue(ArrayRef<uint64_t>(lhsIndices));
|
auto lhsElementAttr = lhsAttr.getValue(ArrayRef<uint64_t>(lhsIndices));
|
||||||
auto rhsElementAttr = rhsAttr.getValue(ArrayRef<uint64_t>(rhsIndices));
|
auto rhsElementAttr = rhsAttr.getValue(ArrayRef<uint64_t>(rhsIndices));
|
||||||
auto elementaryType = lhsAttr.getType().getElementType();
|
auto elementaryType = lhsAttr.getType().getElementType();
|
||||||
auto res = ComputeConstProppElementwiseBinary<ElementwiseBinaryOp>(
|
auto res = ComputeConstPropElementwiseBinary<ElementwiseBinaryOp>(
|
||||||
rewriter, elementaryType, lhsElementAttr, rhsElementAttr);
|
rewriter, elementaryType, lhsElementAttr, rhsElementAttr);
|
||||||
resVector.emplace_back(res);
|
resVector.emplace_back(res);
|
||||||
} else if (lhsFreeRank > rhsFreeRank) {
|
} else if (lhsFreeRank > rhsFreeRank) {
|
||||||
|
@ -156,7 +153,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
|
||||||
int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
|
int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
|
||||||
for (int i = 0; i < lhsSize; ++i) {
|
for (int i = 0; i < lhsSize; ++i) {
|
||||||
lhsIndices[lhsIndex] = i;
|
lhsIndices[lhsIndex] = i;
|
||||||
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||||
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
|
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
|
||||||
rhsFreeRank);
|
rhsFreeRank);
|
||||||
}
|
}
|
||||||
|
@ -168,7 +165,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
|
||||||
int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
|
int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
|
||||||
for (int i = 0; i < rhsSize; ++i) {
|
for (int i = 0; i < rhsSize; ++i) {
|
||||||
rhsIndices[rhsIndex] = i;
|
rhsIndices[rhsIndex] = i;
|
||||||
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||||
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank,
|
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank,
|
||||||
rhsFreeRank - 1);
|
rhsFreeRank - 1);
|
||||||
}
|
}
|
||||||
|
@ -192,7 +189,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
|
||||||
lhsIndices[lhsIndex] = i;
|
lhsIndices[lhsIndex] = i;
|
||||||
if (rhsSize > 1)
|
if (rhsSize > 1)
|
||||||
rhsIndices[rhsIndex] = i;
|
rhsIndices[rhsIndex] = i;
|
||||||
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||||
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
|
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
|
||||||
rhsFreeRank - 1);
|
rhsFreeRank - 1);
|
||||||
}
|
}
|
||||||
|
@ -217,7 +214,7 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||||
SmallVector<uint64_t, 4> lhsIndices(lhsRank, 0);
|
SmallVector<uint64_t, 4> lhsIndices(lhsRank, 0);
|
||||||
SmallVector<uint64_t, 4> rhsIndices(rhsRank, 0);
|
SmallVector<uint64_t, 4> rhsIndices(rhsRank, 0);
|
||||||
std::vector<Attribute> resVector;
|
std::vector<Attribute> resVector;
|
||||||
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
|
RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
|
||||||
lhsDenseAttr, rhsDenseAttr, lhsIndices, rhsIndices, lhsRank, rhsRank);
|
lhsDenseAttr, rhsDenseAttr, lhsIndices, rhsIndices, lhsRank, rhsRank);
|
||||||
ArrayRef<Attribute> resRef(resVector);
|
ArrayRef<Attribute> resRef(resVector);
|
||||||
return DenseElementsAttr::get(resType, resRef);
|
return DenseElementsAttr::get(resType, resRef);
|
||||||
|
@ -228,13 +225,13 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
template <typename OP>
|
template <typename OP>
|
||||||
Attribute ComputeConstProppElementwiseUnary(
|
Attribute ComputeConstPropElementwiseUnary(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
||||||
llvm_unreachable("unkonwn operation");
|
llvm_unreachable("unkonwn operation");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
|
Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
||||||
if (elementType.isa<FloatType>()) {
|
if (elementType.isa<FloatType>()) {
|
||||||
double val = attr.cast<FloatAttr>().getValueAsDouble();
|
double val = attr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
@ -250,15 +247,14 @@ Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ElementwiseUnaryOp>
|
template <typename ElementwiseUnaryOp>
|
||||||
void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
|
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
|
||||||
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
||||||
SmallVector<uint64_t, 4> &indices, int freeRank) {
|
SmallVector<uint64_t, 4> &indices, int freeRank) {
|
||||||
// printf("recurse with free %d\n", freeRank);
|
|
||||||
if (freeRank == 0) {
|
if (freeRank == 0) {
|
||||||
// Fully defined ranks.
|
// Fully defined ranks.
|
||||||
auto elementAttr = attr.getValue(ArrayRef<uint64_t>(indices));
|
auto elementAttr = attr.getValue(ArrayRef<uint64_t>(indices));
|
||||||
auto elementaryType = attr.getType().getElementType();
|
auto elementaryType = attr.getType().getElementType();
|
||||||
auto res = ComputeConstProppElementwiseUnary<ElementwiseUnaryOp>(
|
auto res = ComputeConstPropElementwiseUnary<ElementwiseUnaryOp>(
|
||||||
rewriter, elementaryType, elementAttr);
|
rewriter, elementaryType, elementAttr);
|
||||||
resVector.emplace_back(res);
|
resVector.emplace_back(res);
|
||||||
} else {
|
} else {
|
||||||
|
@ -269,7 +265,7 @@ void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
|
||||||
int size = attr.getType().getShape()[index];
|
int size = attr.getType().getShape()[index];
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
indices[index] = i;
|
indices[index] = i;
|
||||||
RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
|
RecurseConstPropElementwiseUnary<ElementwiseUnaryOp>(
|
||||||
rewriter, resVector, attr, indices, freeRank - 1);
|
rewriter, resVector, attr, indices, freeRank - 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -289,12 +285,61 @@ DenseElementsAttr ConstPropElementwiseUnary(
|
||||||
auto rank = denseAttr.getType().getShape().size();
|
auto rank = denseAttr.getType().getShape().size();
|
||||||
SmallVector<uint64_t, 4> indices(rank, 0);
|
SmallVector<uint64_t, 4> indices(rank, 0);
|
||||||
std::vector<Attribute> resVector;
|
std::vector<Attribute> resVector;
|
||||||
RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
|
RecurseConstPropElementwiseUnary<ElementwiseUnaryOp>(
|
||||||
rewriter, resVector, denseAttr, indices, rank);
|
rewriter, resVector, denseAttr, indices, rank);
|
||||||
ArrayRef<Attribute> resRef(resVector);
|
ArrayRef<Attribute> resRef(resVector);
|
||||||
return DenseElementsAttr::get(resType, resRef);
|
return DenseElementsAttr::get(resType, resRef);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Code to perform constant propagation for transpose.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void RecurseConstPropTranspose(PatternRewriter &rewriter,
|
||||||
|
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
||||||
|
SmallVector<uint64_t, 4> &indices, SmallVector<uint64_t, 4> &perm,
|
||||||
|
int freeRank) {
|
||||||
|
if (freeRank == 0) {
|
||||||
|
// Fully defined ranks.
|
||||||
|
auto res = attr.getValue(ArrayRef<uint64_t>(indices));
|
||||||
|
resVector.emplace_back(res);
|
||||||
|
} else {
|
||||||
|
// Recurse.
|
||||||
|
auto shape = attr.getType().getShape();
|
||||||
|
int rank = shape.size();
|
||||||
|
int index = perm[rank - freeRank];
|
||||||
|
int size = attr.getType().getShape()[index];
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
indices[index] = i;
|
||||||
|
RecurseConstPropTranspose(
|
||||||
|
rewriter, resVector, attr, indices, perm, freeRank - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
|
||||||
|
Value resOperand, Attribute &attr, ArrayAttr &permAttr) {
|
||||||
|
// Read dense attribute, the constant tensor we are transforming.
|
||||||
|
DenseElementsAttr denseAttr =
|
||||||
|
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
|
assert(denseAttr && "expected dense attribute");
|
||||||
|
ShapedType resType = resOperand.getType().cast<RankedTensorType>();
|
||||||
|
auto rank = denseAttr.getType().getShape().size();
|
||||||
|
// Read permute vector.
|
||||||
|
SmallVector<uint64_t, 4> perm;
|
||||||
|
assert(permAttr && "permute attribute expected to be defined here");
|
||||||
|
for (auto permVal : permAttr.getValue())
|
||||||
|
perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
|
||||||
|
// Init indice vector.
|
||||||
|
SmallVector<uint64_t, 4> indices(rank, 0);
|
||||||
|
std::vector<Attribute> resVector;
|
||||||
|
// Copy using permute order.
|
||||||
|
RecurseConstPropTranspose(
|
||||||
|
rewriter, resVector, denseAttr, indices, perm, rank);
|
||||||
|
ArrayRef<Attribute> resRef(resVector);
|
||||||
|
return DenseElementsAttr::get(resType, resRef);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pattern definition.
|
// Pattern definition.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -63,6 +63,9 @@ def CreateNegOfConst :
|
||||||
def CreateMulOfTwoConst :
|
def CreateMulOfTwoConst :
|
||||||
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
|
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
|
||||||
|
|
||||||
|
def CreateTransposeOfConst :
|
||||||
|
NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Patterns to enable opportunities with elementwise ADD operations.
|
// Patterns to enable opportunities with elementwise ADD operations.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -229,4 +232,17 @@ def MulConstProp : Pat<
|
||||||
// Mulitional constraints (no sparse)
|
// Mulitional constraints (no sparse)
|
||||||
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Patterns to enable opportunities with Transpose operations.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Neg of constant is simly -const
|
||||||
|
def TransposeofConst : Pat<
|
||||||
|
// From TransposeOp(c, p)
|
||||||
|
(ONNXTransposeOp:$resOp (ONNXConstantOp $s, $v), $p),
|
||||||
|
// To c' where c' is transposed attribute
|
||||||
|
(ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)),
|
||||||
|
[(AttributeIsNull:$s)]>;
|
||||||
|
|
||||||
|
|
||||||
#endif // ONNX_CONSTPROP
|
#endif // ONNX_CONSTPROP
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s
|
// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -195,3 +195,35 @@ func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||||
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
|
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Transpose tests.
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: test_default_transpose_const_1
|
||||||
|
func @test_default_transpose_const_1() -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
|
||||||
|
%1 = "onnx.Transpose"(%0) : (tensor<2x3x4xi32>) -> tensor<*xi32>
|
||||||
|
"std.return"(%1) : (tensor<*xi32>) -> ()
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 211], [121, 221], [131, 231]{{.}}, [{{.}}112, 212], [122, 222], [132, 232]{{.}}, [{{.}}113, 213], [123, 223], [133, 233]{{.}}, [{{.}}114, 214], [124, 224], [134, 234]{{.}}]> : tensor<4x3x2xi32>} : () -> tensor<4x3x2xi32>
|
||||||
|
// CHECK: return [[RES]] : tensor<4x3x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: test_default_transpose_const_2
|
||||||
|
func @test_default_transpose_const_2() -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
|
||||||
|
%1 = "onnx.Transpose"(%0) {perm = [0, 2, 1]} : (tensor<2x3x4xi32>) -> tensor<*xi32>
|
||||||
|
"std.return"(%1) : (tensor<*xi32>) -> ()
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 121, 131], [112, 122, 132], [113, 123, 133], [114, 124, 134]{{.}}, [{{.}}211, 221, 231], [212, 222, 232], [213, 223, 233], [214, 224, 234]{{.}}]> : tensor<2x4x3xi32>} : () -> tensor<2x4x3xi32>
|
||||||
|
// CHECK: return [[RES]] : tensor<2x4x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: test_default_transpose_const_3
|
||||||
|
func @test_default_transpose_const_3() -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
|
||||||
|
%1 = "onnx.Transpose"(%0) {perm = [1, 0, 2]} : (tensor<2x3x4xi32>) -> tensor<*xi32>
|
||||||
|
"std.return"(%1) : (tensor<*xi32>) -> ()
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32>
|
||||||
|
// CHECK: return [[RES]] : tensor<3x2x4xi32>
|
||||||
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_default_transpose
|
// CHECK-LABEL: test_default_transpose
|
||||||
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
|
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
|
||||||
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
|
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue