358 lines
10 KiB
C++
358 lines
10 KiB
C++
#ifndef BUILDER_ATTRIBUTEIMPL_H_
|
|
#define BUILDER_ATTRIBUTEIMPL_H_
|
|
|
|
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
|
|
|
|
namespace builder {
|
|
|
|
// Implementation of builder::PrimitiveType: maps an element-type tag to a
// factory for the matching MLIR element type plus its storage width in bits.
class PrimitiveType::Impl {
 public:
  // Element-type tags. The numeric values (note BF16 = 16, C64 = 15,
  // C128 = 18, TOKEN = 17) appear to mirror XLA's PrimitiveType enum, so do
  // not renumber them. NOTE(review): confirm against xla_data.proto.
  enum pType {
    PRIMITIVE_TYPE_INVALID = 0,
    PRED = 1,
    S8 = 2,
    S16 = 3,
    S32 = 4,
    S64 = 5,
    U8 = 6,
    U16 = 7,
    U32 = 8,
    U64 = 9,
    F16 = 10,
    F32 = 11,
    BF16 = 16,
    F64 = 12,
    C64 = 15,
    C128 = 18,
    TUPLE = 13,
    OPAQUE_TYPE = 14,
    TOKEN = 17
  };

  // Binds the MLIR type factory and the bit width for `t`.
  //
  // PRED and the signed S* tags map to *signless* MLIR integers, because MHLO
  // integer/predicate element types are signless and the attribute builders
  // in this header create signless i32/i64 values. U* tags map to unsigned
  // MLIR integers. Tags with no element-type mapping here (INVALID, TUPLE,
  // OPAQUE_TYPE, TOKEN, or an out-of-range value) yield mlir::NoneType and a
  // width of 0.
  Impl(pType t)
      : GetMlirType([t](mlir::MLIRContext *context) -> mlir::Type {
          return MakeMlirType(t, context);
        }),
        t_(t),
        unitBits_(BitWidth(t)) {}

  // Two Impls are equal when they describe the same element-type tag.
  inline bool operator==(const PrimitiveType::Impl &impl) const {
    return t_ == impl.t_;
  }
  inline bool operator!=(const PrimitiveType::Impl &impl) const {
    return !(*this == impl);
  }

  // Builds this element type inside the given MLIR context.
  std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;

  // Storage size of one element, rounded up to whole bytes. Sub-byte types
  // such as PRED (1 bit) therefore report 1 byte instead of 0.
  uint64_t GetUnitBytes() const { return (unitBits_ + 7) / 8; }

  // Logical width of one element in bits (0 for unmapped tags).
  uint64_t GetUnitBits() const { return unitBits_; }

 private:
  // Creates the MLIR element type for `t` in `context`.
  static mlir::Type MakeMlirType(pType t, mlir::MLIRContext *context) {
    using Signedness = mlir::IntegerType::SignednessSemantics;
    switch (t) {
      case PRED:
        return mlir::IntegerType::get(context, 1, Signedness::Signless);
      case S8:
        return mlir::IntegerType::get(context, 8, Signedness::Signless);
      case S16:
        return mlir::IntegerType::get(context, 16, Signedness::Signless);
      case S32:
        return mlir::IntegerType::get(context, 32, Signedness::Signless);
      case S64:
        return mlir::IntegerType::get(context, 64, Signedness::Signless);
      case U8:
        return mlir::IntegerType::get(context, 8, Signedness::Unsigned);
      case U16:
        return mlir::IntegerType::get(context, 16, Signedness::Unsigned);
      case U32:
        return mlir::IntegerType::get(context, 32, Signedness::Unsigned);
      case U64:
        return mlir::IntegerType::get(context, 64, Signedness::Unsigned);
      case F16:
        return mlir::Float16Type::get(context);
      case F32:
        return mlir::Float32Type::get(context);
      case BF16:
        return mlir::BFloat16Type::get(context);
      case F64:
        return mlir::Float64Type::get(context);
      case C64:
        // complex<f32>: two 32-bit float components.
        return mlir::ComplexType::get(mlir::Float32Type::get(context));
      case C128:
        // complex<f64>: two 64-bit float components.
        return mlir::ComplexType::get(mlir::Float64Type::get(context));
      default:
        return mlir::NoneType::get(context);
    }
  }

  // Width in bits of one element of type `t`; 0 when `t` is unmapped.
  // Must stay in sync with MakeMlirType above.
  static uint64_t BitWidth(pType t) {
    switch (t) {
      case PRED:
        return 1;
      case S8:
      case U8:
        return 8;
      case S16:
      case U16:
      case F16:
      case BF16:
        return 16;
      case S32:
      case U32:
      case F32:
        return 32;
      case S64:
      case U64:
      case F64:
      case C64:
        return 64;
      case C128:
        return 128;
      default:
        return 0;
    }
  }

  pType t_;
  uint64_t unitBits_;
};
|
|
|
|
// Implementation of builder::Shape: an ordered list of dimension sizes.
class Shape::Impl {
 public:
  Impl(std::vector<int64_t> dims) : dims_(std::move(dims)) {}

  // Returns a copy of the dimension sizes. It is returned as a non-const value
  // so callers can move from it.
  std::vector<int64_t> GetDims() const { return dims_; }

  // Number of elements: the product of all dimension sizes. The product
  // starts at 1, so a rank-0 (scalar) shape has exactly one element.
  // NOTE(review): negative/dynamic dimension values are multiplied as-is.
  int64_t GetSize() const {
    return std::accumulate(dims_.begin(), dims_.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }

 private:
  std::vector<int64_t> dims_;
};
|
|
|
|
// Implementation of builder::Type: a ranked tensor type described by a Shape
// (dimension sizes) and a PrimitiveType (element type).
class Type::Impl {
 public:
  // Both arguments are copied. They are taken by const reference so that
  // temporaries and const objects are accepted as well as plain lvalues.
  Impl(const Shape &shape, const PrimitiveType &primitiveType)
      : shape_(shape), primitiveType_(primitiveType) {}

  Shape GetShape() const { return shape_; }
  PrimitiveType GetType() const { return primitiveType_; }

  // Builds `tensor<dims x element>` in `context`. The dims vector returned by
  // GetDims() is a temporary that outlives this full expression, during
  // which RankedTensorType::get consumes it.
  mlir::Type GetMlirType(mlir::MLIRContext *context) const {
    return mlir::RankedTensorType::get(
        llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
        primitiveType_.GetImpl()->GetMlirType(context));
  }

 private:
  Shape shape_;
  PrimitiveType primitiveType_;
};
|
|
|
|
// Implementation of builder::Integer: lazily produces a scalar
// mlir::IntegerAttr. The constructor overload picked by the caller (int vs
// int64_t) decides whether the attribute's type is a 32- or 64-bit signless
// integer.
class Integer::Impl {
 public:
  // 32-bit signless integer attribute carrying `value`.
  Impl(int value)
      : GetAttr([value](mlir::MLIRContext *context) -> mlir::IntegerAttr {
          return mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
                                        value);
        }) {}

  // 64-bit signless integer attribute carrying `value`.
  Impl(int64_t value)
      : GetAttr([value](mlir::MLIRContext *context) -> mlir::IntegerAttr {
          return mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64),
                                        value);
        }) {}

  // Materializes the attribute inside the supplied context.
  std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr;
};
|
|
|
|
// Implementation of builder::Float: lazily produces a scalar mlir::FloatAttr.
// The constructor overload picked by the caller (float vs double) decides
// whether the attribute's type is f32 or f64.
class Float::Impl {
 public:
  // f32 attribute carrying `value`.
  Impl(float value)
      : GetAttr([value](mlir::MLIRContext *context) -> mlir::FloatAttr {
          return mlir::FloatAttr::get(mlir::Float32Type::get(context), value);
        }) {}

  // f64 attribute carrying `value`.
  Impl(double value)
      : GetAttr([value](mlir::MLIRContext *context) -> mlir::FloatAttr {
          return mlir::FloatAttr::get(mlir::Float64Type::get(context), value);
        }) {}

  // Materializes the attribute inside the supplied context.
  std::function<mlir::FloatAttr(mlir::MLIRContext *context)> GetAttr;
};
|
|
|
|
// Implementation of builder::Array: records the element count and element
// type of a 1-D array literal.
//
// NOTE(review): no constructor assigns GetAttr, so invoking it throws
// std::bad_function_call. TODO: build the mlir::ArrayAttr (or a
// DenseElementsAttr) from the values once the intended encoding is decided.
class Array::Impl {
 public:
  // The casts below are explicit because brace-initializing an int64_t from
  // vector::size() (a size_t) is a narrowing conversion, which is ill-formed.
  Impl(std::vector<int> value)
      : size_(static_cast<int64_t>(value.size())),
        primitiveType_(PrimitiveType::S32()) {}

  Impl(std::vector<int64_t> value)
      : size_(static_cast<int64_t>(value.size())),
        primitiveType_(PrimitiveType::S64()) {}

  // NOTE(review): string arrays are tagged F32 here, which looks like a
  // placeholder; confirm the intended element type.
  Impl(std::vector<std::string> value)
      : size_(static_cast<int64_t>(value.size())),
        primitiveType_(PrimitiveType::F32()) {}

  // Number of elements the array was constructed with.
  int64_t GetSize() const { return size_; }

  // Element type chosen by the constructor overload.
  PrimitiveType GetType() const { return primitiveType_; }

  std::function<mlir::ArrayAttr(mlir::MLIRContext *context)> GetAttr;

 private:
  int64_t size_;
  PrimitiveType primitiveType_;
};
|
|
|
|
// Implementation of builder::Tensor: a constant tensor literal whose
// mlir::DenseElementsAttr is built lazily by GetAttr for a given context.
class Tensor::Impl {
 public:
  // Tensor of 32-bit signless integers with the given shape.
  Impl(const Shape &shape, std::vector<int> value) : shape_(shape) {
    GetAttr = MakeGetAttr(shape_.GetImpl()->GetDims(), std::move(value),
                          [](mlir::MLIRContext *context) -> mlir::Type {
                            return mlir::IntegerType::get(context, 32);
                          });
  }

  // Tensor of 64-bit signless integers with the given shape.
  Impl(const Shape &shape, std::vector<int64_t> value) : shape_(shape) {
    GetAttr = MakeGetAttr(shape_.GetImpl()->GetDims(), std::move(value),
                          [](mlir::MLIRContext *context) -> mlir::Type {
                            return mlir::IntegerType::get(context, 64);
                          });
  }

  // Tensor of f32 values with the given shape.
  Impl(const Shape &shape, std::vector<float> value) : shape_(shape) {
    GetAttr = MakeGetAttr(shape_.GetImpl()->GetDims(), std::move(value),
                          [](mlir::MLIRContext *context) -> mlir::Type {
                            return mlir::FloatType::getF32(context);
                          });
  }

  // Disabled: generic constructor from a raw, untyped buffer. NOTE(review):
  // if re-enabled, capture copies instead of relying on `[=]`, which would
  // capture `this` through shape_/primitiveType_.
  // Impl(Shape &shape, PrimitiveType &primitiveType, const void *value)
  //     : shape_(shape), primitiveType_(primitiveType) {
  //   GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
  //     auto type = mlir::RankedTensorType::get(
  //         llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
  //         primitiveType_.GetImpl()->GetMlirType(context));
  //     return mlir::DenseElementsAttr::get<char>(
  //         type,
  //         llvm::ArrayRef<char>(reinterpret_cast<const char *>(value),
  //                              shape_.GetImpl()->GetSize() *
  //                                  primitiveType_.GetImpl()->GetUnitBytes()));
  //   };
  // }

  Shape GetShape() const { return shape_; }
  // PrimitiveType GetType() { return primitiveType_; }

  // Builds the attribute in `context`. The callable holds only its own copies
  // of the dimensions and values, never `this`, so it stays valid after this
  // Impl is copied, moved, or destroyed.
  std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;

 private:
  // Returns a getter producing a DenseElementsAttr of type
  // tensor<dims x elementType(context)> holding `values`.
  template <typename T, typename ElementTypeFn>
  static std::function<mlir::DenseElementsAttr(mlir::MLIRContext *)>
  MakeGetAttr(std::vector<int64_t> dims, std::vector<T> values,
              ElementTypeFn elementType) {
    return [dims = std::move(dims), values = std::move(values),
            elementType](mlir::MLIRContext *context)
               -> mlir::DenseElementsAttr {
      auto type = mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>(dims),
                                              elementType(context));
      return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(values));
    };
  }

  Shape shape_;
  // PrimitiveType primitiveType_;
};
|
|
|
|
// Implementation of builder::TensorInt: records a Shape and an element
// PrimitiveType.
// NOTE(review): no constructor assigns GetAttr, so it is empty (calling it
// throws std::bad_function_call) unless the owning class sets it.
class TensorInt::Impl {
 public:
  // Both arguments are copied. Const references also accept temporaries.
  Impl(const Shape &shape, const PrimitiveType &primitiveType)
      : shape_(shape), primitiveType_(primitiveType) {}

  Shape GetShape() const { return shape_; }
  PrimitiveType GetType() const { return primitiveType_; }

  std::function<mlir::DenseIntElementsAttr(mlir::MLIRContext *context)> GetAttr;

 private:
  Shape shape_;
  PrimitiveType primitiveType_;
};
|
|
|
|
// Implementation of builder::ChannelHandle. Holds only the attribute factory.
class ChannelHandle::Impl {
 public:
  Impl() = default;

  // Factory for the mhlo::ChannelHandle attribute. It is not assigned here,
  // so calling it before it is set throws std::bad_function_call.
  std::function<mlir::mhlo::ChannelHandle(mlir::MLIRContext *context)> GetAttr;
};
|
|
|
|
// Implementation of builder::ConvDimensionNumbers. Holds only the attribute
// factory.
class ConvDimensionNumbers::Impl {
 public:
  Impl() = default;

  // Factory for the mhlo::ConvDimensionNumbers attribute. It is not assigned
  // here, so calling it before it is set throws std::bad_function_call.
  std::function<mlir::mhlo::ConvDimensionNumbers(mlir::MLIRContext *context)>
      GetAttr;
};
|
|
|
|
// Implementation of builder::DotDimensionNumbers. Holds only the attribute
// factory.
class DotDimensionNumbers::Impl {
 public:
  Impl() = default;

  // Factory for the mhlo::DotDimensionNumbers attribute. It is not assigned
  // here, so calling it before it is set throws std::bad_function_call.
  std::function<mlir::mhlo::DotDimensionNumbers(mlir::MLIRContext *context)>
      GetAttr;
};
|
|
|
|
// Implementation of builder::GatherDimensionNumbers. Holds only the attribute
// factory.
class GatherDimensionNumbers::Impl {
 public:
  Impl() = default;

  // Factory for the mhlo::GatherDimensionNumbers attribute. It is not assigned
  // here, so calling it before it is set throws std::bad_function_call.
  std::function<mlir::mhlo::GatherDimensionNumbers(mlir::MLIRContext *context)>
      GetAttr;
};
|
|
|
|
// Implementation of builder::ScatterDimensionNumbers. Holds only the
// attribute factory.
class ScatterDimensionNumbers::Impl {
 public:
  Impl() = default;

  // Factory for the mhlo::ScatterDimensionNumbers attribute. It is not
  // assigned here, so calling it before it is set throws
  // std::bad_function_call.
  std::function<mlir::mhlo::ScatterDimensionNumbers(mlir::MLIRContext *context)>
      GetAttr;
};
|
|
|
|
} // namespace builder
|
|
|
|
#endif |