refine test and tensor
This commit is contained in:
parent
975a47f7b2
commit
84e7697c6a
32
BUILD
32
BUILD
|
@ -174,39 +174,21 @@ cc_library(
|
|||
":disc_ral",
|
||||
":hlo",
|
||||
":lhlo",
|
||||
":lhlo_gpu",
|
||||
":hlo_ops_builder_gen",
|
||||
"@llvm-project//mlir:MlirTableGenMain",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
# "@llvm-project//mlir:MlirOptLib",
|
||||
# "@llvm-project//mlir:MlirJitRunner",
|
||||
"@llvm-project//llvm:Core",
|
||||
"@llvm-project//llvm:Option",
|
||||
"@llvm-project//llvm:OrcJIT",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:TableGen",
|
||||
"@llvm-project//llvm:Target",
|
||||
"@llvm-project//llvm:config",
|
||||
# "@llvm-project//llvm:AllTargetsAsmParsers",
|
||||
# "@llvm-project//llvm:AllTargetsCodeGens",
|
||||
# "@llvm-project//llvm:Core",
|
||||
# "@llvm-project//llvm:ExecutionEngine",
|
||||
# "@llvm-project//llvm:Option",
|
||||
# "@llvm-project//llvm:OrcJIT",
|
||||
# "@llvm-project//llvm:Support",
|
||||
# "@llvm-project//llvm:Target",
|
||||
# "@llvm-project//mlir:AllPassesAndDialects",
|
||||
# "@llvm-project//mlir:IR",
|
||||
# "@llvm-project//mlir:MlirOptLib",
|
||||
# "@llvm-project//mlir:Support",
|
||||
# "@llvm-project//mlir:MlirJitRunner",
|
||||
|
||||
# "@llvm-project//mlir:Analysis",
|
||||
# "@llvm-project//mlir:ControlFlowInterfaces",
|
||||
# "@llvm-project//mlir:InferTypeOpInterface",
|
||||
# "@llvm-project//mlir:MemRefDialect",
|
||||
# "@llvm-project//mlir:Shape",
|
||||
# "@llvm-project//mlir:SideEffects",
|
||||
# "@llvm-project//mlir:StandardOps",
|
||||
# "@llvm-project//mlir:TensorDialect",
|
||||
# "@llvm-project//mlir:TransformUtils",
|
||||
# "@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -3,19 +3,20 @@
|
|||
int main() {
|
||||
builder::Builder builder;
|
||||
|
||||
builder::Shape shape({100, 100});
|
||||
builder::Shape shape({10, 10});
|
||||
auto pType = builder::PrimitiveType::F32();
|
||||
builder::Type type(shape, pType);
|
||||
|
||||
builder::Tensor tensor(std::vector<int>(100));
|
||||
std::vector<float> data(100);
|
||||
builder::Tensor tensor(shape, data);
|
||||
|
||||
auto in1 = builder.CreateInput(type);
|
||||
auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor);
|
||||
auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor);
|
||||
auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2);
|
||||
auto op4 = builder::mhlo::MulOp::build(builder, type, op3, in1);
|
||||
|
||||
builder.SetOutput(std::vector<builder::Op>({op4}));
|
||||
builder.DumpModule();
|
||||
return 0;
|
||||
}
|
||||
|
||||
// static ::builder::Op build(::builder::Builder &builder, ::builder::Type
|
||||
// output, ::builder::Tensor value);
|
||||
|
|
|
@ -50,6 +50,7 @@ PrimitiveType PrimitiveType::S64() {
|
|||
|
||||
Shape::Shape(std::vector<int64_t> dims)
|
||||
: impl_(std::make_shared<Shape::Impl>(dims)) {}
|
||||
std::vector<int64_t> Shape::GetDims() { return impl_->GetDims(); }
|
||||
|
||||
Type::Type(Shape& shape, PrimitiveType& primitiveType)
|
||||
: impl_(std::make_shared<Type::Impl>(shape, primitiveType)) {}
|
||||
|
@ -70,16 +71,16 @@ Array::Array(std::vector<int64_t> value)
|
|||
Array::Array(std::vector<std::string> value)
|
||||
: impl_(std::make_shared<Array::Impl>(value)) {}
|
||||
|
||||
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
|
||||
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
||||
// Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType, const void* value)
|
||||
// : impl_(std::make_shared<Impl>(shape, primitiveType, value)) {}
|
||||
Shape Tensor::GetShape() { return impl_->GetShape(); }
|
||||
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
|
||||
Tensor::Tensor(std::vector<int> value)
|
||||
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
||||
Tensor::Tensor(std::vector<int64_t> value)
|
||||
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
||||
Tensor::Tensor(std::vector<float> value)
|
||||
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
||||
// PrimitiveType Tensor::GetType() { return impl_->GetType(); }
|
||||
Tensor::Tensor(Shape& shape, std::vector<int> value)
|
||||
: impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
|
||||
Tensor::Tensor(Shape& shape, std::vector<int64_t> value)
|
||||
: impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
|
||||
Tensor::Tensor(Shape& shape, std::vector<float> value)
|
||||
: impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
|
||||
|
||||
TensorInt::TensorInt(Shape& shape, PrimitiveType& primitiveType)
|
||||
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
||||
|
|
|
@ -17,7 +17,7 @@ class PrimitiveType {
|
|||
inline bool operator==(const PrimitiveType& pt);
|
||||
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
std::shared_ptr<Impl> GetImpl() const { return impl_; }
|
||||
PrimitiveType(std::shared_ptr<Impl>);
|
||||
|
||||
private:
|
||||
|
@ -29,7 +29,8 @@ class Shape {
|
|||
public:
|
||||
Shape(std::vector<int64_t> dims);
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
std::shared_ptr<Impl> GetImpl() const { return impl_; }
|
||||
std::vector<int64_t> GetDims();
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
|
@ -75,10 +76,9 @@ class Array {
|
|||
|
||||
class Tensor {
|
||||
public:
|
||||
Tensor(Shape& shape, PrimitiveType& primitiveType);
|
||||
Tensor(std::vector<int> value);
|
||||
Tensor(std::vector<int64_t> value);
|
||||
Tensor(std::vector<float> value);
|
||||
Tensor(Shape& shape, std::vector<int> value);
|
||||
Tensor(Shape& shape, std::vector<int64_t> value);
|
||||
Tensor(Shape& shape, std::vector<float> value);
|
||||
|
||||
Shape GetShape();
|
||||
PrimitiveType GetType();
|
||||
|
@ -107,7 +107,7 @@ class Type {
|
|||
public:
|
||||
Type(Shape& shape, PrimitiveType& primitiveType);
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
const std::shared_ptr<Impl> GetImpl() const { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
|
|
|
@ -48,46 +48,54 @@ class PrimitiveType::Impl {
|
|||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::NoneType::get(context);
|
||||
};
|
||||
unitBits_ = 0;
|
||||
break;
|
||||
case PRED:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(
|
||||
context, 1, mlir::IntegerType::SignednessSemantics::Signed);
|
||||
};
|
||||
unitBits_ = 1;
|
||||
break;
|
||||
case S8:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(
|
||||
context, 8, mlir::IntegerType::SignednessSemantics::Signed);
|
||||
};
|
||||
unitBits_ = 8;
|
||||
break;
|
||||
case S16:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(
|
||||
context, 16, mlir::IntegerType::SignednessSemantics::Signed);
|
||||
};
|
||||
unitBits_ = 16;
|
||||
break;
|
||||
case F32:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::Float32Type::get(context);
|
||||
};
|
||||
unitBits_ = 32;
|
||||
break;
|
||||
case S32:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(
|
||||
context, 32, mlir::IntegerType::SignednessSemantics::Signed);
|
||||
};
|
||||
unitBits_ = 32;
|
||||
break;
|
||||
case S64:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(
|
||||
context, 64, mlir::IntegerType::SignednessSemantics::Signed);
|
||||
};
|
||||
unitBits_ = 64;
|
||||
break;
|
||||
default:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::NoneType::get(context);
|
||||
};
|
||||
unitBits_ = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -116,13 +124,25 @@ class PrimitiveType::Impl {
|
|||
}
|
||||
std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;
|
||||
|
||||
uint64_t GetUnitBytes() const { return unitBits_ / 8; }
|
||||
uint64_t GetUnitBits() const { return unitBits_; }
|
||||
|
||||
private:
|
||||
pType t_;
|
||||
uint64_t unitBits_;
|
||||
};
|
||||
|
||||
class Shape::Impl {
|
||||
public:
|
||||
Impl(std::vector<int64_t> dims) : dims_(dims) {}
|
||||
const std::vector<int64_t> GetDims() const { return dims_; }
|
||||
int64_t GetSize() const {
|
||||
int64_t size;
|
||||
for (auto &d : dims_) {
|
||||
size *= d;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int64_t> dims_;
|
||||
|
@ -132,10 +152,12 @@ class Type::Impl {
|
|||
public:
|
||||
Impl(Shape &shape, PrimitiveType &primitiveType)
|
||||
: shape_(shape), primitiveType_(primitiveType) {}
|
||||
Shape GetShape() { return shape_; }
|
||||
Shape GetShape() const { return shape_; }
|
||||
PrimitiveType GetType() { return primitiveType_; }
|
||||
mlir::Type GetMlirType(mlir::MLIRContext *context) {
|
||||
return mlir::NoneType::get(context);
|
||||
mlir::Type GetMlirType(mlir::MLIRContext *context) const {
|
||||
return mlir::RankedTensorType::get(
|
||||
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
|
||||
primitiveType_.GetImpl()->GetMlirType(context));
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -223,44 +245,53 @@ class Array::Impl {
|
|||
|
||||
class Tensor::Impl {
|
||||
public:
|
||||
Impl(std::vector<int> value)
|
||||
: shape_({value.size()}), primitiveType_(PrimitiveType::S32()) {
|
||||
Impl(Shape &shape, std::vector<int> value) : shape_(shape) {
|
||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||
auto type =
|
||||
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||
mlir::IntegerType::get(context, 32));
|
||||
auto type = mlir::RankedTensorType::get(
|
||||
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
|
||||
mlir::IntegerType::get(context, 32));
|
||||
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||
};
|
||||
}
|
||||
Impl(std::vector<int64_t> value)
|
||||
: shape_({value.size()}), primitiveType_(PrimitiveType::S64()) {
|
||||
Impl(Shape &shape, std::vector<int64_t> value) : shape_(shape) {
|
||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||
auto type =
|
||||
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||
mlir::IntegerType::get(context, 64));
|
||||
auto type = mlir::RankedTensorType::get(
|
||||
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
|
||||
mlir::IntegerType::get(context, 64));
|
||||
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||
};
|
||||
}
|
||||
Impl(std::vector<float> value)
|
||||
: shape_({value.size()}), primitiveType_(PrimitiveType::F32()) {
|
||||
Impl(Shape &shape, std::vector<float> value) : shape_(shape) {
|
||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||
auto type =
|
||||
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||
mlir::FloatType::getF32(context));
|
||||
auto type = mlir::RankedTensorType::get(
|
||||
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
|
||||
mlir::FloatType::getF32(context));
|
||||
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||
};
|
||||
}
|
||||
Impl(Shape &shape, PrimitiveType &primitiveType)
|
||||
: shape_(shape), primitiveType_(primitiveType) {}
|
||||
|
||||
// Impl(Shape &shape, PrimitiveType &primitiveType, const void *value)
|
||||
// : shape_(shape), primitiveType_(primitiveType) {
|
||||
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||
// auto type = mlir::RankedTensorType::get(
|
||||
// llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
|
||||
// primitiveType_.GetImpl()->GetMlirType(context));
|
||||
// return mlir::DenseElementsAttr::get<char>(
|
||||
// type,
|
||||
// llvm::ArrayRef<char>(reinterpret_cast<const char *>(value),
|
||||
// shape_.GetImpl()->GetSize() *
|
||||
// primitiveType_.GetImpl()->GetUnitBytes()));
|
||||
// };
|
||||
// }
|
||||
|
||||
Shape GetShape() { return shape_; }
|
||||
PrimitiveType GetType() { return primitiveType_; }
|
||||
// PrimitiveType GetType() { return primitiveType_; }
|
||||
|
||||
std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
PrimitiveType primitiveType_;
|
||||
// PrimitiveType primitiveType_;
|
||||
};
|
||||
|
||||
class TensorInt::Impl {
|
||||
|
|
|
@ -19,6 +19,15 @@
|
|||
namespace builder {
|
||||
|
||||
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
|
||||
|
||||
builder::Op Builder::CreateInput(const builder::Type& type) {
|
||||
return impl_->CreateInput(type);
|
||||
}
|
||||
|
||||
void Builder::SetOutput(const std::vector<builder::Op>& outputs) {
|
||||
impl_->SetOutput(outputs);
|
||||
}
|
||||
|
||||
void Builder::DumpModule() { impl_->DumpModule(); }
|
||||
|
||||
} // namespace builder
|
||||
|
|
|
@ -11,10 +11,13 @@ namespace builder {
|
|||
class Builder {
|
||||
public:
|
||||
Builder();
|
||||
void SetInput(const std::vector<builder::Op>& inputs);
|
||||
builder::Op CreateInput(const builder::Type& type);
|
||||
void SetOutput(const std::vector<builder::Op>& outputs);
|
||||
void DumpModule();
|
||||
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
std::shared_ptr<Impl> GetImpl() const { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
|
|
|
@ -1,9 +1,15 @@
|
|||
#ifndef BUILDER_BUILDERIMPL_
|
||||
#define BUILDER_BUILDERIMPL_
|
||||
|
||||
#include "Attribute.h"
|
||||
#include "AttributeImpl.h"
|
||||
#include "Builder.h"
|
||||
#include "OpImpl.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
// #include "llvm/Support/InitLLVM.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -13,18 +19,14 @@
|
|||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
// #include "mlir/InitAllDialects.h"
|
||||
// #include "mlir/InitAllPasses.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
class Builder::Impl {
|
||||
public:
|
||||
Impl() : builder_(&context_) {
|
||||
|
||||
// llvm::InitLLVM y(argc, argv);
|
||||
// llvm::InitializeNativeTarget();
|
||||
// llvm::InitializeNativeTargetAsmPrinter();
|
||||
|
@ -37,14 +39,14 @@ class Builder::Impl {
|
|||
// registerDefaultTimingManagerCLOptions();
|
||||
// DebugCounter::registerCLOptions();
|
||||
|
||||
// mlir::registerAllPasses();
|
||||
mlir::registerAllPasses();
|
||||
mlir::mhlo::registerAllMhloPasses();
|
||||
mlir::lmhlo::registerAllLmhloPasses();
|
||||
mlir::disc_ral::registerAllDiscRalPasses();
|
||||
// mlir::lmhlo::registerAllLmhloPasses();
|
||||
// mlir::disc_ral::registerAllDiscRalPasses();
|
||||
|
||||
mlir::DialectRegistry registry;
|
||||
// mlir::registerAllToLLVMIRTranslations(registry);
|
||||
// mlir::registerAllDialects(registry);
|
||||
mlir::registerAllDialects(registry);
|
||||
registry.insert<mlir::mhlo::MhloDialect>();
|
||||
// registry.insert<mlir::chlo::HloClientDialect>();
|
||||
// registry.insert<mlir::lmhlo::LmhloDialect>();
|
||||
|
@ -53,9 +55,6 @@ class Builder::Impl {
|
|||
context_.appendDialectRegistry(registry);
|
||||
context_.loadAllAvailableDialects();
|
||||
|
||||
|
||||
|
||||
|
||||
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
|
||||
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||
|
@ -74,6 +73,35 @@ class Builder::Impl {
|
|||
mlir::MLIRContext* GetContext() { return &context_; }
|
||||
void DumpModule() { module_.dump(); }
|
||||
|
||||
builder::Op CreateInput(const builder::Type& type) {
|
||||
mlir::BlockArgument arg =
|
||||
entry_block_->addArgument(type.GetImpl()->GetMlirType(&context_));
|
||||
builder::Op op;
|
||||
op.GetImpl()->SetValue(arg);
|
||||
return op;
|
||||
}
|
||||
|
||||
void SetOutput(const std::vector<builder::Op>& outputs) {
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||
int arg_num = entry_block_->getNumArguments();
|
||||
for (int i = 0; i < arg_num; ++i) {
|
||||
arg_types.push_back(entry_block_->getArgument(i).getType());
|
||||
}
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
llvm::SmallVector<mlir::Value, 4> ret_vals;
|
||||
for (auto& out : outputs) {
|
||||
mlir::Value v = out.GetImpl()->GetResult();
|
||||
ret_types.push_back(v.getType());
|
||||
ret_vals.push_back(v);
|
||||
}
|
||||
// return all output tensors.
|
||||
builder_.create<mlir::ReturnOp>(builder_.getUnknownLoc(), ret_vals);
|
||||
// Update main function input/output type
|
||||
mlir::FunctionType funcType =
|
||||
builder_.getFunctionType(arg_types, ret_types);
|
||||
main_func_.setType(funcType);
|
||||
}
|
||||
|
||||
private:
|
||||
mlir::MLIRContext context_;
|
||||
mlir::ModuleOp module_;
|
||||
|
|
|
@ -9,7 +9,7 @@ class Op {
|
|||
public:
|
||||
Op();
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
std::shared_ptr<Impl> GetImpl() const { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
|
|
|
@ -14,12 +14,19 @@ namespace builder {
|
|||
|
||||
class Op::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
Impl() : op_(nullptr), value_(){};
|
||||
void SetOperation(mlir::Operation *Op) { op_ = Op; }
|
||||
mlir::Value GetResult() { return op_->getResult(0); }
|
||||
void SetValue(mlir::Value &value) { value_ = value; }
|
||||
mlir::Value GetResult() {
|
||||
if (op_ != nullptr)
|
||||
return op_->getResult(0);
|
||||
else
|
||||
return value_;
|
||||
}
|
||||
|
||||
private:
|
||||
mlir::Operation *op_;
|
||||
mlir::Value value_;
|
||||
};
|
||||
|
||||
} // namespace builder
|
||||
|
|
Loading…
Reference in New Issue