add more types and try to build pass

This commit is contained in:
colin.liang 2021-08-13 15:05:10 +08:00
parent e88366b851
commit 4a9b201c4e
18 changed files with 702 additions and 339 deletions

88
BUILD
View File

@ -140,39 +140,6 @@ cc_binary(
],
)
cc_library(
name = "mlir-hlo-builder",
srcs = glob([
"tools/mlir-tblgen-builder/Builder/*.h",
"tools/mlir-tblgen-builder/Builder/*.cpp",
]),
deps = [
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:IR",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
],
)
cc_test(
name = "mlir-tblgen-builder-basic",
srcs = [
"tests/mlir-tblgen-builder/test_basic.cpp",
],
deps = [
":hlo_ops_builder_gen",
":mlir-hlo-builder",
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:IR",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
],
)
gentbl_cc_library(
name = "hlo_ops_builder_gen",
strip_include_prefix = "include",
@ -195,6 +162,61 @@ gentbl_cc_library(
deps = [":hlo_ops_td_files"],
)
cc_library(
name = "mlir-hlo-builder",
srcs = glob([
"tools/mlir-tblgen-builder/Builder/*.h",
"tools/mlir-tblgen-builder/Builder/*.cpp",
]),
includes = ["include"],
deps = [
":all_passes",
":disc_ral",
":hlo",
":lhlo",
":lhlo_gpu",
":hlo_ops_builder_gen",
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:IR",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
# "@llvm-project//llvm:AllTargetsAsmParsers",
# "@llvm-project//llvm:AllTargetsCodeGens",
# "@llvm-project//llvm:Core",
# "@llvm-project//llvm:ExecutionEngine",
# "@llvm-project//llvm:Option",
# "@llvm-project//llvm:OrcJIT",
# "@llvm-project//llvm:Support",
# "@llvm-project//llvm:Target",
# "@llvm-project//mlir:AllPassesAndDialects",
# "@llvm-project//mlir:IR",
# "@llvm-project//mlir:MlirOptLib",
# "@llvm-project//mlir:Pass",
# "@llvm-project//mlir:Support",
# "@llvm-project//mlir:MlirJitRunner",
],
)
cc_test(
name = "mlir-tblgen-builder-basic",
srcs = [
"tests/mlir-tblgen-builder/test_basic.cpp",
],
deps = [
":mlir-hlo-builder",
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:IR",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
],
)
gentbl_cc_library(
name = "hlo_ops_base_inc_gen",
strip_include_prefix = "include",

View File

@ -1,12 +1,15 @@
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.h.inc"
#include "tools/mlir-tblgen-builder/Builder/Attribute.h"
#include "tools/mlir-tblgen-builder/Builder/Builder.h"
#include "tools/mlir-tblgen-builder/Builder/Op.h"
#include "tools/mlir-tblgen-builder/Builder/PrimitiveType.h"
#include "tools/mlir-tblgen-builder/Builder/Shape.h"
#include "tools/mlir-tblgen-builder/Builder/Tensor.h"
int main() {
builder::Integer i(322);
builder::Builder builder;
builder::Shape shape({100, 100});
auto pType = builder::PrimitiveType::F32();
builder::Type type(shape, pType);
builder::Tensor tensor(std::vector<int>(100));
auto op = builder::mhlo::ConstOp::build(builder, type, tensor);
return 0;
}
// static ::builder::Op build(::builder::Builder &builder, ::builder::Type
// output, ::builder::Tensor value);

View File

@ -1,9 +1,6 @@
#include "Attribute.h"
#include "AttributeImpl.h"
#include "Shape.h"
#include "Tensor.h"
#include "TensorImpl.h"
#include "llvm/Support/Casting.h"
#include "memory.h"
// #include "mlir/Dialect/StandardOps/Ops.h"
@ -15,8 +12,93 @@
namespace builder {
inline bool PrimitiveType::operator==(const PrimitiveType& pt) {
return impl_ == pt.impl_;
}
PrimitiveType::PrimitiveType(std::shared_ptr<PrimitiveType::Impl> impl)
: impl_(impl) {}
PrimitiveType PrimitiveType::PRED() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::PRED));
return p;
}
PrimitiveType PrimitiveType::S8() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S8));
return p;
}
PrimitiveType PrimitiveType::S16() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S16));
return p;
}
PrimitiveType PrimitiveType::F32() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::F32));
return p;
}
PrimitiveType PrimitiveType::S32() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S32));
return p;
}
PrimitiveType PrimitiveType::S64() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S64));
return p;
}
Shape::Shape(std::vector<int64_t> dims)
: impl_(std::make_shared<Shape::Impl>(dims)) {}
Type::Type(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Type::Impl>(shape, primitiveType)) {}
Integer::Integer(int value) : impl_(std::make_shared<Integer::Impl>(value)) {}
Integer::Integer(int64_t value)
: impl_(std::make_shared<Integer::Impl>(value)) {}
Float::Float(float value) : impl_(std::make_shared<Float::Impl>(value)) {}
Float::Float(double value) : impl_(std::make_shared<Float::Impl>(value)) {}
int64_t Array::GetSize() { return impl_->GetSize(); }
PrimitiveType Array::GetType() { return impl_->GetType(); }
Array::Array(std::vector<int> value)
: impl_(std::make_shared<Array::Impl>(value)) {}
Array::Array(std::vector<int64_t> value)
: impl_(std::make_shared<Array::Impl>(value)) {}
Array::Array(std::vector<std::string> value)
: impl_(std::make_shared<Array::Impl>(value)) {}
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
Shape Tensor::GetShape() { return impl_->GetShape(); }
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
Tensor::Tensor(std::vector<int> value)
: impl_(std::make_shared<Tensor::Impl>(value)) {}
Tensor::Tensor(std::vector<int64_t> value)
: impl_(std::make_shared<Tensor::Impl>(value)) {}
Tensor::Tensor(std::vector<float> value)
: impl_(std::make_shared<Tensor::Impl>(value)) {}
TensorInt::TensorInt(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
Shape TensorInt::GetShape() { return impl_->GetShape(); }
PrimitiveType TensorInt::GetType() { return impl_->GetType(); }
ChannelHandle::ChannelHandle()
: impl_(std::make_shared<ChannelHandle::Impl>()) {}
ConvDimensionNumbers::ConvDimensionNumbers()
: impl_(std::make_shared<ConvDimensionNumbers::Impl>()) {}
DotDimensionNumbers::DotDimensionNumbers()
: impl_(std::make_shared<DotDimensionNumbers::Impl>()) {}
GatherDimensionNumbers::GatherDimensionNumbers()
: impl_(std::make_shared<GatherDimensionNumbers::Impl>()) {}
ScatterDimensionNumbers::ScatterDimensionNumbers()
: impl_(std::make_shared<ScatterDimensionNumbers::Impl>()) {}
} // namespace builder

View File

@ -1,16 +1,57 @@
#ifndef BUILDER_ATTRIBUTE_H_
#define BUILDER_ATTRIBUTE_H_
#include <iostream>
#include <memory>
#include <vector>
namespace builder {
class PrimitiveType {
public:
static PrimitiveType PRED();
static PrimitiveType S8();
static PrimitiveType S16();
static PrimitiveType F32();
static PrimitiveType S32();
static PrimitiveType S64();
inline bool operator==(const PrimitiveType& pt);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
PrimitiveType(std::shared_ptr<Impl>);
private:
std::shared_ptr<Impl> impl_;
// PrimitiveType(const PrimitiveType&) = default;
};
class Shape {
public:
Shape(std::vector<int64_t> dims);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class Integer {
public:
Integer(int value);
Integer(int64_t value);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class Float {
public:
Float(float value);
Float(double value);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
@ -18,6 +59,43 @@ class Integer {
class Array {
public:
Array(std::vector<int> value);
Array(std::vector<int64_t> value);
Array(std::vector<std::string> value);
int64_t GetSize();
PrimitiveType GetType();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class Tensor {
public:
Tensor(Shape& shape, PrimitiveType& primitiveType);
Tensor(std::vector<int> value);
Tensor(std::vector<int64_t> value);
Tensor(std::vector<float> value);
Shape GetShape();
PrimitiveType GetType();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class TensorInt {
public:
TensorInt(Shape& shape, PrimitiveType& primitiveType);
Shape GetShape();
PrimitiveType GetType();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
@ -27,6 +105,7 @@ class Array {
class Type {
public:
Type(Shape& shape, PrimitiveType& primitiveType);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
@ -34,16 +113,56 @@ class Type {
std::shared_ptr<Impl> impl_;
};
// template <typename T>
// class DenseIntElementsAttr {
// public:
// explicit DenseIntElementsAttr(Tensor);
// class Impl;
// std::shared_ptr<Impl> GetImpl() { return impl_; }
class ChannelHandle {
public:
ChannelHandle();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class ConvDimensionNumbers {
public:
ConvDimensionNumbers();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class DotDimensionNumbers {
public:
DotDimensionNumbers();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class GatherDimensionNumbers {
public:
GatherDimensionNumbers();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class ScatterDimensionNumbers {
public:
ScatterDimensionNumbers();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
// private:
// std::shared_ptr<Impl> impl_;
// };
} // namespace builder

View File

@ -1,8 +1,13 @@
#ifndef BUILDER_TYPEIMPL_
#define BUILDER_TYPEIMPL_
#ifndef BUILDER_ATTRIBUTEIMPL_H_
#define BUILDER_ATTRIBUTEIMPL_H_
#include <vector>
#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@ -13,45 +18,306 @@
namespace builder {
class PrimitiveType::Impl {
public:
enum pType {
PRIMITIVE_TYPE_INVALID = 0,
PRED = 1,
S8 = 2,
S16 = 3,
S32 = 4,
S64 = 5,
U8 = 6,
U16 = 7,
U32 = 8,
U64 = 9,
F16 = 10,
F32 = 11,
BF16 = 16,
F64 = 12,
C64 = 15,
C128 = 18,
TUPLE = 13,
OPAQUE_TYPE = 14,
TOKEN = 17
};
Impl(pType t) : t_(t) {
switch (t) {
case PRIMITIVE_TYPE_INVALID:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::NoneType::get(context);
};
break;
case PRED:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 1, mlir::IntegerType::SignednessSemantics::Signed);
};
break;
case S8:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 8, mlir::IntegerType::SignednessSemantics::Signed);
};
break;
case S16:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 16, mlir::IntegerType::SignednessSemantics::Signed);
};
break;
case F32:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::Float32Type::get(context);
};
break;
case S32:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 32, mlir::IntegerType::SignednessSemantics::Signed);
};
break;
case S64:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 64, mlir::IntegerType::SignednessSemantics::Signed);
};
break;
default:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::NoneType::get(context);
};
break;
}
}
// class BFloat16Type;
// class ComplexType;
// class Float128Type;
// class Float16Type;
// class Float32Type;
// class Float64Type;
// class Float80Type;
// class FunctionType;
// class IndexType;
// class IntegerType;
// class MemRefType;
// class NoneType;
// class OpaqueType;
// class RankedTensorType;
// class TupleType;
// class UnrankedMemRefType;
// class UnrankedTensorType;
// class VectorType;
inline bool operator==(const PrimitiveType::Impl &impl) {
return t_ == impl.t_;
}
std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;
private:
pType t_;
};
class Shape::Impl {
public:
Impl(std::vector<int64_t> dims) : dims_(dims) {}
private:
std::vector<int64_t> dims_;
};
class Type::Impl {
public:
Impl(Shape &shape, PrimitiveType &primitiveType)
: shape_(shape), primitiveType_(primitiveType) {}
Shape GetShape() { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
private:
Shape shape_;
PrimitiveType primitiveType_;
};
class Integer::Impl {
public:
Impl(){};
Impl(int value) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
auto int32_type = mlir::IntegerType::get(context, 32);
return mlir::IntegerAttr::get(int32_type, value);
};
};
Impl(int64_t value)
: GetAttr([=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
auto int32_type = mlir::IntegerType::get(context, 64);
return mlir::IntegerAttr::get(int32_type, value);
}){};
Impl(int64_t value) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
auto int64_type = mlir::IntegerType::get(context, 64);
return mlir::IntegerAttr::get(int64_type, value);
};
};
std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr;
private:
};
class Float::Impl {
public:
Impl(float value) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::FloatAttr {
auto float32_type = mlir::Float32Type::get(context);
return mlir::FloatAttr::get(float32_type, value);
};
};
Impl(double value) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::FloatAttr {
auto float64_type = mlir::Float64Type::get(context);
return mlir::FloatAttr::get(float64_type, value);
};
};
std::function<mlir::FloatAttr(mlir::MLIRContext *context)> GetAttr;
private:
};
class Array::Impl {
public:
Impl() = default;
Impl(std::vector<int> value)
: size_({value.size()}), primitiveType_(PrimitiveType::S32()) {
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
// auto type =
// mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
// mlir::IntegerType::get(context, 32));
// return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
// };
}
Impl(std::vector<int64_t> value)
: size_({value.size()}), primitiveType_(PrimitiveType::S64()) {
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
// auto type =
// mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
// mlir::IntegerType::get(context, 64));
// return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
// };
}
Impl(std::vector<std::string> value)
: size_({value.size()}), primitiveType_(PrimitiveType::F32()) {
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
// auto type =
// mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
// mlir::FloatType::getF32(context));
// return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
// };
}
int64_t GetSize() { return size_; }
PrimitiveType GetType() { return primitiveType_; }
std::function<mlir::ArrayAttr(mlir::MLIRContext *context)> GetAttr;
private:
int64_t size_;
PrimitiveType primitiveType_;
};
class Type::Impl {
class Tensor::Impl {
public:
Impl() = default;
Impl(std::vector<int> value)
: shape_({value.size()}), primitiveType_(PrimitiveType::S32()) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
auto type =
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
mlir::IntegerType::get(context, 32));
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
};
}
Impl(std::vector<int64_t> value)
: shape_({value.size()}), primitiveType_(PrimitiveType::S64()) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
auto type =
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
mlir::IntegerType::get(context, 64));
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
};
}
Impl(std::vector<float> value)
: shape_({value.size()}), primitiveType_(PrimitiveType::F32()) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
auto type =
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
mlir::FloatType::getF32(context));
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
};
}
Impl(Shape &shape, PrimitiveType &primitiveType)
: shape_(shape), primitiveType_(primitiveType) {}
Shape GetShape() { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;
private:
Shape shape_;
PrimitiveType primitiveType_;
};
class TensorInt::Impl {
public:
Impl(Shape &shape, PrimitiveType &primitiveType)
: shape_(shape), primitiveType_(primitiveType) {}
Shape GetShape() { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
std::function<mlir::DenseIntElementsAttr(mlir::MLIRContext *context)> GetAttr;
private:
Shape shape_;
PrimitiveType primitiveType_;
};
class ChannelHandle::Impl {
public:
Impl() {}
std::function<mlir::mhlo::ChannelHandle(mlir::MLIRContext *context)> GetAttr;
private:
};
// class DenseIntElementsAttr::Impl {
// public:
// Impl() = default;
class ConvDimensionNumbers::Impl {
public:
Impl() {}
std::function<mlir::mhlo::ConvDimensionNumbers(mlir::MLIRContext *context)>
GetAttr;
// private:
// };
private:
};
class DotDimensionNumbers::Impl {
public:
Impl() {}
std::function<mlir::mhlo::DotDimensionNumbers(mlir::MLIRContext *context)>
GetAttr;
private:
};
class GatherDimensionNumbers::Impl {
public:
Impl() {}
std::function<mlir::mhlo::GatherDimensionNumbers(mlir::MLIRContext *context)>
GetAttr;
private:
};
class ScatterDimensionNumbers::Impl {
public:
Impl() {}
std::function<mlir::mhlo::ScatterDimensionNumbers(mlir::MLIRContext *context)>
GetAttr;
private:
};
} // namespace builder

View File

@ -8,10 +8,20 @@
// #include "mlir/IR/StandardTypes.h"
// #include "mlir/IR/Types.h"
// #include "mlir/IR/Value.h"
#include "Attribute.h"
#include "AttributeImpl.h"
#include "Op.h"
#include "OpImpl.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace builder {
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
void Builder::DumpModule() {}
} // namespace builder
} // namespace builder
#define GET_OP_CLASSES
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.cc.inc"

View File

@ -3,14 +3,17 @@
#include <memory>
#include "tools/mlir-tblgen-builder/Builder/Attribute.h"
#include "tools/mlir-tblgen-builder/Builder/Op.h"
namespace builder {
class Builder {
public:
class Impl;
Builder();
void DumpModule();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
@ -18,4 +21,7 @@ class Builder {
};
} // namespace builder
#define GET_OP_CLASSES
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.h.inc"
#endif

View File

@ -4,7 +4,10 @@
#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
@ -14,15 +17,30 @@ namespace builder {
class Builder::Impl {
public:
Impl() {}
// mlir::Location GetLoc() { return mlir_loc_; }
// mlir::OpBuilder GetBuilder() { return mlir_builder_; }
mlir::MLIRContext *GetContext() { return &mlir_context_; }
Impl() : builder_(&context_) {
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
llvm::SmallVector<mlir::Type, 4> arg_types;
// Create the main function.
mlir::FunctionType funcType = builder_.getFunctionType(arg_types, {});
main_func_ = mlir::FuncOp::create(builder_.getUnknownLoc(), "main",
funcType, /* attrs = */ {});
entry_block_ = main_func_.addEntryBlock();
builder_.setInsertionPointToStart(entry_block_);
module_.push_back(main_func_);
}
mlir::Location GetLoc() { return builder_.getUnknownLoc(); }
mlir::OpBuilder GetBuilder() { return builder_; }
mlir::MLIRContext* GetContext() { return &context_; }
private:
// mlir::Location mlir_loc_;
// mlir::OpBuilder mlir_builder_;
mlir::MLIRContext mlir_context_;
mlir::MLIRContext context_;
mlir::ModuleOp module_;
mlir::OpBuilder builder_;
mlir::FuncOp main_func_;
mlir::Block* entry_block_;
};
} // namespace builder

View File

@ -6,6 +6,7 @@
namespace builder {
class Op {
public:
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }

View File

@ -15,10 +15,11 @@ namespace builder {
class Op::Impl {
public:
Impl() = default;
void SetOperation(Operation *Op) { op_ = Op; }
void SetOperation(mlir::Operation *Op) { op_ = Op; }
mlir::Value GetResult() { return op_->getResult(0); }
private:
Operation *op_;
mlir::Operation *op_;
};
} // namespace builder

View File

@ -1,28 +0,0 @@
#include "PrimitiveType.h"
#include "Shape.h"
#include "Tensor.h"
#include "llvm/Support/Casting.h"
// #include "mlir/Dialect/StandardOps/Ops.h"
// #include "mlir/IR/Attributes.h"
// #include "mlir/IR/Operation.h"
// #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
// #include "mlir/IR/Value.h"
namespace builder {
// constexpr mlir::Type GetMlirType(const PrimitiveType primitiveType) {
// switch (primitiveType)
// {
// case PrimitiveType::BF16 :
// return
// break;
// default:
// break;
// }
// }
} // namespace builder

View File

@ -1,29 +0,0 @@
#ifndef BUILDER_PRIMITIVE_TYPE_
#define BUILDER_PRIMITIVE_TYPE_
namespace builder {
enum class PrimitiveType {
PRIMITIVE_TYPE_INVALID = 0,
PRED = 1,
S8 = 2,
S16 = 3,
S32 = 4,
S64 = 5,
U8 = 6,
U16 = 7,
U32 = 8,
U64 = 9,
F16 = 10,
F32 = 11,
BF16 = 16,
F64 = 12,
C64 = 15,
C128 = 18,
TUPLE = 13,
OPAQUE_TYPE = 14,
TOKEN = 17
};
}
#endif

View File

@ -1,17 +0,0 @@
#ifndef BUILDER_SHAPE_
#define BUILDER_SHAPE_
#include <iostream>
#include <memory>
namespace builder {
class Shape {
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
} // namespace builder
#endif

View File

@ -1,24 +0,0 @@
#ifndef BUILDER_SHAPEIMPL_
#define BUILDER_SHAPEIMPL_
#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
class Shape::Impl {
public:
Impl() = default;
private:
};
} // namespace builder
#endif

View File

@ -1,22 +0,0 @@
#include "Tensor.h"
#include "Shape.h"
#include "TensorImpl.h"
#include "llvm/Support/Casting.h"
// #include "mlir/Dialect/StandardOps/Ops.h"
// #include "mlir/IR/Attributes.h"
// #include "mlir/IR/Operation.h"
// #include "mlir/IR/StandardTypes.h"
// #include "mlir/IR/Types.h"
// #include "mlir/IR/Value.h"
namespace builder {
// Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
Shape Tensor::GetShape() { return impl_->GetShape(); }
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
} // namespace builder

View File

@ -1,25 +0,0 @@
#ifndef BUILDER_TENSOR_
#define BUILDER_TENSOR_
#include <iostream>
#include <memory>
#include "PrimitiveType.h"
#include "Shape.h"
namespace builder {
class Tensor {
public:
Tensor(Shape& shape, PrimitiveType& primitiveType);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
Shape GetShape();
PrimitiveType GetType();
private:
std::shared_ptr<Impl> impl_;
};
} // namespace builder
#endif

View File

@ -1,32 +0,0 @@
#ifndef BUILDER_TENSORIMPL_
#define BUILDER_TENSORIMPL_
#include "Builder.h"
#include "PrimitiveType.h"
#include "Shape.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
class Tensor::Impl {
public:
Impl() = default;
Impl(Shape& shape, PrimitiveType& primitiveType)
: shape_(shape), primitiveType_(primitiveType) {}
Shape GetShape() { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
private:
Shape shape_;
PrimitiveType primitiveType_;
};
} // namespace builder
#endif

View File

@ -52,81 +52,93 @@ static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
{"::mlir::StringAttr",
{"std::string",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::StringAttr " << var << "_mlir = mlir::StringAttr::get("
<< var << ", ctx);\n";
body << " mlir::StringAttr " << var
<< "_mlir = mlir::StringAttr::get(ctx, mlir::Twine(" << var
<< "));\n";
return var + "_mlir";
}}},
{"::mlir::IntegerAttr",
{"builder::Integer",
[](std::string &var, OpMethodBody &body) -> std::string {
// body << " mlir::IntegerAttr " << var << "_mlir = mlir::IntegerAttr::get("
// << var << ", ctx);\n";
body << " mlir::IntegerAttr " << var << "_mlir = " << var
<< ".GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::DenseIntElementsAttr",
{"std::vector<int>",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::DenseIntElementsAttr " << var
<< "_mlir = mlir::DenseIntElementsAttr::get("
<< "mlir::VectorType::get(" << var
<< ".size(), opBuilder->getIntegerType(32))," << var << ");\n";
return var + "_mlir";
}}},
{"::mlir::mhlo::ChannelHandle",
{"ChannelHandle",
[](std::string &var, OpMethodBody &body) -> std::string {
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::FloatAttr",
{"float",
{"builder::Float",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::FloatAttr " << var << "_mlir = mlir::FloatAttr::get("
<< var << ", ctx);\n";
body << " mlir::FloatAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::DenseIntElementsAttr",
{"builder::TensorInt",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::DenseIntElementsAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::mhlo::ChannelHandle",
{"builder::ChannelHandle",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::mhlo::ChannelHandle " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::BoolAttr",
{"bool",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::BoolAttr " << var << "_mlir = mlir::BoolAttr::get("
<< var << ", ctx);\n";
body << " mlir::BoolAttr " << var
<< "_mlir = mlir::BoolAttr::get(ctx, " << var << ");\n";
return var + "_mlir";
}}},
{"::mlir::ElementsAttr",
{"::builder::Array",
{"builder::Tensor",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::DenseElementsAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::DenseElementsAttr",
{"::builder::Tensor",
{"builder::Tensor",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::DenseElementsAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
// current only support string array.
{"::mlir::ArrayAttr",
{"std::vector<std::string>",
{"builder::Array",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::ArrayAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::mhlo::ConvDimensionNumbers",
{"::builder::ConvDimensionNumbers",
{"builder::ConvDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::mhlo::ConvDimensionNumbers " << var
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::mhlo::DotDimensionNumbers",
{"::builder::DotDimensionNumbers",
{"builder::DotDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::mhlo::DotDimensionNumbers " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::mhlo::GatherDimensionNumbers",
{"::builder::GatherDimensionNumbers",
{"builder::GatherDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::mhlo::GatherDimensionNumbers " << var
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::mhlo::ScatterDimensionNumbers",
{"::builder::ScatterDimensionNumbers",
{"builder::ScatterDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::mhlo::ScatterDimensionNumbers " << var
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
};
@ -138,16 +150,16 @@ static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
// {"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
// {"::mlir::FloatAttr", "float"},
// {"::mlir::BoolAttr", "bool"},
// {"::mlir::ElementsAttr", "::builder::Array"},
// {"::mlir::DenseElementsAttr", "::builder::Tensor"},
// {"::mlir::ElementsAttr", "builder::Array"},
// {"::mlir::DenseElementsAttr", "builder::Tensor"},
// // current only support string array.
// {"::mlir::ArrayAttr", "std::vector<std::string>"},
// {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
// {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
// {"::mlir::mhlo::ConvDimensionNumbers", "builder::ConvDimensionNumbers"},
// {"::mlir::mhlo::DotDimensionNumbers", "builder::DotDimensionNumbers"},
// {"::mlir::mhlo::GatherDimensionNumbers",
// "::builder::GatherDimensionNumbers"},
// "builder::GatherDimensionNumbers"},
// {"::mlir::mhlo::ScatterDimensionNumbers",
// "::builder::ScatterDimensionNumbers"},
// "builder::ScatterDimensionNumbers"},
// };
StringRef typeConvertFromMLIR(StringRef type) {
@ -707,14 +719,14 @@ OpEmitter::OpEmitter(const Operator &op,
// Generate C++ code for various op methods. The order here determines the
// methods in the generated file.
// genOpAsmInterface();
genOpNameGetter();
//// genOpNameGetter();
// genNamedOperandGetters();
// genNamedOperandSetters();
// genNamedResultGetters();
// genNamedRegionGetters();
genNamedSuccessorGetters();
genAttrGetters();
genAttrSetters();
//// genNamedSuccessorGetters();
//// genAttrGetters();
//// genAttrSetters();
// genOptionalAttrRemovers();
genBuilder();
// genParser();
@ -1179,7 +1191,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
buildParamList(paramList, resultNames, paramKind, attrType);
auto *m = opClass.addMethodAndPrune(
"::builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
"builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
// If the builder is redundant, skip generating the method.
if (!m)
return;
@ -1329,7 +1341,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
resultTypeNames.reserve(numResults);
// paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::builder::Builder &", "builder");
paramList.emplace_back("builder::Builder &", "builder");
// paramList.emplace_back("::mlir::OperationState &", builderOpState);
switch (typeParamKind) {
@ -1344,7 +1356,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
resultName = std::string(formatv("resultType{0}", i));
StringRef type =
result.isVariadic() ? "std::vector<::builder::Type>" : "::builder::Type";
result.isVariadic() ? "std::vector<builder::Type>" : "builder::Type";
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (result.isOptional())
properties = OpMethodParameter::PP_Optional;
@ -1371,7 +1383,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &operand = op.getOperand(numOperands);
StringRef type =
operand.isVariadic() ? "std::vector<::builder::Op>" : "::builder::Op";
operand.isVariadic() ? "std::vector<builder::Op>" : "builder::Op";
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (operand.isOptional())
properties = OpMethodParameter::PP_Optional;
@ -1437,24 +1449,11 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
auto attrs = op.getAttributes();
SmallVector<std::string, 4> newAttrs;
// for(auto o : operands){
// body << " operands name: "<<o.name<<" getCPPClassName:"<<o.constraint.getCPPClassName()<<"\n";
// }
// for(auto a : attrs){
// body << " attrs name: "<<a.name<<" type: "<< a.attr.getStorageType().str()<<"\n";
// std::string attrType = a.attr.getStorageType().str();
// auto ff = typeMapMLIR.find(attrType);
// if(ff != typeMapMLIR.end()){
// body << "// BBBBBB \n";
// std::string attrName = a.name.str();
// ff->second.ConvertToMlir(attrName,body);
// }
// }
// body << "// AAAAAA \n";
// for(auto p : paramList){
// body << " AAA type"<<p.getType()<<" name"<<p.getName()<<"\n";
// if (attrType == "::mlir::DenseIntElementsAttr") {
// body << " // BBBBBBBB getStorageType:" << a.attr.getStorageType().str()
// << "\n";
// body << " // BBBBBBBB getReturnType:" << a.attr.getReturnType().str()
// << "\n";
// }
body << " auto b = builder.GetImpl();\n";
@ -1463,16 +1462,6 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
body << " auto ctx = b->GetContext();\n";
for (auto a : attrs) {
std::string attrType = a.attr.getStorageType().str();
if(attrType == "::mlir::DenseIntElementsAttr")
{
body << " // BBBBBBBB getStorageType:"<< a.attr.getStorageType().str() <<"\n";
body << " // BBBBBBBB getReturnType:"<< a.attr.getReturnType().str() <<"\n";
}
auto typePair = typeMapMLIR.find(attrType);
if (typePair != typeMapMLIR.end()) {
std::string attrName = a.name.str();
@ -1480,20 +1469,42 @@ if(attrType == "::mlir::DenseIntElementsAttr")
newAttrs.emplace_back(mlirName);
}
}
int index = 0;
for (auto v : operands) {
std::string name =
v.name.empty() ? "odsArg" + std::to_string(index) : v.name.str();
index++;
if (v.isVariadic()) {
body << " std::vector<mlir::Value> " << name << "_v;\n";
body << " for(auto v : " << name << "){\n " << name
<< "_v.push_back(v.GetImpl()->GetResult());\n }"
<< "\n";
}
}
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
<< " currentOp =\n";
body << " opBuilder.create<mlir::" << op.getDialectName()
<< "::" << op.getDialectName() << ">(\n";
<< "::" << op.getCppClassName() << ">(\n";
body << " loc";
index = 0;
std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) {
body << ",\n " << v.name << ".getResult()";
std::string name =
v.name.empty() ? "odsArg_" + std::to_string(index) : v.name.str();
index++;
if (v.isVariadic()) {
body << ",\n " << name << "_v";
} else {
body << ",\n " << name << ".GetImpl()->GetResult()";
}
});
std::for_each(newAttrs.begin(), newAttrs.end(),
[&](std::string &n) { body << ",\n " << n; });
body << "\n );\n";
body << " builder::" << op.getCppClassName() << " builderOp;\n";
body << " builder::mhlo::" << op.getCppClassName() << " builderOp;\n";
body << " auto opImpl = builderOp.GetImpl();\n";
body << " opImpl.SetOperation(currentOp.getOperation());\n";
body << " opImpl->SetOperation(currentOp.getOperation());\n";
body << " return builderOp;\n";
// // Push all operands to the result.
@ -2109,7 +2120,8 @@ void OpEmitter::genTraits() {
void OpEmitter::genOpNameGetter() {
auto *method = opClass.addMethodAndPrune(
"std::string", "getOperationName",
OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
OpMethod::Property(OpMethod::MP_Static));
// OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
method->body() << " return std::string(\"" << op.getOperationName()
<< "\");";
}
@ -2196,7 +2208,7 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
std::string className = Operator(def).getQualCppClassName();
llvm::SplitString(StringRef(className), namespaces, StringRef("::"));
if (namespaces.begin() != namespaces.end())
os << "::builder::mhlo::" << namespaces.back().str();
os << "builder::mhlo::" << namespaces.back().str();
},
[&os]() { os << ",\n"; });
}