/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/utils/broadcast_utils.h"

#include <algorithm>

#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/StandardTypes.h"

namespace mlir {
namespace hlo {
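
// Checks whether `broadcast_dims` describe a legal NumPy-style ranked
// broadcast, i.e. the lower-ranked operand is mapped onto the trailing
// dimensions of the higher-ranked operand (rank extension by left-padding).
// Unranked operands are conservatively rejected.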
bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
                                 DenseIntElementsAttr broadcast_dims) {
  RankedTensorType lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
  RankedTensorType rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
  if (!lhs_type || !rhs_type) return false;
  if (lhs_type.getRank() == rhs_type.getRank()) return true;

  // Otherwise, verify that broadcast_dims strictly performs left-padding.
  auto smaller_rank = std::min(lhs_type.getRank(), rhs_type.getRank());
  auto larger_rank = std::max(lhs_type.getRank(), rhs_type.getRank());

  if (smaller_rank != broadcast_dims.getNumElements()) {
    return false;
  }
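  // For legal left-padding, broadcast_dims must be exactly the contiguous
  // range [larger_rank - smaller_rank, larger_rank), i.e. the smaller
  // operand's dimensions map onto the trailing dimensions of the larger one.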
  auto expected_extents =
      llvm::seq<int64_t>(larger_rank - smaller_rank, larger_rank);
  return std::equal(expected_extents.begin(), expected_extents.end(),
                    broadcast_dims.getIntValues().begin());
}
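
// Emits shape-dialect ops that compute the broadcasted result extents for a
// binary elementwise op, returned as a 1-D extent tensor of index type.
// Returns a null Value (and emits an error at `loc`) for unranked operands.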
Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
                                                        Value rhs,
                                                        OpBuilder& builder) {
  auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
  auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
  if (!lhs_type || !rhs_type) {
    emitError(loc) << "shape computation for broadcasting elementwise ops "
                   << "is only implemented for ranked tensors";
    return nullptr;
  }

  int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
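  // Materialize both operand shapes and combine them with shape.broadcast,
  // then convert the resulting shape to an extent tensor of the statically
  // known result rank.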
  Value lhs_shape_v = builder.createOrFold<shape::ShapeOfOp>(loc, lhs);
  Value rhs_shape_v = builder.createOrFold<shape::ShapeOfOp>(loc, rhs);
  Value result_shape_v = builder.createOrFold<shape::BroadcastOp>(
      loc, shape::ShapeType::get(builder.getContext()), lhs_shape_v,
      rhs_shape_v, nullptr /* error */);
  return builder.createOrFold<shape::ToExtentTensorOp>(
      loc, RankedTensorType::get({result_rank}, builder.getIndexType()),
      result_shape_v);
}

} // namespace hlo
} // namespace mlir