refine test and tensor

This commit is contained in:
colin.liang 2021-08-17 11:28:17 +08:00
parent 975a47f7b2
commit 84e7697c6a
10 changed files with 147 additions and 85 deletions

32
BUILD
View File

@ -174,39 +174,21 @@ cc_library(
":disc_ral",
":hlo",
":lhlo",
":lhlo_gpu",
":hlo_ops_builder_gen",
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:AllPassesAndDialects",
# "@llvm-project//mlir:MlirOptLib",
# "@llvm-project//mlir:MlirJitRunner",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Option",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:config",
# "@llvm-project//llvm:AllTargetsAsmParsers",
# "@llvm-project//llvm:AllTargetsCodeGens",
# "@llvm-project//llvm:Core",
# "@llvm-project//llvm:ExecutionEngine",
# "@llvm-project//llvm:Option",
# "@llvm-project//llvm:OrcJIT",
# "@llvm-project//llvm:Support",
# "@llvm-project//llvm:Target",
# "@llvm-project//mlir:AllPassesAndDialects",
# "@llvm-project//mlir:IR",
# "@llvm-project//mlir:MlirOptLib",
# "@llvm-project//mlir:Support",
# "@llvm-project//mlir:MlirJitRunner",
# "@llvm-project//mlir:Analysis",
# "@llvm-project//mlir:ControlFlowInterfaces",
# "@llvm-project//mlir:InferTypeOpInterface",
# "@llvm-project//mlir:MemRefDialect",
# "@llvm-project//mlir:Shape",
# "@llvm-project//mlir:SideEffects",
# "@llvm-project//mlir:StandardOps",
# "@llvm-project//mlir:TensorDialect",
# "@llvm-project//mlir:TransformUtils",
# "@llvm-project//mlir:Transforms",
],
)

View File

@ -3,19 +3,20 @@
int main() {
builder::Builder builder;
builder::Shape shape({100, 100});
builder::Shape shape({10, 10});
auto pType = builder::PrimitiveType::F32();
builder::Type type(shape, pType);
builder::Tensor tensor(std::vector<int>(100));
std::vector<float> data(100);
builder::Tensor tensor(shape, data);
auto in1 = builder.CreateInput(type);
auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor);
auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor);
auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2);
auto op4 = builder::mhlo::MulOp::build(builder, type, op3, in1);
builder.SetOutput(std::vector<builder::Op>({op4}));
builder.DumpModule();
return 0;
}
// static ::builder::Op build(::builder::Builder &builder, ::builder::Type
// output, ::builder::Tensor value);

View File

@ -50,6 +50,7 @@ PrimitiveType PrimitiveType::S64() {
Shape::Shape(std::vector<int64_t> dims)
: impl_(std::make_shared<Shape::Impl>(dims)) {}
std::vector<int64_t> Shape::GetDims() { return impl_->GetDims(); }
Type::Type(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Type::Impl>(shape, primitiveType)) {}
@ -70,16 +71,16 @@ Array::Array(std::vector<int64_t> value)
Array::Array(std::vector<std::string> value)
: impl_(std::make_shared<Array::Impl>(value)) {}
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
// Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType, const void* value)
// : impl_(std::make_shared<Impl>(shape, primitiveType, value)) {}
Shape Tensor::GetShape() { return impl_->GetShape(); }
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
Tensor::Tensor(std::vector<int> value)
: impl_(std::make_shared<Tensor::Impl>(value)) {}
Tensor::Tensor(std::vector<int64_t> value)
: impl_(std::make_shared<Tensor::Impl>(value)) {}
Tensor::Tensor(std::vector<float> value)
: impl_(std::make_shared<Tensor::Impl>(value)) {}
// PrimitiveType Tensor::GetType() { return impl_->GetType(); }
Tensor::Tensor(Shape& shape, std::vector<int> value)
: impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
Tensor::Tensor(Shape& shape, std::vector<int64_t> value)
: impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
Tensor::Tensor(Shape& shape, std::vector<float> value)
: impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
TensorInt::TensorInt(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}

View File

@ -17,7 +17,7 @@ class PrimitiveType {
inline bool operator==(const PrimitiveType& pt);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
std::shared_ptr<Impl> GetImpl() const { return impl_; }
PrimitiveType(std::shared_ptr<Impl>);
private:
@ -29,7 +29,8 @@ class Shape {
public:
Shape(std::vector<int64_t> dims);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
std::shared_ptr<Impl> GetImpl() const { return impl_; }
std::vector<int64_t> GetDims();
private:
std::shared_ptr<Impl> impl_;
@ -75,10 +76,9 @@ class Array {
class Tensor {
public:
Tensor(Shape& shape, PrimitiveType& primitiveType);
Tensor(std::vector<int> value);
Tensor(std::vector<int64_t> value);
Tensor(std::vector<float> value);
Tensor(Shape& shape, std::vector<int> value);
Tensor(Shape& shape, std::vector<int64_t> value);
Tensor(Shape& shape, std::vector<float> value);
Shape GetShape();
PrimitiveType GetType();
@ -107,7 +107,7 @@ class Type {
public:
Type(Shape& shape, PrimitiveType& primitiveType);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
const std::shared_ptr<Impl> GetImpl() const { return impl_; }
private:
std::shared_ptr<Impl> impl_;

View File

@ -48,46 +48,54 @@ class PrimitiveType::Impl {
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::NoneType::get(context);
};
unitBits_ = 0;
break;
case PRED:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 1, mlir::IntegerType::SignednessSemantics::Signed);
};
unitBits_ = 1;
break;
case S8:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 8, mlir::IntegerType::SignednessSemantics::Signed);
};
unitBits_ = 8;
break;
case S16:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 16, mlir::IntegerType::SignednessSemantics::Signed);
};
unitBits_ = 16;
break;
case F32:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::Float32Type::get(context);
};
unitBits_ = 32;
break;
case S32:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 32, mlir::IntegerType::SignednessSemantics::Signed);
};
unitBits_ = 32;
break;
case S64:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 64, mlir::IntegerType::SignednessSemantics::Signed);
};
unitBits_ = 64;
break;
default:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::NoneType::get(context);
};
unitBits_ = 0;
break;
}
}
@ -116,13 +124,25 @@ class PrimitiveType::Impl {
}
std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;
uint64_t GetUnitBytes() const { return unitBits_ / 8; }
uint64_t GetUnitBits() const { return unitBits_; }
private:
pType t_;
uint64_t unitBits_;
};
class Shape::Impl {
public:
Impl(std::vector<int64_t> dims) : dims_(dims) {}
const std::vector<int64_t> GetDims() const { return dims_; }
int64_t GetSize() const {
int64_t size;
for (auto &d : dims_) {
size *= d;
}
return size;
}
private:
std::vector<int64_t> dims_;
@ -132,10 +152,12 @@ class Type::Impl {
public:
Impl(Shape &shape, PrimitiveType &primitiveType)
: shape_(shape), primitiveType_(primitiveType) {}
Shape GetShape() { return shape_; }
Shape GetShape() const { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
mlir::Type GetMlirType(mlir::MLIRContext *context) {
return mlir::NoneType::get(context);
mlir::Type GetMlirType(mlir::MLIRContext *context) const {
return mlir::RankedTensorType::get(
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
primitiveType_.GetImpl()->GetMlirType(context));
}
private:
@ -223,44 +245,53 @@ class Array::Impl {
class Tensor::Impl {
public:
Impl(std::vector<int> value)
: shape_({value.size()}), primitiveType_(PrimitiveType::S32()) {
Impl(Shape &shape, std::vector<int> value) : shape_(shape) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
auto type =
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
auto type = mlir::RankedTensorType::get(
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
mlir::IntegerType::get(context, 32));
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
};
}
Impl(std::vector<int64_t> value)
: shape_({value.size()}), primitiveType_(PrimitiveType::S64()) {
Impl(Shape &shape, std::vector<int64_t> value) : shape_(shape) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
auto type =
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
auto type = mlir::RankedTensorType::get(
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
mlir::IntegerType::get(context, 64));
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
};
}
Impl(std::vector<float> value)
: shape_({value.size()}), primitiveType_(PrimitiveType::F32()) {
Impl(Shape &shape, std::vector<float> value) : shape_(shape) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
auto type =
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
auto type = mlir::RankedTensorType::get(
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
mlir::FloatType::getF32(context));
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
};
}
Impl(Shape &shape, PrimitiveType &primitiveType)
: shape_(shape), primitiveType_(primitiveType) {}
// Impl(Shape &shape, PrimitiveType &primitiveType, const void *value)
// : shape_(shape), primitiveType_(primitiveType) {
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
// auto type = mlir::RankedTensorType::get(
// llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
// primitiveType_.GetImpl()->GetMlirType(context));
// return mlir::DenseElementsAttr::get<char>(
// type,
// llvm::ArrayRef<char>(reinterpret_cast<const char *>(value),
// shape_.GetImpl()->GetSize() *
// primitiveType_.GetImpl()->GetUnitBytes()));
// };
// }
Shape GetShape() { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
// PrimitiveType GetType() { return primitiveType_; }
std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;
private:
Shape shape_;
PrimitiveType primitiveType_;
// PrimitiveType primitiveType_;
};
class TensorInt::Impl {

View File

@ -19,6 +19,15 @@
namespace builder {
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
builder::Op Builder::CreateInput(const builder::Type& type) {
return impl_->CreateInput(type);
}
void Builder::SetOutput(const std::vector<builder::Op>& outputs) {
impl_->SetOutput(outputs);
}
void Builder::DumpModule() { impl_->DumpModule(); }
} // namespace builder

View File

@ -11,10 +11,13 @@ namespace builder {
class Builder {
public:
Builder();
void SetInput(const std::vector<builder::Op>& inputs);
builder::Op CreateInput(const builder::Type& type);
void SetOutput(const std::vector<builder::Op>& outputs);
void DumpModule();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
std::shared_ptr<Impl> GetImpl() const { return impl_; }
private:
std::shared_ptr<Impl> impl_;

View File

@ -1,9 +1,15 @@
#ifndef BUILDER_BUILDERIMPL_
#define BUILDER_BUILDERIMPL_
#include "Attribute.h"
#include "AttributeImpl.h"
#include "Builder.h"
#include "OpImpl.h"
#include "llvm/Support/Casting.h"
// #include "llvm/Support/InitLLVM.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
@ -13,18 +19,14 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
// #include "mlir/InitAllDialects.h"
// #include "mlir/InitAllPasses.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
namespace builder {
class Builder::Impl {
public:
Impl() : builder_(&context_) {
// llvm::InitLLVM y(argc, argv);
// llvm::InitializeNativeTarget();
// llvm::InitializeNativeTargetAsmPrinter();
@ -37,14 +39,14 @@ class Builder::Impl {
// registerDefaultTimingManagerCLOptions();
// DebugCounter::registerCLOptions();
// mlir::registerAllPasses();
mlir::registerAllPasses();
mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses();
mlir::disc_ral::registerAllDiscRalPasses();
// mlir::lmhlo::registerAllLmhloPasses();
// mlir::disc_ral::registerAllDiscRalPasses();
mlir::DialectRegistry registry;
// mlir::registerAllToLLVMIRTranslations(registry);
// mlir::registerAllDialects(registry);
mlir::registerAllDialects(registry);
registry.insert<mlir::mhlo::MhloDialect>();
// registry.insert<mlir::chlo::HloClientDialect>();
// registry.insert<mlir::lmhlo::LmhloDialect>();
@ -53,9 +55,6 @@ class Builder::Impl {
context_.appendDialectRegistry(registry);
context_.loadAllAvailableDialects();
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
llvm::SmallVector<mlir::Type, 4> arg_types;
@ -74,6 +73,35 @@ class Builder::Impl {
mlir::MLIRContext* GetContext() { return &context_; }
void DumpModule() { module_.dump(); }
builder::Op CreateInput(const builder::Type& type) {
mlir::BlockArgument arg =
entry_block_->addArgument(type.GetImpl()->GetMlirType(&context_));
builder::Op op;
op.GetImpl()->SetValue(arg);
return op;
}
void SetOutput(const std::vector<builder::Op>& outputs) {
llvm::SmallVector<mlir::Type, 4> arg_types;
int arg_num = entry_block_->getNumArguments();
for (int i = 0; i < arg_num; ++i) {
arg_types.push_back(entry_block_->getArgument(i).getType());
}
llvm::SmallVector<mlir::Type, 4> ret_types;
llvm::SmallVector<mlir::Value, 4> ret_vals;
for (auto& out : outputs) {
mlir::Value v = out.GetImpl()->GetResult();
ret_types.push_back(v.getType());
ret_vals.push_back(v);
}
// return all output tensors.
builder_.create<mlir::ReturnOp>(builder_.getUnknownLoc(), ret_vals);
// Update main function input/output type
mlir::FunctionType funcType =
builder_.getFunctionType(arg_types, ret_types);
main_func_.setType(funcType);
}
private:
mlir::MLIRContext context_;
mlir::ModuleOp module_;

View File

@ -9,7 +9,7 @@ class Op {
public:
Op();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
std::shared_ptr<Impl> GetImpl() const { return impl_; }
private:
std::shared_ptr<Impl> impl_;

View File

@ -14,12 +14,19 @@ namespace builder {
class Op::Impl {
public:
Impl() = default;
Impl() : op_(nullptr), value_(){};
void SetOperation(mlir::Operation *Op) { op_ = Op; }
mlir::Value GetResult() { return op_->getResult(0); }
void SetValue(mlir::Value &value) { value_ = value; }
mlir::Value GetResult() {
if (op_ != nullptr)
return op_->getResult(0);
else
return value_;
}
private:
mlir::Operation *op_;
mlir::Value value_;
};
} // namespace builder