refine test and tensor
This commit is contained in:
		
							parent
							
								
									975a47f7b2
								
							
						
					
					
						commit
						84e7697c6a
					
				
							
								
								
									
										32
									
								
								BUILD
								
								
								
								
							
							
						
						
									
										32
									
								
								BUILD
								
								
								
								
							| 
						 | 
					@ -174,39 +174,21 @@ cc_library(
 | 
				
			||||||
        ":disc_ral",
 | 
					        ":disc_ral",
 | 
				
			||||||
        ":hlo",
 | 
					        ":hlo",
 | 
				
			||||||
        ":lhlo",
 | 
					        ":lhlo",
 | 
				
			||||||
        ":lhlo_gpu",
 | 
					 | 
				
			||||||
        ":hlo_ops_builder_gen",
 | 
					        ":hlo_ops_builder_gen",
 | 
				
			||||||
        "@llvm-project//mlir:MlirTableGenMain",
 | 
					        "@llvm-project//mlir:MlirTableGenMain",
 | 
				
			||||||
        "@llvm-project//mlir:Support",
 | 
					        "@llvm-project//mlir:Support",
 | 
				
			||||||
        "@llvm-project//mlir:IR",
 | 
					        "@llvm-project//mlir:IR",
 | 
				
			||||||
        "@llvm-project//mlir:Pass",
 | 
					        "@llvm-project//mlir:Pass",
 | 
				
			||||||
 | 
					        "@llvm-project//mlir:AllPassesAndDialects",
 | 
				
			||||||
 | 
					        # "@llvm-project//mlir:MlirOptLib",
 | 
				
			||||||
 | 
					        # "@llvm-project//mlir:MlirJitRunner",
 | 
				
			||||||
 | 
					        "@llvm-project//llvm:Core",
 | 
				
			||||||
 | 
					        "@llvm-project//llvm:Option",
 | 
				
			||||||
 | 
					        "@llvm-project//llvm:OrcJIT",
 | 
				
			||||||
        "@llvm-project//llvm:Support",
 | 
					        "@llvm-project//llvm:Support",
 | 
				
			||||||
        "@llvm-project//llvm:TableGen",
 | 
					        "@llvm-project//llvm:TableGen",
 | 
				
			||||||
 | 
					        "@llvm-project//llvm:Target",
 | 
				
			||||||
        "@llvm-project//llvm:config",
 | 
					        "@llvm-project//llvm:config",
 | 
				
			||||||
        # "@llvm-project//llvm:AllTargetsAsmParsers",
 | 
					 | 
				
			||||||
        # "@llvm-project//llvm:AllTargetsCodeGens",
 | 
					 | 
				
			||||||
        # "@llvm-project//llvm:Core",
 | 
					 | 
				
			||||||
        # "@llvm-project//llvm:ExecutionEngine",
 | 
					 | 
				
			||||||
        # "@llvm-project//llvm:Option",
 | 
					 | 
				
			||||||
        # "@llvm-project//llvm:OrcJIT",
 | 
					 | 
				
			||||||
        # "@llvm-project//llvm:Support",
 | 
					 | 
				
			||||||
        # "@llvm-project//llvm:Target",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:AllPassesAndDialects",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:IR",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:MlirOptLib",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:Support",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:MlirJitRunner",
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:Analysis",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:ControlFlowInterfaces",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:InferTypeOpInterface",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:MemRefDialect",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:Shape",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:SideEffects",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:StandardOps",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:TensorDialect",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:TransformUtils",
 | 
					 | 
				
			||||||
        # "@llvm-project//mlir:Transforms",
 | 
					 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,19 +3,20 @@
 | 
				
			||||||
int main() {
 | 
					int main() {
 | 
				
			||||||
  builder::Builder builder;
 | 
					  builder::Builder builder;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  builder::Shape shape({100, 100});
 | 
					  builder::Shape shape({10, 10});
 | 
				
			||||||
  auto pType = builder::PrimitiveType::F32();
 | 
					  auto pType = builder::PrimitiveType::F32();
 | 
				
			||||||
  builder::Type type(shape, pType);
 | 
					  builder::Type type(shape, pType);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  builder::Tensor tensor(std::vector<int>(100));
 | 
					  std::vector<float> data(100);
 | 
				
			||||||
 | 
					  builder::Tensor tensor(shape, data);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto in1 = builder.CreateInput(type);
 | 
				
			||||||
  auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor);
 | 
					  auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor);
 | 
				
			||||||
  auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor);
 | 
					  auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor);
 | 
				
			||||||
  auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2);
 | 
					  auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2);
 | 
				
			||||||
 | 
					  auto op4 = builder::mhlo::MulOp::build(builder, type, op3, in1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  builder.SetOutput(std::vector<builder::Op>({op4}));
 | 
				
			||||||
  builder.DumpModule();
 | 
					  builder.DumpModule();
 | 
				
			||||||
  return 0;
 | 
					  return 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
// static ::builder::Op build(::builder::Builder &builder, ::builder::Type
 | 
					 | 
				
			||||||
// output, ::builder::Tensor value);
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -50,6 +50,7 @@ PrimitiveType PrimitiveType::S64() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Shape::Shape(std::vector<int64_t> dims)
 | 
					Shape::Shape(std::vector<int64_t> dims)
 | 
				
			||||||
    : impl_(std::make_shared<Shape::Impl>(dims)) {}
 | 
					    : impl_(std::make_shared<Shape::Impl>(dims)) {}
 | 
				
			||||||
 | 
					std::vector<int64_t> Shape::GetDims() { return impl_->GetDims(); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Type::Type(Shape& shape, PrimitiveType& primitiveType)
 | 
					Type::Type(Shape& shape, PrimitiveType& primitiveType)
 | 
				
			||||||
    : impl_(std::make_shared<Type::Impl>(shape, primitiveType)) {}
 | 
					    : impl_(std::make_shared<Type::Impl>(shape, primitiveType)) {}
 | 
				
			||||||
| 
						 | 
					@ -70,16 +71,16 @@ Array::Array(std::vector<int64_t> value)
 | 
				
			||||||
Array::Array(std::vector<std::string> value)
 | 
					Array::Array(std::vector<std::string> value)
 | 
				
			||||||
    : impl_(std::make_shared<Array::Impl>(value)) {}
 | 
					    : impl_(std::make_shared<Array::Impl>(value)) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
 | 
					// Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType, const void* value)
 | 
				
			||||||
    : impl_(std::make_shared<Impl>(shape, primitiveType)) {}
 | 
					//     : impl_(std::make_shared<Impl>(shape, primitiveType, value)) {}
 | 
				
			||||||
Shape Tensor::GetShape() { return impl_->GetShape(); }
 | 
					Shape Tensor::GetShape() { return impl_->GetShape(); }
 | 
				
			||||||
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
 | 
					// PrimitiveType Tensor::GetType() { return impl_->GetType(); }
 | 
				
			||||||
Tensor::Tensor(std::vector<int> value)
 | 
					Tensor::Tensor(Shape& shape, std::vector<int> value)
 | 
				
			||||||
    : impl_(std::make_shared<Tensor::Impl>(value)) {}
 | 
					    : impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
 | 
				
			||||||
Tensor::Tensor(std::vector<int64_t> value)
 | 
					Tensor::Tensor(Shape& shape, std::vector<int64_t> value)
 | 
				
			||||||
    : impl_(std::make_shared<Tensor::Impl>(value)) {}
 | 
					    : impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
 | 
				
			||||||
Tensor::Tensor(std::vector<float> value)
 | 
					Tensor::Tensor(Shape& shape, std::vector<float> value)
 | 
				
			||||||
    : impl_(std::make_shared<Tensor::Impl>(value)) {}
 | 
					    : impl_(std::make_shared<Tensor::Impl>(shape, value)) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TensorInt::TensorInt(Shape& shape, PrimitiveType& primitiveType)
 | 
					TensorInt::TensorInt(Shape& shape, PrimitiveType& primitiveType)
 | 
				
			||||||
    : impl_(std::make_shared<Impl>(shape, primitiveType)) {}
 | 
					    : impl_(std::make_shared<Impl>(shape, primitiveType)) {}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -17,7 +17,7 @@ class PrimitiveType {
 | 
				
			||||||
  inline bool operator==(const PrimitiveType& pt);
 | 
					  inline bool operator==(const PrimitiveType& pt);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  class Impl;
 | 
					  class Impl;
 | 
				
			||||||
  std::shared_ptr<Impl> GetImpl() { return impl_; }
 | 
					  std::shared_ptr<Impl> GetImpl() const { return impl_; }
 | 
				
			||||||
  PrimitiveType(std::shared_ptr<Impl>);
 | 
					  PrimitiveType(std::shared_ptr<Impl>);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
| 
						 | 
					@ -29,7 +29,8 @@ class Shape {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  Shape(std::vector<int64_t> dims);
 | 
					  Shape(std::vector<int64_t> dims);
 | 
				
			||||||
  class Impl;
 | 
					  class Impl;
 | 
				
			||||||
  std::shared_ptr<Impl> GetImpl() { return impl_; }
 | 
					  std::shared_ptr<Impl> GetImpl() const { return impl_; }
 | 
				
			||||||
 | 
					  std::vector<int64_t> GetDims();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  std::shared_ptr<Impl> impl_;
 | 
					  std::shared_ptr<Impl> impl_;
 | 
				
			||||||
| 
						 | 
					@ -75,10 +76,9 @@ class Array {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Tensor {
 | 
					class Tensor {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  Tensor(Shape& shape, PrimitiveType& primitiveType);
 | 
					  Tensor(Shape& shape, std::vector<int> value);
 | 
				
			||||||
  Tensor(std::vector<int> value);
 | 
					  Tensor(Shape& shape, std::vector<int64_t> value);
 | 
				
			||||||
  Tensor(std::vector<int64_t> value);
 | 
					  Tensor(Shape& shape, std::vector<float> value);
 | 
				
			||||||
  Tensor(std::vector<float> value);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  Shape GetShape();
 | 
					  Shape GetShape();
 | 
				
			||||||
  PrimitiveType GetType();
 | 
					  PrimitiveType GetType();
 | 
				
			||||||
| 
						 | 
					@ -107,7 +107,7 @@ class Type {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  Type(Shape& shape, PrimitiveType& primitiveType);
 | 
					  Type(Shape& shape, PrimitiveType& primitiveType);
 | 
				
			||||||
  class Impl;
 | 
					  class Impl;
 | 
				
			||||||
  std::shared_ptr<Impl> GetImpl() { return impl_; }
 | 
					  const std::shared_ptr<Impl> GetImpl() const { return impl_; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  std::shared_ptr<Impl> impl_;
 | 
					  std::shared_ptr<Impl> impl_;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -48,46 +48,54 @@ class PrimitiveType::Impl {
 | 
				
			||||||
        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
					        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
				
			||||||
          return mlir::NoneType::get(context);
 | 
					          return mlir::NoneType::get(context);
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					        unitBits_ = 0;
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
      case PRED:
 | 
					      case PRED:
 | 
				
			||||||
        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
					        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
				
			||||||
          return mlir::IntegerType::get(
 | 
					          return mlir::IntegerType::get(
 | 
				
			||||||
              context, 1, mlir::IntegerType::SignednessSemantics::Signed);
 | 
					              context, 1, mlir::IntegerType::SignednessSemantics::Signed);
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					        unitBits_ = 1;
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
      case S8:
 | 
					      case S8:
 | 
				
			||||||
        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
					        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
				
			||||||
          return mlir::IntegerType::get(
 | 
					          return mlir::IntegerType::get(
 | 
				
			||||||
              context, 8, mlir::IntegerType::SignednessSemantics::Signed);
 | 
					              context, 8, mlir::IntegerType::SignednessSemantics::Signed);
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					        unitBits_ = 8;
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
      case S16:
 | 
					      case S16:
 | 
				
			||||||
        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
					        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
				
			||||||
          return mlir::IntegerType::get(
 | 
					          return mlir::IntegerType::get(
 | 
				
			||||||
              context, 16, mlir::IntegerType::SignednessSemantics::Signed);
 | 
					              context, 16, mlir::IntegerType::SignednessSemantics::Signed);
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					        unitBits_ = 16;
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
      case F32:
 | 
					      case F32:
 | 
				
			||||||
        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
					        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
				
			||||||
          return mlir::Float32Type::get(context);
 | 
					          return mlir::Float32Type::get(context);
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					        unitBits_ = 32;
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
      case S32:
 | 
					      case S32:
 | 
				
			||||||
        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
					        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
				
			||||||
          return mlir::IntegerType::get(
 | 
					          return mlir::IntegerType::get(
 | 
				
			||||||
              context, 32, mlir::IntegerType::SignednessSemantics::Signed);
 | 
					              context, 32, mlir::IntegerType::SignednessSemantics::Signed);
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					        unitBits_ = 32;
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
      case S64:
 | 
					      case S64:
 | 
				
			||||||
        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
					        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
				
			||||||
          return mlir::IntegerType::get(
 | 
					          return mlir::IntegerType::get(
 | 
				
			||||||
              context, 64, mlir::IntegerType::SignednessSemantics::Signed);
 | 
					              context, 64, mlir::IntegerType::SignednessSemantics::Signed);
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					        unitBits_ = 64;
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
      default:
 | 
					      default:
 | 
				
			||||||
        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
					        GetMlirType = [](mlir::MLIRContext *context) -> mlir::Type {
 | 
				
			||||||
          return mlir::NoneType::get(context);
 | 
					          return mlir::NoneType::get(context);
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					        unitBits_ = 0;
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -116,13 +124,25 @@ class PrimitiveType::Impl {
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;
 | 
					  std::function<mlir::Type(mlir::MLIRContext *context)> GetMlirType;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  uint64_t GetUnitBytes() const { return unitBits_ / 8; }
 | 
				
			||||||
 | 
					  uint64_t GetUnitBits() const { return unitBits_; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  pType t_;
 | 
					  pType t_;
 | 
				
			||||||
 | 
					  uint64_t unitBits_;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Shape::Impl {
 | 
					class Shape::Impl {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  Impl(std::vector<int64_t> dims) : dims_(dims) {}
 | 
					  Impl(std::vector<int64_t> dims) : dims_(dims) {}
 | 
				
			||||||
 | 
					  const std::vector<int64_t> GetDims() const { return dims_; }
 | 
				
			||||||
 | 
					  int64_t GetSize() const {
 | 
				
			||||||
 | 
					    int64_t size;
 | 
				
			||||||
 | 
					    for (auto &d : dims_) {
 | 
				
			||||||
 | 
					      size *= d;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return size;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  std::vector<int64_t> dims_;
 | 
					  std::vector<int64_t> dims_;
 | 
				
			||||||
| 
						 | 
					@ -132,10 +152,12 @@ class Type::Impl {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  Impl(Shape &shape, PrimitiveType &primitiveType)
 | 
					  Impl(Shape &shape, PrimitiveType &primitiveType)
 | 
				
			||||||
      : shape_(shape), primitiveType_(primitiveType) {}
 | 
					      : shape_(shape), primitiveType_(primitiveType) {}
 | 
				
			||||||
  Shape GetShape() { return shape_; }
 | 
					  Shape GetShape() const { return shape_; }
 | 
				
			||||||
  PrimitiveType GetType() { return primitiveType_; }
 | 
					  PrimitiveType GetType() { return primitiveType_; }
 | 
				
			||||||
  mlir::Type GetMlirType(mlir::MLIRContext *context) {
 | 
					  mlir::Type GetMlirType(mlir::MLIRContext *context) const {
 | 
				
			||||||
    return mlir::NoneType::get(context);
 | 
					    return mlir::RankedTensorType::get(
 | 
				
			||||||
 | 
					        llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
 | 
				
			||||||
 | 
					        primitiveType_.GetImpl()->GetMlirType(context));
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
| 
						 | 
					@ -223,44 +245,53 @@ class Array::Impl {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Tensor::Impl {
 | 
					class Tensor::Impl {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  Impl(std::vector<int> value)
 | 
					  Impl(Shape &shape, std::vector<int> value) : shape_(shape) {
 | 
				
			||||||
      : shape_({value.size()}), primitiveType_(PrimitiveType::S32()) {
 | 
					 | 
				
			||||||
    GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
 | 
					    GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
 | 
				
			||||||
      auto type =
 | 
					      auto type = mlir::RankedTensorType::get(
 | 
				
			||||||
          mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
 | 
					          llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
 | 
				
			||||||
          mlir::IntegerType::get(context, 32));
 | 
					          mlir::IntegerType::get(context, 32));
 | 
				
			||||||
      return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
 | 
					      return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  Impl(std::vector<int64_t> value)
 | 
					  Impl(Shape &shape, std::vector<int64_t> value) : shape_(shape) {
 | 
				
			||||||
      : shape_({value.size()}), primitiveType_(PrimitiveType::S64()) {
 | 
					 | 
				
			||||||
    GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
 | 
					    GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
 | 
				
			||||||
      auto type =
 | 
					      auto type = mlir::RankedTensorType::get(
 | 
				
			||||||
          mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
 | 
					          llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
 | 
				
			||||||
          mlir::IntegerType::get(context, 64));
 | 
					          mlir::IntegerType::get(context, 64));
 | 
				
			||||||
      return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
 | 
					      return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  Impl(std::vector<float> value)
 | 
					  Impl(Shape &shape, std::vector<float> value) : shape_(shape) {
 | 
				
			||||||
      : shape_({value.size()}), primitiveType_(PrimitiveType::F32()) {
 | 
					 | 
				
			||||||
    GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
 | 
					    GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
 | 
				
			||||||
      auto type =
 | 
					      auto type = mlir::RankedTensorType::get(
 | 
				
			||||||
          mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>({value.size()}),
 | 
					          llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
 | 
				
			||||||
          mlir::FloatType::getF32(context));
 | 
					          mlir::FloatType::getF32(context));
 | 
				
			||||||
      return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
 | 
					      return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(value));
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  Impl(Shape &shape, PrimitiveType &primitiveType)
 | 
					
 | 
				
			||||||
      : shape_(shape), primitiveType_(primitiveType) {}
 | 
					  // Impl(Shape &shape, PrimitiveType &primitiveType, const void *value)
 | 
				
			||||||
 | 
					  //     : shape_(shape), primitiveType_(primitiveType) {
 | 
				
			||||||
 | 
					  //   GetAttr = [=](mlir::MLIRContext *context) -> mlir::DenseElementsAttr {
 | 
				
			||||||
 | 
					  //     auto type = mlir::RankedTensorType::get(
 | 
				
			||||||
 | 
					  //         llvm::ArrayRef<int64_t>(shape_.GetImpl()->GetDims()),
 | 
				
			||||||
 | 
					  //         primitiveType_.GetImpl()->GetMlirType(context));
 | 
				
			||||||
 | 
					  //     return mlir::DenseElementsAttr::get<char>(
 | 
				
			||||||
 | 
					  //         type,
 | 
				
			||||||
 | 
					  //         llvm::ArrayRef<char>(reinterpret_cast<const char *>(value),
 | 
				
			||||||
 | 
					  //                              shape_.GetImpl()->GetSize() *
 | 
				
			||||||
 | 
					  //                                  primitiveType_.GetImpl()->GetUnitBytes()));
 | 
				
			||||||
 | 
					  //   };
 | 
				
			||||||
 | 
					  // }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  Shape GetShape() { return shape_; }
 | 
					  Shape GetShape() { return shape_; }
 | 
				
			||||||
  PrimitiveType GetType() { return primitiveType_; }
 | 
					  // PrimitiveType GetType() { return primitiveType_; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;
 | 
					  std::function<mlir::DenseElementsAttr(mlir::MLIRContext *context)> GetAttr;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  Shape shape_;
 | 
					  Shape shape_;
 | 
				
			||||||
  PrimitiveType primitiveType_;
 | 
					  // PrimitiveType primitiveType_;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TensorInt::Impl {
 | 
					class TensorInt::Impl {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,6 +19,15 @@
 | 
				
			||||||
namespace builder {
 | 
					namespace builder {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
 | 
					Builder::Builder() : impl_(std::make_shared<Impl>()) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					builder::Op Builder::CreateInput(const builder::Type& type) {
 | 
				
			||||||
 | 
					  return impl_->CreateInput(type);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void Builder::SetOutput(const std::vector<builder::Op>& outputs) {
 | 
				
			||||||
 | 
					  impl_->SetOutput(outputs);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void Builder::DumpModule() { impl_->DumpModule(); }
 | 
					void Builder::DumpModule() { impl_->DumpModule(); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace builder
 | 
					}  // namespace builder
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -11,10 +11,13 @@ namespace builder {
 | 
				
			||||||
class Builder {
 | 
					class Builder {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  Builder();
 | 
					  Builder();
 | 
				
			||||||
 | 
					  void SetInput(const std::vector<builder::Op>& inputs);
 | 
				
			||||||
 | 
					  builder::Op CreateInput(const builder::Type& type);
 | 
				
			||||||
 | 
					  void SetOutput(const std::vector<builder::Op>& outputs);
 | 
				
			||||||
  void DumpModule();
 | 
					  void DumpModule();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  class Impl;
 | 
					  class Impl;
 | 
				
			||||||
  std::shared_ptr<Impl> GetImpl() { return impl_; }
 | 
					  std::shared_ptr<Impl> GetImpl() const { return impl_; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  std::shared_ptr<Impl> impl_;
 | 
					  std::shared_ptr<Impl> impl_;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,9 +1,15 @@
 | 
				
			||||||
#ifndef BUILDER_BUILDERIMPL_
 | 
					#ifndef BUILDER_BUILDERIMPL_
 | 
				
			||||||
#define BUILDER_BUILDERIMPL_
 | 
					#define BUILDER_BUILDERIMPL_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "Attribute.h"
 | 
				
			||||||
 | 
					#include "AttributeImpl.h"
 | 
				
			||||||
#include "Builder.h"
 | 
					#include "Builder.h"
 | 
				
			||||||
 | 
					#include "OpImpl.h"
 | 
				
			||||||
#include "llvm/Support/Casting.h"
 | 
					#include "llvm/Support/Casting.h"
 | 
				
			||||||
// #include "llvm/Support/InitLLVM.h"
 | 
					// #include "llvm/Support/InitLLVM.h"
 | 
				
			||||||
 | 
					#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 | 
				
			||||||
 | 
					#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 | 
				
			||||||
 | 
					#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
 | 
				
			||||||
#include "mlir/IR/Attributes.h"
 | 
					#include "mlir/IR/Attributes.h"
 | 
				
			||||||
#include "mlir/IR/Block.h"
 | 
					#include "mlir/IR/Block.h"
 | 
				
			||||||
#include "mlir/IR/Builders.h"
 | 
					#include "mlir/IR/Builders.h"
 | 
				
			||||||
| 
						 | 
					@ -13,18 +19,14 @@
 | 
				
			||||||
#include "mlir/IR/Operation.h"
 | 
					#include "mlir/IR/Operation.h"
 | 
				
			||||||
#include "mlir/IR/Types.h"
 | 
					#include "mlir/IR/Types.h"
 | 
				
			||||||
#include "mlir/IR/Value.h"
 | 
					#include "mlir/IR/Value.h"
 | 
				
			||||||
// #include "mlir/InitAllDialects.h"
 | 
					#include "mlir/InitAllDialects.h"
 | 
				
			||||||
// #include "mlir/InitAllPasses.h"
 | 
					#include "mlir/InitAllPasses.h"
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 | 
					 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 | 
					 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace builder {
 | 
					namespace builder {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Builder::Impl {
 | 
					class Builder::Impl {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  Impl() : builder_(&context_) {
 | 
					  Impl() : builder_(&context_) {
 | 
				
			||||||
 | 
					 | 
				
			||||||
    // llvm::InitLLVM y(argc, argv);
 | 
					    // llvm::InitLLVM y(argc, argv);
 | 
				
			||||||
    // llvm::InitializeNativeTarget();
 | 
					    // llvm::InitializeNativeTarget();
 | 
				
			||||||
    // llvm::InitializeNativeTargetAsmPrinter();
 | 
					    // llvm::InitializeNativeTargetAsmPrinter();
 | 
				
			||||||
| 
						 | 
					@ -37,14 +39,14 @@ class Builder::Impl {
 | 
				
			||||||
    // registerDefaultTimingManagerCLOptions();
 | 
					    // registerDefaultTimingManagerCLOptions();
 | 
				
			||||||
    // DebugCounter::registerCLOptions();
 | 
					    // DebugCounter::registerCLOptions();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // mlir::registerAllPasses();
 | 
					    mlir::registerAllPasses();
 | 
				
			||||||
    mlir::mhlo::registerAllMhloPasses();
 | 
					    mlir::mhlo::registerAllMhloPasses();
 | 
				
			||||||
    mlir::lmhlo::registerAllLmhloPasses();
 | 
					    // mlir::lmhlo::registerAllLmhloPasses();
 | 
				
			||||||
    mlir::disc_ral::registerAllDiscRalPasses();
 | 
					    // mlir::disc_ral::registerAllDiscRalPasses();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    mlir::DialectRegistry registry;
 | 
					    mlir::DialectRegistry registry;
 | 
				
			||||||
    // mlir::registerAllToLLVMIRTranslations(registry);
 | 
					    // mlir::registerAllToLLVMIRTranslations(registry);
 | 
				
			||||||
    // mlir::registerAllDialects(registry);
 | 
					    mlir::registerAllDialects(registry);
 | 
				
			||||||
    registry.insert<mlir::mhlo::MhloDialect>();
 | 
					    registry.insert<mlir::mhlo::MhloDialect>();
 | 
				
			||||||
    // registry.insert<mlir::chlo::HloClientDialect>();
 | 
					    // registry.insert<mlir::chlo::HloClientDialect>();
 | 
				
			||||||
    // registry.insert<mlir::lmhlo::LmhloDialect>();
 | 
					    // registry.insert<mlir::lmhlo::LmhloDialect>();
 | 
				
			||||||
| 
						 | 
					@ -53,9 +55,6 @@ class Builder::Impl {
 | 
				
			||||||
    context_.appendDialectRegistry(registry);
 | 
					    context_.appendDialectRegistry(registry);
 | 
				
			||||||
    context_.loadAllAvailableDialects();
 | 
					    context_.loadAllAvailableDialects();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
 | 
					    module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llvm::SmallVector<mlir::Type, 4> arg_types;
 | 
					    llvm::SmallVector<mlir::Type, 4> arg_types;
 | 
				
			||||||
| 
						 | 
					@ -74,6 +73,35 @@ class Builder::Impl {
 | 
				
			||||||
  mlir::MLIRContext* GetContext() { return &context_; }
 | 
					  mlir::MLIRContext* GetContext() { return &context_; }
 | 
				
			||||||
  void DumpModule() { module_.dump(); }
 | 
					  void DumpModule() { module_.dump(); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  builder::Op CreateInput(const builder::Type& type) {
 | 
				
			||||||
 | 
					    mlir::BlockArgument arg =
 | 
				
			||||||
 | 
					        entry_block_->addArgument(type.GetImpl()->GetMlirType(&context_));
 | 
				
			||||||
 | 
					    builder::Op op;
 | 
				
			||||||
 | 
					    op.GetImpl()->SetValue(arg);
 | 
				
			||||||
 | 
					    return op;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void SetOutput(const std::vector<builder::Op>& outputs) {
 | 
				
			||||||
 | 
					    llvm::SmallVector<mlir::Type, 4> arg_types;
 | 
				
			||||||
 | 
					    int arg_num = entry_block_->getNumArguments();
 | 
				
			||||||
 | 
					    for (int i = 0; i < arg_num; ++i) {
 | 
				
			||||||
 | 
					      arg_types.push_back(entry_block_->getArgument(i).getType());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    llvm::SmallVector<mlir::Type, 4> ret_types;
 | 
				
			||||||
 | 
					    llvm::SmallVector<mlir::Value, 4> ret_vals;
 | 
				
			||||||
 | 
					    for (auto& out : outputs) {
 | 
				
			||||||
 | 
					      mlir::Value v = out.GetImpl()->GetResult();
 | 
				
			||||||
 | 
					      ret_types.push_back(v.getType());
 | 
				
			||||||
 | 
					      ret_vals.push_back(v);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    // return all output tensors.
 | 
				
			||||||
 | 
					    builder_.create<mlir::ReturnOp>(builder_.getUnknownLoc(), ret_vals);
 | 
				
			||||||
 | 
					    // Update main function input/output type
 | 
				
			||||||
 | 
					    mlir::FunctionType funcType =
 | 
				
			||||||
 | 
					        builder_.getFunctionType(arg_types, ret_types);
 | 
				
			||||||
 | 
					    main_func_.setType(funcType);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  mlir::MLIRContext context_;
 | 
					  mlir::MLIRContext context_;
 | 
				
			||||||
  mlir::ModuleOp module_;
 | 
					  mlir::ModuleOp module_;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -9,7 +9,7 @@ class Op {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  Op();
 | 
					  Op();
 | 
				
			||||||
  class Impl;
 | 
					  class Impl;
 | 
				
			||||||
  std::shared_ptr<Impl> GetImpl() { return impl_; }
 | 
					  std::shared_ptr<Impl> GetImpl() const { return impl_; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  std::shared_ptr<Impl> impl_;
 | 
					  std::shared_ptr<Impl> impl_;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -14,12 +14,19 @@ namespace builder {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Op::Impl {
 | 
					class Op::Impl {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  Impl() = default;
 | 
					  Impl() : op_(nullptr), value_(){};
 | 
				
			||||||
  void SetOperation(mlir::Operation *Op) { op_ = Op; }
 | 
					  void SetOperation(mlir::Operation *Op) { op_ = Op; }
 | 
				
			||||||
  mlir::Value GetResult() { return op_->getResult(0); }
 | 
					  void SetValue(mlir::Value &value) { value_ = value; }
 | 
				
			||||||
 | 
					  mlir::Value GetResult() {
 | 
				
			||||||
 | 
					    if (op_ != nullptr)
 | 
				
			||||||
 | 
					      return op_->getResult(0);
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					      return value_;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  mlir::Operation *op_;
 | 
					  mlir::Operation *op_;
 | 
				
			||||||
 | 
					  mlir::Value value_;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace builder
 | 
					}  // namespace builder
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue