/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"

#include "llvm/ADT/APFloat.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

namespace mlir {
namespace chlo {

// Fallback verifier: accepts every op. The non-template overload taking a
// ConstantLikeOp is preferred by overload resolution for that op.
template <typename T>
static LogicalResult Verify(T op) {
  return success();
}

// Creates a constant shaped like `val` that holds the largest finite value of
// `val`'s element type. The element type is cast to a float type, so `val`
// must have a floating-point element type.
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
  auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
  return getConstantLike(
      b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
}

// Creates a constant shaped like `val` that holds +infinity, or -infinity when
// `negative` is true, in `val`'s floating-point element type.
Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
                              bool negative) {
  auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
  return getConstantLike(
      b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
}

// Creates a constant-like op at `loc` whose value attribute is `constant`
// expressed in `val`'s element type and whose operand is `val`.
Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
                      Value val) {
  Type ty = getElementTypeOrSelf(val.getType());
  return b.create<ConstantLikeOp>(loc, b.getFloatAttr(ty, constant), val);
}

//===----------------------------------------------------------------------===//
// BinaryOps
//===----------------------------------------------------------------------===//

namespace {
// Gets the resulting type from a broadcast between two types.
//
// Returns an unranked tensor of `element_type` when either input is not a
// ranked tensor or when `broadcast_dimensions_attr` is illegal (wrong number
// of entries, or an entry outside the rank of the larger operand). Otherwise
// returns a ranked tensor of `element_type`. A dimension size of -1 is
// propagated as-is (it is treated here as the "unknown extent" marker).
static Type GetBroadcastType(Type x, Type y, Type element_type,
                             DenseIntElementsAttr broadcast_dimensions_attr) {
  auto x_ranked = x.dyn_cast<RankedTensorType>();
  auto y_ranked = y.dyn_cast<RankedTensorType>();
  if (!x_ranked || !y_ranked) {
    return UnrankedTensorType::get(element_type);
  }

  auto shape_x = x_ranked.getShape();
  auto shape_y = y_ranked.getShape();

  // Equal ranks: combine dimension by dimension; broadcast_dimensions_attr is
  // not consulted on this path.
  if (shape_x.size() == shape_y.size()) {
    llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
    for (size_t i = 0, e = shape_x.size(); i < e; ++i) {
      auto x_val = shape_x[i];
      auto y_val = shape_y[i];
      if (x_val == -1 || y_val == -1) {
        out_shape[i] = -1;
      } else {
        out_shape[i] = std::max(x_val, y_val);
      }
    }
    return RankedTensorType::get(out_shape, element_type);
  }

  auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
  auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;

  llvm::SmallVector<int64_t, 4> broadcast_dimensions;
  if (broadcast_dimensions_attr) {
    // Explicit broadcast dimensions.
    const int64_t large_rank = static_cast<int64_t>(shape_large.size());
    for (const APInt& int_value : broadcast_dimensions_attr.getIntValues()) {
      const int64_t dimension = int_value.getSExtValue();
      if (dimension < 0 || dimension >= large_rank) {
        // Each entry is used below as an index into the larger shape; an
        // out-of-range entry is illegal, so signal it as unranked rather than
        // indexing out of bounds.
        return UnrankedTensorType::get(element_type);
      }
      broadcast_dimensions.push_back(dimension);
    }
    if (broadcast_dimensions.size() != shape_small.size()) {
      // Signal illegal broadcast_dimensions as unranked.
      return UnrankedTensorType::get(element_type);
    }
  } else {
    // If no broadcast dimensions, assume "numpy" broadcasting: the smaller
    // shape is aligned with the trailing dimensions of the larger one.
    broadcast_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(
        shape_large.size() - shape_small.size(), shape_large.size()));
  }

  llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
                                          shape_large.end());

  // Update according to the broadcast dimensions. An entry of the larger
  // shape that is already -1 stays -1; otherwise it is replaced when the
  // corresponding small-shape extent is -1 or strictly larger.
  for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
    auto old_value = out_shape[index_pair.value()];
    auto new_value = shape_small[index_pair.index()];
    if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
      out_shape[index_pair.value()] = new_value;
    }
  }

  return RankedTensorType::get(out_shape, element_type);
}

// Shared return-type inference for broadcasting binary ops.
//
// Reads the optional "broadcast_dimensions" attribute, requires both operands
// to be shaped types with the same element type (otherwise emits "mismatched
// operand types" via emitOptionalError), and appends exactly one entry to
// `inferedReturnShapes`. When `element_type` is null, the lhs element type is
// used for the result.
LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, Type element_type,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  // Find broadcast_dimensions.
  DenseIntElementsAttr broadcast_dimensions =
      attributes.get("broadcast_dimensions")
          .dyn_cast_or_null<DenseIntElementsAttr>();

  ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
  ShapedType rhs_type = operands[1].getType().dyn_cast<ShapedType>();
  if (!lhs_type || !rhs_type ||
      lhs_type.getElementType() != rhs_type.getElementType()) {
    return emitOptionalError(location, "mismatched operand types");
  }
  if (!element_type) element_type = lhs_type.getElementType();
  Type result_type =
      GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions);

  if (auto ranked_result_type = result_type.dyn_cast<RankedTensorType>()) {
    inferedReturnShapes.emplace_back(ranked_result_type.getShape(),
                                     element_type);
    return success();
  }

  // Unranked result: keep the element type so that consumers of the inferred
  // components do not see a null element type. The element-type-only
  // constructor is the same one ConstantLikeOp::inferReturnTypeComponents in
  // this file uses for its unranked case.
  inferedReturnShapes.emplace_back(element_type);
  return success();
}

// Shared shape reification for broadcasting binary ops.
//
// Emits a warning and fails when the op carries a "broadcast_dimensions"
// attribute that hlo::IsLegalNumpyRankedBroadcast rejects. Otherwise delegates
// to hlo::ComputeBinaryElementwiseBroadcastingResultExtents and appends the
// resulting value to `reifiedReturnShapes`; fails if that helper returns a
// null value.
LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
    OpBuilder& builder, Operation* op,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  auto loc = op->getLoc();
  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);

  // Check for "numpy"-style rank broadcast.
  auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
                                  .dyn_cast_or_null<DenseIntElementsAttr>();
  if (broadcast_dimensions &&
      !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
    // Note: It is unclear whether the general specification of explicit
    // broadcast_dimensions on binary ops is a feature we want to carry
    // forward. While it can technically be implemented for ranked-dynamic,
    // it is incompatible with unranked inputs. If this warning is emitted
    // in real programs, it is an indication that the feature should be
    // implemented versus just falling back on the more standard definition
    // of numpy-like prefix-padding.
    return op->emitWarning()
           << "unsupported non prefix-padded dynamic rank "
           << "broadcast_dimensions = " << broadcast_dimensions;
  }

  Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
      loc, lhs, rhs, builder, /*unsafe_as_extent_tensor=*/false);
  if (!computed_shape) return failure();
  reifiedReturnShapes.push_back(computed_shape);
  return success();
}
}  // namespace

//===----------------------------------------------------------------------===//
// BroadcastComplexOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===// LogicalResult BroadcastComplexOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { ShapedType lhs_type = operands[0].getType().dyn_cast(); if (!lhs_type) { return emitOptionalError(location, "expected ShapedType"); } Type element_type = ComplexType::get(lhs_type.getElementType()); return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, attributes, element_type, inferedReturnShapes); } LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), reifiedReturnShapes); } //===----------------------------------------------------------------------===// // BroadcastCompareOp (has custom type inference due to different result type). //===----------------------------------------------------------------------===// void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, Value rhs, DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction, StringAttr compare_type) { auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(), builder.getI1Type(), broadcast_dimensions); build(builder, result, new_type, lhs, rhs, broadcast_dimensions, comparison_direction, compare_type); } LogicalResult BroadcastCompareOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { Type element_type = IntegerType::get(context, 1); return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, attributes, element_type, inferedReturnShapes); } LogicalResult BroadcastCompareOp::reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { return 
ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), reifiedReturnShapes); } //===----------------------------------------------------------------------===// // Macros for method definitions that are common to most broadcasting ops. //===----------------------------------------------------------------------===// #define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \ LogicalResult Op::inferReturnTypeComponents( \ MLIRContext* context, Optional location, ValueRange operands, \ DictionaryAttr attributes, RegionRange regions, \ SmallVectorImpl& inferedReturnShapes) { \ return InferBroadcastBinaryOpReturnTypeComponents( \ context, location, operands, attributes, /*element_type=*/nullptr, \ inferedReturnShapes); \ } \ LogicalResult Op::reifyReturnTypeShapes( \ OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { \ return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), \ reifiedReturnShapes); \ } #define BROADCAST_BINARY_OP_DEFS(Op) \ void Op::build(OpBuilder& builder, OperationState& result, Value left, \ Value right, DenseIntElementsAttr broadcast_dimensions) { \ auto type = GetBroadcastType( \ left.getType().cast(), right.getType().cast(), \ getElementTypeOrSelf(right.getType()), broadcast_dimensions); \ return Op::build(builder, result, type, left, right, \ broadcast_dimensions); \ } \ BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) BROADCAST_BINARY_OP_DEFS(BroadcastAddOp); BROADCAST_BINARY_OP_DEFS(BroadcastAndOp); BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op); BROADCAST_BINARY_OP_DEFS(BroadcastDivOp); BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp); BROADCAST_BINARY_OP_DEFS(BroadcastMinOp); BROADCAST_BINARY_OP_DEFS(BroadcastMulOp); BROADCAST_BINARY_OP_DEFS(BroadcastOrOp); BROADCAST_BINARY_OP_DEFS(BroadcastPowOp); BROADCAST_BINARY_OP_DEFS(BroadcastRemOp); BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp); BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp); BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp); BROADCAST_BINARY_OP_DEFS(BroadcastSubOp); 
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);

#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
#undef BROADCAST_BINARY_OP_DEFS

// Verifies that the type of the op's `value` attribute equals the element
// type of the op's result type.
static LogicalResult Verify(ConstantLikeOp op) {
  if (op.value().getType() != op.getType().cast<ShapedType>().getElementType())
    return op.emitOpError() << "value's type doesn't match element return type";
  return success();
}

// Infers the result components of a constant-like op: the element type is the
// type of the `value` attribute; the shape is copied from the operand when it
// is ranked, and left unranked when the operand is an unranked tensor.
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  ConstantLikeOp::Adaptor op(operands, attributes);

  // `location` is optional; reading its value when absent is invalid, so fall
  // back to an unknown location for adaptor verification diagnostics.
  Location verify_loc = UnknownLoc::get(context);
  if (location.hasValue()) verify_loc = location.getValue();
  if (failed(op.verify(verify_loc))) return failure();

  Type element_type = op.value().getType();
  Type operand_type = op.operand().getType();
  if (operand_type.isa<UnrankedTensorType>()) {
    inferedReturnShapes.emplace_back(element_type);
  } else {
    const auto& shape = operand_type.cast<RankedTensorType>().getShape();
    inferedReturnShapes.emplace_back(shape, element_type);
  }
  return success();
}

// Canonicalization pattern: when the operand of a constant-like op has a
// static shape, replace the op with a constant op holding a dense splat of
// `value` with that shape.
struct ConstantLikeToConstant : public OpRewritePattern<ConstantLikeOp> {
  using OpRewritePattern<ConstantLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConstantLikeOp op,
                                PatternRewriter& rewriter) const override {
    auto op_type = op.operand().getType().cast<ShapedType>();
    // Without a fully static shape no dense constant can be materialized.
    if (!op_type.hasStaticShape()) return failure();
    auto type = RankedTensorType::get(op_type.getShape(), op.value().getType());
    ElementsAttr attr = DenseElementsAttr::get(type, op.value());
    rewriter.replaceOpWithNewOp<mhlo::ConstOp>(op.getOperation(), attr);
    return success();
  }
};

// Registers the canonicalization patterns for constant-like ops.
void ConstantLikeOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<ConstantLikeToConstant>(context);
}

}  // namespace chlo
}  // namespace mlir

#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"

namespace mlir {
namespace chlo {

//===----------------------------------------------------------------------===//
// chlo Dialect Constructor
//===----------------------------------------------------------------------===//

// Registers every op listed in the generated op list with the dialect.
void HloClientDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
      >();
}

}  // namespace chlo
}  // namespace mlir