// mlir-hlo/tools/mlir-tblgen-builder/Builder/AttributeImpl.h
//
// Private "Impl" classes backing the builder's type and attribute wrappers.

#ifndef BUILDER_ATTRIBUTEIMPL_H_
#define BUILDER_ATTRIBUTEIMPL_H_
#include <cstdint>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
// Implementation of PrimitiveType: maps an element-type code to a factory for
// the corresponding MLIR element type, plus the element's storage width.
class PrimitiveType::Impl {
 public:
  // Element-type codes. The numeric values appear to mirror XLA's
  // PrimitiveType enum (xla_data.proto) — NOTE(review): confirm before
  // relying on the numbering.
  enum pType {
    PRIMITIVE_TYPE_INVALID = 0,
    PRED = 1,
    S8 = 2,
    S16 = 3,
    S32 = 4,
    S64 = 5,
    U8 = 6,
    U16 = 7,
    U32 = 8,
    U64 = 9,
    F16 = 10,
    F32 = 11,
    BF16 = 16,
    F64 = 12,
    C64 = 15,
    C128 = 18,
    TUPLE = 13,
    OPAQUE_TYPE = 14,
    TOKEN = 17
  };

  // Selects the MLIR element type and bit width for `t`.
  //
  // Signed integers (and PRED) are built *signless* (`i1`, `i8`, ... `i64`),
  // matching MHLO's element-type conventions and the signless integers that
  // Integer/Tensor attributes in this file produce; unsigned codes are built
  // as `ui8` ... `ui64`. Codes without an element-type mapping here
  // (PRIMITIVE_TYPE_INVALID, TUPLE, OPAQUE_TYPE, TOKEN) fall back to
  // NoneType with a width of 0 bits.
  Impl(pType t) : t_(t) {
    using Signedness = mlir::IntegerType::SignednessSemantics;
    switch (t) {
      case PRED:
        SetIntegerType(1, Signedness::Signless);
        break;
      case S8:
        SetIntegerType(8, Signedness::Signless);
        break;
      case S16:
        SetIntegerType(16, Signedness::Signless);
        break;
      case S32:
        SetIntegerType(32, Signedness::Signless);
        break;
      case S64:
        SetIntegerType(64, Signedness::Signless);
        break;
      case U8:
        SetIntegerType(8, Signedness::Unsigned);
        break;
      case U16:
        SetIntegerType(16, Signedness::Unsigned);
        break;
      case U32:
        SetIntegerType(32, Signedness::Unsigned);
        break;
      case U64:
        SetIntegerType(64, Signedness::Unsigned);
        break;
      case F16:
        SetType(16, [](mlir::MLIRContext *context) -> mlir::Type {
          return mlir::FloatType::getF16(context);
        });
        break;
      case BF16:
        SetType(16, [](mlir::MLIRContext *context) -> mlir::Type {
          return mlir::FloatType::getBF16(context);
        });
        break;
      case F32:
        SetType(32, [](mlir::MLIRContext *context) -> mlir::Type {
          return mlir::Float32Type::get(context);
        });
        break;
      case F64:
        SetType(64, [](mlir::MLIRContext *context) -> mlir::Type {
          return mlir::FloatType::getF64(context);
        });
        break;
      case C64:
        // complex<f32>: two 32-bit components.
        SetType(64, [](mlir::MLIRContext *context) -> mlir::Type {
          return mlir::ComplexType::get(mlir::FloatType::getF32(context));
        });
        break;
      case C128:
        // complex<f64>: two 64-bit components.
        SetType(128, [](mlir::MLIRContext *context) -> mlir::Type {
          return mlir::ComplexType::get(mlir::FloatType::getF64(context));
        });
        break;
      case PRIMITIVE_TYPE_INVALID:
      default:
        // TODO: TUPLE / OPAQUE_TYPE / TOKEN have no element-type mapping yet.
        SetType(0, [](mlir::MLIRContext *context) -> mlir::Type {
          return mlir::NoneType::get(context);
        });
        break;
    }
  }

  // Two Impls compare equal when they describe the same element-type code.
  bool operator==(const PrimitiveType::Impl &impl) const {
    return t_ == impl.t_;
  }

  // Builds the MLIR element type for this code inside `context`.
  std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;

  // Storage size of one element in whole bytes, rounded up so that sub-byte
  // types such as PRED report 1 byte rather than 0.
  uint64_t GetUnitBytes() const { return (unitBits_ + 7) / 8; }

  // Width of one element in bits (0 for codes without a mapping).
  uint64_t GetUnitBits() const { return unitBits_; }

 private:
  // Records the element width and installs the MLIR type factory.
  void SetType(uint64_t bits,
               std::function<mlir::Type(mlir::MLIRContext *)> factory) {
    unitBits_ = bits;
    GetMlirType = std::move(factory);
  }

  // Installs an integer element type of `width` bits and given signedness.
  void SetIntegerType(unsigned width,
                      mlir::IntegerType::SignednessSemantics signedness) {
    SetType(width,
            [width, signedness](mlir::MLIRContext *context) -> mlir::Type {
              return mlir::IntegerType::get(context, width, signedness);
            });
  }

  pType t_;
  uint64_t unitBits_ = 0;
};
// Implementation of Shape: an ordered list of dimension sizes.
class Shape::Impl {
 public:
  // Takes ownership of the dimension sizes (moved, not copied).
  Impl(std::vector<int64_t> dims) : dims_(std::move(dims)) {}

  // Returns the dimension sizes without copying them.
  const std::vector<int64_t> &GetDims() const { return dims_; }

  // Returns the number of elements, i.e. the product of all dimensions.
  // A rank-0 (scalar) shape has exactly one element. The accumulator must
  // start at 1; the previous version multiplied into an uninitialized value.
  int64_t GetSize() const {
    int64_t size = 1;
    for (const int64_t d : dims_) {
      size *= d;
    }
    return size;
  }

 private:
  std::vector<int64_t> dims_;
};
// Implementation of Type: a ranked tensor type described by a Shape and a
// PrimitiveType element type.
class Type::Impl {
 public:
  // Copies `shape` and `primitiveType`; taking them by const reference lets
  // callers pass temporaries and const objects.
  Impl(const Shape &shape, const PrimitiveType &primitiveType)
      : shape_(shape), primitiveType_(primitiveType) {}

  // Returns a copy of the tensor shape.
  Shape GetShape() const { return shape_; }

  // Returns a copy of the element type.
  PrimitiveType GetType() const { return primitiveType_; }

  // Builds `tensor<dims x element>` in `context`.
  mlir::Type GetMlirType(mlir::MLIRContext *context) const {
    mlir::Type elementType = primitiveType_.GetImpl()->GetMlirType(context);
    // Dims are wrapped in the same full-expression that consumes them so the
    // ArrayRef never outlives the vector it views.
    return mlir::RankedTensorType::get(
        llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()), elementType);
  }

 private:
  Shape shape_;
  PrimitiveType primitiveType_;
};
// Implementation of Integer: a scalar integer attribute whose MLIR form is
// produced on demand for a given context through GetAttr.
class Integer::Impl {
 public:
  // Stores a 32-bit signless integer attribute factory.
  Impl(int value) { GetAttr = MakeGetter(32, value); }

  // Stores a 64-bit signless integer attribute factory.
  Impl(int64_t value) { GetAttr = MakeGetter(64, value); }

  // Materializes the IntegerAttr inside the supplied context.
  std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr;

 private:
  // Returns a factory building an IntegerAttr of `width` signless bits.
  static std::function<mlir::IntegerAttr(mlir::MLIRContext *)> MakeGetter(
      unsigned width, int64_t value) {
    return [width, value](mlir::MLIRContext *context) -> mlir::IntegerAttr {
      mlir::IntegerType elementType = mlir::IntegerType::get(context, width);
      return mlir::IntegerAttr::get(elementType, value);
    };
  }
};
// Implementation of Float: a scalar floating-point attribute whose MLIR form
// is produced on demand for a given context through GetAttr.
class Float::Impl {
 public:
  // Stores an f32 attribute factory.
  Impl(float value) { GetAttr = MakeGetter<mlir::Float32Type>(value); }

  // Stores an f64 attribute factory.
  Impl(double value) { GetAttr = MakeGetter<mlir::Float64Type>(value); }

  // Materializes the FloatAttr inside the supplied context.
  std::function<mlir::FloatAttr(mlir::MLIRContext *context)> GetAttr;

 private:
  // Returns a factory building a FloatAttr whose type is `FloatTy`.
  template <typename FloatTy>
  static std::function<mlir::FloatAttr(mlir::MLIRContext *)> MakeGetter(
      double value) {
    return [value](mlir::MLIRContext *context) -> mlir::FloatAttr {
      return mlir::FloatAttr::get(FloatTy::get(context), value);
    };
  }
};
// Implementation of Array: a one-dimensional list attribute with a length
// and an element PrimitiveType.
class Array::Impl {
 public:
  // Array of 32-bit integers; GetAttr yields an ArrayAttr of i32 IntegerAttrs.
  Impl(std::vector<int> value)
      // static_cast: brace-initializing int64_t from size_t is a narrowing
      // conversion and ill-formed.
      : size_(static_cast<int64_t>(value.size())),
        primitiveType_(PrimitiveType::S32()) {
    GetAttr = [value = std::move(value)](
                  mlir::MLIRContext *context) -> mlir::ArrayAttr {
      return mlir::Builder(context).getI32ArrayAttr(llvm::makeArrayRef(value));
    };
  }

  // Array of 64-bit integers; GetAttr yields an ArrayAttr of i64 IntegerAttrs.
  Impl(std::vector<int64_t> value)
      : size_(static_cast<int64_t>(value.size())),
        primitiveType_(PrimitiveType::S64()) {
    GetAttr = [value = std::move(value)](
                  mlir::MLIRContext *context) -> mlir::ArrayAttr {
      return mlir::Builder(context).getI64ArrayAttr(llvm::makeArrayRef(value));
    };
  }

  // Array of strings tagged as F32. GetAttr is intentionally left unset:
  // the intended encoding (string attributes vs. parsed floats) is unclear.
  // TODO(review): decide the encoding and populate GetAttr.
  Impl(std::vector<std::string> value)
      : size_(static_cast<int64_t>(value.size())),
        primitiveType_(PrimitiveType::F32()) {}

  // Number of elements in the array.
  int64_t GetSize() const { return size_; }

  // Element type of the array.
  PrimitiveType GetType() const { return primitiveType_; }

  // Materializes the ArrayAttr inside the supplied context. Empty (calling it
  // throws std::bad_function_call) for the string constructor.
  std::function<mlir::ArrayAttr(mlir::MLIRContext *context)> GetAttr;

 private:
  int64_t size_;
  PrimitiveType primitiveType_;
};
// Implementation of Tensor: a dense constant of a given Shape whose MLIR
// DenseElementsAttr is materialized lazily, per context, through GetAttr.
//
// The stored lambdas capture the dimension sizes *by value*. They previously
// used `[=]` and read `shape_`, which implicitly captured `this`; a copied or
// moved Impl would then carry a GetAttr that dereferences the old object.
class Tensor::Impl {
 public:
  // Tensor of 32-bit signless integers.
  Impl(const Shape &shape, std::vector<int> value) : shape_(shape) {
    GetAttr = [dims = shape.GetImpl()->GetDims(), value = std::move(value)](
                  mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
      auto type = mlir::RankedTensorType::get(
          llvm::ArrayRef<int64_t>(dims), mlir::IntegerType::get(context, 32));
      return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
    };
  }

  // Tensor of 64-bit signless integers.
  Impl(const Shape &shape, std::vector<int64_t> value) : shape_(shape) {
    GetAttr = [dims = shape.GetImpl()->GetDims(), value = std::move(value)](
                  mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
      auto type = mlir::RankedTensorType::get(
          llvm::ArrayRef<int64_t>(dims), mlir::IntegerType::get(context, 64));
      return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
    };
  }

  // Tensor of f32 values.
  Impl(const Shape &shape, std::vector<float> value) : shape_(shape) {
    GetAttr = [dims = shape.GetImpl()->GetDims(), value = std::move(value)](
                  mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
      auto type = mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>(dims),
                                              mlir::FloatType::getF32(context));
      return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
    };
  }

  // Disabled raw-buffer constructor, kept for reference. If re-enabled, it
  // must capture dims/type by value as above rather than reading members.
  // Impl(Shape &shape, PrimitiveType &primitiveType, const void *value)
  //     : shape_(shape), primitiveType_(primitiveType) {
  //   GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
  //     auto type = mlir::RankedTensorType::get(
  //         llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
  //         primitiveType_.GetImpl()->GetMlirType(context));
  //     return mlir::DenseElementsAttr::get<char>(
  //         type,
  //         llvm::ArrayRef<char>(reinterpret_cast<const char *>(value),
  //                              shape_.GetImpl()->GetSize() *
  //                                  primitiveType_.GetImpl()->GetUnitBytes()));
  //   };
  // }

  // Returns a copy of the tensor shape.
  Shape GetShape() const { return shape_; }
  // PrimitiveType GetType() { return primitiveType_; }

  // Materializes the DenseElementsAttr inside the supplied context.
  std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;

 private:
  Shape shape_;
  // PrimitiveType primitiveType_;
};
// Implementation of TensorInt: an integer tensor described by a Shape and an
// element PrimitiveType.
class TensorInt::Impl {
 public:
  // Copies `shape` and `primitiveType`; const references accept temporaries.
  Impl(const Shape &shape, const PrimitiveType &primitiveType)
      : shape_(shape), primitiveType_(primitiveType) {}

  // Returns a copy of the tensor shape.
  Shape GetShape() const { return shape_; }

  // Returns a copy of the element type.
  PrimitiveType GetType() const { return primitiveType_; }

  // Materializes the DenseIntElementsAttr inside the supplied context.
  // Not assigned by this class — NOTE(review): presumably set by the owning
  // wrapper; calling it while empty throws std::bad_function_call.
  std::function<mlir::DenseIntElementsAttr(mlir::MLIRContext *context)> GetAttr;

 private:
  Shape shape_;
  PrimitiveType primitiveType_;
};
// Implementation of ChannelHandle: holds only a factory for the mhlo
// ChannelHandle attribute.
class ChannelHandle::Impl {
 public:
  Impl() = default;

  // Builds the attribute inside the supplied context. Not assigned by this
  // class — NOTE(review): presumably set by the owning wrapper.
  std::function<mlir::mhlo::ChannelHandle(mlir::MLIRContext *context)> GetAttr;
};
// Implementation of ConvDimensionNumbers: holds only a factory for the mhlo
// ConvDimensionNumbers attribute.
class ConvDimensionNumbers::Impl {
 public:
  Impl() = default;

  // Builds the attribute inside the supplied context. Not assigned by this
  // class — NOTE(review): presumably set by the owning wrapper.
  std::function<mlir::mhlo::ConvDimensionNumbers(mlir::MLIRContext *context)>
      GetAttr;
};
// Implementation of DotDimensionNumbers: holds only a factory for the mhlo
// DotDimensionNumbers attribute.
class DotDimensionNumbers::Impl {
 public:
  Impl() = default;

  // Builds the attribute inside the supplied context. Not assigned by this
  // class — NOTE(review): presumably set by the owning wrapper.
  std::function<mlir::mhlo::DotDimensionNumbers(mlir::MLIRContext *context)>
      GetAttr;
};
// Implementation of GatherDimensionNumbers: holds only a factory for the mhlo
// GatherDimensionNumbers attribute.
class GatherDimensionNumbers::Impl {
 public:
  Impl() = default;

  // Builds the attribute inside the supplied context. Not assigned by this
  // class — NOTE(review): presumably set by the owning wrapper.
  std::function<mlir::mhlo::GatherDimensionNumbers(mlir::MLIRContext *context)>
      GetAttr;
};
// Implementation of ScatterDimensionNumbers: holds only a factory for the
// mhlo ScatterDimensionNumbers attribute.
class ScatterDimensionNumbers::Impl {
 public:
  Impl() = default;

  // Builds the attribute inside the supplied context. Not assigned by this
  // class — NOTE(review): presumably set by the owning wrapper.
  std::function<mlir::mhlo::ScatterDimensionNumbers(mlir::MLIRContext *context)>
      GetAttr;
};
} // namespace builder
#endif