[MLIR] Lower `chlo.constant_like` to MHLO
Lower `chlo.constant_like` to a constant and, if needed, a broadcast.

PiperOrigin-RevId: 331964137
parent da43c8596b
commit a6fdebdc6c
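For reference, a minimal sketch of the lowering this enables for a dynamically shaped operand, mirroring the new test below (the SSA names %cst, %shape, and %casted are illustrative, not produced verbatim):

  %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 }
      : (tensor<?x?xi64>) -> tensor<?x?xf32>

becomes roughly

  %cst = mhlo.constant dense<3.200000e+00> : tensor<f32>
  %shape = shape.shape_of %arg : tensor<?x?xi64> -> tensor<?xindex>
  %casted = tensor_cast %shape : tensor<?xindex> to tensor<2xindex>
  %result = "mhlo.dynamic_broadcast_in_dim"(%cst, %casted)
      {broadcast_dimensions = dense<> : tensor<0xi64>}
      : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>

Statically shaped results skip the shape computation and broadcast entirely and lower to a single mhlo.constant of the result type.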
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <numeric>
+
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
@@ -31,6 +33,39 @@ namespace mlir {
 namespace chlo {
 namespace {
 
+struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
+  using OpConversionPattern<ConstantLikeOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      ConstantLikeOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto result_ty = op.getType().cast<ShapedType>();
+
+    // Unranked uses are not supported. Consider `transform-unranked-hlo`.
+    if (!result_ty.hasRank()) return failure();
+
+    // Lower to MHLO constant if statically shaped.
+    if (result_ty.hasStaticShape()) {
+      rewriter.replaceOpWithNewOp<mhlo::ConstOp>(
+          op, DenseElementsAttr::get(result_ty, op.value()));
+      return success();
+    }
+
+    // Lower to broadcasted constant.
+    ConstantLikeOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+    Type extent_tensor_type = shape::getExtentTensorType(op.getContext());
+    Value constant = rewriter.create<mhlo::ConstOp>(loc, op.value());
+    Value uncasted_shape = rewriter.create<shape::ShapeOfOp>(
+        loc, extent_tensor_type, transformed.operand());
+    Type shape_ty =
+        RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
+    Value shape = rewriter.create<TensorCastOp>(loc, shape_ty, uncasted_shape);
+    rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
+        op, result_ty, constant, shape, rewriter.getI64TensorAttr({}));
+    return success();
+  }
+};
+
 // Converts binary ops that statically are determined to not broadcast directly
 // to the corresponding mhlo non-broadcasting op.
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
@@ -505,6 +540,9 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
       context, patterns);
   PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
       context, patterns);
+
+  // Other patterns.
+  patterns->insert<ConvertConstantLikeOp>(context);
 }
 
 }  // namespace chlo
@@ -0,0 +1,26 @@
+// RUN: mlir-hlo-opt --mhlo-test-chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
+
+// Lower statically shaped `constant_like` to constant.
+// CHECK-LABEL: @constant_like_static_shape
+func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> {
+  // CHECK: %[[RESULT:.*]] = mhlo.constant dense<3.200000e+00> : tensor<1x2xf32>
+  // CHECK: return %[[RESULT]]
+  %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 }
+      : (tensor<1x2xi64>) -> tensor<1x2xf32>
+  return %result : tensor<1x2xf32>
+}
+
+// Lower dynamically shaped `constant_like` to broadcasted constant.
+// CHECK-LABEL: constant_like_dynamic_shape
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?xi64>)
+func @constant_like_dynamic_shape(%arg : tensor<?x?xi64>) -> tensor<?x?xf32> {
+  // CHECK: %[[CONSTANT:.*]] = mhlo.constant dense<3.200000e+00> : tensor<f32>
+  // CHECK: %[[UNCASTED_SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x?xi64> -> tensor<?xindex>
+  // CHECK: %[[SHAPE:.*]] = tensor_cast %[[UNCASTED_SHAPE]] : tensor<?xindex> to tensor<2xindex>
+  // CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
+  // CHECK: return %[[BROADCASTED_CONSTANT]] : tensor<?x?xf32>
+  %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 }
+      : (tensor<?x?xi64>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}