mlir-hlo/include/mlir-hlo/utils/hlo_utils.h

98 lines
3.9 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
#include <complex>
#include <cstdint>
#include <string>

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir {
namespace hlo {
// Computes the broadcast dimensions attribute for an elementwise binary
// operator applied to two ranked tensors.
//
// Parameters:
//   b           - builder used to create the returned attribute.
//   x, y        - the two operands of the binary operator; both are expected
//                 to be ranked tensors.
//   allow_empty - when true (the default), a null attribute may be returned to
//                 signal that the broadcast is an "identity".
//
// NOTE(review): the behaviour for `allow_empty == false` is defined by the
// out-of-line implementation, which is not visible from this header —
// presumably a non-null attribute is always produced; confirm there.
mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
mlir::Value x,
mlir::Value y,
bool allow_empty = true);
// Builds a constant splat attribute of tensor type `ty` from `constant`.
//
// The element type of `ty` decides how `constant` is materialised:
//   * signless integer        -> integer attribute created through `b`;
//   * any FloatType           -> float attribute created through `b`;
//   * complex of f32 / of f64 -> `constant` cast to std::complex<float> /
//                                std::complex<double>.
// Any other element type is a programming error and reaches
// llvm_unreachable.
//
// Requires `ty` to be a statically shaped RankedTensorType.
template <typename T>
static ElementsAttr getSplat(Builder* b, RankedTensorType ty, T constant) {
  Type elem_type = getElementTypeOrSelf(ty);

  // Integer and floating-point elements go through the builder so that the
  // attribute carries exactly the tensor's element type.
  if (elem_type.isSignlessInteger()) {
    return DenseElementsAttr::get(ty, b->getIntegerAttr(elem_type, constant));
  }
  if (elem_type.isa<FloatType>()) {
    return DenseElementsAttr::get(ty, b->getFloatAttr(elem_type, constant));
  }

  // Only complex element types remain as supported candidates.
  auto as_complex = elem_type.dyn_cast<ComplexType>();
  if (!as_complex) {
    llvm_unreachable("unhandled element type");
  }

  auto component_type = as_complex.getElementType();
  if (component_type.isF32()) {
    return DenseElementsAttr::get(ty,
                                  static_cast<std::complex<float>>(constant));
  }
  if (component_type.isF64()) {
    return DenseElementsAttr::get(ty,
                                  static_cast<std::complex<double>>(constant));
  }

  // Complex types over any other component type are not supported.
  llvm_unreachable("unhandled element type");
}
// Convenience overload: builds the splat using the type of `val`, which is
// cast (unchecked by this function) to RankedTensorType before delegating to
// the RankedTensorType overload of getSplat.
template <typename T>
static ElementsAttr getSplat(Builder* b, Value val, T constant) {
  auto ranked_type = val.getType().cast<RankedTensorType>();
  return getSplat(b, ranked_type, constant);
}
// Returns a rank-zero DenseElementsAttr whose element type is `ty` and whose
// value is `raw_value`.
//
// Requires `ty` to be either FloatType, IntegerType, or ComplexType.
//
// NOTE(review): how `raw_value` is converted for float and complex element
// types is decided by the out-of-line implementation, which is not visible
// from this header — confirm there before relying on a specific conversion.
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
// Enum type used to specify the scalar argument to GetScalarLimitOfType, i.e.
// which limit value of the element type should be produced.
enum ScalarLimit {
kLowest, // The scalar corresponding to numeric_limits<T>::lowest.
kInfinityLowest, // Like kLowest, but returns -infinity where available.
kMax, // The scalar corresponding to numeric_limits<T>::max.
kInfinityMax, // Like kMax, but returns infinity where available.
};
// Returns a scalar limit value for the given type.
//
// Parameters:
//   ty    - element type whose limit is requested.
//   limit - describes which scalar value to return (see ScalarLimit).
//
// Requires `ty` to be either FloatType or IntegerType.
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
// Given `op_name` from LMHLO, returns the corresponding op name in MHLO.
// Returns an empty string if no such op exists.
//
// NOTE(review): `context` is presumably used to check which MHLO ops are
// registered — the implementation is not visible from this header; confirm.
std::string LmhloToMhloOpName(llvm::StringRef op_name,
mlir::MLIRContext* context);
// Returns true if the values of `attr` form the sequence [0, 1, ...], i.e.
// consecutive integers starting at 0.
//
// NOTE(review): the result for an empty attribute is not specified here —
// confirm against the implementation.
bool IsSequenceStartingWith0(DenseIntElementsAttr attr);
} // namespace hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_