refine test and tensor
This commit is contained in:
parent
975a47f7b2
commit
84e7697c6a
32
BUILD
32
BUILD
|
@ -174,39 +174,21 @@ cc_library(
|
||||||
":disc_ral",
|
":disc_ral",
|
||||||
":hlo",
|
":hlo",
|
||||||
":lhlo",
|
":lhlo",
|
||||||
":lhlo_gpu",
|
|
||||||
":hlo_ops_builder_gen",
|
":hlo_ops_builder_gen",
|
||||||
"@llvm-project//mlir:MlirTableGenMain",
|
"@llvm-project//mlir:MlirTableGenMain",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:AllPassesAndDialects",
|
||||||
|
# "@llvm-project//mlir:MlirOptLib",
|
||||||
|
# "@llvm-project//mlir:MlirJitRunner",
|
||||||
|
"@llvm-project//llvm:Core",
|
||||||
|
"@llvm-project//llvm:Option",
|
||||||
|
"@llvm-project//llvm:OrcJIT",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//llvm:TableGen",
|
"@llvm-project//llvm:TableGen",
|
||||||
|
"@llvm-project//llvm:Target",
|
||||||
"@llvm-project//llvm:config",
|
"@llvm-project//llvm:config",
|
||||||
# "@llvm-project//llvm:AllTargetsAsmParsers",
|
|
||||||
# "@llvm-project//llvm:AllTargetsCodeGens",
|
|
||||||
# "@llvm-project//llvm:Core",
|
|
||||||
# "@llvm-project//llvm:ExecutionEngine",
|
|
||||||
# "@llvm-project//llvm:Option",
|
|
||||||
# "@llvm-project//llvm:OrcJIT",
|
|
||||||
# "@llvm-project//llvm:Support",
|
|
||||||
# "@llvm-project//llvm:Target",
|
|
||||||
# "@llvm-project//mlir:AllPassesAndDialects",
|
|
||||||
# "@llvm-project//mlir:IR",
|
|
||||||
# "@llvm-project//mlir:MlirOptLib",
|
|
||||||
# "@llvm-project//mlir:Support",
|
|
||||||
# "@llvm-project//mlir:MlirJitRunner",
|
|
||||||
|
|
||||||
# "@llvm-project//mlir:Analysis",
|
|
||||||
# "@llvm-project//mlir:ControlFlowInterfaces",
|
|
||||||
# "@llvm-project//mlir:InferTypeOpInterface",
|
|
||||||
# "@llvm-project//mlir:MemRefDialect",
|
|
||||||
# "@llvm-project//mlir:Shape",
|
|
||||||
# "@llvm-project//mlir:SideEffects",
|
|
||||||
# "@llvm-project//mlir:StandardOps",
|
|
||||||
# "@llvm-project//mlir:TensorDialect",
|
|
||||||
# "@llvm-project//mlir:TransformUtils",
|
|
||||||
# "@llvm-project//mlir:Transforms",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -3,19 +3,20 @@
|
||||||
int main() {
|
int main() {
|
||||||
builder::Builder builder;
|
builder::Builder builder;
|
||||||
|
|
||||||
builder::Shape shape({100, 100});
|
builder::Shape shape({10, 10});
|
||||||
auto pType = builder::PrimitiveType::F32();
|
auto pType = builder::PrimitiveType::F32();
|
||||||
builder::Type type(shape, pType);
|
builder::Type type(shape, pType);
|
||||||
|
|
||||||
builder::Tensor tensor(std::vector<int>(100));
|
std::vector<float> data(100);
|
||||||
|
builder::Tensor tensor(shape, data);
|
||||||
|
|
||||||
|
auto in1 = builder.CreateInput(type);
|
||||||
auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor);
|
auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor);
|
||||||
auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor);
|
auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor);
|
||||||
auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2);
|
auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2);
|
||||||
|
auto op4 = builder::mhlo::MulOp::build(builder, type, op3, in1);
|
||||||
|
|
||||||
|
builder.SetOutput(std::vector<builder::Op>({op4}));
|
||||||
builder.DumpModule();
|
builder.DumpModule();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// static ::builder::Op build(::builder::Builder &builder, ::builder::Type
|
|
||||||
// output, ::builder::Tensor value);
|
|
||||||
|
|
|
@ -50,6 +50,7 @@ PrimitiveType PrimitiveType::S64() {
|
||||||
|
|
||||||
Shape::Shape(std::vector<int64_t> dims)
|
Shape::Shape(std::vector<int64_t> dims)
|
||||||
: impl_(std::make_shared<Shape::Impl>(dims)) {}
|
: impl_(std::make_shared<Shape::Impl>(dims)) {}
|
||||||
|
std::vector<int64_t> Shape::GetDims() { return impl_->GetDims(); }
|
||||||
|
|
||||||
Type::Type(Shape& shape, PrimitiveType& primitiveType)
|
Type::Type(Shape& shape, PrimitiveType& primitiveType)
|
||||||
: impl_(std::make_shared<Type::Impl>(shape, primitiveType)) {}
|
: impl_(std::make_shared<Type::Impl>(shape, primitiveType)) {}
|
||||||
|
@ -70,16 +71,16 @@ Array::Array(std::vector<int64_t> value)
|
||||||
Array::Array(std::vector<std::string> value)
|
Array::Array(std::vector<std::string> value)
|
||||||
: impl_(std::make_shared<Array::Impl>(value)) {}
|
: impl_(std::make_shared<Array::Impl>(value)) {}
|
||||||
|
|
||||||
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
|
// Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType, const void* value)
|
||||||
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
// : impl_(std::make_shared<Impl>(shape, primitiveType, value)) {}
|
||||||
Shape Tensor::GetShape() { return impl_->GetShape(); }
|
Shape Tensor::GetShape() { return impl_->GetShape(); }
|
||||||
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
|
// PrimitiveType Tensor::GetType() { return impl_->GetType(); }
|
||||||
Tensor::Tensor(std::vector<int> value)
|
Tensor::Tensor(Shape& shape, std::vector<int> value)
|
||||||
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
: impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
|
||||||
Tensor::Tensor(std::vector<int64_t> value)
|
Tensor::Tensor(Shape& shape, std::vector<int64_t> value)
|
||||||
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
: impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
|
||||||
Tensor::Tensor(std::vector<float> value)
|
Tensor::Tensor(Shape& shape, std::vector<float> value)
|
||||||
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
: impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
|
||||||
|
|
||||||
TensorInt::TensorInt(Shape& shape, PrimitiveType& primitiveType)
|
TensorInt::TensorInt(Shape& shape, PrimitiveType& primitiveType)
|
||||||
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
||||||
|
|
|
@ -17,7 +17,7 @@ class PrimitiveType {
|
||||||
inline bool operator==(const PrimitiveType& pt);
|
inline bool operator==(const PrimitiveType& pt);
|
||||||
|
|
||||||
class Impl;
|
class Impl;
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
std::shared_ptr<Impl> GetImpl() const { return impl_; }
|
||||||
PrimitiveType(std::shared_ptr<Impl>);
|
PrimitiveType(std::shared_ptr<Impl>);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -29,7 +29,8 @@ class Shape {
|
||||||
public:
|
public:
|
||||||
Shape(std::vector<int64_t> dims);
|
Shape(std::vector<int64_t> dims);
|
||||||
class Impl;
|
class Impl;
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
std::shared_ptr<Impl> GetImpl() const { return impl_; }
|
||||||
|
std::vector<int64_t> GetDims();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Impl> impl_;
|
std::shared_ptr<Impl> impl_;
|
||||||
|
@ -75,10 +76,9 @@ class Array {
|
||||||
|
|
||||||
class Tensor {
|
class Tensor {
|
||||||
public:
|
public:
|
||||||
Tensor(Shape& shape, PrimitiveType& primitiveType);
|
Tensor(Shape& shape, std::vector<int> value);
|
||||||
Tensor(std::vector<int> value);
|
Tensor(Shape& shape, std::vector<int64_t> value);
|
||||||
Tensor(std::vector<int64_t> value);
|
Tensor(Shape& shape, std::vector<float> value);
|
||||||
Tensor(std::vector<float> value);
|
|
||||||
|
|
||||||
Shape GetShape();
|
Shape GetShape();
|
||||||
PrimitiveType GetType();
|
PrimitiveType GetType();
|
||||||
|
@ -107,7 +107,7 @@ class Type {
|
||||||
public:
|
public:
|
||||||
Type(Shape& shape, PrimitiveType& primitiveType);
|
Type(Shape& shape, PrimitiveType& primitiveType);
|
||||||
class Impl;
|
class Impl;
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
const std::shared_ptr<Impl> GetImpl() const { return impl_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Impl> impl_;
|
std::shared_ptr<Impl> impl_;
|
||||||
|
|
|
@ -48,46 +48,54 @@ class PrimitiveType::Impl {
|
||||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::NoneType::get(context);
|
return mlir::NoneType::get(context);
|
||||||
};
|
};
|
||||||
|
unitBits_ = 0;
|
||||||
break;
|
break;
|
||||||
case PRED:
|
case PRED:
|
||||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::IntegerType::get(
|
return mlir::IntegerType::get(
|
||||||
context, 1, mlir::IntegerType::SignednessSemantics::Signed);
|
context, 1, mlir::IntegerType::SignednessSemantics::Signed);
|
||||||
};
|
};
|
||||||
|
unitBits_ = 1;
|
||||||
break;
|
break;
|
||||||
case S8:
|
case S8:
|
||||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::IntegerType::get(
|
return mlir::IntegerType::get(
|
||||||
context, 8, mlir::IntegerType::SignednessSemantics::Signed);
|
context, 8, mlir::IntegerType::SignednessSemantics::Signed);
|
||||||
};
|
};
|
||||||
|
unitBits_ = 8;
|
||||||
break;
|
break;
|
||||||
case S16:
|
case S16:
|
||||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::IntegerType::get(
|
return mlir::IntegerType::get(
|
||||||
context, 16, mlir::IntegerType::SignednessSemantics::Signed);
|
context, 16, mlir::IntegerType::SignednessSemantics::Signed);
|
||||||
};
|
};
|
||||||
|
unitBits_ = 16;
|
||||||
break;
|
break;
|
||||||
case F32:
|
case F32:
|
||||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::Float32Type::get(context);
|
return mlir::Float32Type::get(context);
|
||||||
};
|
};
|
||||||
|
unitBits_ = 32;
|
||||||
break;
|
break;
|
||||||
case S32:
|
case S32:
|
||||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::IntegerType::get(
|
return mlir::IntegerType::get(
|
||||||
context, 32, mlir::IntegerType::SignednessSemantics::Signed);
|
context, 32, mlir::IntegerType::SignednessSemantics::Signed);
|
||||||
};
|
};
|
||||||
|
unitBits_ = 32;
|
||||||
break;
|
break;
|
||||||
case S64:
|
case S64:
|
||||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::IntegerType::get(
|
return mlir::IntegerType::get(
|
||||||
context, 64, mlir::IntegerType::SignednessSemantics::Signed);
|
context, 64, mlir::IntegerType::SignednessSemantics::Signed);
|
||||||
};
|
};
|
||||||
|
unitBits_ = 64;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::NoneType::get(context);
|
return mlir::NoneType::get(context);
|
||||||
};
|
};
|
||||||
|
unitBits_ = 0;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -116,13 +124,25 @@ class PrimitiveType::Impl {
|
||||||
}
|
}
|
||||||
std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;
|
std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;
|
||||||
|
|
||||||
|
uint64_t GetUnitBytes() const { return unitBits_ / 8; }
|
||||||
|
uint64_t GetUnitBits() const { return unitBits_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
pType t_;
|
pType t_;
|
||||||
|
uint64_t unitBits_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Shape::Impl {
|
class Shape::Impl {
|
||||||
public:
|
public:
|
||||||
Impl(std::vector<int64_t> dims) : dims_(dims) {}
|
Impl(std::vector<int64_t> dims) : dims_(dims) {}
|
||||||
|
const std::vector<int64_t> GetDims() const { return dims_; }
|
||||||
|
int64_t GetSize() const {
|
||||||
|
int64_t size;
|
||||||
|
for (auto &d : dims_) {
|
||||||
|
size *= d;
|
||||||
|
}
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int64_t> dims_;
|
std::vector<int64_t> dims_;
|
||||||
|
@ -132,10 +152,12 @@ class Type::Impl {
|
||||||
public:
|
public:
|
||||||
Impl(Shape &shape, PrimitiveType &primitiveType)
|
Impl(Shape &shape, PrimitiveType &primitiveType)
|
||||||
: shape_(shape), primitiveType_(primitiveType) {}
|
: shape_(shape), primitiveType_(primitiveType) {}
|
||||||
Shape GetShape() { return shape_; }
|
Shape GetShape() const { return shape_; }
|
||||||
PrimitiveType GetType() { return primitiveType_; }
|
PrimitiveType GetType() { return primitiveType_; }
|
||||||
mlir::Type GetMlirType(mlir::MLIRContext *context) {
|
mlir::Type GetMlirType(mlir::MLIRContext *context) const {
|
||||||
return mlir::NoneType::get(context);
|
return mlir::RankedTensorType::get(
|
||||||
|
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
|
||||||
|
primitiveType_.GetImpl()->GetMlirType(context));
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -223,44 +245,53 @@ class Array::Impl {
|
||||||
|
|
||||||
class Tensor::Impl {
|
class Tensor::Impl {
|
||||||
public:
|
public:
|
||||||
Impl(std::vector<int> value)
|
Impl(Shape &shape, std::vector<int> value) : shape_(shape) {
|
||||||
: shape_({value.size()}), primitiveType_(PrimitiveType::S32()) {
|
|
||||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||||
auto type =
|
auto type = mlir::RankedTensorType::get(
|
||||||
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
|
||||||
mlir::IntegerType::get(context, 32));
|
mlir::IntegerType::get(context, 32));
|
||||||
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
Impl(std::vector<int64_t> value)
|
Impl(Shape &shape, std::vector<int64_t> value) : shape_(shape) {
|
||||||
: shape_({value.size()}), primitiveType_(PrimitiveType::S64()) {
|
|
||||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||||
auto type =
|
auto type = mlir::RankedTensorType::get(
|
||||||
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
|
||||||
mlir::IntegerType::get(context, 64));
|
mlir::IntegerType::get(context, 64));
|
||||||
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
Impl(std::vector<float> value)
|
Impl(Shape &shape, std::vector<float> value) : shape_(shape) {
|
||||||
: shape_({value.size()}), primitiveType_(PrimitiveType::F32()) {
|
|
||||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||||
auto type =
|
auto type = mlir::RankedTensorType::get(
|
||||||
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
|
||||||
mlir::FloatType::getF32(context));
|
mlir::FloatType::getF32(context));
|
||||||
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
Impl(Shape &shape, PrimitiveType &primitiveType)
|
|
||||||
: shape_(shape), primitiveType_(primitiveType) {}
|
// Impl(Shape &shape, PrimitiveType &primitiveType, const void *value)
|
||||||
|
// : shape_(shape), primitiveType_(primitiveType) {
|
||||||
|
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||||
|
// auto type = mlir::RankedTensorType::get(
|
||||||
|
// llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
|
||||||
|
// primitiveType_.GetImpl()->GetMlirType(context));
|
||||||
|
// return mlir::DenseElementsAttr::get<char>(
|
||||||
|
// type,
|
||||||
|
// llvm::ArrayRef<char>(reinterpret_cast<const char *>(value),
|
||||||
|
// shape_.GetImpl()->GetSize() *
|
||||||
|
// primitiveType_.GetImpl()->GetUnitBytes()));
|
||||||
|
// };
|
||||||
|
// }
|
||||||
|
|
||||||
Shape GetShape() { return shape_; }
|
Shape GetShape() { return shape_; }
|
||||||
PrimitiveType GetType() { return primitiveType_; }
|
// PrimitiveType GetType() { return primitiveType_; }
|
||||||
|
|
||||||
std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;
|
std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape shape_;
|
Shape shape_;
|
||||||
PrimitiveType primitiveType_;
|
// PrimitiveType primitiveType_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TensorInt::Impl {
|
class TensorInt::Impl {
|
||||||
|
|
|
@ -19,6 +19,15 @@
|
||||||
namespace builder {
|
namespace builder {
|
||||||
|
|
||||||
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
|
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
|
||||||
|
|
||||||
|
builder::Op Builder::CreateInput(const builder::Type& type) {
|
||||||
|
return impl_->CreateInput(type);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Builder::SetOutput(const std::vector<builder::Op>& outputs) {
|
||||||
|
impl_->SetOutput(outputs);
|
||||||
|
}
|
||||||
|
|
||||||
void Builder::DumpModule() { impl_->DumpModule(); }
|
void Builder::DumpModule() { impl_->DumpModule(); }
|
||||||
|
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
|
|
@ -11,10 +11,13 @@ namespace builder {
|
||||||
class Builder {
|
class Builder {
|
||||||
public:
|
public:
|
||||||
Builder();
|
Builder();
|
||||||
|
void SetInput(const std::vector<builder::Op>& inputs);
|
||||||
|
builder::Op CreateInput(const builder::Type& type);
|
||||||
|
void SetOutput(const std::vector<builder::Op>& outputs);
|
||||||
void DumpModule();
|
void DumpModule();
|
||||||
|
|
||||||
class Impl;
|
class Impl;
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
std::shared_ptr<Impl> GetImpl() const { return impl_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Impl> impl_;
|
std::shared_ptr<Impl> impl_;
|
||||||
|
|
|
@ -1,9 +1,15 @@
|
||||||
#ifndef BUILDER_BUILDERIMPL_
|
#ifndef BUILDER_BUILDERIMPL_
|
||||||
#define BUILDER_BUILDERIMPL_
|
#define BUILDER_BUILDERIMPL_
|
||||||
|
|
||||||
|
#include "Attribute.h"
|
||||||
|
#include "AttributeImpl.h"
|
||||||
#include "Builder.h"
|
#include "Builder.h"
|
||||||
|
#include "OpImpl.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
// #include "llvm/Support/InitLLVM.h"
|
// #include "llvm/Support/InitLLVM.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Block.h"
|
#include "mlir/IR/Block.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
@ -13,18 +19,14 @@
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
// #include "mlir/InitAllDialects.h"
|
#include "mlir/InitAllDialects.h"
|
||||||
// #include "mlir/InitAllPasses.h"
|
#include "mlir/InitAllPasses.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
|
||||||
|
|
||||||
namespace builder {
|
namespace builder {
|
||||||
|
|
||||||
class Builder::Impl {
|
class Builder::Impl {
|
||||||
public:
|
public:
|
||||||
Impl() : builder_(&context_) {
|
Impl() : builder_(&context_) {
|
||||||
|
|
||||||
// llvm::InitLLVM y(argc, argv);
|
// llvm::InitLLVM y(argc, argv);
|
||||||
// llvm::InitializeNativeTarget();
|
// llvm::InitializeNativeTarget();
|
||||||
// llvm::InitializeNativeTargetAsmPrinter();
|
// llvm::InitializeNativeTargetAsmPrinter();
|
||||||
|
@ -37,14 +39,14 @@ class Builder::Impl {
|
||||||
// registerDefaultTimingManagerCLOptions();
|
// registerDefaultTimingManagerCLOptions();
|
||||||
// DebugCounter::registerCLOptions();
|
// DebugCounter::registerCLOptions();
|
||||||
|
|
||||||
// mlir::registerAllPasses();
|
mlir::registerAllPasses();
|
||||||
mlir::mhlo::registerAllMhloPasses();
|
mlir::mhlo::registerAllMhloPasses();
|
||||||
mlir::lmhlo::registerAllLmhloPasses();
|
// mlir::lmhlo::registerAllLmhloPasses();
|
||||||
mlir::disc_ral::registerAllDiscRalPasses();
|
// mlir::disc_ral::registerAllDiscRalPasses();
|
||||||
|
|
||||||
mlir::DialectRegistry registry;
|
mlir::DialectRegistry registry;
|
||||||
// mlir::registerAllToLLVMIRTranslations(registry);
|
// mlir::registerAllToLLVMIRTranslations(registry);
|
||||||
// mlir::registerAllDialects(registry);
|
mlir::registerAllDialects(registry);
|
||||||
registry.insert<mlir::mhlo::MhloDialect>();
|
registry.insert<mlir::mhlo::MhloDialect>();
|
||||||
// registry.insert<mlir::chlo::HloClientDialect>();
|
// registry.insert<mlir::chlo::HloClientDialect>();
|
||||||
// registry.insert<mlir::lmhlo::LmhloDialect>();
|
// registry.insert<mlir::lmhlo::LmhloDialect>();
|
||||||
|
@ -53,9 +55,6 @@ class Builder::Impl {
|
||||||
context_.appendDialectRegistry(registry);
|
context_.appendDialectRegistry(registry);
|
||||||
context_.loadAllAvailableDialects();
|
context_.loadAllAvailableDialects();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
|
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||||
|
@ -74,6 +73,35 @@ class Builder::Impl {
|
||||||
mlir::MLIRContext* GetContext() { return &context_; }
|
mlir::MLIRContext* GetContext() { return &context_; }
|
||||||
void DumpModule() { module_.dump(); }
|
void DumpModule() { module_.dump(); }
|
||||||
|
|
||||||
|
builder::Op CreateInput(const builder::Type& type) {
|
||||||
|
mlir::BlockArgument arg =
|
||||||
|
entry_block_->addArgument(type.GetImpl()->GetMlirType(&context_));
|
||||||
|
builder::Op op;
|
||||||
|
op.GetImpl()->SetValue(arg);
|
||||||
|
return op;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetOutput(const std::vector<builder::Op>& outputs) {
|
||||||
|
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||||
|
int arg_num = entry_block_->getNumArguments();
|
||||||
|
for (int i = 0; i < arg_num; ++i) {
|
||||||
|
arg_types.push_back(entry_block_->getArgument(i).getType());
|
||||||
|
}
|
||||||
|
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||||
|
llvm::SmallVector<mlir::Value, 4> ret_vals;
|
||||||
|
for (auto& out : outputs) {
|
||||||
|
mlir::Value v = out.GetImpl()->GetResult();
|
||||||
|
ret_types.push_back(v.getType());
|
||||||
|
ret_vals.push_back(v);
|
||||||
|
}
|
||||||
|
// return all output tensors.
|
||||||
|
builder_.create<mlir::ReturnOp>(builder_.getUnknownLoc(), ret_vals);
|
||||||
|
// Update main function input/output type
|
||||||
|
mlir::FunctionType funcType =
|
||||||
|
builder_.getFunctionType(arg_types, ret_types);
|
||||||
|
main_func_.setType(funcType);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mlir::MLIRContext context_;
|
mlir::MLIRContext context_;
|
||||||
mlir::ModuleOp module_;
|
mlir::ModuleOp module_;
|
||||||
|
|
|
@ -9,7 +9,7 @@ class Op {
|
||||||
public:
|
public:
|
||||||
Op();
|
Op();
|
||||||
class Impl;
|
class Impl;
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
std::shared_ptr<Impl> GetImpl() const { return impl_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Impl> impl_;
|
std::shared_ptr<Impl> impl_;
|
||||||
|
|
|
@ -14,12 +14,19 @@ namespace builder {
|
||||||
|
|
||||||
class Op::Impl {
|
class Op::Impl {
|
||||||
public:
|
public:
|
||||||
Impl() = default;
|
Impl() : op_(nullptr), value_(){};
|
||||||
void SetOperation(mlir::Operation *Op) { op_ = Op; }
|
void SetOperation(mlir::Operation *Op) { op_ = Op; }
|
||||||
mlir::Value GetResult() { return op_->getResult(0); }
|
void SetValue(mlir::Value &value) { value_ = value; }
|
||||||
|
mlir::Value GetResult() {
|
||||||
|
if (op_ != nullptr)
|
||||||
|
return op_->getResult(0);
|
||||||
|
else
|
||||||
|
return value_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mlir::Operation *op_;
|
mlir::Operation *op_;
|
||||||
|
mlir::Value value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
|
Loading…
Reference in New Issue