add more types and try to build pass

This commit is contained in:
colin.liang 2021-08-13 15:05:10 +08:00
parent e88366b851
commit 4a9b201c4e
18 changed files with 702 additions and 339 deletions

88
BUILD
View File

@ -140,39 +140,6 @@ cc_binary(
], ],
) )
cc_library(
name = "mlir-hlo-builder",
srcs = glob([
"tools/mlir-tblgen-builder/Builder/*.h",
"tools/mlir-tblgen-builder/Builder/*.cpp",
]),
deps = [
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:IR",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
],
)
cc_test(
name = "mlir-tblgen-builder-basic",
srcs = [
"tests/mlir-tblgen-builder/test_basic.cpp",
],
deps = [
":hlo_ops_builder_gen",
":mlir-hlo-builder",
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:IR",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
],
)
gentbl_cc_library( gentbl_cc_library(
name = "hlo_ops_builder_gen", name = "hlo_ops_builder_gen",
strip_include_prefix = "include", strip_include_prefix = "include",
@ -195,6 +162,61 @@ gentbl_cc_library(
deps = [":hlo_ops_td_files"], deps = [":hlo_ops_td_files"],
) )
cc_library(
name = "mlir-hlo-builder",
srcs = glob([
"tools/mlir-tblgen-builder/Builder/*.h",
"tools/mlir-tblgen-builder/Builder/*.cpp",
]),
includes = ["include"],
deps = [
":all_passes",
":disc_ral",
":hlo",
":lhlo",
":lhlo_gpu",
":hlo_ops_builder_gen",
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:IR",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
# "@llvm-project//llvm:AllTargetsAsmParsers",
# "@llvm-project//llvm:AllTargetsCodeGens",
# "@llvm-project//llvm:Core",
# "@llvm-project//llvm:ExecutionEngine",
# "@llvm-project//llvm:Option",
# "@llvm-project//llvm:OrcJIT",
# "@llvm-project//llvm:Support",
# "@llvm-project//llvm:Target",
# "@llvm-project//mlir:AllPassesAndDialects",
# "@llvm-project//mlir:IR",
# "@llvm-project//mlir:MlirOptLib",
# "@llvm-project//mlir:Pass",
# "@llvm-project//mlir:Support",
# "@llvm-project//mlir:MlirJitRunner",
],
)
cc_test(
name = "mlir-tblgen-builder-basic",
srcs = [
"tests/mlir-tblgen-builder/test_basic.cpp",
],
deps = [
":mlir-hlo-builder",
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:IR",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
],
)
gentbl_cc_library( gentbl_cc_library(
name = "hlo_ops_base_inc_gen", name = "hlo_ops_base_inc_gen",
strip_include_prefix = "include", strip_include_prefix = "include",

View File

@ -1,12 +1,15 @@
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.h.inc"
#include "tools/mlir-tblgen-builder/Builder/Attribute.h"
#include "tools/mlir-tblgen-builder/Builder/Builder.h" #include "tools/mlir-tblgen-builder/Builder/Builder.h"
#include "tools/mlir-tblgen-builder/Builder/Op.h"
#include "tools/mlir-tblgen-builder/Builder/PrimitiveType.h"
#include "tools/mlir-tblgen-builder/Builder/Shape.h"
#include "tools/mlir-tblgen-builder/Builder/Tensor.h"
int main() { int main() {
builder::Integer i(322); builder::Builder builder;
builder::Shape shape({100, 100});
auto pType = builder::PrimitiveType::F32();
builder::Type type(shape, pType);
builder::Tensor tensor(std::vector<int>(100));
auto op = builder::mhlo::ConstOp::build(builder, type, tensor);
return 0; return 0;
} }
// static ::builder::Op build(::builder::Builder &builder, ::builder::Type
// output, ::builder::Tensor value);

View File

@ -1,9 +1,6 @@
#include "Attribute.h" #include "Attribute.h"
#include "AttributeImpl.h" #include "AttributeImpl.h"
#include "Shape.h"
#include "Tensor.h"
#include "TensorImpl.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "memory.h" #include "memory.h"
// #include "mlir/Dialect/StandardOps/Ops.h" // #include "mlir/Dialect/StandardOps/Ops.h"
@ -15,8 +12,93 @@
namespace builder { namespace builder {
inline bool PrimitiveType::operator==(const PrimitiveType& pt) {
return impl_ == pt.impl_;
}
PrimitiveType::PrimitiveType(std::shared_ptr<PrimitiveType::Impl> impl)
: impl_(impl) {}
PrimitiveType PrimitiveType::PRED() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::PRED));
return p;
}
PrimitiveType PrimitiveType::S8() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S8));
return p;
}
PrimitiveType PrimitiveType::S16() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S16));
return p;
}
PrimitiveType PrimitiveType::F32() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::F32));
return p;
}
PrimitiveType PrimitiveType::S32() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S32));
return p;
}
PrimitiveType PrimitiveType::S64() {
PrimitiveType p(
std::make_shared<PrimitiveType::Impl>(PrimitiveType::Impl::pType::S64));
return p;
}
Shape::Shape(std::vector<int64_t> dims)
: impl_(std::make_shared<Shape::Impl>(dims)) {}
Type::Type(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Type::Impl>(shape, primitiveType)) {}
Integer::Integer(int value) : impl_(std::make_shared<Integer::Impl>(value)) {} Integer::Integer(int value) : impl_(std::make_shared<Integer::Impl>(value)) {}
Integer::Integer(int64_t value) Integer::Integer(int64_t value)
: impl_(std::make_shared<Integer::Impl>(value)) {} : impl_(std::make_shared<Integer::Impl>(value)) {}
Float::Float(float value) : impl_(std::make_shared<Float::Impl>(value)) {}
Float::Float(double value) : impl_(std::make_shared<Float::Impl>(value)) {}
int64_t Array::GetSize() { return impl_->GetSize(); }
PrimitiveType Array::GetType() { return impl_->GetType(); }
Array::Array(std::vector<int> value)
: impl_(std::make_shared<Array::Impl>(value)) {}
Array::Array(std::vector<int64_t> value)
: impl_(std::make_shared<Array::Impl>(value)) {}
Array::Array(std::vector<std::string> value)
: impl_(std::make_shared<Array::Impl>(value)) {}
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
Shape Tensor::GetShape() { return impl_->GetShape(); }
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
Tensor::Tensor(std::vector<int> value)
: impl_(std::make_shared<Tensor::Impl>(value)) {}
Tensor::Tensor(std::vector<int64_t> value)
: impl_(std::make_shared<Tensor::Impl>(value)) {}
Tensor::Tensor(std::vector<float> value)
: impl_(std::make_shared<Tensor::Impl>(value)) {}
TensorInt::TensorInt(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
Shape TensorInt::GetShape() { return impl_->GetShape(); }
PrimitiveType TensorInt::GetType() { return impl_->GetType(); }
ChannelHandle::ChannelHandle()
: impl_(std::make_shared<ChannelHandle::Impl>()) {}
ConvDimensionNumbers::ConvDimensionNumbers()
: impl_(std::make_shared<ConvDimensionNumbers::Impl>()) {}
DotDimensionNumbers::DotDimensionNumbers()
: impl_(std::make_shared<DotDimensionNumbers::Impl>()) {}
GatherDimensionNumbers::GatherDimensionNumbers()
: impl_(std::make_shared<GatherDimensionNumbers::Impl>()) {}
ScatterDimensionNumbers::ScatterDimensionNumbers()
: impl_(std::make_shared<ScatterDimensionNumbers::Impl>()) {}
} // namespace builder } // namespace builder

View File

@ -1,16 +1,57 @@
#ifndef BUILDER_ATTRIBUTE_H_ #ifndef BUILDER_ATTRIBUTE_H_
#define BUILDER_ATTRIBUTE_H_ #define BUILDER_ATTRIBUTE_H_
#include <iostream>
#include <memory> #include <memory>
#include <vector>
namespace builder { namespace builder {
class PrimitiveType {
public:
static PrimitiveType PRED();
static PrimitiveType S8();
static PrimitiveType S16();
static PrimitiveType F32();
static PrimitiveType S32();
static PrimitiveType S64();
inline bool operator==(const PrimitiveType& pt);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
PrimitiveType(std::shared_ptr<Impl>);
private:
std::shared_ptr<Impl> impl_;
// PrimitiveType(const PrimitiveType&) = default;
};
class Shape {
public:
Shape(std::vector<int64_t> dims);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class Integer { class Integer {
public: public:
Integer(int value); Integer(int value);
Integer(int64_t value); Integer(int64_t value);
class Impl; class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class Float {
public:
Float(float value);
Float(double value);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private: private:
std::shared_ptr<Impl> impl_; std::shared_ptr<Impl> impl_;
@ -18,6 +59,43 @@ class Integer {
class Array { class Array {
public: public:
Array(std::vector<int> value);
Array(std::vector<int64_t> value);
Array(std::vector<std::string> value);
int64_t GetSize();
PrimitiveType GetType();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class Tensor {
public:
Tensor(Shape& shape, PrimitiveType& primitiveType);
Tensor(std::vector<int> value);
Tensor(std::vector<int64_t> value);
Tensor(std::vector<float> value);
Shape GetShape();
PrimitiveType GetType();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class TensorInt {
public:
TensorInt(Shape& shape, PrimitiveType& primitiveType);
Shape GetShape();
PrimitiveType GetType();
class Impl; class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; } std::shared_ptr<Impl> GetImpl() { return impl_; }
@ -27,6 +105,7 @@ class Array {
class Type { class Type {
public: public:
Type(Shape& shape, PrimitiveType& primitiveType);
class Impl; class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; } std::shared_ptr<Impl> GetImpl() { return impl_; }
@ -34,16 +113,56 @@ class Type {
std::shared_ptr<Impl> impl_; std::shared_ptr<Impl> impl_;
}; };
// template <typename T> class ChannelHandle {
// class DenseIntElementsAttr { public:
// public: ChannelHandle();
// explicit DenseIntElementsAttr(Tensor); class Impl;
// class Impl; std::shared_ptr<Impl> GetImpl() { return impl_; }
// std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class ConvDimensionNumbers {
public:
ConvDimensionNumbers();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class DotDimensionNumbers {
public:
DotDimensionNumbers();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class GatherDimensionNumbers {
public:
GatherDimensionNumbers();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class ScatterDimensionNumbers {
public:
ScatterDimensionNumbers();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
// private:
// std::shared_ptr<Impl> impl_;
// };
} // namespace builder } // namespace builder

View File

@ -1,8 +1,13 @@
#ifndef BUILDER_TYPEIMPL_ #ifndef BUILDER_ATTRIBUTEIMPL_H_
#define BUILDER_TYPEIMPL_ #define BUILDER_ATTRIBUTEIMPL_H_
#include <vector>
#include "Builder.h" #include "Builder.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
@ -13,45 +18,306 @@
namespace builder { namespace builder {
class PrimitiveType::Impl {
public:
enum pType {
PRIMITIVE_TYPE_INVALID = 0,
PRED = 1,
S8 = 2,
S16 = 3,
S32 = 4,
S64 = 5,
U8 = 6,
U16 = 7,
U32 = 8,
U64 = 9,
F16 = 10,
F32 = 11,
BF16 = 16,
F64 = 12,
C64 = 15,
C128 = 18,
TUPLE = 13,
OPAQUE_TYPE = 14,
TOKEN = 17
};
Impl(pType t) : t_(t) {
switch (t) {
case PRIMITIVE_TYPE_INVALID:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::NoneType::get(context);
};
break;
case PRED:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 1, mlir::IntegerType::SignednessSemantics::Signed);
};
break;
case S8:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 8, mlir::IntegerType::SignednessSemantics::Signed);
};
break;
case S16:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 16, mlir::IntegerType::SignednessSemantics::Signed);
};
break;
case F32:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::Float32Type::get(context);
};
break;
case S32:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 32, mlir::IntegerType::SignednessSemantics::Signed);
};
break;
case S64:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, 64, mlir::IntegerType::SignednessSemantics::Signed);
};
break;
default:
GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::NoneType::get(context);
};
break;
}
}
// class BFloat16Type;
// class ComplexType;
// class Float128Type;
// class Float16Type;
// class Float32Type;
// class Float64Type;
// class Float80Type;
// class FunctionType;
// class IndexType;
// class IntegerType;
// class MemRefType;
// class NoneType;
// class OpaqueType;
// class RankedTensorType;
// class TupleType;
// class UnrankedMemRefType;
// class UnrankedTensorType;
// class VectorType;
inline bool operator==(const PrimitiveType::Impl &impl) {
return t_ == impl.t_;
}
std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;
private:
pType t_;
};
class Shape::Impl {
public:
Impl(std::vector<int64_t> dims) : dims_(dims) {}
private:
std::vector<int64_t> dims_;
};
class Type::Impl {
public:
Impl(Shape &shape, PrimitiveType &primitiveType)
: shape_(shape), primitiveType_(primitiveType) {}
Shape GetShape() { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
private:
Shape shape_;
PrimitiveType primitiveType_;
};
class Integer::Impl { class Integer::Impl {
public: public:
Impl(){};
Impl(int value) { Impl(int value) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr { GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
auto int32_type = mlir::IntegerType::get(context, 32); auto int32_type = mlir::IntegerType::get(context, 32);
return mlir::IntegerAttr::get(int32_type, value); return mlir::IntegerAttr::get(int32_type, value);
}; };
}; };
Impl(int64_t value) Impl(int64_t value) {
: GetAttr([=](mlir::MLIRContext *context) -> mlir::IntegerAttr { GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
auto int32_type = mlir::IntegerType::get(context, 64); auto int64_type = mlir::IntegerType::get(context, 64);
return mlir::IntegerAttr::get(int32_type, value); return mlir::IntegerAttr::get(int64_type, value);
}){}; };
};
std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr; std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr;
private: private:
}; };
class Float::Impl {
public:
Impl(float value) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::FloatAttr {
auto float32_type = mlir::Float32Type::get(context);
return mlir::FloatAttr::get(float32_type, value);
};
};
Impl(double value) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::FloatAttr {
auto float64_type = mlir::Float64Type::get(context);
return mlir::FloatAttr::get(float64_type, value);
};
};
std::function<mlir::FloatAttr(mlir::MLIRContext *context)> GetAttr;
private:
};
class Array::Impl { class Array::Impl {
public: public:
Impl() = default; Impl(std::vector<int> value)
: size_({value.size()}), primitiveType_(PrimitiveType::S32()) {
private: // GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
}; // auto type =
class Type::Impl { // mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
public: // mlir::IntegerType::get(context, 32));
Impl() = default; // return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
private:
};
// class DenseIntElementsAttr::Impl {
// public:
// Impl() = default;
// private:
// }; // };
}
Impl(std::vector<int64_t> value)
: size_({value.size()}), primitiveType_(PrimitiveType::S64()) {
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
// auto type =
// mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
// mlir::IntegerType::get(context, 64));
// return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
// };
}
Impl(std::vector<std::string> value)
: size_({value.size()}), primitiveType_(PrimitiveType::F32()) {
// GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr {
// auto type =
// mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
// mlir::FloatType::getF32(context));
// return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value));
// };
}
int64_t GetSize() { return size_; }
PrimitiveType GetType() { return primitiveType_; }
std::function<mlir::ArrayAttr(mlir::MLIRContext *context)> GetAttr;
private:
int64_t size_;
PrimitiveType primitiveType_;
};
class Tensor::Impl {
public:
Impl(std::vector<int> value)
: shape_({value.size()}), primitiveType_(PrimitiveType::S32()) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
auto type =
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
mlir::IntegerType::get(context, 32));
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
};
}
Impl(std::vector<int64_t> value)
: shape_({value.size()}), primitiveType_(PrimitiveType::S64()) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
auto type =
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
mlir::IntegerType::get(context, 64));
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
};
}
Impl(std::vector<float> value)
: shape_({value.size()}), primitiveType_(PrimitiveType::F32()) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
auto type =
mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
mlir::FloatType::getF32(context));
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
};
}
Impl(Shape &shape, PrimitiveType &primitiveType)
: shape_(shape), primitiveType_(primitiveType) {}
Shape GetShape() { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;
private:
Shape shape_;
PrimitiveType primitiveType_;
};
class TensorInt::Impl {
public:
Impl(Shape &shape, PrimitiveType &primitiveType)
: shape_(shape), primitiveType_(primitiveType) {}
Shape GetShape() { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
std::function<mlir::DenseIntElementsAttr(mlir::MLIRContext *context)> GetAttr;
private:
Shape shape_;
PrimitiveType primitiveType_;
};
class ChannelHandle::Impl {
public:
Impl() {}
std::function<mlir::mhlo::ChannelHandle(mlir::MLIRContext *context)> GetAttr;
private:
};
class ConvDimensionNumbers::Impl {
public:
Impl() {}
std::function<mlir::mhlo::ConvDimensionNumbers(mlir::MLIRContext *context)>
GetAttr;
private:
};
class DotDimensionNumbers::Impl {
public:
Impl() {}
std::function<mlir::mhlo::DotDimensionNumbers(mlir::MLIRContext *context)>
GetAttr;
private:
};
class GatherDimensionNumbers::Impl {
public:
Impl() {}
std::function<mlir::mhlo::GatherDimensionNumbers(mlir::MLIRContext *context)>
GetAttr;
private:
};
class ScatterDimensionNumbers::Impl {
public:
Impl() {}
std::function<mlir::mhlo::ScatterDimensionNumbers(mlir::MLIRContext *context)>
GetAttr;
private:
};
} // namespace builder } // namespace builder

View File

@ -8,6 +8,13 @@
// #include "mlir/IR/StandardTypes.h" // #include "mlir/IR/StandardTypes.h"
// #include "mlir/IR/Types.h" // #include "mlir/IR/Types.h"
// #include "mlir/IR/Value.h" // #include "mlir/IR/Value.h"
#include "Attribute.h"
#include "AttributeImpl.h"
#include "Op.h"
#include "OpImpl.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace builder { namespace builder {
@ -15,3 +22,6 @@ Builder::Builder() : impl_(std::make_shared<Impl>()) {}
void Builder::DumpModule() {} void Builder::DumpModule() {}
} // namespace builder } // namespace builder
#define GET_OP_CLASSES
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.cc.inc"

View File

@ -3,14 +3,17 @@
#include <memory> #include <memory>
#include "tools/mlir-tblgen-builder/Builder/Attribute.h"
#include "tools/mlir-tblgen-builder/Builder/Op.h"
namespace builder { namespace builder {
class Builder { class Builder {
public: public:
class Impl;
Builder(); Builder();
void DumpModule(); void DumpModule();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; } std::shared_ptr<Impl> GetImpl() { return impl_; }
private: private:
@ -18,4 +21,7 @@ class Builder {
}; };
} // namespace builder } // namespace builder
#define GET_OP_CLASSES
#include "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.h.inc"
#endif #endif

View File

@ -4,7 +4,10 @@
#include "Builder.h" #include "Builder.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
@ -14,15 +17,30 @@ namespace builder {
class Builder::Impl { class Builder::Impl {
public: public:
Impl() {} Impl() : builder_(&context_) {
// mlir::Location GetLoc() { return mlir_loc_; } module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
// mlir::OpBuilder GetBuilder() { return mlir_builder_; }
mlir::MLIRContext *GetContext() { return &mlir_context_; } llvm::SmallVector<mlir::Type, 4> arg_types;
// Create the main function.
mlir::FunctionType funcType = builder_.getFunctionType(arg_types, {});
main_func_ = mlir::FuncOp::create(builder_.getUnknownLoc(), "main",
funcType, /* attrs = */ {});
entry_block_ = main_func_.addEntryBlock();
builder_.setInsertionPointToStart(entry_block_);
module_.push_back(main_func_);
}
mlir::Location GetLoc() { return builder_.getUnknownLoc(); }
mlir::OpBuilder GetBuilder() { return builder_; }
mlir::MLIRContext* GetContext() { return &context_; }
private: private:
// mlir::Location mlir_loc_; mlir::MLIRContext context_;
// mlir::OpBuilder mlir_builder_; mlir::ModuleOp module_;
mlir::MLIRContext mlir_context_; mlir::OpBuilder builder_;
mlir::FuncOp main_func_;
mlir::Block* entry_block_;
}; };
} // namespace builder } // namespace builder

View File

@ -6,6 +6,7 @@
namespace builder { namespace builder {
class Op { class Op {
public:
class Impl; class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; } std::shared_ptr<Impl> GetImpl() { return impl_; }

View File

@ -15,10 +15,11 @@ namespace builder {
class Op::Impl { class Op::Impl {
public: public:
Impl() = default; Impl() = default;
void SetOperation(Operation *Op) { op_ = Op; } void SetOperation(mlir::Operation *Op) { op_ = Op; }
mlir::Value GetResult() { return op_->getResult(0); }
private: private:
Operation *op_; mlir::Operation *op_;
}; };
} // namespace builder } // namespace builder

View File

@ -1,28 +0,0 @@
#include "PrimitiveType.h"
#include "Shape.h"
#include "Tensor.h"
#include "llvm/Support/Casting.h"
// #include "mlir/Dialect/StandardOps/Ops.h"
// #include "mlir/IR/Attributes.h"
// #include "mlir/IR/Operation.h"
// #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
// #include "mlir/IR/Value.h"
namespace builder {
// constexpr mlir::Type GetMlirType(const PrimitiveType primitiveType) {
// switch (primitiveType)
// {
// case PrimitiveType::BF16 :
// return
// break;
// default:
// break;
// }
// }
} // namespace builder

View File

@ -1,29 +0,0 @@
#ifndef BUILDER_PRIMITIVE_TYPE_
#define BUILDER_PRIMITIVE_TYPE_
namespace builder {
enum class PrimitiveType {
PRIMITIVE_TYPE_INVALID = 0,
PRED = 1,
S8 = 2,
S16 = 3,
S32 = 4,
S64 = 5,
U8 = 6,
U16 = 7,
U32 = 8,
U64 = 9,
F16 = 10,
F32 = 11,
BF16 = 16,
F64 = 12,
C64 = 15,
C128 = 18,
TUPLE = 13,
OPAQUE_TYPE = 14,
TOKEN = 17
};
}
#endif

View File

@ -1,17 +0,0 @@
#ifndef BUILDER_SHAPE_
#define BUILDER_SHAPE_
#include <iostream>
#include <memory>
namespace builder {
class Shape {
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
} // namespace builder
#endif

View File

@ -1,24 +0,0 @@
#ifndef BUILDER_SHAPEIMPL_
#define BUILDER_SHAPEIMPL_
#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
class Shape::Impl {
public:
Impl() = default;
private:
};
} // namespace builder
#endif

View File

@ -1,22 +0,0 @@
#include "Tensor.h"
#include "Shape.h"
#include "TensorImpl.h"
#include "llvm/Support/Casting.h"
// #include "mlir/Dialect/StandardOps/Ops.h"
// #include "mlir/IR/Attributes.h"
// #include "mlir/IR/Operation.h"
// #include "mlir/IR/StandardTypes.h"
// #include "mlir/IR/Types.h"
// #include "mlir/IR/Value.h"
namespace builder {
// Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
Shape Tensor::GetShape() { return impl_->GetShape(); }
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
} // namespace builder

View File

@ -1,25 +0,0 @@
#ifndef BUILDER_TENSOR_
#define BUILDER_TENSOR_
#include <iostream>
#include <memory>
#include "PrimitiveType.h"
#include "Shape.h"
namespace builder {
class Tensor {
public:
Tensor(Shape& shape, PrimitiveType& primitiveType);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
Shape GetShape();
PrimitiveType GetType();
private:
std::shared_ptr<Impl> impl_;
};
} // namespace builder
#endif

View File

@ -1,32 +0,0 @@
#ifndef BUILDER_TENSORIMPL_
#define BUILDER_TENSORIMPL_
#include "Builder.h"
#include "PrimitiveType.h"
#include "Shape.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
class Tensor::Impl {
public:
Impl() = default;
Impl(Shape& shape, PrimitiveType& primitiveType)
: shape_(shape), primitiveType_(primitiveType) {}
Shape GetShape() { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
private:
Shape shape_;
PrimitiveType primitiveType_;
};
} // namespace builder
#endif

View File

@ -52,81 +52,93 @@ static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
{"::mlir::StringAttr", {"::mlir::StringAttr",
{"std::string", {"std::string",
[](std::string &var, OpMethodBody &body) -> std::string { [](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::StringAttr " << var << "_mlir = mlir::StringAttr::get(" body << " mlir::StringAttr " << var
<< var << ", ctx);\n"; << "_mlir = mlir::StringAttr::get(ctx, mlir::Twine(" << var
<< "));\n";
return var + "_mlir"; return var + "_mlir";
}}}, }}},
{"::mlir::IntegerAttr", {"::mlir::IntegerAttr",
{"builder::Integer", {"builder::Integer",
[](std::string &var, OpMethodBody &body) -> std::string { [](std::string &var, OpMethodBody &body) -> std::string {
// body << " mlir::IntegerAttr " << var << "_mlir = mlir::IntegerAttr::get("
// << var << ", ctx);\n";
body << " mlir::IntegerAttr " << var << "_mlir = " << var body << " mlir::IntegerAttr " << var << "_mlir = " << var
<< ".GetAttr(ctx);\n"; << ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::DenseIntElementsAttr",
{"std::vector<int>",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::DenseIntElementsAttr " << var
<< "_mlir = mlir::DenseIntElementsAttr::get("
<< "mlir::VectorType::get(" << var
<< ".size(), opBuilder->getIntegerType(32))," << var << ");\n";
return var + "_mlir";
}}},
{"::mlir::mhlo::ChannelHandle",
{"ChannelHandle",
[](std::string &var, OpMethodBody &body) -> std::string {
return var + "_mlir"; return var + "_mlir";
}}}, }}},
{"::mlir::FloatAttr", {"::mlir::FloatAttr",
{"float", {"builder::Float",
[](std::string &var, OpMethodBody &body) -> std::string { [](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::FloatAttr " << var << "_mlir = mlir::FloatAttr::get(" body << " mlir::FloatAttr " << var << "_mlir = " << var
<< var << ", ctx);\n"; << ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::DenseIntElementsAttr",
{"builder::TensorInt",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::DenseIntElementsAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::mhlo::ChannelHandle",
{"builder::ChannelHandle",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::mhlo::ChannelHandle " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir"; return var + "_mlir";
}}}, }}},
{"::mlir::BoolAttr", {"::mlir::BoolAttr",
{"bool", {"bool",
[](std::string &var, OpMethodBody &body) -> std::string { [](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::BoolAttr " << var << "_mlir = mlir::BoolAttr::get(" body << " mlir::BoolAttr " << var
<< var << ", ctx);\n"; << "_mlir = mlir::BoolAttr::get(ctx, " << var << ");\n";
return var + "_mlir"; return var + "_mlir";
}}}, }}},
{"::mlir::ElementsAttr", {"::mlir::ElementsAttr",
{"::builder::Array", {"builder::Tensor",
[](std::string &var, OpMethodBody &body) -> std::string { [](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::DenseElementsAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir"; return var + "_mlir";
}}}, }}},
{"::mlir::DenseElementsAttr", {"::mlir::DenseElementsAttr",
{"::builder::Tensor", {"builder::Tensor",
[](std::string &var, OpMethodBody &body) -> std::string { [](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::DenseElementsAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir"; return var + "_mlir";
}}}, }}},
// current only support string array.
{"::mlir::ArrayAttr", {"::mlir::ArrayAttr",
{"std::vector<std::string>", {"builder::Array",
[](std::string &var, OpMethodBody &body) -> std::string { [](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::ArrayAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir"; return var + "_mlir";
}}}, }}},
{"::mlir::mhlo::ConvDimensionNumbers", {"::mlir::mhlo::ConvDimensionNumbers",
{"::builder::ConvDimensionNumbers", {"builder::ConvDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string { [](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::mhlo::ConvDimensionNumbers " << var
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir"; return var + "_mlir";
}}}, }}},
{"::mlir::mhlo::DotDimensionNumbers", {"::mlir::mhlo::DotDimensionNumbers",
{"::builder::DotDimensionNumbers", {"builder::DotDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string { [](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::mhlo::DotDimensionNumbers " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir"; return var + "_mlir";
}}}, }}},
{"::mlir::mhlo::GatherDimensionNumbers", {"::mlir::mhlo::GatherDimensionNumbers",
{"::builder::GatherDimensionNumbers", {"builder::GatherDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string { [](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::mhlo::GatherDimensionNumbers " << var
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir"; return var + "_mlir";
}}}, }}},
{"::mlir::mhlo::ScatterDimensionNumbers", {"::mlir::mhlo::ScatterDimensionNumbers",
{"::builder::ScatterDimensionNumbers", {"builder::ScatterDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string { [](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::mhlo::ScatterDimensionNumbers " << var
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
return var + "_mlir"; return var + "_mlir";
}}}, }}},
}; };
@ -138,16 +150,16 @@ static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
// {"::mlir::mhlo::ChannelHandle", "ChannelHandle"}, // {"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
// {"::mlir::FloatAttr", "float"}, // {"::mlir::FloatAttr", "float"},
// {"::mlir::BoolAttr", "bool"}, // {"::mlir::BoolAttr", "bool"},
// {"::mlir::ElementsAttr", "::builder::Array"}, // {"::mlir::ElementsAttr", "builder::Array"},
// {"::mlir::DenseElementsAttr", "::builder::Tensor"}, // {"::mlir::DenseElementsAttr", "builder::Tensor"},
// // current only support string array. // // current only support string array.
// {"::mlir::ArrayAttr", "std::vector<std::string>"}, // {"::mlir::ArrayAttr", "std::vector<std::string>"},
// {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"}, // {"::mlir::mhlo::ConvDimensionNumbers", "builder::ConvDimensionNumbers"},
// {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"}, // {"::mlir::mhlo::DotDimensionNumbers", "builder::DotDimensionNumbers"},
// {"::mlir::mhlo::GatherDimensionNumbers", // {"::mlir::mhlo::GatherDimensionNumbers",
// "::builder::GatherDimensionNumbers"}, // "builder::GatherDimensionNumbers"},
// {"::mlir::mhlo::ScatterDimensionNumbers", // {"::mlir::mhlo::ScatterDimensionNumbers",
// "::builder::ScatterDimensionNumbers"}, // "builder::ScatterDimensionNumbers"},
// }; // };
StringRef typeConvertFromMLIR(StringRef type) { StringRef typeConvertFromMLIR(StringRef type) {
@ -707,14 +719,14 @@ OpEmitter::OpEmitter(const Operator &op,
// Generate C++ code for various op methods. The order here determines the // Generate C++ code for various op methods. The order here determines the
// methods in the generated file. // methods in the generated file.
// genOpAsmInterface(); // genOpAsmInterface();
genOpNameGetter(); //// genOpNameGetter();
// genNamedOperandGetters(); // genNamedOperandGetters();
// genNamedOperandSetters(); // genNamedOperandSetters();
// genNamedResultGetters(); // genNamedResultGetters();
// genNamedRegionGetters(); // genNamedRegionGetters();
genNamedSuccessorGetters(); //// genNamedSuccessorGetters();
genAttrGetters(); //// genAttrGetters();
genAttrSetters(); //// genAttrSetters();
// genOptionalAttrRemovers(); // genOptionalAttrRemovers();
genBuilder(); genBuilder();
// genParser(); // genParser();
@ -1179,7 +1191,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
buildParamList(paramList, resultNames, paramKind, attrType); buildParamList(paramList, resultNames, paramKind, attrType);
auto *m = opClass.addMethodAndPrune( auto *m = opClass.addMethodAndPrune(
"::builder::Op", "build", OpMethod::MP_Static, std::move(paramList)); "builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
// If the builder is redundant, skip generating the method. // If the builder is redundant, skip generating the method.
if (!m) if (!m)
return; return;
@ -1329,7 +1341,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
resultTypeNames.reserve(numResults); resultTypeNames.reserve(numResults);
// paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); // paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::builder::Builder &", "builder"); paramList.emplace_back("builder::Builder &", "builder");
// paramList.emplace_back("::mlir::OperationState &", builderOpState); // paramList.emplace_back("::mlir::OperationState &", builderOpState);
switch (typeParamKind) { switch (typeParamKind) {
@ -1344,7 +1356,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
resultName = std::string(formatv("resultType{0}", i)); resultName = std::string(formatv("resultType{0}", i));
StringRef type = StringRef type =
result.isVariadic() ? "std::vector<::builder::Type>" : "::builder::Type"; result.isVariadic() ? "std::vector<builder::Type>" : "builder::Type";
OpMethodParameter::Property properties = OpMethodParameter::PP_None; OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (result.isOptional()) if (result.isOptional())
properties = OpMethodParameter::PP_Optional; properties = OpMethodParameter::PP_Optional;
@ -1371,7 +1383,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
if (argument.is<tblgen::NamedTypeConstraint *>()) { if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &operand = op.getOperand(numOperands); const auto &operand = op.getOperand(numOperands);
StringRef type = StringRef type =
operand.isVariadic() ? "std::vector<::builder::Op>" : "::builder::Op"; operand.isVariadic() ? "std::vector<builder::Op>" : "builder::Op";
OpMethodParameter::Property properties = OpMethodParameter::PP_None; OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (operand.isOptional()) if (operand.isOptional())
properties = OpMethodParameter::PP_Optional; properties = OpMethodParameter::PP_Optional;
@ -1437,24 +1449,11 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
auto attrs = op.getAttributes(); auto attrs = op.getAttributes();
SmallVector<std::string, 4> newAttrs; SmallVector<std::string, 4> newAttrs;
// for(auto o : operands){ // if (attrType == "::mlir::DenseIntElementsAttr") {
// body << " operands name: "<<o.name<<" getCPPClassName:"<<o.constraint.getCPPClassName()<<"\n"; // body << " // BBBBBBBB getStorageType:" << a.attr.getStorageType().str()
// } // << "\n";
// for(auto a : attrs){ // body << " // BBBBBBBB getReturnType:" << a.attr.getReturnType().str()
// body << " attrs name: "<<a.name<<" type: "<< a.attr.getStorageType().str()<<"\n"; // << "\n";
// std::string attrType = a.attr.getStorageType().str();
// auto ff = typeMapMLIR.find(attrType);
// if(ff != typeMapMLIR.end()){
// body << "// BBBBBB \n";
// std::string attrName = a.name.str();
// ff->second.ConvertToMlir(attrName,body);
// }
// }
// body << "// AAAAAA \n";
// for(auto p : paramList){
// body << " AAA type"<<p.getType()<<" name"<<p.getName()<<"\n";
// } // }
body << " auto b = builder.GetImpl();\n"; body << " auto b = builder.GetImpl();\n";
@ -1463,16 +1462,6 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
body << " auto ctx = b->GetContext();\n"; body << " auto ctx = b->GetContext();\n";
for (auto a : attrs) { for (auto a : attrs) {
std::string attrType = a.attr.getStorageType().str(); std::string attrType = a.attr.getStorageType().str();
if(attrType == "::mlir::DenseIntElementsAttr")
{
body << " // BBBBBBBB getStorageType:"<< a.attr.getStorageType().str() <<"\n";
body << " // BBBBBBBB getReturnType:"<< a.attr.getReturnType().str() <<"\n";
}
auto typePair = typeMapMLIR.find(attrType); auto typePair = typeMapMLIR.find(attrType);
if (typePair != typeMapMLIR.end()) { if (typePair != typeMapMLIR.end()) {
std::string attrName = a.name.str(); std::string attrName = a.name.str();
@ -1480,20 +1469,42 @@ if(attrType == "::mlir::DenseIntElementsAttr")
newAttrs.emplace_back(mlirName); newAttrs.emplace_back(mlirName);
} }
} }
int index = 0;
for (auto v : operands) {
std::string name =
v.name.empty() ? "odsArg" + std::to_string(index) : v.name.str();
index++;
if (v.isVariadic()) {
body << " std::vector<mlir::Value> " << name << "_v;\n";
body << " for(auto v : " << name << "){\n " << name
<< "_v.push_back(v.GetImpl()->GetResult());\n }"
<< "\n";
}
}
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName() body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
<< " currentOp =\n"; << " currentOp =\n";
body << " opBuilder.create<mlir::" << op.getDialectName() body << " opBuilder.create<mlir::" << op.getDialectName()
<< "::" << op.getDialectName() << ">(\n"; << "::" << op.getCppClassName() << ">(\n";
body << " loc"; body << " loc";
index = 0;
std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) { std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) {
body << ",\n " << v.name << ".getResult()"; std::string name =
v.name.empty() ? "odsArg_" + std::to_string(index) : v.name.str();
index++;
if (v.isVariadic()) {
body << ",\n " << name << "_v";
} else {
body << ",\n " << name << ".GetImpl()->GetResult()";
}
}); });
std::for_each(newAttrs.begin(), newAttrs.end(), std::for_each(newAttrs.begin(), newAttrs.end(),
[&](std::string &n) { body << ",\n " << n; }); [&](std::string &n) { body << ",\n " << n; });
body << "\n );\n"; body << "\n );\n";
body << " builder::" << op.getCppClassName() << " builderOp;\n"; body << " builder::mhlo::" << op.getCppClassName() << " builderOp;\n";
body << " auto opImpl = builderOp.GetImpl();\n"; body << " auto opImpl = builderOp.GetImpl();\n";
body << " opImpl.SetOperation(currentOp.getOperation());\n"; body << " opImpl->SetOperation(currentOp.getOperation());\n";
body << " return builderOp;\n"; body << " return builderOp;\n";
// // Push all operands to the result. // // Push all operands to the result.
@ -2109,7 +2120,8 @@ void OpEmitter::genTraits() {
void OpEmitter::genOpNameGetter() { void OpEmitter::genOpNameGetter() {
auto *method = opClass.addMethodAndPrune( auto *method = opClass.addMethodAndPrune(
"std::string", "getOperationName", "std::string", "getOperationName",
OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr)); OpMethod::Property(OpMethod::MP_Static));
// OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
method->body() << " return std::string(\"" << op.getOperationName() method->body() << " return std::string(\"" << op.getOperationName()
<< "\");"; << "\");";
} }
@ -2196,7 +2208,7 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
std::string className = Operator(def).getQualCppClassName(); std::string className = Operator(def).getQualCppClassName();
llvm::SplitString(StringRef(className), namespaces, StringRef("::")); llvm::SplitString(StringRef(className), namespaces, StringRef("::"));
if (namespaces.begin() != namespaces.end()) if (namespaces.begin() != namespaces.end())
os << "::builder::mhlo::" << namespaces.back().str(); os << "builder::mhlo::" << namespaces.back().str();
}, },
[&os]() { os << ",\n"; }); [&os]() { os << ",\n"; });
} }