// onnx-mlir/test/numerical/TestConv.cpp

#include <algorithm>
#include <cmath>
#include <iostream>
#include <random>
#include <rapidcheck.h>
#include <string>
#include <vector>
#include "mlir/IR/Module.h"
#include "llvm/Support/FileSystem.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/MainUtils.hpp"
#include "src/Runtime/ExecusionSession.hpp"
#define SHARED_LIB_BASE string("./TestConv_main_graph")
using namespace std;
// Returns whether the onnx-mlir compiled convolution produces the same
// results as a naive implementation of convolution for a specific set of
// convolution parameters/configuration.
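// For example, isOMConvTheSameAsNaiveImplFor(1, 2, 5, 5, 3, 3, 1, 1, 1, 1)
// checks a 1x2x5x5 input against 3x3 kernels with one pixel of padding on
// every side.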
bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
    const int W, const int kH, const int kW, const int pHBegin,
    const int pHEnd, const int pWBegin, const int pWEnd) {
  registerDialects();
  MLIRContext ctx;

  auto module = ModuleOp::create(UnknownLoc::get(&ctx));
  OpBuilder builder(&ctx);
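  // Build a module holding a single Conv operator: input X is N x C x H x W,
  // the weights W are C x C x kH x kW (so the channel count is preserved),
  // and the result Y starts out unranked until shape inference runs below.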
  llvm::SmallVector<int64_t, 4> xShape = {N, C, H, W};
  llvm::SmallVector<int64_t, 1> bShape = {C};
  llvm::SmallVector<int64_t, 4> wShape = {C, C, kH, kW};
  auto xType = RankedTensorType::get(xShape, builder.getF32Type());
  auto wType = RankedTensorType::get(wShape, builder.getF32Type());
  auto yType = UnrankedTensorType::get(builder.getF32Type());

  llvm::SmallVector<Type, 2> inputsType{xType, wType};
  llvm::SmallVector<Type, 1> outputsType{yType};

  auto funcType = builder.getFunctionType(inputsType, outputsType);
  string funcName = "main_graph";
  llvm::SmallVector<NamedAttribute, 1> attrs;
  auto funcOp =
      builder.create<FuncOp>(UnknownLoc::get(&ctx), funcName, funcType, attrs);

  auto entryBlock = funcOp.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);
  auto xVal = entryBlock->getArgument(0);
  auto wVal = entryBlock->getArgument(1);
  auto bVal =
      builder.create<ConstantOp>(UnknownLoc::get(&ctx), builder.getUnitAttr())
          .getResult();
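  // No bias is supplied: the constant holding a unit attribute stands in for
  // the optional B operand.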
  auto dilations = builder.getI64ArrayAttr({1, 1});
  auto kernel_shape = builder.getI64ArrayAttr({kH, kW});
  auto pads = builder.getI64ArrayAttr({pHBegin, pWBegin, pHEnd, pWEnd});
  auto strides = builder.getI64ArrayAttr({1, 1});
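  // Per the ONNX Conv spec, pads list all begin values first, then all end
  // values, hence the {pHBegin, pWBegin, pHEnd, pWEnd} ordering above.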
  auto convOp = builder.create<ONNXConvOp>(UnknownLoc::get(&ctx),
      /*Y=*/yType,
      /*X=*/xVal, /*W=*/wVal, /*B=*/bVal,
      /*auto_pad=*/builder.getStringAttr("NOTSET"),
      /*dilations=*/dilations, /*group=*/builder.getI64IntegerAttr(1),
      /*kernel_shape=*/kernel_shape, /*pads=*/pads,
      /*strides=*/strides);
  // Use the convOp shape inference method to compute the output shape, then
  // unset it so that we don't leave the IR in an inconsistent state.
  convOp.inferShapes();
  auto outputShape =
      convOp.getResult().getType().cast<ShapedType>().getShape();
  auto NOut = outputShape[0];
  auto COut = outputShape[1];
  auto HOut = outputShape[2];
  auto WOut = outputShape[3];
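  // With unit strides and dilations, the inferred spatial dimensions should
  // come out to H + pHBegin + pHEnd - kH + 1 and W + pWBegin + pWEnd - kW + 1.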
  convOp.getResult().setType(yType);

  llvm::SmallVector<Value, 1> results = {convOp.getResult()};
  builder.create<ReturnOp>(UnknownLoc::get(&ctx), results);
  module.push_back(funcOp);
  // Emit the entry point operation which specifies the number of user
  // inputs and outputs.
  auto entryPoint = ONNXEntryPointOp::create(UnknownLoc::get(&ctx), funcOp,
      /*numInputs=*/2,
      /*numOutputs=*/1);
  module.push_back(entryPoint);
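  // Lower the module to a shared library, then load it through the runtime's
  // ExecutionSession using the generated entry point symbol.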
  OwningModuleRef moduleRef(module);
  compileModule(moduleRef, ctx, SHARED_LIB_BASE, EmitLib);
  onnx_mlir::ExecutionSession sess(
      SHARED_LIB_BASE + ".so", "_dyn_entry_point_main_graph");
  std::vector<unique_ptr<RtMemRef>> inputs;
  auto xRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({N, C, H, W}));
  inputs.emplace_back(move(xRmr));
  auto wRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({C, C, kH, kW}));
  inputs.emplace_back(move(wRmr));
  auto ref = RtMemRef::create<float>({NOut, COut, HOut, WOut});
  auto &img = inputs.at(0);
  auto &filter = inputs.at(1);
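  // Naive reference convolution: each output element accumulates the product
  // of the zero-padded input window and the corresponding filter weights.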
  for (int64_t n = 0; n < NOut; n++)
    for (int64_t c = 0; c < COut; c++)
      for (int64_t h = 0; h < HOut; h++)
        for (int64_t w = 0; w < WOut; w++) {
          ref->elem<float>({n, c, h, w}) = 0;
          for (int64_t ci = 0; ci < C; ci++)
            for (int64_t kh = 0; kh < kH; kh++)
              for (int64_t kw = 0; kw < kW; kw++)
                if ((h + kh - pHBegin >= 0 && h + kh - pHBegin < H) &&
                    (w + kw - pWBegin >= 0 && w + kw - pWBegin < W))
                  ref->elem<float>({n, c, h, w}) +=
                      img->elem<float>(
                          {n, ci, h + kh - pHBegin, w + kw - pWBegin}) *
                      filter->elem<float>({c, ci, kh, kw});
        }
  auto outputs = sess.run(move(inputs));
  auto &conv = outputs.at(0);

  return isRmrClose<float>(conv.get(), ref);
}
int main(int argc, char *argv[]) {
  setExecPath(argv[0], (void *)main);
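  // Delete the compiled shared library once main returns.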
  llvm::FileRemover remover(SHARED_LIB_BASE + ".so");
  // RapidCheck test case generation.
  rc::check("convolution implementation correctness", []() {
    const auto N = *rc::gen::inRange(1, 10);
    const auto C = *rc::gen::inRange(1, 20);
    const auto H = *rc::gen::inRange(5, 20);
    const auto W = *rc::gen::inRange(5, 20);
    const auto kH = *rc::gen::inRange(1, 15);
    const auto kW = *rc::gen::inRange(1, 15);
    // We don't want an entire window of padding.
    const auto pHBegin = *rc::gen::inRange(0, kH - 1);
    const auto pHEnd = *rc::gen::inRange(0, kH - 1);
    const auto pWBegin = *rc::gen::inRange(0, kW - 1);
    const auto pWEnd = *rc::gen::inRange(0, kW - 1);
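    // Note that rc::gen::inRange is half-open, so each pad is at most its
    // kernel dimension minus 2, keeping padding strictly inside one window.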
    // Make sure we have at least one output pixel per spatial dimension.
    RC_PRE((H >= kH) && (W >= kW));

    RC_ASSERT(isOMConvTheSameAsNaiveImplFor(
        N, C, H, W, kH, kW, pHBegin, pHEnd, pWBegin, pWEnd));
  });
  // Exhaustive test case generation.
  for (int pHBegin = 0; pHBegin < 3; pHBegin++)
    for (int pHEnd = 0; pHEnd < 3; pHEnd++)
      for (int pWBegin = 0; pWBegin < 3; pWBegin++)
        for (int pWEnd = 0; pWEnd < 3; pWEnd++)
          assert(isOMConvTheSameAsNaiveImplFor(
              2, 4, 5, 5, 3, 3, pHBegin, pHEnd, pWBegin, pWEnd));
  return 0;
}