add more types and try to build pass
This commit is contained in:
parent
e88366b851
commit
4a9b201c4e
88
BUILD
88
BUILD
|
@ -140,39 +140,6 @@ cc_binary(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "mlir-hlo-builder",
|
|
||||||
srcs = glob([
|
|
||||||
"tools/mlir-tblgen-builder/Builder/*.h",
|
|
||||||
"tools/mlir-tblgen-builder/Builder/*.cpp",
|
|
||||||
]),
|
|
||||||
deps = [
|
|
||||||
"@llvm-project//mlir:MlirTableGenMain",
|
|
||||||
"@llvm-project//mlir:Support",
|
|
||||||
"@llvm-project//mlir:IR",
|
|
||||||
"@llvm-project//llvm:Support",
|
|
||||||
"@llvm-project//llvm:TableGen",
|
|
||||||
"@llvm-project//llvm:config",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "mlir-tblgen-builder-basic",
|
|
||||||
srcs = [
|
|
||||||
"tests/mlir-tblgen-builder/test_basic.cpp",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
":hlo_ops_builder_gen",
|
|
||||||
":mlir-hlo-builder",
|
|
||||||
"@llvm-project//mlir:MlirTableGenMain",
|
|
||||||
"@llvm-project//mlir:Support",
|
|
||||||
"@llvm-project//mlir:IR",
|
|
||||||
"@llvm-project//llvm:Support",
|
|
||||||
"@llvm-project//llvm:TableGen",
|
|
||||||
"@llvm-project//llvm:config",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
gentbl_cc_library(
|
gentbl_cc_library(
|
||||||
name = "hlo_ops_builder_gen",
|
name = "hlo_ops_builder_gen",
|
||||||
strip_include_prefix = "include",
|
strip_include_prefix = "include",
|
||||||
|
@ -195,6 +162,61 @@ gentbl_cc_library(
|
||||||
deps = [":hlo_ops_td_files"],
|
deps = [":hlo_ops_td_files"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "mlir-hlo-builder",
|
||||||
|
srcs = glob([
|
||||||
|
"tools/mlir-tblgen-builder/Builder/*.h",
|
||||||
|
"tools/mlir-tblgen-builder/Builder/*.cpp",
|
||||||
|
]),
|
||||||
|
includes = ["include"],
|
||||||
|
deps = [
|
||||||
|
":all_passes",
|
||||||
|
":disc_ral",
|
||||||
|
":hlo",
|
||||||
|
":lhlo",
|
||||||
|
":lhlo_gpu",
|
||||||
|
":hlo_ops_builder_gen",
|
||||||
|
"@llvm-project//mlir:MlirTableGenMain",
|
||||||
|
"@llvm-project//mlir:Support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
|
"@llvm-project//llvm:TableGen",
|
||||||
|
"@llvm-project//llvm:config",
|
||||||
|
# "@llvm-project//llvm:AllTargetsAsmParsers",
|
||||||
|
# "@llvm-project//llvm:AllTargetsCodeGens",
|
||||||
|
# "@llvm-project//llvm:Core",
|
||||||
|
# "@llvm-project//llvm:ExecutionEngine",
|
||||||
|
# "@llvm-project//llvm:Option",
|
||||||
|
# "@llvm-project//llvm:OrcJIT",
|
||||||
|
# "@llvm-project//llvm:Support",
|
||||||
|
# "@llvm-project//llvm:Target",
|
||||||
|
# "@llvm-project//mlir:AllPassesAndDialects",
|
||||||
|
# "@llvm-project//mlir:IR",
|
||||||
|
# "@llvm-project//mlir:MlirOptLib",
|
||||||
|
# "@llvm-project//mlir:Pass",
|
||||||
|
# "@llvm-project//mlir:Support",
|
||||||
|
# "@llvm-project//mlir:MlirJitRunner",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "mlir-tblgen-builder-basic",
|
||||||
|
srcs = [
|
||||||
|
"tests/mlir-tblgen-builder/test_basic.cpp",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":mlir-hlo-builder",
|
||||||
|
"@llvm-project//mlir:MlirTableGenMain",
|
||||||
|
"@llvm-project//mlir:Support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
|
"@llvm-project//llvm:TableGen",
|
||||||
|
"@llvm-project//llvm:config",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
gentbl_cc_library(
|
gentbl_cc_library(
|
||||||
name = "hlo_ops_base_inc_gen",
|
name = "hlo_ops_base_inc_gen",
|
||||||
strip_include_prefix = "include",
|
strip_include_prefix = "include",
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.h.inc"
|
|
||||||
#include "tools/mlir-tblgen-builder/Builder/Attribute.h"
|
|
||||||
#include "tools/mlir-tblgen-builder/Builder/Builder.h"
|
#include "tools/mlir-tblgen-builder/Builder/Builder.h"
|
||||||
#include "tools/mlir-tblgen-builder/Builder/Op.h"
|
|
||||||
#include "tools/mlir-tblgen-builder/Builder/PrimitiveType.h"
|
|
||||||
#include "tools/mlir-tblgen-builder/Builder/Shape.h"
|
|
||||||
#include "tools/mlir-tblgen-builder/Builder/Tensor.h"
|
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
builder::Integer i(322);
|
builder::Builder builder;
|
||||||
|
builder::Shape shape({100, 100});
|
||||||
|
auto pType = builder::PrimitiveType::F32();
|
||||||
|
builder::Type type(shape, pType);
|
||||||
|
|
||||||
|
builder::Tensor tensor(std::vector<int>(100));
|
||||||
|
auto op = builder::mhlo::ConstOp::build(builder, type, tensor);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// static ::builder::Op build(::builder::Builder &builder, ::builder::Type
|
||||||
|
// output, ::builder::Tensor value);
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
#include "Attribute.h"
|
#include "Attribute.h"
|
||||||
|
|
||||||
#include "AttributeImpl.h"
|
#include "AttributeImpl.h"
|
||||||
#include "Shape.h"
|
|
||||||
#include "Tensor.h"
|
|
||||||
#include "TensorImpl.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "memory.h"
|
#include "memory.h"
|
||||||
// #include "mlir/Dialect/StandardOps/Ops.h"
|
// #include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
@ -15,8 +12,93 @@
|
||||||
|
|
||||||
namespace builder {
|
namespace builder {
|
||||||
|
|
||||||
|
inline bool PrimitiveType::operator==(const PrimitiveType& pt) {
|
||||||
|
return impl_ == pt.impl_;
|
||||||
|
}
|
||||||
|
PrimitiveType::PrimitiveType(std::shared_ptr<PrimitiveType::Impl> impl)
|
||||||
|
: impl_(impl) {}
|
||||||
|
PrimitiveType PrimitiveType::PRED() {
|
||||||
|
PrimitiveType p(
|
||||||
|
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::PRED));
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
PrimitiveType PrimitiveType::S8() {
|
||||||
|
PrimitiveType p(
|
||||||
|
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S8));
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
PrimitiveType PrimitiveType::S16() {
|
||||||
|
PrimitiveType p(
|
||||||
|
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S16));
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
PrimitiveType PrimitiveType::F32() {
|
||||||
|
PrimitiveType p(
|
||||||
|
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::F32));
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
PrimitiveType PrimitiveType::S32() {
|
||||||
|
PrimitiveType p(
|
||||||
|
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S32));
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
PrimitiveType PrimitiveType::S64() {
|
||||||
|
PrimitiveType p(
|
||||||
|
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S64));
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
Shape::Shape(std::vector<int64_t> dims)
|
||||||
|
: impl_(std::make_shared<Shape::Impl>(dims)) {}
|
||||||
|
|
||||||
|
Type::Type(Shape& shape, PrimitiveType& primitiveType)
|
||||||
|
: impl_(std::make_shared<Type::Impl>(shape, primitiveType)) {}
|
||||||
|
|
||||||
Integer::Integer(int value) : impl_(std::make_shared<Integer::Impl>(value)) {}
|
Integer::Integer(int value) : impl_(std::make_shared<Integer::Impl>(value)) {}
|
||||||
Integer::Integer(int64_t value)
|
Integer::Integer(int64_t value)
|
||||||
: impl_(std::make_shared<Integer::Impl>(value)) {}
|
: impl_(std::make_shared<Integer::Impl>(value)) {}
|
||||||
|
|
||||||
|
Float::Float(float value) : impl_(std::make_shared<Float::Impl>(value)) {}
|
||||||
|
Float::Float(double value) : impl_(std::make_shared<Float::Impl>(value)) {}
|
||||||
|
|
||||||
|
int64_t Array::GetSize() { return impl_->GetSize(); }
|
||||||
|
PrimitiveType Array::GetType() { return impl_->GetType(); }
|
||||||
|
Array::Array(std::vector<int> value)
|
||||||
|
: impl_(std::make_shared<Array::Impl>(value)) {}
|
||||||
|
Array::Array(std::vector<int64_t> value)
|
||||||
|
: impl_(std::make_shared<Array::Impl>(value)) {}
|
||||||
|
Array::Array(std::vector<std::string> value)
|
||||||
|
: impl_(std::make_shared<Array::Impl>(value)) {}
|
||||||
|
|
||||||
|
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
|
||||||
|
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
||||||
|
Shape Tensor::GetShape() { return impl_->GetShape(); }
|
||||||
|
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
|
||||||
|
Tensor::Tensor(std::vector<int> value)
|
||||||
|
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
||||||
|
Tensor::Tensor(std::vector<int64_t> value)
|
||||||
|
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
||||||
|
Tensor::Tensor(std::vector<float> value)
|
||||||
|
: impl_(std::make_shared<Tensor::Impl>(value)) {}
|
||||||
|
|
||||||
|
TensorInt::TensorInt(Shape& shape, PrimitiveType& primitiveType)
|
||||||
|
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
||||||
|
Shape TensorInt::GetShape() { return impl_->GetShape(); }
|
||||||
|
PrimitiveType TensorInt::GetType() { return impl_->GetType(); }
|
||||||
|
|
||||||
|
ChannelHandle::ChannelHandle()
|
||||||
|
: impl_(std::make_shared<ChannelHandle::Impl>()) {}
|
||||||
|
|
||||||
|
ConvDimensionNumbers::ConvDimensionNumbers()
|
||||||
|
: impl_(std::make_shared<ConvDimensionNumbers::Impl>()) {}
|
||||||
|
|
||||||
|
DotDimensionNumbers::DotDimensionNumbers()
|
||||||
|
: impl_(std::make_shared<DotDimensionNumbers::Impl>()) {}
|
||||||
|
|
||||||
|
GatherDimensionNumbers::GatherDimensionNumbers()
|
||||||
|
: impl_(std::make_shared<GatherDimensionNumbers::Impl>()) {}
|
||||||
|
|
||||||
|
ScatterDimensionNumbers::ScatterDimensionNumbers()
|
||||||
|
: impl_(std::make_shared<ScatterDimensionNumbers::Impl>()) {}
|
||||||
|
|
||||||
} // namespace builder
|
} // namespace builder
|
|
@ -1,16 +1,57 @@
|
||||||
#ifndef BUILDER_ATTRIBUTE_H_
|
#ifndef BUILDER_ATTRIBUTE_H_
|
||||||
#define BUILDER_ATTRIBUTE_H_
|
#define BUILDER_ATTRIBUTE_H_
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace builder {
|
namespace builder {
|
||||||
|
|
||||||
|
class PrimitiveType {
|
||||||
|
public:
|
||||||
|
static PrimitiveType PRED();
|
||||||
|
static PrimitiveType S8();
|
||||||
|
static PrimitiveType S16();
|
||||||
|
static PrimitiveType F32();
|
||||||
|
static PrimitiveType S32();
|
||||||
|
static PrimitiveType S64();
|
||||||
|
inline bool operator==(const PrimitiveType& pt);
|
||||||
|
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
PrimitiveType(std::shared_ptr<Impl>);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
// PrimitiveType(const PrimitiveType&) = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Shape {
|
||||||
|
public:
|
||||||
|
Shape(std::vector<int64_t> dims);
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
class Integer {
|
class Integer {
|
||||||
public:
|
public:
|
||||||
Integer(int value);
|
Integer(int value);
|
||||||
Integer(int64_t value);
|
Integer(int64_t value);
|
||||||
class Impl;
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Float {
|
||||||
|
public:
|
||||||
|
Float(float value);
|
||||||
|
Float(double value);
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Impl> impl_;
|
std::shared_ptr<Impl> impl_;
|
||||||
|
@ -18,6 +59,43 @@ class Integer {
|
||||||
|
|
||||||
class Array {
|
class Array {
|
||||||
public:
|
public:
|
||||||
|
Array(std::vector<int> value);
|
||||||
|
Array(std::vector<int64_t> value);
|
||||||
|
Array(std::vector<std::string> value);
|
||||||
|
|
||||||
|
int64_t GetSize();
|
||||||
|
PrimitiveType GetType();
|
||||||
|
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Tensor {
|
||||||
|
public:
|
||||||
|
Tensor(Shape& shape, PrimitiveType& primitiveType);
|
||||||
|
Tensor(std::vector<int> value);
|
||||||
|
Tensor(std::vector<int64_t> value);
|
||||||
|
Tensor(std::vector<float> value);
|
||||||
|
|
||||||
|
Shape GetShape();
|
||||||
|
PrimitiveType GetType();
|
||||||
|
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TensorInt {
|
||||||
|
public:
|
||||||
|
TensorInt(Shape& shape, PrimitiveType& primitiveType);
|
||||||
|
Shape GetShape();
|
||||||
|
PrimitiveType GetType();
|
||||||
|
|
||||||
class Impl;
|
class Impl;
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
@ -27,6 +105,7 @@ class Array {
|
||||||
|
|
||||||
class Type {
|
class Type {
|
||||||
public:
|
public:
|
||||||
|
Type(Shape& shape, PrimitiveType& primitiveType);
|
||||||
class Impl;
|
class Impl;
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
@ -34,16 +113,56 @@ class Type {
|
||||||
std::shared_ptr<Impl> impl_;
|
std::shared_ptr<Impl> impl_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// template <typename T>
|
class ChannelHandle {
|
||||||
// class DenseIntElementsAttr {
|
public:
|
||||||
// public:
|
ChannelHandle();
|
||||||
// explicit DenseIntElementsAttr(Tensor);
|
class Impl;
|
||||||
// class Impl;
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
// std::shared_ptr<Impl> GetImpl() { return impl_; }
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ConvDimensionNumbers {
|
||||||
|
public:
|
||||||
|
ConvDimensionNumbers();
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DotDimensionNumbers {
|
||||||
|
public:
|
||||||
|
DotDimensionNumbers();
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class GatherDimensionNumbers {
|
||||||
|
public:
|
||||||
|
GatherDimensionNumbers();
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ScatterDimensionNumbers {
|
||||||
|
public:
|
||||||
|
ScatterDimensionNumbers();
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
// private:
|
|
||||||
// std::shared_ptr<Impl> impl_;
|
|
||||||
// };
|
|
||||||
|
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,13 @@
|
||||||
#ifndef BUILDER_TYPEIMPL_
|
#ifndef BUILDER_ATTRIBUTEIMPL_H_
|
||||||
#define BUILDER_TYPEIMPL_
|
#define BUILDER_ATTRIBUTEIMPL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "Builder.h"
|
#include "Builder.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
@ -13,45 +18,306 @@
|
||||||
|
|
||||||
namespace builder {
|
namespace builder {
|
||||||
|
|
||||||
|
class PrimitiveType::Impl {
|
||||||
|
public:
|
||||||
|
enum pType {
|
||||||
|
PRIMITIVE_TYPE_INVALID = 0,
|
||||||
|
PRED = 1,
|
||||||
|
S8 = 2,
|
||||||
|
S16 = 3,
|
||||||
|
S32 = 4,
|
||||||
|
S64 = 5,
|
||||||
|
U8 = 6,
|
||||||
|
U16 = 7,
|
||||||
|
U32 = 8,
|
||||||
|
U64 = 9,
|
||||||
|
F16 = 10,
|
||||||
|
F32 = 11,
|
||||||
|
BF16 = 16,
|
||||||
|
F64 = 12,
|
||||||
|
C64 = 15,
|
||||||
|
C128 = 18,
|
||||||
|
TUPLE = 13,
|
||||||
|
OPAQUE_TYPE = 14,
|
||||||
|
TOKEN = 17
|
||||||
|
};
|
||||||
|
|
||||||
|
Impl(pType t) : t_(t) {
|
||||||
|
switch (t) {
|
||||||
|
case PRIMITIVE_TYPE_INVALID:
|
||||||
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
|
return mlir::NoneType::get(context);
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
case PRED:
|
||||||
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
|
return mlir::IntegerType::get(
|
||||||
|
context, 1, mlir::IntegerType::SignednessSemantics::Signed);
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
case S8:
|
||||||
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
|
return mlir::IntegerType::get(
|
||||||
|
context, 8, mlir::IntegerType::SignednessSemantics::Signed);
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
case S16:
|
||||||
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
|
return mlir::IntegerType::get(
|
||||||
|
context, 16, mlir::IntegerType::SignednessSemantics::Signed);
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
case F32:
|
||||||
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
|
return mlir::Float32Type::get(context);
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
case S32:
|
||||||
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
|
return mlir::IntegerType::get(
|
||||||
|
context, 32, mlir::IntegerType::SignednessSemantics::Signed);
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
case S64:
|
||||||
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
|
return mlir::IntegerType::get(
|
||||||
|
context, 64, mlir::IntegerType::SignednessSemantics::Signed);
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
|
return mlir::NoneType::get(context);
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// class BFloat16Type;
|
||||||
|
// class ComplexType;
|
||||||
|
// class Float128Type;
|
||||||
|
// class Float16Type;
|
||||||
|
// class Float32Type;
|
||||||
|
// class Float64Type;
|
||||||
|
// class Float80Type;
|
||||||
|
// class FunctionType;
|
||||||
|
// class IndexType;
|
||||||
|
// class IntegerType;
|
||||||
|
// class MemRefType;
|
||||||
|
// class NoneType;
|
||||||
|
// class OpaqueType;
|
||||||
|
// class RankedTensorType;
|
||||||
|
// class TupleType;
|
||||||
|
// class UnrankedMemRefType;
|
||||||
|
// class UnrankedTensorType;
|
||||||
|
// class VectorType;
|
||||||
|
|
||||||
|
inline bool operator==(const PrimitiveType::Impl &impl) {
|
||||||
|
return t_ == impl.t_;
|
||||||
|
}
|
||||||
|
std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;
|
||||||
|
|
||||||
|
private:
|
||||||
|
pType t_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Shape::Impl {
|
||||||
|
public:
|
||||||
|
Impl(std::vector<int64_t> dims) : dims_(dims) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<int64_t> dims_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Type::Impl {
|
||||||
|
public:
|
||||||
|
Impl(Shape &shape, PrimitiveType &primitiveType)
|
||||||
|
: shape_(shape), primitiveType_(primitiveType) {}
|
||||||
|
Shape GetShape() { return shape_; }
|
||||||
|
PrimitiveType GetType() { return primitiveType_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Shape shape_;
|
||||||
|
PrimitiveType primitiveType_;
|
||||||
|
};
|
||||||
|
|
||||||
class Integer::Impl {
|
class Integer::Impl {
|
||||||
public:
|
public:
|
||||||
Impl(){};
|
|
||||||
|
|
||||||
Impl(int value) {
|
Impl(int value) {
|
||||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
|
GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
|
||||||
auto int32_type = mlir::IntegerType::get(context, 32);
|
auto int32_type = mlir::IntegerType::get(context, 32);
|
||||||
return mlir::IntegerAttr::get(int32_type, value);
|
return mlir::IntegerAttr::get(int32_type, value);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
Impl(int64_t value)
|
Impl(int64_t value) {
|
||||||
: GetAttr([=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
|
GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
|
||||||
auto int32_type = mlir::IntegerType::get(context, 64);
|
auto int64_type = mlir::IntegerType::get(context, 64);
|
||||||
return mlir::IntegerAttr::get(int32_type, value);
|
return mlir::IntegerAttr::get(int64_type, value);
|
||||||
}){};
|
};
|
||||||
|
};
|
||||||
std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr;
|
std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Float::Impl {
|
||||||
|
public:
|
||||||
|
Impl(float value) {
|
||||||
|
GetAttr = [=](mlir::MLIRContext *context) -> mlir::FloatAttr {
|
||||||
|
auto float32_type = mlir::Float32Type::get(context);
|
||||||
|
return mlir::FloatAttr::get(float32_type, value);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
Impl(double value) {
|
||||||
|
GetAttr = [=](mlir::MLIRContext *context) -> mlir::FloatAttr {
|
||||||
|
auto float64_type = mlir::Float64Type::get(context);
|
||||||
|
return mlir::FloatAttr::get(float64_type, value);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
std::function<mlir::FloatAttr(mlir::MLIRContext *context)> GetAttr;
|
||||||
|
|
||||||
|
private:
|
||||||
|
};
|
||||||
|
|
||||||
class Array::Impl {
|
class Array::Impl {
|
||||||
public:
|
public:
|
||||||
Impl() = default;
|
Impl(std::vector<int> value)
|
||||||
|
: size_({value.size()}), primitiveType_(PrimitiveType::S32()) {
|
||||||
|
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
|
||||||
|
// auto type =
|
||||||
|
// mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||||
|
// mlir::IntegerType::get(context, 32));
|
||||||
|
// return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
|
||||||
|
// };
|
||||||
|
}
|
||||||
|
Impl(std::vector<int64_t> value)
|
||||||
|
: size_({value.size()}), primitiveType_(PrimitiveType::S64()) {
|
||||||
|
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
|
||||||
|
// auto type =
|
||||||
|
// mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||||
|
// mlir::IntegerType::get(context, 64));
|
||||||
|
// return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
|
||||||
|
// };
|
||||||
|
}
|
||||||
|
Impl(std::vector<std::string> value)
|
||||||
|
: size_({value.size()}), primitiveType_(PrimitiveType::F32()) {
|
||||||
|
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
|
||||||
|
// auto type =
|
||||||
|
// mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||||
|
// mlir::FloatType::getF32(context));
|
||||||
|
// return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
|
||||||
|
// };
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t GetSize() { return size_; }
|
||||||
|
PrimitiveType GetType() { return primitiveType_; }
|
||||||
|
|
||||||
|
std::function<mlir::ArrayAttr(mlir::MLIRContext *context)> GetAttr;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
int64_t size_;
|
||||||
|
PrimitiveType primitiveType_;
|
||||||
};
|
};
|
||||||
class Type::Impl {
|
|
||||||
|
class Tensor::Impl {
|
||||||
public:
|
public:
|
||||||
Impl() = default;
|
Impl(std::vector<int> value)
|
||||||
|
: shape_({value.size()}), primitiveType_(PrimitiveType::S32()) {
|
||||||
|
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||||
|
auto type =
|
||||||
|
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||||
|
mlir::IntegerType::get(context, 32));
|
||||||
|
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||||
|
};
|
||||||
|
}
|
||||||
|
Impl(std::vector<int64_t> value)
|
||||||
|
: shape_({value.size()}), primitiveType_(PrimitiveType::S64()) {
|
||||||
|
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||||
|
auto type =
|
||||||
|
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||||
|
mlir::IntegerType::get(context, 64));
|
||||||
|
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||||
|
};
|
||||||
|
}
|
||||||
|
Impl(std::vector<float> value)
|
||||||
|
: shape_({value.size()}), primitiveType_(PrimitiveType::F32()) {
|
||||||
|
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
|
||||||
|
auto type =
|
||||||
|
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
|
||||||
|
mlir::FloatType::getF32(context));
|
||||||
|
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
|
||||||
|
};
|
||||||
|
}
|
||||||
|
Impl(Shape &shape, PrimitiveType &primitiveType)
|
||||||
|
: shape_(shape), primitiveType_(primitiveType) {}
|
||||||
|
|
||||||
|
Shape GetShape() { return shape_; }
|
||||||
|
PrimitiveType GetType() { return primitiveType_; }
|
||||||
|
|
||||||
|
std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Shape shape_;
|
||||||
|
PrimitiveType primitiveType_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TensorInt::Impl {
|
||||||
|
public:
|
||||||
|
Impl(Shape &shape, PrimitiveType &primitiveType)
|
||||||
|
: shape_(shape), primitiveType_(primitiveType) {}
|
||||||
|
|
||||||
|
Shape GetShape() { return shape_; }
|
||||||
|
PrimitiveType GetType() { return primitiveType_; }
|
||||||
|
|
||||||
|
std::function<mlir::DenseIntElementsAttr(mlir::MLIRContext *context)> GetAttr;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Shape shape_;
|
||||||
|
PrimitiveType primitiveType_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ChannelHandle::Impl {
|
||||||
|
public:
|
||||||
|
Impl() {}
|
||||||
|
std::function<mlir::mhlo::ChannelHandle(mlir::MLIRContext *context)> GetAttr;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
};
|
};
|
||||||
|
|
||||||
// class DenseIntElementsAttr::Impl {
|
class ConvDimensionNumbers::Impl {
|
||||||
// public:
|
public:
|
||||||
// Impl() = default;
|
Impl() {}
|
||||||
|
std::function<mlir::mhlo::ConvDimensionNumbers(mlir::MLIRContext *context)>
|
||||||
|
GetAttr;
|
||||||
|
|
||||||
// private:
|
private:
|
||||||
// };
|
};
|
||||||
|
|
||||||
|
class DotDimensionNumbers::Impl {
|
||||||
|
public:
|
||||||
|
Impl() {}
|
||||||
|
std::function<mlir::mhlo::DotDimensionNumbers(mlir::MLIRContext *context)>
|
||||||
|
GetAttr;
|
||||||
|
|
||||||
|
private:
|
||||||
|
};
|
||||||
|
|
||||||
|
class GatherDimensionNumbers::Impl {
|
||||||
|
public:
|
||||||
|
Impl() {}
|
||||||
|
std::function<mlir::mhlo::GatherDimensionNumbers(mlir::MLIRContext *context)>
|
||||||
|
GetAttr;
|
||||||
|
|
||||||
|
private:
|
||||||
|
};
|
||||||
|
|
||||||
|
class ScatterDimensionNumbers::Impl {
|
||||||
|
public:
|
||||||
|
Impl() {}
|
||||||
|
std::function<mlir::mhlo::ScatterDimensionNumbers(mlir::MLIRContext *context)>
|
||||||
|
GetAttr;
|
||||||
|
|
||||||
|
private:
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,13 @@
|
||||||
// #include "mlir/IR/StandardTypes.h"
|
// #include "mlir/IR/StandardTypes.h"
|
||||||
// #include "mlir/IR/Types.h"
|
// #include "mlir/IR/Types.h"
|
||||||
// #include "mlir/IR/Value.h"
|
// #include "mlir/IR/Value.h"
|
||||||
|
#include "Attribute.h"
|
||||||
|
#include "AttributeImpl.h"
|
||||||
|
#include "Op.h"
|
||||||
|
#include "OpImpl.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
|
||||||
namespace builder {
|
namespace builder {
|
||||||
|
|
||||||
|
@ -15,3 +22,6 @@ Builder::Builder() : impl_(std::make_shared<Impl>()) {}
|
||||||
void Builder::DumpModule() {}
|
void Builder::DumpModule() {}
|
||||||
|
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.cc.inc"
|
||||||
|
|
|
@ -3,14 +3,17 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tools/mlir-tblgen-builder/Builder/Attribute.h"
|
||||||
|
#include "tools/mlir-tblgen-builder/Builder/Op.h"
|
||||||
|
|
||||||
namespace builder {
|
namespace builder {
|
||||||
|
|
||||||
class Builder {
|
class Builder {
|
||||||
public:
|
public:
|
||||||
class Impl;
|
|
||||||
|
|
||||||
Builder();
|
Builder();
|
||||||
void DumpModule();
|
void DumpModule();
|
||||||
|
|
||||||
|
class Impl;
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -18,4 +21,7 @@ class Builder {
|
||||||
};
|
};
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.h.inc"
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -4,7 +4,10 @@
|
||||||
#include "Builder.h"
|
#include "Builder.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
@ -14,15 +17,30 @@ namespace builder {
|
||||||
|
|
||||||
class Builder::Impl {
|
class Builder::Impl {
|
||||||
public:
|
public:
|
||||||
Impl() {}
|
Impl() : builder_(&context_) {
|
||||||
// mlir::Location GetLoc() { return mlir_loc_; }
|
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
|
||||||
// mlir::OpBuilder GetBuilder() { return mlir_builder_; }
|
|
||||||
mlir::MLIRContext *GetContext() { return &mlir_context_; }
|
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||||
|
// Create the main function.
|
||||||
|
mlir::FunctionType funcType = builder_.getFunctionType(arg_types, {});
|
||||||
|
main_func_ = mlir::FuncOp::create(builder_.getUnknownLoc(), "main",
|
||||||
|
funcType, /* attrs = */ {});
|
||||||
|
|
||||||
|
entry_block_ = main_func_.addEntryBlock();
|
||||||
|
builder_.setInsertionPointToStart(entry_block_);
|
||||||
|
module_.push_back(main_func_);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Location GetLoc() { return builder_.getUnknownLoc(); }
|
||||||
|
mlir::OpBuilder GetBuilder() { return builder_; }
|
||||||
|
mlir::MLIRContext* GetContext() { return &context_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// mlir::Location mlir_loc_;
|
mlir::MLIRContext context_;
|
||||||
// mlir::OpBuilder mlir_builder_;
|
mlir::ModuleOp module_;
|
||||||
mlir::MLIRContext mlir_context_;
|
mlir::OpBuilder builder_;
|
||||||
|
mlir::FuncOp main_func_;
|
||||||
|
mlir::Block* entry_block_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
namespace builder {
|
namespace builder {
|
||||||
class Op {
|
class Op {
|
||||||
|
public:
|
||||||
class Impl;
|
class Impl;
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
|
|
@ -15,10 +15,11 @@ namespace builder {
|
||||||
class Op::Impl {
|
class Op::Impl {
|
||||||
public:
|
public:
|
||||||
Impl() = default;
|
Impl() = default;
|
||||||
void SetOperation(Operation *Op) { op_ = Op; }
|
void SetOperation(mlir::Operation *Op) { op_ = Op; }
|
||||||
|
mlir::Value GetResult() { return op_->getResult(0); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Operation *op_;
|
mlir::Operation *op_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
|
|
@ -1,28 +0,0 @@
|
||||||
#include "PrimitiveType.h"
|
|
||||||
|
|
||||||
#include "Shape.h"
|
|
||||||
#include "Tensor.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
// #include "mlir/Dialect/StandardOps/Ops.h"
|
|
||||||
// #include "mlir/IR/Attributes.h"
|
|
||||||
// #include "mlir/IR/Operation.h"
|
|
||||||
// #include "mlir/IR/StandardTypes.h"
|
|
||||||
#include "mlir/IR/Types.h"
|
|
||||||
// #include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
namespace builder {
|
|
||||||
|
|
||||||
// constexpr mlir::Type GetMlirType(const PrimitiveType primitiveType) {
|
|
||||||
// switch (primitiveType)
|
|
||||||
// {
|
|
||||||
// case PrimitiveType::BF16 :
|
|
||||||
// return
|
|
||||||
// break;
|
|
||||||
|
|
||||||
// default:
|
|
||||||
// break;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// }
|
|
||||||
|
|
||||||
} // namespace builder
|
|
|
@ -1,29 +0,0 @@
|
||||||
#ifndef BUILDER_PRIMITIVE_TYPE_
|
|
||||||
#define BUILDER_PRIMITIVE_TYPE_
|
|
||||||
|
|
||||||
namespace builder {
|
|
||||||
|
|
||||||
enum class PrimitiveType {
|
|
||||||
PRIMITIVE_TYPE_INVALID = 0,
|
|
||||||
PRED = 1,
|
|
||||||
S8 = 2,
|
|
||||||
S16 = 3,
|
|
||||||
S32 = 4,
|
|
||||||
S64 = 5,
|
|
||||||
U8 = 6,
|
|
||||||
U16 = 7,
|
|
||||||
U32 = 8,
|
|
||||||
U64 = 9,
|
|
||||||
F16 = 10,
|
|
||||||
F32 = 11,
|
|
||||||
BF16 = 16,
|
|
||||||
F64 = 12,
|
|
||||||
C64 = 15,
|
|
||||||
C128 = 18,
|
|
||||||
TUPLE = 13,
|
|
||||||
OPAQUE_TYPE = 14,
|
|
||||||
TOKEN = 17
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
#endif
|
|
|
@ -1,17 +0,0 @@
|
||||||
#ifndef BUILDER_SHAPE_
|
|
||||||
#define BUILDER_SHAPE_
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
namespace builder {
|
|
||||||
class Shape {
|
|
||||||
class Impl;
|
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::shared_ptr<Impl> impl_;
|
|
||||||
};
|
|
||||||
} // namespace builder
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -1,24 +0,0 @@
|
||||||
#ifndef BUILDER_SHAPEIMPL_
|
|
||||||
#define BUILDER_SHAPEIMPL_
|
|
||||||
|
|
||||||
#include "Builder.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
#include "mlir/IR/Attributes.h"
|
|
||||||
#include "mlir/IR/Builders.h"
|
|
||||||
#include "mlir/IR/MLIRContext.h"
|
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/Types.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
namespace builder {
|
|
||||||
|
|
||||||
class Shape::Impl {
|
|
||||||
public:
|
|
||||||
Impl() = default;
|
|
||||||
|
|
||||||
private:
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace builder
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -1,22 +0,0 @@
|
||||||
#include "Tensor.h"
|
|
||||||
|
|
||||||
#include "Shape.h"
|
|
||||||
#include "TensorImpl.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
// #include "mlir/Dialect/StandardOps/Ops.h"
|
|
||||||
// #include "mlir/IR/Attributes.h"
|
|
||||||
// #include "mlir/IR/Operation.h"
|
|
||||||
// #include "mlir/IR/StandardTypes.h"
|
|
||||||
// #include "mlir/IR/Types.h"
|
|
||||||
// #include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
namespace builder {
|
|
||||||
|
|
||||||
// Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
|
|
||||||
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
|
|
||||||
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
|
||||||
|
|
||||||
Shape Tensor::GetShape() { return impl_->GetShape(); }
|
|
||||||
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
|
|
||||||
|
|
||||||
} // namespace builder
|
|
|
@ -1,25 +0,0 @@
|
||||||
#ifndef BUILDER_TENSOR_
|
|
||||||
#define BUILDER_TENSOR_
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "PrimitiveType.h"
|
|
||||||
#include "Shape.h"
|
|
||||||
|
|
||||||
namespace builder {
|
|
||||||
|
|
||||||
class Tensor {
|
|
||||||
public:
|
|
||||||
Tensor(Shape& shape, PrimitiveType& primitiveType);
|
|
||||||
class Impl;
|
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
|
||||||
Shape GetShape();
|
|
||||||
PrimitiveType GetType();
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::shared_ptr<Impl> impl_;
|
|
||||||
};
|
|
||||||
} // namespace builder
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -1,32 +0,0 @@
|
||||||
#ifndef BUILDER_TENSORIMPL_
|
|
||||||
#define BUILDER_TENSORIMPL_
|
|
||||||
|
|
||||||
#include "Builder.h"
|
|
||||||
#include "PrimitiveType.h"
|
|
||||||
#include "Shape.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
#include "mlir/IR/Attributes.h"
|
|
||||||
#include "mlir/IR/Builders.h"
|
|
||||||
#include "mlir/IR/MLIRContext.h"
|
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/Types.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
namespace builder {
|
|
||||||
|
|
||||||
class Tensor::Impl {
|
|
||||||
public:
|
|
||||||
Impl() = default;
|
|
||||||
Impl(Shape& shape, PrimitiveType& primitiveType)
|
|
||||||
: shape_(shape), primitiveType_(primitiveType) {}
|
|
||||||
Shape GetShape() { return shape_; }
|
|
||||||
PrimitiveType GetType() { return primitiveType_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
Shape shape_;
|
|
||||||
PrimitiveType primitiveType_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace builder
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -52,81 +52,93 @@ static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
|
||||||
{"::mlir::StringAttr",
|
{"::mlir::StringAttr",
|
||||||
{"std::string",
|
{"std::string",
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
body << " mlir::StringAttr " << var << "_mlir = mlir::StringAttr::get("
|
body << " mlir::StringAttr " << var
|
||||||
<< var << ", ctx);\n";
|
<< "_mlir = mlir::StringAttr::get(ctx, mlir::Twine(" << var
|
||||||
|
<< "));\n";
|
||||||
return var + "_mlir";
|
return var + "_mlir";
|
||||||
}}},
|
}}},
|
||||||
{"::mlir::IntegerAttr",
|
{"::mlir::IntegerAttr",
|
||||||
{"builder::Integer",
|
{"builder::Integer",
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
// body << " mlir::IntegerAttr " << var << "_mlir = mlir::IntegerAttr::get("
|
|
||||||
// << var << ", ctx);\n";
|
|
||||||
body << " mlir::IntegerAttr " << var << "_mlir = " << var
|
body << " mlir::IntegerAttr " << var << "_mlir = " << var
|
||||||
<< ".GetAttr(ctx);\n";
|
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||||
return var + "_mlir";
|
|
||||||
}}},
|
|
||||||
{"::mlir::DenseIntElementsAttr",
|
|
||||||
{"std::vector<int>",
|
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
|
||||||
body << " mlir::DenseIntElementsAttr " << var
|
|
||||||
<< "_mlir = mlir::DenseIntElementsAttr::get("
|
|
||||||
<< "mlir::VectorType::get(" << var
|
|
||||||
<< ".size(), opBuilder->getIntegerType(32))," << var << ");\n";
|
|
||||||
return var + "_mlir";
|
|
||||||
}}},
|
|
||||||
{"::mlir::mhlo::ChannelHandle",
|
|
||||||
{"ChannelHandle",
|
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
|
||||||
return var + "_mlir";
|
return var + "_mlir";
|
||||||
}}},
|
}}},
|
||||||
{"::mlir::FloatAttr",
|
{"::mlir::FloatAttr",
|
||||||
{"float",
|
{"builder::Float",
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
body << " mlir::FloatAttr " << var << "_mlir = mlir::FloatAttr::get("
|
body << " mlir::FloatAttr " << var << "_mlir = " << var
|
||||||
<< var << ", ctx);\n";
|
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::DenseIntElementsAttr",
|
||||||
|
{"builder::TensorInt",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::DenseIntElementsAttr " << var << "_mlir = " << var
|
||||||
|
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::mhlo::ChannelHandle",
|
||||||
|
{"builder::ChannelHandle",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::mhlo::ChannelHandle " << var << "_mlir = " << var
|
||||||
|
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||||
return var + "_mlir";
|
return var + "_mlir";
|
||||||
}}},
|
}}},
|
||||||
{"::mlir::BoolAttr",
|
{"::mlir::BoolAttr",
|
||||||
{"bool",
|
{"bool",
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
body << " mlir::BoolAttr " << var << "_mlir = mlir::BoolAttr::get("
|
body << " mlir::BoolAttr " << var
|
||||||
<< var << ", ctx);\n";
|
<< "_mlir = mlir::BoolAttr::get(ctx, " << var << ");\n";
|
||||||
return var + "_mlir";
|
return var + "_mlir";
|
||||||
}}},
|
}}},
|
||||||
{"::mlir::ElementsAttr",
|
{"::mlir::ElementsAttr",
|
||||||
{"::builder::Array",
|
{"builder::Tensor",
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::DenseElementsAttr " << var << "_mlir = " << var
|
||||||
|
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||||
return var + "_mlir";
|
return var + "_mlir";
|
||||||
}}},
|
}}},
|
||||||
{"::mlir::DenseElementsAttr",
|
{"::mlir::DenseElementsAttr",
|
||||||
{"::builder::Tensor",
|
{"builder::Tensor",
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::DenseElementsAttr " << var << "_mlir = " << var
|
||||||
|
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||||
return var + "_mlir";
|
return var + "_mlir";
|
||||||
}}},
|
}}},
|
||||||
// current only support string array.
|
|
||||||
{"::mlir::ArrayAttr",
|
{"::mlir::ArrayAttr",
|
||||||
{"std::vector<std::string>",
|
{"builder::Array",
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::ArrayAttr " << var << "_mlir = " << var
|
||||||
|
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||||
return var + "_mlir";
|
return var + "_mlir";
|
||||||
}}},
|
}}},
|
||||||
{"::mlir::mhlo::ConvDimensionNumbers",
|
{"::mlir::mhlo::ConvDimensionNumbers",
|
||||||
{"::builder::ConvDimensionNumbers",
|
{"builder::ConvDimensionNumbers",
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::mhlo::ConvDimensionNumbers " << var
|
||||||
|
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
|
||||||
return var + "_mlir";
|
return var + "_mlir";
|
||||||
}}},
|
}}},
|
||||||
{"::mlir::mhlo::DotDimensionNumbers",
|
{"::mlir::mhlo::DotDimensionNumbers",
|
||||||
{"::builder::DotDimensionNumbers",
|
{"builder::DotDimensionNumbers",
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::mhlo::DotDimensionNumbers " << var << "_mlir = " << var
|
||||||
|
<< ".GetImpl()->GetAttr(ctx);\n";
|
||||||
return var + "_mlir";
|
return var + "_mlir";
|
||||||
}}},
|
}}},
|
||||||
{"::mlir::mhlo::GatherDimensionNumbers",
|
{"::mlir::mhlo::GatherDimensionNumbers",
|
||||||
{"::builder::GatherDimensionNumbers",
|
{"builder::GatherDimensionNumbers",
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::mhlo::GatherDimensionNumbers " << var
|
||||||
|
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
|
||||||
return var + "_mlir";
|
return var + "_mlir";
|
||||||
}}},
|
}}},
|
||||||
{"::mlir::mhlo::ScatterDimensionNumbers",
|
{"::mlir::mhlo::ScatterDimensionNumbers",
|
||||||
{"::builder::ScatterDimensionNumbers",
|
{"builder::ScatterDimensionNumbers",
|
||||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::mhlo::ScatterDimensionNumbers " << var
|
||||||
|
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
|
||||||
return var + "_mlir";
|
return var + "_mlir";
|
||||||
}}},
|
}}},
|
||||||
};
|
};
|
||||||
|
@ -138,16 +150,16 @@ static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
|
||||||
// {"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
|
// {"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
|
||||||
// {"::mlir::FloatAttr", "float"},
|
// {"::mlir::FloatAttr", "float"},
|
||||||
// {"::mlir::BoolAttr", "bool"},
|
// {"::mlir::BoolAttr", "bool"},
|
||||||
// {"::mlir::ElementsAttr", "::builder::Array"},
|
// {"::mlir::ElementsAttr", "builder::Array"},
|
||||||
// {"::mlir::DenseElementsAttr", "::builder::Tensor"},
|
// {"::mlir::DenseElementsAttr", "builder::Tensor"},
|
||||||
// // current only support string array.
|
// // current only support string array.
|
||||||
// {"::mlir::ArrayAttr", "std::vector<std::string>"},
|
// {"::mlir::ArrayAttr", "std::vector<std::string>"},
|
||||||
// {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
|
// {"::mlir::mhlo::ConvDimensionNumbers", "builder::ConvDimensionNumbers"},
|
||||||
// {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
|
// {"::mlir::mhlo::DotDimensionNumbers", "builder::DotDimensionNumbers"},
|
||||||
// {"::mlir::mhlo::GatherDimensionNumbers",
|
// {"::mlir::mhlo::GatherDimensionNumbers",
|
||||||
// "::builder::GatherDimensionNumbers"},
|
// "builder::GatherDimensionNumbers"},
|
||||||
// {"::mlir::mhlo::ScatterDimensionNumbers",
|
// {"::mlir::mhlo::ScatterDimensionNumbers",
|
||||||
// "::builder::ScatterDimensionNumbers"},
|
// "builder::ScatterDimensionNumbers"},
|
||||||
// };
|
// };
|
||||||
|
|
||||||
StringRef typeConvertFromMLIR(StringRef type) {
|
StringRef typeConvertFromMLIR(StringRef type) {
|
||||||
|
@ -707,14 +719,14 @@ OpEmitter::OpEmitter(const Operator &op,
|
||||||
// Generate C++ code for various op methods. The order here determines the
|
// Generate C++ code for various op methods. The order here determines the
|
||||||
// methods in the generated file.
|
// methods in the generated file.
|
||||||
// genOpAsmInterface();
|
// genOpAsmInterface();
|
||||||
genOpNameGetter();
|
//// genOpNameGetter();
|
||||||
// genNamedOperandGetters();
|
// genNamedOperandGetters();
|
||||||
// genNamedOperandSetters();
|
// genNamedOperandSetters();
|
||||||
// genNamedResultGetters();
|
// genNamedResultGetters();
|
||||||
// genNamedRegionGetters();
|
// genNamedRegionGetters();
|
||||||
genNamedSuccessorGetters();
|
//// genNamedSuccessorGetters();
|
||||||
genAttrGetters();
|
//// genAttrGetters();
|
||||||
genAttrSetters();
|
//// genAttrSetters();
|
||||||
// genOptionalAttrRemovers();
|
// genOptionalAttrRemovers();
|
||||||
genBuilder();
|
genBuilder();
|
||||||
// genParser();
|
// genParser();
|
||||||
|
@ -1179,7 +1191,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
|
||||||
buildParamList(paramList, resultNames, paramKind, attrType);
|
buildParamList(paramList, resultNames, paramKind, attrType);
|
||||||
|
|
||||||
auto *m = opClass.addMethodAndPrune(
|
auto *m = opClass.addMethodAndPrune(
|
||||||
"::builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
|
"builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
|
||||||
// If the builder is redundant, skip generating the method.
|
// If the builder is redundant, skip generating the method.
|
||||||
if (!m)
|
if (!m)
|
||||||
return;
|
return;
|
||||||
|
@ -1329,7 +1341,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
||||||
resultTypeNames.reserve(numResults);
|
resultTypeNames.reserve(numResults);
|
||||||
|
|
||||||
// paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
|
// paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
|
||||||
paramList.emplace_back("::builder::Builder &", "builder");
|
paramList.emplace_back("builder::Builder &", "builder");
|
||||||
// paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
// paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
||||||
|
|
||||||
switch (typeParamKind) {
|
switch (typeParamKind) {
|
||||||
|
@ -1344,7 +1356,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
||||||
resultName = std::string(formatv("resultType{0}", i));
|
resultName = std::string(formatv("resultType{0}", i));
|
||||||
|
|
||||||
StringRef type =
|
StringRef type =
|
||||||
result.isVariadic() ? "std::vector<::builder::Type>" : "::builder::Type";
|
result.isVariadic() ? "std::vector<builder::Type>" : "builder::Type";
|
||||||
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
|
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
|
||||||
if (result.isOptional())
|
if (result.isOptional())
|
||||||
properties = OpMethodParameter::PP_Optional;
|
properties = OpMethodParameter::PP_Optional;
|
||||||
|
@ -1371,7 +1383,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
||||||
if (argument.is<tblgen::NamedTypeConstraint *>()) {
|
if (argument.is<tblgen::NamedTypeConstraint *>()) {
|
||||||
const auto &operand = op.getOperand(numOperands);
|
const auto &operand = op.getOperand(numOperands);
|
||||||
StringRef type =
|
StringRef type =
|
||||||
operand.isVariadic() ? "std::vector<::builder::Op>" : "::builder::Op";
|
operand.isVariadic() ? "std::vector<builder::Op>" : "builder::Op";
|
||||||
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
|
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
|
||||||
if (operand.isOptional())
|
if (operand.isOptional())
|
||||||
properties = OpMethodParameter::PP_Optional;
|
properties = OpMethodParameter::PP_Optional;
|
||||||
|
@ -1437,24 +1449,11 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||||
auto attrs = op.getAttributes();
|
auto attrs = op.getAttributes();
|
||||||
SmallVector<std::string, 4> newAttrs;
|
SmallVector<std::string, 4> newAttrs;
|
||||||
|
|
||||||
// for(auto o : operands){
|
// if (attrType == "::mlir::DenseIntElementsAttr") {
|
||||||
// body << " operands name: "<<o.name<<" getCPPClassName:"<<o.constraint.getCPPClassName()<<"\n";
|
// body << " // BBBBBBBB getStorageType:" << a.attr.getStorageType().str()
|
||||||
// }
|
// << "\n";
|
||||||
// for(auto a : attrs){
|
// body << " // BBBBBBBB getReturnType:" << a.attr.getReturnType().str()
|
||||||
// body << " attrs name: "<<a.name<<" type: "<< a.attr.getStorageType().str()<<"\n";
|
// << "\n";
|
||||||
// std::string attrType = a.attr.getStorageType().str();
|
|
||||||
// auto ff = typeMapMLIR.find(attrType);
|
|
||||||
// if(ff != typeMapMLIR.end()){
|
|
||||||
// body << "// BBBBBB \n";
|
|
||||||
// std::string attrName = a.name.str();
|
|
||||||
// ff->second.ConvertToMlir(attrName,body);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
|
||||||
// body << "// AAAAAA \n";
|
|
||||||
// for(auto p : paramList){
|
|
||||||
// body << " AAA type"<<p.getType()<<" name"<<p.getName()<<"\n";
|
|
||||||
// }
|
// }
|
||||||
|
|
||||||
body << " auto b = builder.GetImpl();\n";
|
body << " auto b = builder.GetImpl();\n";
|
||||||
|
@ -1463,16 +1462,6 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||||
body << " auto ctx = b->GetContext();\n";
|
body << " auto ctx = b->GetContext();\n";
|
||||||
for (auto a : attrs) {
|
for (auto a : attrs) {
|
||||||
std::string attrType = a.attr.getStorageType().str();
|
std::string attrType = a.attr.getStorageType().str();
|
||||||
|
|
||||||
if(attrType == "::mlir::DenseIntElementsAttr")
|
|
||||||
{
|
|
||||||
body << " // BBBBBBBB getStorageType:"<< a.attr.getStorageType().str() <<"\n";
|
|
||||||
body << " // BBBBBBBB getReturnType:"<< a.attr.getReturnType().str() <<"\n";
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
auto typePair = typeMapMLIR.find(attrType);
|
auto typePair = typeMapMLIR.find(attrType);
|
||||||
if (typePair != typeMapMLIR.end()) {
|
if (typePair != typeMapMLIR.end()) {
|
||||||
std::string attrName = a.name.str();
|
std::string attrName = a.name.str();
|
||||||
|
@ -1480,20 +1469,42 @@ if(attrType == "::mlir::DenseIntElementsAttr")
|
||||||
newAttrs.emplace_back(mlirName);
|
newAttrs.emplace_back(mlirName);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
int index = 0;
|
||||||
|
for (auto v : operands) {
|
||||||
|
std::string name =
|
||||||
|
v.name.empty() ? "odsArg" + std::to_string(index) : v.name.str();
|
||||||
|
index++;
|
||||||
|
if (v.isVariadic()) {
|
||||||
|
body << " std::vector<mlir::Value> " << name << "_v;\n";
|
||||||
|
body << " for(auto v : " << name << "){\n " << name
|
||||||
|
<< "_v.push_back(v.GetImpl()->GetResult());\n }"
|
||||||
|
<< "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
|
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
|
||||||
<< " currentOp =\n";
|
<< " currentOp =\n";
|
||||||
body << " opBuilder.create<mlir::" << op.getDialectName()
|
body << " opBuilder.create<mlir::" << op.getDialectName()
|
||||||
<< "::" << op.getDialectName() << ">(\n";
|
<< "::" << op.getCppClassName() << ">(\n";
|
||||||
body << " loc";
|
body << " loc";
|
||||||
|
|
||||||
|
index = 0;
|
||||||
std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) {
|
std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) {
|
||||||
body << ",\n " << v.name << ".getResult()";
|
std::string name =
|
||||||
|
v.name.empty() ? "odsArg_" + std::to_string(index) : v.name.str();
|
||||||
|
index++;
|
||||||
|
if (v.isVariadic()) {
|
||||||
|
body << ",\n " << name << "_v";
|
||||||
|
} else {
|
||||||
|
body << ",\n " << name << ".GetImpl()->GetResult()";
|
||||||
|
}
|
||||||
});
|
});
|
||||||
std::for_each(newAttrs.begin(), newAttrs.end(),
|
std::for_each(newAttrs.begin(), newAttrs.end(),
|
||||||
[&](std::string &n) { body << ",\n " << n; });
|
[&](std::string &n) { body << ",\n " << n; });
|
||||||
body << "\n );\n";
|
body << "\n );\n";
|
||||||
body << " builder::" << op.getCppClassName() << " builderOp;\n";
|
body << " builder::mhlo::" << op.getCppClassName() << " builderOp;\n";
|
||||||
body << " auto opImpl = builderOp.GetImpl();\n";
|
body << " auto opImpl = builderOp.GetImpl();\n";
|
||||||
body << " opImpl.SetOperation(currentOp.getOperation());\n";
|
body << " opImpl->SetOperation(currentOp.getOperation());\n";
|
||||||
body << " return builderOp;\n";
|
body << " return builderOp;\n";
|
||||||
|
|
||||||
// // Push all operands to the result.
|
// // Push all operands to the result.
|
||||||
|
@ -2109,7 +2120,8 @@ void OpEmitter::genTraits() {
|
||||||
void OpEmitter::genOpNameGetter() {
|
void OpEmitter::genOpNameGetter() {
|
||||||
auto *method = opClass.addMethodAndPrune(
|
auto *method = opClass.addMethodAndPrune(
|
||||||
"std::string", "getOperationName",
|
"std::string", "getOperationName",
|
||||||
OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
|
OpMethod::Property(OpMethod::MP_Static));
|
||||||
|
// OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
|
||||||
method->body() << " return std::string(\"" << op.getOperationName()
|
method->body() << " return std::string(\"" << op.getOperationName()
|
||||||
<< "\");";
|
<< "\");";
|
||||||
}
|
}
|
||||||
|
@ -2196,7 +2208,7 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
|
||||||
std::string className = Operator(def).getQualCppClassName();
|
std::string className = Operator(def).getQualCppClassName();
|
||||||
llvm::SplitString(StringRef(className), namespaces, StringRef("::"));
|
llvm::SplitString(StringRef(className), namespaces, StringRef("::"));
|
||||||
if (namespaces.begin() != namespaces.end())
|
if (namespaces.begin() != namespaces.end())
|
||||||
os << "::builder::mhlo::" << namespaces.back().str();
|
os << "builder::mhlo::" << namespaces.back().str();
|
||||||
},
|
},
|
||||||
[&os]() { os << ",\n"; });
|
[&os]() { os << ",\n"; });
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue