Lower ConstantOp (#28)
* Lower ConstantOp * Refactor the code * Edit error messages * Check whether attribute is sparse or dense during shape inference
This commit is contained in:
parent
162ac1bc32
commit
a65820940c
|
@ -47,7 +47,7 @@ OpsWithShapeInference = [
|
||||||
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
||||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
||||||
'Sign'
|
'Sign', 'Constant'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
|
|
|
@ -91,6 +91,7 @@ add_library(onnf_lower_frontend
|
||||||
conversion/onnx_to_krnl/tensor/padconstantvaluepad.cpp
|
conversion/onnx_to_krnl/tensor/padconstantvaluepad.cpp
|
||||||
conversion/onnx_to_krnl/tensor/transpose.cpp
|
conversion/onnx_to_krnl/tensor/transpose.cpp
|
||||||
conversion/onnx_to_krnl/tensor/unsqueeze.cpp
|
conversion/onnx_to_krnl/tensor/unsqueeze.cpp
|
||||||
|
conversion/onnx_to_krnl/tensor/constant.cpp
|
||||||
conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
|
conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
|
||||||
target_include_directories(onnf_lower_frontend
|
target_include_directories(onnf_lower_frontend
|
||||||
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||||
|
|
|
@ -97,6 +97,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
||||||
populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
|
populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXTransposeOpPattern(patterns, &getContext());
|
populateLoweringONNXTransposeOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
|
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
|
||||||
|
populateLoweringONNXConstantOpPattern(patterns, &getContext());
|
||||||
// Neural network
|
// Neural network
|
||||||
populateLoweringONNXConvOpPattern(patterns, &getContext());
|
populateLoweringONNXConvOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
|
populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
|
||||||
|
|
|
@ -242,3 +242,7 @@ void populateLoweringONNXReshapeOpPattern(
|
||||||
|
|
||||||
void populateLoweringONNXIdentityOpPattern(
|
void populateLoweringONNXIdentityOpPattern(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||||
|
|
||||||
|
void populateLoweringONNXConstantOpPattern(
|
||||||
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,100 @@
|
||||||
|
//===---- constant.cpp - Lowering Constant Op -----------------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file lowers the ONNX Constant Operator to Krnl dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
template <typename ElementAttr>
|
||||||
|
void emitConstantAndStoreOpForDenseElementsAttr(
|
||||||
|
ConversionPatternRewriter &rewriter, Location loc,
|
||||||
|
DenseElementsAttr constantValue, ArrayRef<int64_t> valueShape,
|
||||||
|
ArrayRef<Value> constantIndices, Value alloc) {
|
||||||
|
// The following functor recursively walks the dimensions of the constant
|
||||||
|
// shape, generating a store when the recursion hits the base case.
|
||||||
|
SmallVector<Value, 2> indices;
|
||||||
|
auto valueIt = constantValue.getValues<ElementAttr>().begin();
|
||||||
|
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
|
||||||
|
// The last dimension is the base case of the recursion, at this point
|
||||||
|
// we store the element at the given index.
|
||||||
|
if (dimension == valueShape.size()) {
|
||||||
|
rewriter.create<AffineStoreOp>(loc,
|
||||||
|
rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
|
||||||
|
llvm::makeArrayRef(indices));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, iterate over the current dimension and add the indices to
|
||||||
|
// the list.
|
||||||
|
for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
|
||||||
|
indices.push_back(constantIndices[i]);
|
||||||
|
storeElements(dimension + 1);
|
||||||
|
indices.pop_back();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Start the element storing recursion from the first dimension.
|
||||||
|
storeElements(/*dimension=*/0);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ONNXConstantOpLowering : public ConversionPattern {
|
||||||
|
ONNXConstantOpLowering(MLIRContext *ctx)
|
||||||
|
: ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto constantOp = llvm::dyn_cast<ONNXConstantOp>(op);
|
||||||
|
|
||||||
|
if (constantOp.sparse_value().hasValue()) {
|
||||||
|
emitError(loc, "Only support dense values at this time");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||||
|
|
||||||
|
Value alloc;
|
||||||
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
|
||||||
|
if (hasAllConstantDimensions(memRefType))
|
||||||
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
|
else
|
||||||
|
emitError(loc, "Unexpected output has non-Constant shape");
|
||||||
|
|
||||||
|
DenseElementsAttr constantValue =
|
||||||
|
constantOp.value().getValue().cast<DenseElementsAttr>();
|
||||||
|
|
||||||
|
auto valueShape = memRefType.getShape();
|
||||||
|
SmallVector<Value, 8> constantIndices;
|
||||||
|
for (auto i : llvm::seq<int64_t>(
|
||||||
|
0, *std::max_element(valueShape.begin(), valueShape.end())))
|
||||||
|
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
|
||||||
|
|
||||||
|
// The constant operation represents a multi-dimensional constant, so we
|
||||||
|
// will need to generate a store for each of the elements.
|
||||||
|
if (memRefType.getElementType().isa<IntegerType>()) {
|
||||||
|
emitConstantAndStoreOpForDenseElementsAttr<IntegerAttr>(
|
||||||
|
rewriter, loc, constantValue, valueShape, constantIndices, alloc);
|
||||||
|
} else if (memRefType.getElementType().isa<FloatType>()) {
|
||||||
|
emitConstantAndStoreOpForDenseElementsAttr<FloatAttr>(
|
||||||
|
rewriter, loc, constantValue, valueShape, constantIndices, alloc);
|
||||||
|
} else {
|
||||||
|
emitError(loc, "Unsupported output type");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace this operation with the generated alloc.
|
||||||
|
rewriter.replaceOp(op, alloc);
|
||||||
|
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void populateLoweringONNXConstantOpPattern(
|
||||||
|
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||||
|
patterns.insert<ONNXConstantOpLowering>(ctx);
|
||||||
|
}
|
|
@ -1171,6 +1171,23 @@ void ONNXUnsqueezeOp::inferShapes() {
|
||||||
getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Constant
|
||||||
|
|
||||||
|
void ONNXConstantOp::inferShapes() {
|
||||||
|
if ((sparse_value().hasValue() && value().hasValue()) ||
|
||||||
|
(!sparse_value().hasValue() && !value().hasValue()))
|
||||||
|
emitError("Require exactly one of the two attributes, either value or "
|
||||||
|
"sparse_value");
|
||||||
|
|
||||||
|
ElementsAttr valAttr;
|
||||||
|
if (sparse_value().hasValue())
|
||||||
|
valAttr = sparse_valueAttr().cast<SparseElementsAttr>();
|
||||||
|
else
|
||||||
|
valAttr = valueAttr().cast<DenseElementsAttr>();
|
||||||
|
getResult().setType(valAttr.getType());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -340,7 +340,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXConstantOp:ONNX_Op<"Constant",
|
def ONNXConstantOp:ONNX_Op<"Constant",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Constant operation";
|
let summary = "ONNX Constant operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"A constant tensor. Exactly one of the two attributes, either value or sparse_value,"
|
"A constant tensor. Exactly one of the two attributes, either value or sparse_value,"
|
||||||
|
|
|
@ -120,6 +120,7 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.PadConstantPad" &&
|
op->getName().getStringRef() != "onnx.PadConstantPad" &&
|
||||||
op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
|
op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
|
||||||
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
|
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Constant" &&
|
||||||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
op->getName().getStringRef() != "onnx.Unsqueeze")
|
||||||
return false;
|
return false;
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||||
|
|
|
@ -1535,3 +1535,27 @@ func @test_constant_pad1(%arg0: tensor<16x16xf32>) -> tensor<18x20xf32> {
|
||||||
// CHECK: store [[LOAD]], [[RES]][%arg1, [[ADD]]] : memref<18x20xf32>
|
// CHECK: store [[LOAD]], [[RES]][%arg1, [[ADD]]] : memref<18x20xf32>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
// CHECK-LABEL: test_constant_dense_2d_value
|
||||||
|
// CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
|
||||||
|
// CHECK: %[[INDEX_0:.+]] = constant 0 : index
|
||||||
|
// CHECK: %[[INDEX_1:.+]] = constant 1 : index
|
||||||
|
// CHECK: %[[INDEX_2:.+]] = constant 2 : index
|
||||||
|
// CHECK: [[CONSTANT_0:%.+]] = constant 0.000000e+00 : f32
|
||||||
|
// CHECK: affine.store [[CONSTANT_0]], %0[%[[INDEX_0]], %[[INDEX_0]]] : memref<3x2xf32>
|
||||||
|
// CHECK: [[CONSTANT_1:%.+]] = constant 0.000000e+00 : f32
|
||||||
|
// CHECK: affine.store [[CONSTANT_1]], %0[%[[INDEX_0]], %[[INDEX_1]]] : memref<3x2xf32>
|
||||||
|
// CHECK: [[CONSTANT_2:%.+]] = constant 1.000000e+00 : f32
|
||||||
|
// CHECK: affine.store [[CONSTANT_2]], %0[%[[INDEX_1]], %[[INDEX_0]]] : memref<3x2xf32>
|
||||||
|
// CHECK: [[CONSTANT_3:%.+]] = constant 1.100000e+00 : f32
|
||||||
|
// CHECK: affine.store [[CONSTANT_3]], %0[%[[INDEX_1]], %[[INDEX_1]]] : memref<3x2xf32>
|
||||||
|
// CHECK: [[CONSTANT_4:%.+]] = constant 2.000000e+00 : f32
|
||||||
|
// CHECK: affine.store [[CONSTANT_4]], %0[%[[INDEX_2]], %[[INDEX_0]]] : memref<3x2xf32>
|
||||||
|
// CHECK: [[CONSTANT_5:%.+]] = constant 2.100000e+00 : f32
|
||||||
|
// CHECK: affine.store [[CONSTANT_5]], %0[%[[INDEX_2]], %[[INDEX_1]]] : memref<3x2xf32>
|
||||||
|
// CHECK: return [[RES]] : memref<3x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -296,3 +296,38 @@ func @test_PadConstantPad_1(%arg0 : tensor<16x13xf32>, %arg1 : tensor<*xf32>) ->
|
||||||
// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 2, 3, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32>
|
// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 2, 3, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32>
|
||||||
// CHECK: return [[RES]] : tensor<18x17xf32>
|
// CHECK: return [[RES]] : tensor<18x17xf32>
|
||||||
|
|
||||||
|
/// Test ConstantOp shape inference for 1-D dense tensor.
|
||||||
|
func @test_constant_dense_1d_value() -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: test_constant_dense_1d_value
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<3xf32>
|
||||||
|
|
||||||
|
/// Test ConstantOp shape inference for 2-D dense tensor.
|
||||||
|
func @test_constant_dense_2d_value() -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: test_constant_dense_2d_value
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<{{\[}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00{{\]}}]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<3x2xf32>
|
||||||
|
|
||||||
|
/// Test ConstantOp shape inference for 1-D sparse tensor.
|
||||||
|
func @test_constant_sparse_1d_value() -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Constant"() {sparse_value = sparse<[[0]], [1.0]> : tensor<3xf32>} : () -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: test_constant_sparse_1d_value
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<0, 1.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<3xf32>
|
||||||
|
|
||||||
|
/// Test ConstantOp shape inference for 2-D sparse tensor.
|
||||||
|
func @test_constant_sparse_2d_value() -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Constant"() {sparse_value = sparse<[[0, 1]], [2.0]> : tensor<3x2xf32>} : () -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: test_constant_sparse_2d_value
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<{{\[}}[0, 1{{\]}}], 2.000000e+00> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<3x2xf32>
|
||||||
|
|
Loading…
Reference in New Issue