/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/utils/hlo_utils.h"

#include <algorithm>
#include <cstdint>
#include <numeric>

#include "mlir/IR/Attributes.h"

namespace mlir {
namespace hlo {

/// Builds the 1-D i64 `broadcast_dimensions` attribute that maps every
/// dimension of the lower-rank operand onto the trailing dimensions of the
/// higher-rank operand (TensorFlow-style prefix padding).
///
/// @param b            Builder used to create the i64 element type.
/// @param x, y         The two operands being broadcast against each other.
/// @param allow_empty  When true, return a null attribute if the operands have
///                     identical types or identical ranks (nothing to map).
/// @return The attribute, or a null attribute if either operand is not a
///         ranked tensor, or if `allow_empty` is set and the ranks match.
DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y,
                                                bool allow_empty) {
  // getRank() is only meaningful for ranked tensors, so only accept those;
  // any other type (including unranked tensors) takes the null-attribute path
  // instead of asserting inside getRank().
  RankedTensorType xType = x.getType().dyn_cast<RankedTensorType>();
  RankedTensorType yType = y.getType().dyn_cast<RankedTensorType>();
  if (!xType || !yType) return {};
  if (allow_empty && xType == yType) return {};

  // If the shapes have the same rank, then there is nothing to do.
  const int64_t xRank = xType.getRank();
  const int64_t yRank = yType.getRank();
  if (allow_empty && xRank == yRank) return {};

  // Otherwise if the ranks of the inputs don't match, TensorFlow automatically
  // reshapes the smaller by padding with dimensions of size 1 as a prefix. In
  // other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to
  // have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast
  // from lower to higher rank, but doesn't assume you want to pad as a prefix
  // of the dimensions, and instead needs to be told which dimensions of the
  // higher rank tensor to match to the lower rank tensor.
  const int64_t maxRank = std::max(xRank, yRank);
  const int64_t minRank = std::min(xRank, yRank);

  // Match the lower rank tensor along the larger-numbered dimensions of the
  // higher rank tensor: [maxRank - minRank, ..., maxRank - 1].
  SmallVector<int64_t, 4> broadcastDimensions(minRank);
  std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
            maxRank - minRank);

  RankedTensorType type =
      RankedTensorType::get({minRank}, b->getIntegerType(64));
  return DenseIntElementsAttr::get(type, broadcastDimensions);
}

/// Returns a rank-0 (scalar) tensor constant of element type `ty` holding
/// `raw_value`.
///
/// For float element types the integer is converted to the type's float
/// semantics (rounding to nearest-even when it is not exactly representable).
/// Otherwise `ty` is cast to IntegerType and the value is stored as a signed
/// integer of that width.
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);

  if (auto float_ty = ty.dyn_cast<FloatType>()) {
    // APFloat's integer constructor takes an *unsigned* integerPart and always
    // yields a non-negative value, so a negative raw_value must not be passed
    // to it directly. Build the float from the magnitude and restore the sign
    // afterwards. The magnitude is computed in uint64_t so INT64_MIN does not
    // overflow.
    const bool is_negative = raw_value < 0;
    const uint64_t magnitude =
        is_negative ? uint64_t{0} - static_cast<uint64_t>(raw_value)
                    : static_cast<uint64_t>(raw_value);
    APFloat value(float_ty.getFloatSemantics(), magnitude);
    if (is_negative) value.changeSign();
    return DenseElementsAttr::get(scalar_ty, value);
  }

  auto int_ty = ty.cast<IntegerType>();
  APInt value(int_ty.getWidth(), static_cast<uint64_t>(raw_value),
              /*isSigned=*/true);
  return DenseElementsAttr::get(scalar_ty, value);
}

}  // namespace hlo
}  // namespace mlir