71 lines
2.9 KiB
C++
71 lines
2.9 KiB
C++
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||
|
|
||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
you may not use this file except in compliance with the License.
|
||
|
You may obtain a copy of the License at
|
||
|
|
||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
Unless required by applicable law or agreed to in writing, software
|
||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
See the License for the specific language governing permissions and
|
||
|
limitations under the License.
|
||
|
==============================================================================*/
|
||
|
|
||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||
|
|
||
|
#include <numeric>
|
||
|
|
||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||
|
|
||
|
namespace mlir {
|
||
|
namespace xla {
|
||
|
|
||
|
// Builds the `broadcast_dimensions` attribute that maps the lower-rank operand
// of a binary op onto the trailing dimensions of the higher-rank operand,
// mirroring TensorFlow's implicit "pad shape with leading 1s" broadcasting.
//
// Parameters:
//   b           - builder used to create the i64 element type.
//   x, y        - the two operands; both must have ranked tensor types.
//   allow_empty - when true, a null attribute is returned if no explicit
//                 mapping is required (identical types or equal ranks).
//
// Returns a null attribute if either operand is not a ranked tensor, or (with
// `allow_empty`) if both operands have the same type or the same rank.
// Otherwise returns a 1-D i64 attribute of length min(rank(x), rank(y)) holding
// the consecutive dimension indices [maxRank - minRank, maxRank).
DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y,
                                                bool allow_empty) {
  auto lhsType = x.getType().dyn_cast<RankedTensorType>();
  auto rhsType = y.getType().dyn_cast<RankedTensorType>();
  // A dimension mapping can only be computed between two ranked tensors.
  if (!lhsType || !rhsType) return {};

  const int64_t lhsRank = lhsType.getRank();
  const int64_t rhsRank = rhsType.getRank();

  // When the caller accepts an empty result, operands of identical type or of
  // equal rank need no explicit mapping.
  if (allow_empty && (lhsType == rhsType || lhsRank == rhsRank)) return {};

  // TensorFlow pads the smaller-rank shape with leading size-1 dimensions
  // (a 5-vector broadcast against a rank-3 tensor acts like shape [1,1,5]).
  // XLA can broadcast from lower to higher rank but must be told explicitly
  // which higher-rank dimensions the lower-rank operand lines up with, so map
  // it onto the trailing (largest-numbered) dimensions.
  const int64_t highRank = std::max(lhsRank, rhsRank);
  const int64_t lowRank = std::min(lhsRank, rhsRank);

  SmallVector<int64_t, 4> broadcastDimensions;
  broadcastDimensions.reserve(lowRank);
  for (int64_t dim = highRank - lowRank; dim < highRank; ++dim)
    broadcastDimensions.push_back(dim);

  auto attrType = RankedTensorType::get({lowRank}, b->getIntegerType(64));
  return DenseIntElementsAttr::get(attrType, broadcastDimensions);
}
|
||
|
|
||
|
// Returns a rank-0 (scalar) DenseElementsAttr whose element type is `ty` and
// whose single element holds `raw_value`.
//
// Parameters:
//   ty        - element type; must be a FloatType or an IntegerType. Any other
//               type fails the `cast<IntegerType>()` below.
//   raw_value - signed integer value to materialize in `ty`.
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);

  if (auto float_ty = ty.dyn_cast<FloatType>()) {
    // APFloat's (semantics, integerPart) constructor reads its argument as an
    // *unsigned* 64-bit integer. A negative raw_value such as -1 would
    // otherwise become ~1.8e19, or inf in narrow formats. Converting through a
    // signed APInt keeps the sign and rounds to the target float semantics.
    APFloat value = APFloat::getZero(float_ty.getFloatSemantics());
    value.convertFromAPInt(APInt(64, raw_value, /*isSigned=*/true),
                           /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
    return DenseElementsAttr::get(scalar_ty, value);
  }

  auto int_ty = ty.cast<IntegerType>();
  // Build an APInt of the integer type's bit width, treating raw_value as
  // signed.
  APInt value(int_ty.getWidth(), raw_value, /*isSigned=*/true);
  return DenseElementsAttr::get(scalar_ty, value);
}
|
||
|
|
||
|
} // namespace xla
|
||
|
} // namespace mlir
|