diff --git a/BUILD b/BUILD index 4459d8b..accf524 100644 --- a/BUILD +++ b/BUILD @@ -174,39 +174,21 @@ cc_library( ":disc_ral", ":hlo", ":lhlo", - ":lhlo_gpu", ":hlo_ops_builder_gen", "@llvm-project//mlir:MlirTableGenMain", "@llvm-project//mlir:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:AllPassesAndDialects", + # "@llvm-project//mlir:MlirOptLib", + # "@llvm-project//mlir:MlirJitRunner", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Option", + "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:Support", "@llvm-project//llvm:TableGen", + "@llvm-project//llvm:Target", "@llvm-project//llvm:config", - # "@llvm-project//llvm:AllTargetsAsmParsers", - # "@llvm-project//llvm:AllTargetsCodeGens", - # "@llvm-project//llvm:Core", - # "@llvm-project//llvm:ExecutionEngine", - # "@llvm-project//llvm:Option", - # "@llvm-project//llvm:OrcJIT", - # "@llvm-project//llvm:Support", - # "@llvm-project//llvm:Target", - # "@llvm-project//mlir:AllPassesAndDialects", - # "@llvm-project//mlir:IR", - # "@llvm-project//mlir:MlirOptLib", - # "@llvm-project//mlir:Support", - # "@llvm-project//mlir:MlirJitRunner", - - # "@llvm-project//mlir:Analysis", - # "@llvm-project//mlir:ControlFlowInterfaces", - # "@llvm-project//mlir:InferTypeOpInterface", - # "@llvm-project//mlir:MemRefDialect", - # "@llvm-project//mlir:Shape", - # "@llvm-project//mlir:SideEffects", - # "@llvm-project//mlir:StandardOps", - # "@llvm-project//mlir:TensorDialect", - # "@llvm-project//mlir:TransformUtils", - # "@llvm-project//mlir:Transforms", ], ) diff --git a/tests/mlir-tblgen-builder/test_basic.cpp b/tests/mlir-tblgen-builder/test_basic.cpp index 2604215..6b06b27 100644 --- a/tests/mlir-tblgen-builder/test_basic.cpp +++ b/tests/mlir-tblgen-builder/test_basic.cpp @@ -3,19 +3,20 @@ int main() { builder::Builder builder; - builder::Shape shape({100, 100}); + builder::Shape shape({10, 10}); auto pType = builder::PrimitiveType::F32(); builder::Type type(shape, pType); - builder::Tensor tensor(std::vector(100)); + std::vector data(100); + builder::Tensor tensor(shape, data); + auto in1 = builder.CreateInput(type); auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor); auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor); auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2); + auto op4 = builder::mhlo::MulOp::build(builder, type, op3, in1); + builder.SetOutput(std::vector({op4})); builder.DumpModule(); return 0; } - -// static ::builder::Op build(::builder::Builder &builder, ::builder::Type -// output, ::builder::Tensor value); diff --git a/tools/mlir-tblgen-builder/Builder/Attribute.cpp b/tools/mlir-tblgen-builder/Builder/Attribute.cpp index f2e32e8..090ed1f 100644 --- a/tools/mlir-tblgen-builder/Builder/Attribute.cpp +++ b/tools/mlir-tblgen-builder/Builder/Attribute.cpp @@ -50,6 +50,7 @@ PrimitiveType PrimitiveType::S64() { Shape::Shape(std::vector dims) : impl_(std::make_shared(dims)) {} +std::vector Shape::GetDims() { return impl_->GetDims(); } Type::Type(Shape& shape, PrimitiveType& primitiveType) : impl_(std::make_shared(shape, primitiveType)) {} @@ -70,16 +71,16 @@ Array::Array(std::vector value) Array::Array(std::vector value) : impl_(std::make_shared(value)) {} -Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType) - : impl_(std::make_shared(shape, primitiveType)) {} +// Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType, const void* value) +// : impl_(std::make_shared(shape, primitiveType, value)) {} Shape Tensor::GetShape() { return impl_->GetShape(); } -PrimitiveType Tensor::GetType() { return impl_->GetType(); } -Tensor::Tensor(std::vector value) - : impl_(std::make_shared(value)) {} -Tensor::Tensor(std::vector value) - : impl_(std::make_shared(value)) {} -Tensor::Tensor(std::vector value) - : impl_(std::make_shared(value)) {} +// PrimitiveType Tensor::GetType() { return impl_->GetType(); } +Tensor::Tensor(Shape& shape, std::vector value) + : impl_(std::make_shared(shape, value)) {} +Tensor::Tensor(Shape& shape, std::vector value) + : impl_(std::make_shared(shape, value)) {} +Tensor::Tensor(Shape& shape, std::vector value) + : impl_(std::make_shared(shape, value)) {} TensorInt::TensorInt(Shape& shape, PrimitiveType& primitiveType) : impl_(std::make_shared(shape, primitiveType)) {} diff --git a/tools/mlir-tblgen-builder/Builder/Attribute.h b/tools/mlir-tblgen-builder/Builder/Attribute.h index b0033cf..9ae524c 100644 --- a/tools/mlir-tblgen-builder/Builder/Attribute.h +++ b/tools/mlir-tblgen-builder/Builder/Attribute.h @@ -17,7 +17,7 @@ class PrimitiveType { inline bool operator==(const PrimitiveType& pt); class Impl; - std::shared_ptr GetImpl() { return impl_; } + std::shared_ptr GetImpl() const { return impl_; } PrimitiveType(std::shared_ptr); private: @@ -29,7 +29,8 @@ class Shape { public: Shape(std::vector dims); class Impl; - std::shared_ptr GetImpl() { return impl_; } + std::shared_ptr GetImpl() const { return impl_; } + std::vector GetDims(); private: std::shared_ptr impl_; @@ -75,10 +76,9 @@ class Array { class Tensor { public: - Tensor(Shape& shape, PrimitiveType& primitiveType); - Tensor(std::vector value); - Tensor(std::vector value); - Tensor(std::vector value); + Tensor(Shape& shape, std::vector value); + Tensor(Shape& shape, std::vector value); + Tensor(Shape& shape, std::vector value); Shape GetShape(); PrimitiveType GetType(); @@ -107,7 +107,7 @@ class Type { public: Type(Shape& shape, PrimitiveType& primitiveType); class Impl; - std::shared_ptr GetImpl() { return impl_; } + const std::shared_ptr GetImpl() const { return impl_; } private: std::shared_ptr impl_; diff --git a/tools/mlir-tblgen-builder/Builder/AttributeImpl.h b/tools/mlir-tblgen-builder/Builder/AttributeImpl.h index 928dd9d..7f2bb06 100644 --- a/tools/mlir-tblgen-builder/Builder/AttributeImpl.h +++ b/tools/mlir-tblgen-builder/Builder/AttributeImpl.h @@ -48,46 +48,54 @@ class PrimitiveType::Impl { GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::NoneType::get(context); }; + unitBits_ = 0; break; case PRED: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::IntegerType::get( context, 1, mlir::IntegerType::SignednessSemantics::Signed); }; + unitBits_ = 1; break; case S8: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::IntegerType::get( context, 8, mlir::IntegerType::SignednessSemantics::Signed); }; + unitBits_ = 8; break; case S16: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::IntegerType::get( context, 16, mlir::IntegerType::SignednessSemantics::Signed); }; + unitBits_ = 16; break; case F32: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::Float32Type::get(context); }; + unitBits_ = 32; break; case S32: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::IntegerType::get( context, 32, mlir::IntegerType::SignednessSemantics::Signed); }; + unitBits_ = 32; break; case S64: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::IntegerType::get( context, 64, mlir::IntegerType::SignednessSemantics::Signed); }; + unitBits_ = 64; break; default: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::NoneType::get(context); }; + unitBits_ = 0; break; } } @@ -116,13 +124,25 @@ class PrimitiveType::Impl { } std::function GetMlirType; + uint64_t GetUnitBytes() const { return unitBits_ / 8; } + uint64_t GetUnitBits() const { return unitBits_; } + private: pType t_; + uint64_t unitBits_; }; class Shape::Impl { public: Impl(std::vector dims) : dims_(dims) {} + const std::vector GetDims() const { return dims_; } + int64_t GetSize() const { + int64_t size; + for (auto &d : dims_) { + size *= d; + } + return size; + } private: std::vector dims_; @@ -132,10 +152,12 @@ class Type::Impl { public: Impl(Shape &shape, PrimitiveType &primitiveType) : shape_(shape), primitiveType_(primitiveType) {} - Shape GetShape() { return shape_; } + Shape GetShape() const { return shape_; } PrimitiveType GetType() { return primitiveType_; } - mlir::Type GetMlirType(mlir::MLIRContext *context) { - return mlir::NoneType::get(context); + mlir::Type GetMlirType(mlir::MLIRContext *context) const { + return mlir::RankedTensorType::get( + llvm::ArrayRef(shape_.GetImpl()->GetDims()), + primitiveType_.GetImpl()->GetMlirType(context)); } private: @@ -223,44 +245,53 @@ class Array::Impl { class Tensor::Impl { public: - Impl(std::vector value) - : shape_({value.size()}), primitiveType_(PrimitiveType::S32()) { + Impl(Shape &shape, std::vector value) : shape_(shape) { GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr { - auto type = - mlir::RankedTensorType::get(llvm::ArrayRef({value.size()}), - mlir::IntegerType::get(context, 32)); + auto type = mlir::RankedTensorType::get( + llvm::ArrayRef(shape_.GetImpl()->GetDims()), + mlir::IntegerType::get(context, 32)); return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value)); }; } - Impl(std::vector value) - : shape_({value.size()}), primitiveType_(PrimitiveType::S64()) { + Impl(Shape &shape, std::vector value) : shape_(shape) { GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr { - auto type = - mlir::RankedTensorType::get(llvm::ArrayRef({value.size()}), - mlir::IntegerType::get(context, 64)); + auto type = mlir::RankedTensorType::get( + llvm::ArrayRef(shape_.GetImpl()->GetDims()), + mlir::IntegerType::get(context, 64)); return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value)); }; } - Impl(std::vector value) - : shape_({value.size()}), primitiveType_(PrimitiveType::F32()) { + Impl(Shape &shape, std::vector value) : shape_(shape) { GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr { - auto type = - mlir::RankedTensorType::get(llvm::ArrayRef({value.size()}), - mlir::FloatType::getF32(context)); + auto type = mlir::RankedTensorType::get( + llvm::ArrayRef(shape_.GetImpl()->GetDims()), + mlir::FloatType::getF32(context)); return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value)); }; } - Impl(Shape &shape, PrimitiveType &primitiveType) - : shape_(shape), primitiveType_(primitiveType) {} + + // Impl(Shape &shape, PrimitiveType &primitiveType, const void *value) + // : shape_(shape), primitiveType_(primitiveType) { + // GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr { + // auto type = mlir::RankedTensorType::get( + // llvm::ArrayRef(shape_.GetImpl()->GetDims()), + // primitiveType_.GetImpl()->GetMlirType(context)); + // return mlir::DenseElementsAttr::get( + // type, + // llvm::ArrayRef(reinterpret_cast(value), + // shape_.GetImpl()->GetSize() * + // primitiveType_.GetImpl()->GetUnitBytes())); + // }; + // } Shape GetShape() { return shape_; } - PrimitiveType GetType() { return primitiveType_; } + // PrimitiveType GetType() { return primitiveType_; } std::function GetAttr; private: Shape shape_; - PrimitiveType primitiveType_; + // PrimitiveType primitiveType_; }; class TensorInt::Impl { diff --git a/tools/mlir-tblgen-builder/Builder/Builder.cpp b/tools/mlir-tblgen-builder/Builder/Builder.cpp index 3fbff45..7f678d3 100644 --- a/tools/mlir-tblgen-builder/Builder/Builder.cpp +++ b/tools/mlir-tblgen-builder/Builder/Builder.cpp @@ -19,6 +19,15 @@ namespace builder { Builder::Builder() : impl_(std::make_shared()) {} + +builder::Op Builder::CreateInput(const builder::Type& type) { + return impl_->CreateInput(type); +} + +void Builder::SetOutput(const std::vector& outputs) { + impl_->SetOutput(outputs); +} + void Builder::DumpModule() { impl_->DumpModule(); } } // namespace builder diff --git a/tools/mlir-tblgen-builder/Builder/Builder.h b/tools/mlir-tblgen-builder/Builder/Builder.h index 5d83449..72985c3 100644 --- a/tools/mlir-tblgen-builder/Builder/Builder.h +++ b/tools/mlir-tblgen-builder/Builder/Builder.h @@ -11,10 +11,13 @@ namespace builder { class Builder { public: Builder(); + void SetInput(const std::vector& inputs); + builder::Op CreateInput(const builder::Type& type); + void SetOutput(const std::vector& outputs); void DumpModule(); class Impl; - std::shared_ptr GetImpl() { return impl_; } + std::shared_ptr GetImpl() const { return impl_; } private: std::shared_ptr impl_; diff --git a/tools/mlir-tblgen-builder/Builder/BuilderImpl.h b/tools/mlir-tblgen-builder/Builder/BuilderImpl.h index 4c1b8c0..358ab36 100644 --- a/tools/mlir-tblgen-builder/Builder/BuilderImpl.h +++ b/tools/mlir-tblgen-builder/Builder/BuilderImpl.h @@ -1,9 +1,15 @@ #ifndef BUILDER_BUILDERIMPL_ #define BUILDER_BUILDERIMPL_ +#include "Attribute.h" +#include "AttributeImpl.h" #include "Builder.h" +#include "OpImpl.h" #include "llvm/Support/Casting.h" // #include "llvm/Support/InitLLVM.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" @@ -13,18 +19,14 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" -// #include "mlir/InitAllDialects.h" -// #include "mlir/InitAllPasses.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" namespace builder { class Builder::Impl { public: Impl() : builder_(&context_) { - // llvm::InitLLVM y(argc, argv); // llvm::InitializeNativeTarget(); // llvm::InitializeNativeTargetAsmPrinter(); @@ -37,14 +39,14 @@ class Builder::Impl { // registerDefaultTimingManagerCLOptions(); // DebugCounter::registerCLOptions(); - // mlir::registerAllPasses(); + mlir::registerAllPasses(); mlir::mhlo::registerAllMhloPasses(); - mlir::lmhlo::registerAllLmhloPasses(); - mlir::disc_ral::registerAllDiscRalPasses(); + // mlir::lmhlo::registerAllLmhloPasses(); + // mlir::disc_ral::registerAllDiscRalPasses(); mlir::DialectRegistry registry; // mlir::registerAllToLLVMIRTranslations(registry); - // mlir::registerAllDialects(registry); + mlir::registerAllDialects(registry); registry.insert(); // registry.insert(); // registry.insert(); @@ -53,9 +55,6 @@ class Builder::Impl { context_.appendDialectRegistry(registry); context_.loadAllAvailableDialects(); - - - module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_)); llvm::SmallVector arg_types; @@ -74,6 +73,35 @@ class Builder::Impl { mlir::MLIRContext* GetContext() { return &context_; } void DumpModule() { module_.dump(); } + builder::Op CreateInput(const builder::Type& type) { + mlir::BlockArgument arg = + entry_block_->addArgument(type.GetImpl()->GetMlirType(&context_)); + builder::Op op; + op.GetImpl()->SetValue(arg); + return op; + } + + void SetOutput(const std::vector& outputs) { + llvm::SmallVector arg_types; + int arg_num = entry_block_->getNumArguments(); + for (int i = 0; i < arg_num; ++i) { + arg_types.push_back(entry_block_->getArgument(i).getType()); + } + llvm::SmallVector ret_types; + llvm::SmallVector ret_vals; + for (auto& out : outputs) { + mlir::Value v = out.GetImpl()->GetResult(); + ret_types.push_back(v.getType()); + ret_vals.push_back(v); + } + // return all output tensors. + builder_.create(builder_.getUnknownLoc(), ret_vals); + // Update main function input/output type + mlir::FunctionType funcType = + builder_.getFunctionType(arg_types, ret_types); + main_func_.setType(funcType); + } + private: mlir::MLIRContext context_; mlir::ModuleOp module_; diff --git a/tools/mlir-tblgen-builder/Builder/Op.h b/tools/mlir-tblgen-builder/Builder/Op.h index 4ce2fdf..fb248d1 100644 --- a/tools/mlir-tblgen-builder/Builder/Op.h +++ b/tools/mlir-tblgen-builder/Builder/Op.h @@ -9,7 +9,7 @@ class Op { public: Op(); class Impl; - std::shared_ptr GetImpl() { return impl_; } + std::shared_ptr GetImpl() const { return impl_; } private: std::shared_ptr impl_; diff --git a/tools/mlir-tblgen-builder/Builder/OpImpl.h b/tools/mlir-tblgen-builder/Builder/OpImpl.h index 6cb4b73..705a77c 100644 --- a/tools/mlir-tblgen-builder/Builder/OpImpl.h +++ b/tools/mlir-tblgen-builder/Builder/OpImpl.h @@ -14,12 +14,19 @@ namespace builder { class Op::Impl { public: - Impl() = default; + Impl() : op_(nullptr), value_(){}; void SetOperation(mlir::Operation *Op) { op_ = Op; } - mlir::Value GetResult() { return op_->getResult(0); } + void SetValue(mlir::Value &value) { value_ = value; } + mlir::Value GetResult() { + if (op_ != nullptr) + return op_->getResult(0); + else + return value_; + } private: mlir::Operation *op_; + mlir::Value value_; }; } // namespace builder