#ifndef BUILDER_ATTRIBUTEIMPL_H_ #define BUILDER_ATTRIBUTEIMPL_H_ #include #include "Builder.h" #include "llvm/Support/Casting.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" namespace builder { class PrimitiveType::Impl { public: enum pType { PRIMITIVE_TYPE_INVALID = 0, PRED = 1, S8 = 2, S16 = 3, S32 = 4, S64 = 5, U8 = 6, U16 = 7, U32 = 8, U64 = 9, F16 = 10, F32 = 11, BF16 = 16, F64 = 12, C64 = 15, C128 = 18, TUPLE = 13, OPAQUE_TYPE = 14, TOKEN = 17 }; Impl(pType t) : t_(t) { switch (t) { case PRIMITIVE_TYPE_INVALID: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::NoneType::get(context); }; unitBits_ = 0; break; case PRED: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::IntegerType::get( context, 1, mlir::IntegerType::SignednessSemantics::Signed); }; unitBits_ = 1; break; case S8: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::IntegerType::get( context, 8, mlir::IntegerType::SignednessSemantics::Signed); }; unitBits_ = 8; break; case S16: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::IntegerType::get( context, 16, mlir::IntegerType::SignednessSemantics::Signed); }; unitBits_ = 16; break; case F32: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::Float32Type::get(context); }; unitBits_ = 32; break; case S32: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::IntegerType::get( context, 32, mlir::IntegerType::SignednessSemantics::Signed); }; unitBits_ = 32; break; case S64: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::IntegerType::get( context, 64, mlir::IntegerType::SignednessSemantics::Signed); }; unitBits_ = 64; break; default: GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type { return mlir::NoneType::get(context); }; unitBits_ = 0; break; } } // class BFloat16Type; // class ComplexType; // class Float128Type; // class Float16Type; // class Float32Type; // class Float64Type; // class Float80Type; // class FunctionType; // class IndexType; // class IntegerType; // class MemRefType; // class NoneType; // class OpaqueType; // class RankedTensorType; // class TupleType; // class UnrankedMemRefType; // class UnrankedTensorType; // class VectorType; inline bool operator==(const PrimitiveType::Impl &impl) { return t_ == impl.t_; } std::function GetMlirType; uint64_t GetUnitBytes() const { return unitBits_ / 8; } uint64_t GetUnitBits() const { return unitBits_; } private: pType t_; uint64_t unitBits_; }; class Shape::Impl { public: Impl(std::vector dims) : dims_(dims) {} const std::vector GetDims() const { return dims_; } int64_t GetSize() const { int64_t size; for (auto &d : dims_) { size *= d; } return size; } private: std::vector dims_; }; class Type::Impl { public: Impl(Shape &shape, PrimitiveType &primitiveType) : shape_(shape), primitiveType_(primitiveType) {} Shape GetShape() const { return shape_; } PrimitiveType GetType() { return primitiveType_; } mlir::Type GetMlirType(mlir::MLIRContext *context) const { return mlir::RankedTensorType::get( llvm::ArrayRef(shape_.GetImpl()->GetDims()), primitiveType_.GetImpl()->GetMlirType(context)); } private: Shape shape_; PrimitiveType primitiveType_; }; class Integer::Impl { public: Impl(int value) { GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr { auto int32_type = mlir::IntegerType::get(context, 32); return mlir::IntegerAttr::get(int32_type, value); }; }; Impl(int64_t value) { GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr { auto int64_type = mlir::IntegerType::get(context, 64); return mlir::IntegerAttr::get(int64_type, value); }; }; std::function GetAttr; private: }; class Float::Impl { public: Impl(float value) { GetAttr = [=](mlir::MLIRContext *context) -> mlir::FloatAttr { auto float32_type = mlir::Float32Type::get(context); return mlir::FloatAttr::get(float32_type, value); }; }; Impl(double value) { GetAttr = [=](mlir::MLIRContext *context) -> mlir::FloatAttr { auto float64_type = mlir::Float64Type::get(context); return mlir::FloatAttr::get(float64_type, value); }; }; std::function GetAttr; private: }; class Array::Impl { public: Impl(std::vector value) : size_({value.size()}), primitiveType_(PrimitiveType::S32()) { // GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr { // auto type = // mlir::RankedTensorType::get(llvm::ArrayRef({value.size()}), // mlir::IntegerType::get(context, 32)); // return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value)); // }; } Impl(std::vector value) : size_({value.size()}), primitiveType_(PrimitiveType::S64()) { // GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr { // auto type = // mlir::RankedTensorType::get(llvm::ArrayRef({value.size()}), // mlir::IntegerType::get(context, 64)); // return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value)); // }; } Impl(std::vector value) : size_({value.size()}), primitiveType_(PrimitiveType::F32()) { // GetAttr = [=](mlir::MLIRContext *context) -> mlir::ArrayAttr { // auto type = // mlir::RankedTensorType::get(llvm::ArrayRef({value.size()}), // mlir::FloatType::getF32(context)); // return mlir::ArrayAttr::get(type, llvm::makeArrayRef(value)); // }; } int64_t GetSize() { return size_; } PrimitiveType GetType() { return primitiveType_; } std::function GetAttr; private: int64_t size_; PrimitiveType primitiveType_; }; class Tensor::Impl { public: Impl(Shape &shape, std::vector value) : shape_(shape) { GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr { auto type = mlir::RankedTensorType::get( llvm::ArrayRef(shape_.GetImpl()->GetDims()), mlir::IntegerType::get(context, 32)); return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value)); }; } Impl(Shape &shape, std::vector value) : shape_(shape) { GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr { auto type = mlir::RankedTensorType::get( llvm::ArrayRef(shape_.GetImpl()->GetDims()), mlir::IntegerType::get(context, 64)); return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value)); }; } Impl(Shape &shape, std::vector value) : shape_(shape) { GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr { auto type = mlir::RankedTensorType::get( llvm::ArrayRef(shape_.GetImpl()->GetDims()), mlir::FloatType::getF32(context)); return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value)); }; } // Impl(Shape &shape, PrimitiveType &primitiveType, const void *value) // : shape_(shape), primitiveType_(primitiveType) { // GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr { // auto type = mlir::RankedTensorType::get( // llvm::ArrayRef(shape_.GetImpl()->GetDims()), // primitiveType_.GetImpl()->GetMlirType(context)); // return mlir::DenseElementsAttr::get( // type, // llvm::ArrayRef(reinterpret_cast(value), // shape_.GetImpl()->GetSize() * // primitiveType_.GetImpl()->GetUnitBytes())); // }; // } Shape GetShape() { return shape_; } // PrimitiveType GetType() { return primitiveType_; } std::function GetAttr; private: Shape shape_; // PrimitiveType primitiveType_; }; class TensorInt::Impl { public: Impl(Shape &shape, PrimitiveType &primitiveType) : shape_(shape), primitiveType_(primitiveType) {} Shape GetShape() { return shape_; } PrimitiveType GetType() { return primitiveType_; } std::function GetAttr; private: Shape shape_; PrimitiveType primitiveType_; }; class ChannelHandle::Impl { public: Impl() {} std::function GetAttr; private: }; class ConvDimensionNumbers::Impl { public: Impl() {} std::function GetAttr; private: }; class DotDimensionNumbers::Impl { public: Impl() {} std::function GetAttr; private: }; class GatherDimensionNumbers::Impl { public: Impl() {} std::function GetAttr; private: }; class ScatterDimensionNumbers::Impl { public: Impl() {} std::function GetAttr; private: }; } // namespace builder #endif