add more types and try to build pass
This commit is contained in:
parent
e88366b851
commit
4a9b201c4e
88
BUILD
88
BUILD
|
@ -140,39 +140,6 @@ cc_binary(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir-hlo-builder",
|
||||
srcs = glob([
|
||||
"tools/mlir-tblgen-builder/Builder/*.h",
|
||||
"tools/mlir-tblgen-builder/Builder/*.cpp",
|
||||
]),
|
||||
deps = [
|
||||
"@llvm-project//mlir:MlirTableGenMain",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:TableGen",
|
||||
"@llvm-project//llvm:config",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "mlir-tblgen-builder-basic",
|
||||
srcs = [
|
||||
"tests/mlir-tblgen-builder/test_basic.cpp",
|
||||
],
|
||||
deps = [
|
||||
":hlo_ops_builder_gen",
|
||||
":mlir-hlo-builder",
|
||||
"@llvm-project//mlir:MlirTableGenMain",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:TableGen",
|
||||
"@llvm-project//llvm:config",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl_cc_library(
|
||||
name = "hlo_ops_builder_gen",
|
||||
strip_include_prefix = "include",
|
||||
|
@ -195,6 +162,61 @@ gentbl_cc_library(
|
|||
deps = [":hlo_ops_td_files"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir-hlo-builder",
|
||||
srcs = glob([
|
||||
"tools/mlir-tblgen-builder/Builder/*.h",
|
||||
"tools/mlir-tblgen-builder/Builder/*.cpp",
|
||||
]),
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":all_passes",
|
||||
":disc_ral",
|
||||
":hlo",
|
||||
":lhlo",
|
||||
":lhlo_gpu",
|
||||
":hlo_ops_builder_gen",
|
||||
"@llvm-project//mlir:MlirTableGenMain",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:TableGen",
|
||||
"@llvm-project//llvm:config",
|
||||
# "@llvm-project//llvm:AllTargetsAsmParsers",
|
||||
# "@llvm-project//llvm:AllTargetsCodeGens",
|
||||
# "@llvm-project//llvm:Core",
|
||||
# "@llvm-project//llvm:ExecutionEngine",
|
||||
# "@llvm-project//llvm:Option",
|
||||
# "@llvm-project//llvm:OrcJIT",
|
||||
# "@llvm-project//llvm:Support",
|
||||
# "@llvm-project//llvm:Target",
|
||||
# "@llvm-project//mlir:AllPassesAndDialects",
|
||||
# "@llvm-project//mlir:IR",
|
||||
# "@llvm-project//mlir:MlirOptLib",
|
||||
# "@llvm-project//mlir:Pass",
|
||||
# "@llvm-project//mlir:Support",
|
||||
# "@llvm-project//mlir:MlirJitRunner",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "mlir-tblgen-builder-basic",
|
||||
srcs = [
|
||||
"tests/mlir-tblgen-builder/test_basic.cpp",
|
||||
],
|
||||
deps = [
|
||||
":mlir-hlo-builder",
|
||||
"@llvm-project//mlir:MlirTableGenMain",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:TableGen",
|
||||
"@llvm-project//llvm:config",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
|
||||
gentbl_cc_library(
|
||||
name = "hlo_ops_base_inc_gen",
|
||||
strip_include_prefix = "include",
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.h.inc"
|
||||
#include "tools/mlir-tblgen-builder/Builder/Attribute.h"
|
||||
#include "tools/mlir-tblgen-builder/Builder/Builder.h"
|
||||
#include "tools/mlir-tblgen-builder/Builder/Op.h"
|
||||
#include "tools/mlir-tblgen-builder/Builder/PrimitiveType.h"
|
||||
#include "tools/mlir-tblgen-builder/Builder/Shape.h"
|
||||
#include "tools/mlir-tblgen-builder/Builder/Tensor.h"
|
||||
|
||||
int main() {
|
||||
builder::Integer i(322);
|
||||
builder::Builder builder;
|
||||
builder::Shape shape({100, 100});
|
||||
auto pType = builder::PrimitiveType::F32();
|
||||
builder::Type type(shape, pType);
|
||||
|
||||
builder::Tensor tensor(std::vector<int>(100));
|
||||
auto op = builder::mhlo::ConstOp::build(builder, type, tensor);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// static ::builder::Op build(::builder::Builder &builder, ::builder::Type
|
||||
// output, ::builder::Tensor value);
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
#include "Attribute.h"
|
||||
|
||||
#include "AttributeImpl.h"
|
||||
#include "Shape.h"
|
||||
#include "Tensor.h"
|
||||
#include "TensorImpl.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "memory.h"
|
||||
// #include "mlir/Dialect/StandardOps/Ops.h"
|
||||
|
@ -15,8 +12,93 @@
|
|||
|
||||
namespace builder {
|
||||
|
||||
inline bool PrimitiveType::operator==(const PrimitiveType& pt) {
|
||||
return impl_ == pt.impl_;
|
||||
}
|
||||
PrimitiveType::PrimitiveType(std::shared_ptr<PrimitiveType::Impl> impl)
|
||||
: impl_(impl) {}
|
||||
PrimitiveType PrimitiveType::PRED() {
|
||||
PrimitiveType p(
|
||||
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::PRED));
|
||||
return p;
|
||||
}
|
||||
PrimitiveType PrimitiveType::S8() {
|
||||
PrimitiveType p(
|
||||
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S8));
|
||||
return p;
|
||||
}
|
||||
PrimitiveType PrimitiveType::S16() {
|
||||
PrimitiveType p(
|
||||
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S16));
|
||||
return p;
|
||||
}
|
||||
PrimitiveType PrimitiveType::F32() {
|
||||
PrimitiveType p(
|
||||
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::F32));
|
||||
return p;
|
||||
}
|
||||
PrimitiveType PrimitiveType::S32() {
|
||||
PrimitiveType p(
|
||||
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S32));
|
||||
return p;
|
||||
}
|
||||
PrimitiveType PrimitiveType::S64() {
|
||||
PrimitiveType p(
|
||||
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S64));
|
||||
return p;
|
||||
}
|
||||
|
||||
Shape::Shape(std::vector<int64_t> dims)
|
||||
: impl_(std::make_shared<Shape::Impl>(dims)) {}
|
||||
|
||||
Type::Type(Shape& shape, PrimitiveType& primitiveType)
|
||||
: impl_(std::make_shared<Type::Impl>(shape, primitiveType)) {}
|
||||
|
||||
Integer::Integer(int value) : impl_(std::make_shared<Integer::Impl>(value)) {}
|
||||
Integer::Integer(int64_t value)
|
||||
: impl_(std::make_shared<Integer::Impl>(value)) {}
|
||||
|
||||
Float::Float(float value) : impl_(std::make_shared<Float::Impl>(value)) {}
|
||||
Float::Float(double value) : impl_(std::make_shared<Float::Impl>(value)) {}
|
||||
|
||||
int64_t Array::GetSize() { return impl_->GetSize(); }
|
||||
PrimitiveType Array::GetType() { return impl_->GetType(); }
|
||||
Array::Array(std::vector<int> value)
|
||||
: impl_(std::make_shared<Array::Impl>(value)) {}
|
||||
Array::Array(std::vector<int64_t> value)
|
||||
: impl_(std::make_shared<Array::Impl>(value)) {}
|
||||
Array::Array(std::vector<std::string> value)
|
||||
: impl_(std::make_shared<Array::Impl>(value)) {}
|
||||
|
||||
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
|
||||
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
||||
Shape Tensor::GetShape() { return impl_->GetShape(); }
|
||||
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
|
||||
Tensor::Tensor(std::vector<int> value)
|
||||
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
||||
Tensor::Tensor(std::vector<int64_t> value)
|
||||
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
||||
Tensor::Tensor(std::vector<float> value)
|
||||
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
||||
|
||||
TensorInt::TensorInt(Shape& shape, PrimitiveType& primitiveType)
|
||||
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
||||
Shape TensorInt::GetShape() { return impl_->GetShape(); }
|
||||
PrimitiveType TensorInt::GetType() { return impl_->GetType(); }
|
||||
|
||||
ChannelHandle::ChannelHandle()
|
||||
: impl_(std::make_shared<ChannelHandle::Impl>()) {}
|
||||
|
||||
ConvDimensionNumbers::ConvDimensionNumbers()
|
||||
: impl_(std::make_shared<ConvDimensionNumbers::Impl>()) {}
|
||||
|
||||
DotDimensionNumbers::DotDimensionNumbers()
|
||||
: impl_(std::make_shared<DotDimensionNumbers::Impl>()) {}
|
||||
|
||||
GatherDimensionNumbers::GatherDimensionNumbers()
|
||||
: impl_(std::make_shared<GatherDimensionNumbers::Impl>()) {}
|
||||
|
||||
ScatterDimensionNumbers::ScatterDimensionNumbers()
|
||||
: impl_(std::make_shared<ScatterDimensionNumbers::Impl>()) {}
|
||||
|
||||
} // namespace builder
|
|
@ -1,16 +1,57 @@
|
|||
#ifndef BUILDER_ATTRIBUTE_H_
|
||||
#define BUILDER_ATTRIBUTE_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace builder {
|
||||
|
||||
class PrimitiveType {
|
||||
public:
|
||||
static PrimitiveType PRED();
|
||||
static PrimitiveType S8();
|
||||
static PrimitiveType S16();
|
||||
static PrimitiveType F32();
|
||||
static PrimitiveType S32();
|
||||
static PrimitiveType S64();
|
||||
inline bool operator==(const PrimitiveType& pt);
|
||||
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
PrimitiveType(std::shared_ptr<Impl>);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
// PrimitiveType(const PrimitiveType&) = default;
|
||||
};
|
||||
|
||||
class Shape {
|
||||
public:
|
||||
Shape(std::vector<int64_t> dims);
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class Integer {
|
||||
public:
|
||||
Integer(int value);
|
||||
Integer(int64_t value);
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class Float {
|
||||
public:
|
||||
Float(float value);
|
||||
Float(double value);
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
|
@ -18,6 +59,43 @@ class Integer {
|
|||
|
||||
class Array {
|
||||
public:
|
||||
Array(std::vector<int> value);
|
||||
Array(std::vector<int64_t> value);
|
||||
Array(std::vector<std::string> value);
|
||||
|
||||
int64_t GetSize();
|
||||
PrimitiveType GetType();
|
||||
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class Tensor {
|
||||
public:
|
||||
Tensor(Shape& shape, PrimitiveType& primitiveType);
|
||||
Tensor(std::vector<int> value);
|
||||
Tensor(std::vector<int64_t> value);
|
||||
Tensor(std::vector<float> value);
|
||||
|
||||
Shape GetShape();
|
||||
PrimitiveType GetType();
|
||||
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class TensorInt {
|
||||
public:
|
||||
TensorInt(Shape& shape, PrimitiveType& primitiveType);
|
||||
Shape GetShape();
|
||||
PrimitiveType GetType();
|
||||
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
|
@ -27,6 +105,7 @@ class Array {
|
|||
|
||||
class Type {
|
||||
public:
|
||||
Type(Shape& shape, PrimitiveType& primitiveType);
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
|
@ -34,16 +113,56 @@ class Type {
|
|||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
// template <typename T>
|
||||
// class DenseIntElementsAttr {
|
||||
// public:
|
||||
// explicit DenseIntElementsAttr(Tensor);
|
||||
// class Impl;
|
||||
// std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
class ChannelHandle {
|
||||
public:
|
||||
ChannelHandle();
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class ConvDimensionNumbers {
|
||||
public:
|
||||
ConvDimensionNumbers();
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class DotDimensionNumbers {
|
||||
public:
|
||||
DotDimensionNumbers();
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class GatherDimensionNumbers {
|
||||
public:
|
||||
GatherDimensionNumbers();
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class ScatterDimensionNumbers {
|
||||
public:
|
||||
ScatterDimensionNumbers();
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
// private:
|
||||
// std::shared_ptr<Impl> impl_;
|
||||
// };
|
||||
|
||||
} // namespace builder
|
||||
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
#ifndef BUILDER_TYPEIMPL_
|
||||
#define BUILDER_TYPEIMPL_
|
||||
#ifndef BUILDER_ATTRIBUTEIMPL_H_
|
||||
#define BUILDER_ATTRIBUTEIMPL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Builder.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
@ -13,45 +18,306 @@
|
|||
|
||||
namespace builder {
|
||||
|
||||
class PrimitiveType::Impl {
|
||||
public:
|
||||
enum pType {
|
||||
PRIMITIVE_TYPE_INVALID = 0,
|
||||
PRED = 1,
|
||||
S8 = 2,
|
||||
S16 = 3,
|
||||
S32 = 4,
|
||||
S64 = 5,
|
||||
U8 = 6,
|
||||
U16 = 7,
|
||||
U32 = 8,
|
||||
U64 = 9,
|
||||
F16 = 10,
|
||||
F32 = 11,
|
||||
BF16 = 16,
|
||||
F64 = 12,
|
||||
C64 = 15,
|
||||
C128 = 18,
|
||||
TUPLE = 13,
|
||||
OPAQUE_TYPE = 14,
|
||||
TOKEN = 17
|
||||
};
|
||||
|
||||
Impl(pType t) : t_(t) {
|
||||
switch (t) {
|
||||
case PRIMITIVE_TYPE_INVALID:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::NoneType::get(context);
|
||||
};
|
||||
break;
|
||||
case PRED:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(
|
||||
context, 1, mlir::IntegerType::SignednessSemantics::Signed);
|
||||
};
|
||||
break;
|
||||
case S8:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(
|
||||
context, 8, mlir::IntegerType::SignednessSemantics::Signed);
|
||||
};
|
||||
break;
|
||||
case S16:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(
|
||||
context, 16, mlir::IntegerType::SignednessSemantics::Signed);
|
||||
};
|
||||
break;
|
||||
case F32:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::Float32Type::get(context);
|
||||
};
|
||||
break;
|
||||
case S32:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(
|
||||
context, 32, mlir::IntegerType::SignednessSemantics::Signed);
|
||||
};
|
||||
break;
|
||||
case S64:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(
|
||||
context, 64, mlir::IntegerType::SignednessSemantics::Signed);
|
||||
};
|
||||
break;
|
||||
default:
|
||||
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::NoneType::get(context);
|
||||
};
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// class BFloat16Type;
|
||||
// class ComplexType;
|
||||
// class Float128Type;
|
||||
// class Float16Type;
|
||||
// class Float32Type;
|
||||
// class Float64Type;
|
||||
// class Float80Type;
|
||||
// class FunctionType;
|
||||
// class IndexType;
|
||||
// class IntegerType;
|
||||
// class MemRefType;
|
||||
// class NoneType;
|
||||
// class OpaqueType;
|
||||
// class RankedTensorType;
|
||||
// class TupleType;
|
||||
// class UnrankedMemRefType;
|
||||
// class UnrankedTensorType;
|
||||
// class VectorType;
|
||||
|
||||
inline bool operator==(const PrimitiveType::Impl &impl) {
|
||||
return t_ == impl.t_;
|
||||
}
|
||||
std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;
|
||||
|
||||
private:
|
||||
pType t_;
|
||||
};
|
||||
|
||||
class Shape::Impl {
|
||||
public:
|
||||
Impl(std::vector<int64_t> dims) : dims_(dims) {}
|
||||
|
||||
private:
|
||||
std::vector<int64_t> dims_;
|
||||
};
|
||||
|
||||
class Type::Impl {
|
||||
public:
|
||||
Impl(Shape &shape, PrimitiveType &primitiveType)
|
||||
: shape_(shape), primitiveType_(primitiveType) {}
|
||||
Shape GetShape() { return shape_; }
|
||||
PrimitiveType GetType() { return primitiveType_; }
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
PrimitiveType primitiveType_;
|
||||
};
|
||||
|
||||
class Integer::Impl {
|
||||
public:
|
||||
Impl(){};
|
||||
|
||||
Impl(int value) {
|
||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
|
||||
auto int32_type = mlir::IntegerType::get(context, 32);
|
||||
return mlir::IntegerAttr::get(int32_type, value);
|
||||
};
|
||||
};
|
||||
Impl(int64_t value)
|
||||
: GetAttr([=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
|
||||
auto int32_type = mlir::IntegerType::get(context, 64);
|
||||
return mlir::IntegerAttr::get(int32_type, value);
|
||||
}){};
|
||||
Impl(int64_t value) {
|
||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
|
||||
auto int64_type = mlir::IntegerType::get(context, 64);
|
||||
return mlir::IntegerAttr::get(int64_type, value);
|
||||
};
|
||||
};
|
||||
std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
class Float::Impl {
|
||||
public:
|
||||
Impl(float value) {
|
||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::FloatAttr {
|
||||
auto float32_type = mlir::Float32Type::get(context);
|
||||
return mlir::FloatAttr::get(float32_type, value);
|
||||
};
|
||||
};
|
||||
Impl(double value) {
|
||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::FloatAttr {
|
||||
auto float64_type = mlir::Float64Type::get(context);
|
||||
return mlir::FloatAttr::get(float64_type, value);
|
||||
};
|
||||
};
|
||||
std::function<mlir::FloatAttr(mlir::MLIRContext *context)> GetAttr;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
class Array::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
Impl(std::vector<int> value)
|
||||
: size_({value.size()}), primitiveType_(PrimitiveType::S32()) {
|
||||
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
|
||||
// auto type =
|
||||
// mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||
// mlir::IntegerType::get(context, 32));
|
||||
// return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
|
||||
// };
|
||||
}
|
||||
Impl(std::vector<int64_t> value)
|
||||
: size_({value.size()}), primitiveType_(PrimitiveType::S64()) {
|
||||
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
|
||||
// auto type =
|
||||
// mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||
// mlir::IntegerType::get(context, 64));
|
||||
// return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
|
||||
// };
|
||||
}
|
||||
Impl(std::vector<std::string> value)
|
||||
: size_({value.size()}), primitiveType_(PrimitiveType::F32()) {
|
||||
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
|
||||
// auto type =
|
||||
// mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||
// mlir::FloatType::getF32(context));
|
||||
// return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
|
||||
// };
|
||||
}
|
||||
|
||||
int64_t GetSize() { return size_; }
|
||||
PrimitiveType GetType() { return primitiveType_; }
|
||||
|
||||
std::function<mlir::ArrayAttr(mlir::MLIRContext *context)> GetAttr;
|
||||
|
||||
private:
|
||||
int64_t size_;
|
||||
PrimitiveType primitiveType_;
|
||||
};
|
||||
class Type::Impl {
|
||||
|
||||
class Tensor::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
Impl(std::vector<int> value)
|
||||
: shape_({value.size()}), primitiveType_(PrimitiveType::S32()) {
|
||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||
auto type =
|
||||
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||
mlir::IntegerType::get(context, 32));
|
||||
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||
};
|
||||
}
|
||||
Impl(std::vector<int64_t> value)
|
||||
: shape_({value.size()}), primitiveType_(PrimitiveType::S64()) {
|
||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||
auto type =
|
||||
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||
mlir::IntegerType::get(context, 64));
|
||||
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||
};
|
||||
}
|
||||
Impl(std::vector<float> value)
|
||||
: shape_({value.size()}), primitiveType_(PrimitiveType::F32()) {
|
||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||
auto type =
|
||||
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||
mlir::FloatType::getF32(context));
|
||||
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||
};
|
||||
}
|
||||
Impl(Shape &shape, PrimitiveType &primitiveType)
|
||||
: shape_(shape), primitiveType_(primitiveType) {}
|
||||
|
||||
Shape GetShape() { return shape_; }
|
||||
PrimitiveType GetType() { return primitiveType_; }
|
||||
|
||||
std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
PrimitiveType primitiveType_;
|
||||
};
|
||||
|
||||
class TensorInt::Impl {
|
||||
public:
|
||||
Impl(Shape &shape, PrimitiveType &primitiveType)
|
||||
: shape_(shape), primitiveType_(primitiveType) {}
|
||||
|
||||
Shape GetShape() { return shape_; }
|
||||
PrimitiveType GetType() { return primitiveType_; }
|
||||
|
||||
std::function<mlir::DenseIntElementsAttr(mlir::MLIRContext *context)> GetAttr;
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
PrimitiveType primitiveType_;
|
||||
};
|
||||
|
||||
class ChannelHandle::Impl {
|
||||
public:
|
||||
Impl() {}
|
||||
std::function<mlir::mhlo::ChannelHandle(mlir::MLIRContext *context)> GetAttr;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
// class DenseIntElementsAttr::Impl {
|
||||
// public:
|
||||
// Impl() = default;
|
||||
class ConvDimensionNumbers::Impl {
|
||||
public:
|
||||
Impl() {}
|
||||
std::function<mlir::mhlo::ConvDimensionNumbers(mlir::MLIRContext *context)>
|
||||
GetAttr;
|
||||
|
||||
// private:
|
||||
// };
|
||||
private:
|
||||
};
|
||||
|
||||
class DotDimensionNumbers::Impl {
|
||||
public:
|
||||
Impl() {}
|
||||
std::function<mlir::mhlo::DotDimensionNumbers(mlir::MLIRContext *context)>
|
||||
GetAttr;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
class GatherDimensionNumbers::Impl {
|
||||
public:
|
||||
Impl() {}
|
||||
std::function<mlir::mhlo::GatherDimensionNumbers(mlir::MLIRContext *context)>
|
||||
GetAttr;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
class ScatterDimensionNumbers::Impl {
|
||||
public:
|
||||
Impl() {}
|
||||
std::function<mlir::mhlo::ScatterDimensionNumbers(mlir::MLIRContext *context)>
|
||||
GetAttr;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
} // namespace builder
|
||||
|
||||
|
|
|
@ -8,10 +8,20 @@
|
|||
// #include "mlir/IR/StandardTypes.h"
|
||||
// #include "mlir/IR/Types.h"
|
||||
// #include "mlir/IR/Value.h"
|
||||
#include "Attribute.h"
|
||||
#include "AttributeImpl.h"
|
||||
#include "Op.h"
|
||||
#include "OpImpl.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
|
||||
void Builder::DumpModule() {}
|
||||
|
||||
} // namespace builder
|
||||
} // namespace builder
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.cc.inc"
|
||||
|
|
|
@ -3,14 +3,17 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "tools/mlir-tblgen-builder/Builder/Attribute.h"
|
||||
#include "tools/mlir-tblgen-builder/Builder/Op.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
class Builder {
|
||||
public:
|
||||
class Impl;
|
||||
|
||||
Builder();
|
||||
void DumpModule();
|
||||
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
|
@ -18,4 +21,7 @@ class Builder {
|
|||
};
|
||||
} // namespace builder
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.h.inc"
|
||||
|
||||
#endif
|
|
@ -4,7 +4,10 @@
|
|||
#include "Builder.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
@ -14,15 +17,30 @@ namespace builder {
|
|||
|
||||
class Builder::Impl {
|
||||
public:
|
||||
Impl() {}
|
||||
// mlir::Location GetLoc() { return mlir_loc_; }
|
||||
// mlir::OpBuilder GetBuilder() { return mlir_builder_; }
|
||||
mlir::MLIRContext *GetContext() { return &mlir_context_; }
|
||||
Impl() : builder_(&context_) {
|
||||
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
|
||||
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||
// Create the main function.
|
||||
mlir::FunctionType funcType = builder_.getFunctionType(arg_types, {});
|
||||
main_func_ = mlir::FuncOp::create(builder_.getUnknownLoc(), "main",
|
||||
funcType, /* attrs = */ {});
|
||||
|
||||
entry_block_ = main_func_.addEntryBlock();
|
||||
builder_.setInsertionPointToStart(entry_block_);
|
||||
module_.push_back(main_func_);
|
||||
}
|
||||
|
||||
mlir::Location GetLoc() { return builder_.getUnknownLoc(); }
|
||||
mlir::OpBuilder GetBuilder() { return builder_; }
|
||||
mlir::MLIRContext* GetContext() { return &context_; }
|
||||
|
||||
private:
|
||||
// mlir::Location mlir_loc_;
|
||||
// mlir::OpBuilder mlir_builder_;
|
||||
mlir::MLIRContext mlir_context_;
|
||||
mlir::MLIRContext context_;
|
||||
mlir::ModuleOp module_;
|
||||
mlir::OpBuilder builder_;
|
||||
mlir::FuncOp main_func_;
|
||||
mlir::Block* entry_block_;
|
||||
};
|
||||
|
||||
} // namespace builder
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
namespace builder {
|
||||
class Op {
|
||||
public:
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
|
|
|
@ -15,10 +15,11 @@ namespace builder {
|
|||
class Op::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
void SetOperation(Operation *Op) { op_ = Op; }
|
||||
void SetOperation(mlir::Operation *Op) { op_ = Op; }
|
||||
mlir::Value GetResult() { return op_->getResult(0); }
|
||||
|
||||
private:
|
||||
Operation *op_;
|
||||
mlir::Operation *op_;
|
||||
};
|
||||
|
||||
} // namespace builder
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
#include "PrimitiveType.h"
|
||||
|
||||
#include "Shape.h"
|
||||
#include "Tensor.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
// #include "mlir/Dialect/StandardOps/Ops.h"
|
||||
// #include "mlir/IR/Attributes.h"
|
||||
// #include "mlir/IR/Operation.h"
|
||||
// #include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
// #include "mlir/IR/Value.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
// constexpr mlir::Type GetMlirType(const PrimitiveType primitiveType) {
|
||||
// switch (primitiveType)
|
||||
// {
|
||||
// case PrimitiveType::BF16 :
|
||||
// return
|
||||
// break;
|
||||
|
||||
// default:
|
||||
// break;
|
||||
// }
|
||||
|
||||
// }
|
||||
|
||||
} // namespace builder
|
|
@ -1,29 +0,0 @@
|
|||
#ifndef BUILDER_PRIMITIVE_TYPE_
|
||||
#define BUILDER_PRIMITIVE_TYPE_
|
||||
|
||||
namespace builder {
|
||||
|
||||
enum class PrimitiveType {
|
||||
PRIMITIVE_TYPE_INVALID = 0,
|
||||
PRED = 1,
|
||||
S8 = 2,
|
||||
S16 = 3,
|
||||
S32 = 4,
|
||||
S64 = 5,
|
||||
U8 = 6,
|
||||
U16 = 7,
|
||||
U32 = 8,
|
||||
U64 = 9,
|
||||
F16 = 10,
|
||||
F32 = 11,
|
||||
BF16 = 16,
|
||||
F64 = 12,
|
||||
C64 = 15,
|
||||
C128 = 18,
|
||||
TUPLE = 13,
|
||||
OPAQUE_TYPE = 14,
|
||||
TOKEN = 17
|
||||
};
|
||||
|
||||
}
|
||||
#endif
|
|
@ -1,17 +0,0 @@
|
|||
#ifndef BUILDER_SHAPE_
|
||||
#define BUILDER_SHAPE_
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
namespace builder {
|
||||
class Shape {
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -1,24 +0,0 @@
|
|||
#ifndef BUILDER_SHAPEIMPL_
|
||||
#define BUILDER_SHAPEIMPL_
|
||||
|
||||
#include "Builder.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
class Shape::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -1,22 +0,0 @@
|
|||
#include "Tensor.h"
|
||||
|
||||
#include "Shape.h"
|
||||
#include "TensorImpl.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
// #include "mlir/Dialect/StandardOps/Ops.h"
|
||||
// #include "mlir/IR/Attributes.h"
|
||||
// #include "mlir/IR/Operation.h"
|
||||
// #include "mlir/IR/StandardTypes.h"
|
||||
// #include "mlir/IR/Types.h"
|
||||
// #include "mlir/IR/Value.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
// Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
|
||||
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
|
||||
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
||||
|
||||
Shape Tensor::GetShape() { return impl_->GetShape(); }
|
||||
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
|
||||
|
||||
} // namespace builder
|
|
@ -1,25 +0,0 @@
|
|||
#ifndef BUILDER_TENSOR_
|
||||
#define BUILDER_TENSOR_
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
#include "PrimitiveType.h"
|
||||
#include "Shape.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
class Tensor {
|
||||
public:
|
||||
Tensor(Shape& shape, PrimitiveType& primitiveType);
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
Shape GetShape();
|
||||
PrimitiveType GetType();
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -1,32 +0,0 @@
|
|||
#ifndef BUILDER_TENSORIMPL_
|
||||
#define BUILDER_TENSORIMPL_
|
||||
|
||||
#include "Builder.h"
|
||||
#include "PrimitiveType.h"
|
||||
#include "Shape.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
class Tensor::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
Impl(Shape& shape, PrimitiveType& primitiveType)
|
||||
: shape_(shape), primitiveType_(primitiveType) {}
|
||||
Shape GetShape() { return shape_; }
|
||||
PrimitiveType GetType() { return primitiveType_; }
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
PrimitiveType primitiveType_;
|
||||
};
|
||||
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -52,81 +52,93 @@ static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
|
|||
{"::mlir::StringAttr",
|
||||
{"std::string",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::StringAttr " << var << "_mlir = mlir::StringAttr::get("
|
||||
<< var << ", ctx);\n";
|
||||
body << " mlir::StringAttr " << var
|
||||
<< "_mlir = mlir::StringAttr::get(ctx, mlir::Twine(" << var
|
||||
<< "));\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::IntegerAttr",
|
||||
{"builder::Integer",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
// body << " mlir::IntegerAttr " << var << "_mlir = mlir::IntegerAttr::get("
|
||||
// << var << ", ctx);\n";
|
||||
body << " mlir::IntegerAttr " << var << "_mlir = " << var
|
||||
<< ".GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::DenseIntElementsAttr",
|
||||
{"std::vector<int>",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::DenseIntElementsAttr " << var
|
||||
<< "_mlir = mlir::DenseIntElementsAttr::get("
|
||||
<< "mlir::VectorType::get(" << var
|
||||
<< ".size(), opBuilder->getIntegerType(32))," << var << ");\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::mhlo::ChannelHandle",
|
||||
{"ChannelHandle",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::FloatAttr",
|
||||
{"float",
|
||||
{"builder::Float",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::FloatAttr " << var << "_mlir = mlir::FloatAttr::get("
|
||||
<< var << ", ctx);\n";
|
||||
body << " mlir::FloatAttr " << var << "_mlir = " << var
|
||||
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::DenseIntElementsAttr",
|
||||
{"builder::TensorInt",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::DenseIntElementsAttr " << var << "_mlir = " << var
|
||||
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::mhlo::ChannelHandle",
|
||||
{"builder::ChannelHandle",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::mhlo::ChannelHandle " << var << "_mlir = " << var
|
||||
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::BoolAttr",
|
||||
{"bool",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::BoolAttr " << var << "_mlir = mlir::BoolAttr::get("
|
||||
<< var << ", ctx);\n";
|
||||
body << " mlir::BoolAttr " << var
|
||||
<< "_mlir = mlir::BoolAttr::get(ctx, " << var << ");\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::ElementsAttr",
|
||||
{"::builder::Array",
|
||||
{"builder::Tensor",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::DenseElementsAttr " << var << "_mlir = " << var
|
||||
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::DenseElementsAttr",
|
||||
{"::builder::Tensor",
|
||||
{"builder::Tensor",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::DenseElementsAttr " << var << "_mlir = " << var
|
||||
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
// current only support string array.
|
||||
{"::mlir::ArrayAttr",
|
||||
{"std::vector<std::string>",
|
||||
{"builder::Array",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::ArrayAttr " << var << "_mlir = " << var
|
||||
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::mhlo::ConvDimensionNumbers",
|
||||
{"::builder::ConvDimensionNumbers",
|
||||
{"builder::ConvDimensionNumbers",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::mhlo::ConvDimensionNumbers " << var
|
||||
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::mhlo::DotDimensionNumbers",
|
||||
{"::builder::DotDimensionNumbers",
|
||||
{"builder::DotDimensionNumbers",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::mhlo::DotDimensionNumbers " << var << "_mlir = " << var
|
||||
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::mhlo::GatherDimensionNumbers",
|
||||
{"::builder::GatherDimensionNumbers",
|
||||
{"builder::GatherDimensionNumbers",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::mhlo::GatherDimensionNumbers " << var
|
||||
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::mhlo::ScatterDimensionNumbers",
|
||||
{"::builder::ScatterDimensionNumbers",
|
||||
{"builder::ScatterDimensionNumbers",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::mhlo::ScatterDimensionNumbers " << var
|
||||
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
};
|
||||
|
@ -138,16 +150,16 @@ static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
|
|||
// {"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
|
||||
// {"::mlir::FloatAttr", "float"},
|
||||
// {"::mlir::BoolAttr", "bool"},
|
||||
// {"::mlir::ElementsAttr", "::builder::Array"},
|
||||
// {"::mlir::DenseElementsAttr", "::builder::Tensor"},
|
||||
// {"::mlir::ElementsAttr", "builder::Array"},
|
||||
// {"::mlir::DenseElementsAttr", "builder::Tensor"},
|
||||
// // current only support string array.
|
||||
// {"::mlir::ArrayAttr", "std::vector<std::string>"},
|
||||
// {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
|
||||
// {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
|
||||
// {"::mlir::mhlo::ConvDimensionNumbers", "builder::ConvDimensionNumbers"},
|
||||
// {"::mlir::mhlo::DotDimensionNumbers", "builder::DotDimensionNumbers"},
|
||||
// {"::mlir::mhlo::GatherDimensionNumbers",
|
||||
// "::builder::GatherDimensionNumbers"},
|
||||
// "builder::GatherDimensionNumbers"},
|
||||
// {"::mlir::mhlo::ScatterDimensionNumbers",
|
||||
// "::builder::ScatterDimensionNumbers"},
|
||||
// "builder::ScatterDimensionNumbers"},
|
||||
// };
|
||||
|
||||
StringRef typeConvertFromMLIR(StringRef type) {
|
||||
|
@ -707,14 +719,14 @@ OpEmitter::OpEmitter(const Operator &op,
|
|||
// Generate C++ code for various op methods. The order here determines the
|
||||
// methods in the generated file.
|
||||
// genOpAsmInterface();
|
||||
genOpNameGetter();
|
||||
//// genOpNameGetter();
|
||||
// genNamedOperandGetters();
|
||||
// genNamedOperandSetters();
|
||||
// genNamedResultGetters();
|
||||
// genNamedRegionGetters();
|
||||
genNamedSuccessorGetters();
|
||||
genAttrGetters();
|
||||
genAttrSetters();
|
||||
//// genNamedSuccessorGetters();
|
||||
//// genAttrGetters();
|
||||
//// genAttrSetters();
|
||||
// genOptionalAttrRemovers();
|
||||
genBuilder();
|
||||
// genParser();
|
||||
|
@ -1179,7 +1191,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
|
|||
buildParamList(paramList, resultNames, paramKind, attrType);
|
||||
|
||||
auto *m = opClass.addMethodAndPrune(
|
||||
"::builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
|
||||
"builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
|
||||
// If the builder is redundant, skip generating the method.
|
||||
if (!m)
|
||||
return;
|
||||
|
@ -1329,7 +1341,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|||
resultTypeNames.reserve(numResults);
|
||||
|
||||
// paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
|
||||
paramList.emplace_back("::builder::Builder &", "builder");
|
||||
paramList.emplace_back("builder::Builder &", "builder");
|
||||
// paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
||||
|
||||
switch (typeParamKind) {
|
||||
|
@ -1344,7 +1356,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|||
resultName = std::string(formatv("resultType{0}", i));
|
||||
|
||||
StringRef type =
|
||||
result.isVariadic() ? "std::vector<::builder::Type>" : "::builder::Type";
|
||||
result.isVariadic() ? "std::vector<builder::Type>" : "builder::Type";
|
||||
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
|
||||
if (result.isOptional())
|
||||
properties = OpMethodParameter::PP_Optional;
|
||||
|
@ -1371,7 +1383,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|||
if (argument.is<tblgen::NamedTypeConstraint *>()) {
|
||||
const auto &operand = op.getOperand(numOperands);
|
||||
StringRef type =
|
||||
operand.isVariadic() ? "std::vector<::builder::Op>" : "::builder::Op";
|
||||
operand.isVariadic() ? "std::vector<builder::Op>" : "builder::Op";
|
||||
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
|
||||
if (operand.isOptional())
|
||||
properties = OpMethodParameter::PP_Optional;
|
||||
|
@ -1437,24 +1449,11 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
|||
auto attrs = op.getAttributes();
|
||||
SmallVector<std::string, 4> newAttrs;
|
||||
|
||||
// for(auto o : operands){
|
||||
// body << " operands name: "<<o.name<<" getCPPClassName:"<<o.constraint.getCPPClassName()<<"\n";
|
||||
// }
|
||||
// for(auto a : attrs){
|
||||
// body << " attrs name: "<<a.name<<" type: "<< a.attr.getStorageType().str()<<"\n";
|
||||
// std::string attrType = a.attr.getStorageType().str();
|
||||
// auto ff = typeMapMLIR.find(attrType);
|
||||
// if(ff != typeMapMLIR.end()){
|
||||
// body << "// BBBBBB \n";
|
||||
// std::string attrName = a.name.str();
|
||||
// ff->second.ConvertToMlir(attrName,body);
|
||||
// }
|
||||
// }
|
||||
|
||||
|
||||
// body << "// AAAAAA \n";
|
||||
// for(auto p : paramList){
|
||||
// body << " AAA type"<<p.getType()<<" name"<<p.getName()<<"\n";
|
||||
// if (attrType == "::mlir::DenseIntElementsAttr") {
|
||||
// body << " // BBBBBBBB getStorageType:" << a.attr.getStorageType().str()
|
||||
// << "\n";
|
||||
// body << " // BBBBBBBB getReturnType:" << a.attr.getReturnType().str()
|
||||
// << "\n";
|
||||
// }
|
||||
|
||||
body << " auto b = builder.GetImpl();\n";
|
||||
|
@ -1463,16 +1462,6 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
|||
body << " auto ctx = b->GetContext();\n";
|
||||
for (auto a : attrs) {
|
||||
std::string attrType = a.attr.getStorageType().str();
|
||||
|
||||
if(attrType == "::mlir::DenseIntElementsAttr")
|
||||
{
|
||||
body << " // BBBBBBBB getStorageType:"<< a.attr.getStorageType().str() <<"\n";
|
||||
body << " // BBBBBBBB getReturnType:"<< a.attr.getReturnType().str() <<"\n";
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
auto typePair = typeMapMLIR.find(attrType);
|
||||
if (typePair != typeMapMLIR.end()) {
|
||||
std::string attrName = a.name.str();
|
||||
|
@ -1480,20 +1469,42 @@ if(attrType == "::mlir::DenseIntElementsAttr")
|
|||
newAttrs.emplace_back(mlirName);
|
||||
}
|
||||
}
|
||||
int index = 0;
|
||||
for (auto v : operands) {
|
||||
std::string name =
|
||||
v.name.empty() ? "odsArg" + std::to_string(index) : v.name.str();
|
||||
index++;
|
||||
if (v.isVariadic()) {
|
||||
body << " std::vector<mlir::Value> " << name << "_v;\n";
|
||||
body << " for(auto v : " << name << "){\n " << name
|
||||
<< "_v.push_back(v.GetImpl()->GetResult());\n }"
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
|
||||
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
|
||||
<< " currentOp =\n";
|
||||
body << " opBuilder.create<mlir::" << op.getDialectName()
|
||||
<< "::" << op.getDialectName() << ">(\n";
|
||||
<< "::" << op.getCppClassName() << ">(\n";
|
||||
body << " loc";
|
||||
|
||||
index = 0;
|
||||
std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) {
|
||||
body << ",\n " << v.name << ".getResult()";
|
||||
std::string name =
|
||||
v.name.empty() ? "odsArg_" + std::to_string(index) : v.name.str();
|
||||
index++;
|
||||
if (v.isVariadic()) {
|
||||
body << ",\n " << name << "_v";
|
||||
} else {
|
||||
body << ",\n " << name << ".GetImpl()->GetResult()";
|
||||
}
|
||||
});
|
||||
std::for_each(newAttrs.begin(), newAttrs.end(),
|
||||
[&](std::string &n) { body << ",\n " << n; });
|
||||
body << "\n );\n";
|
||||
body << " builder::" << op.getCppClassName() << " builderOp;\n";
|
||||
body << " builder::mhlo::" << op.getCppClassName() << " builderOp;\n";
|
||||
body << " auto opImpl = builderOp.GetImpl();\n";
|
||||
body << " opImpl.SetOperation(currentOp.getOperation());\n";
|
||||
body << " opImpl->SetOperation(currentOp.getOperation());\n";
|
||||
body << " return builderOp;\n";
|
||||
|
||||
// // Push all operands to the result.
|
||||
|
@ -2109,7 +2120,8 @@ void OpEmitter::genTraits() {
|
|||
void OpEmitter::genOpNameGetter() {
|
||||
auto *method = opClass.addMethodAndPrune(
|
||||
"std::string", "getOperationName",
|
||||
OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
|
||||
OpMethod::Property(OpMethod::MP_Static));
|
||||
// OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
|
||||
method->body() << " return std::string(\"" << op.getOperationName()
|
||||
<< "\");";
|
||||
}
|
||||
|
@ -2196,7 +2208,7 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
|
|||
std::string className = Operator(def).getQualCppClassName();
|
||||
llvm::SplitString(StringRef(className), namespaces, StringRef("::"));
|
||||
if (namespaces.begin() != namespaces.end())
|
||||
os << "::builder::mhlo::" << namespaces.back().str();
|
||||
os << "builder::mhlo::" << namespaces.back().str();
|
||||
},
|
||||
[&os]() { os << ",\n"; });
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue