[MHLO:linalg] Be more aggressive about turning mhlo.const into std.constant
On tensors the only difference between these ops is that mhlo.const supports unsigned types. PiperOrigin-RevId: 377970948
This commit is contained in:
parent
25b93c8d66
commit
d1c60df2fe
|
@ -1077,36 +1077,42 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
class ConstConverter : public OpConversionPattern<OpTy> {
|
||||
class ConstConverterBuffer : public OpConversionPattern<lmhlo::ConstOp> {
|
||||
public:
|
||||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
OpTy const_op, ArrayRef<Value> /*args*/,
|
||||
lmhlo::ConstOp const_op, ArrayRef<Value> /*args*/,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Location loc = const_op.getLoc();
|
||||
auto value_attr = const_op.value().template cast<DenseElementsAttr>();
|
||||
auto value_attr = const_op.value().cast<DenseElementsAttr>();
|
||||
if (value_attr.getType().getRank() != 0) return failure();
|
||||
ReplaceConstOp(loc, const_op, value_attr, rewriter);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
void ReplaceConstOp(Location loc, mhlo::ConstOp op,
|
||||
DenseElementsAttr value_attr,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr);
|
||||
rewriter.replaceOp(op, {std_tensor_const});
|
||||
}
|
||||
void ReplaceConstOp(Location loc, lmhlo::ConstOp op,
|
||||
DenseElementsAttr value_attr,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Value std_scalar_const =
|
||||
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
|
||||
rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(),
|
||||
llvm::None);
|
||||
rewriter.eraseOp(op);
|
||||
rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const,
|
||||
const_op.getOperand(), llvm::None);
|
||||
rewriter.eraseOp(const_op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConstConverterTensor : public OpConversionPattern<mhlo::ConstOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::ConstOp const_op, ArrayRef<Value> /*args*/,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto value_attr = const_op.value().cast<DenseElementsAttr>();
|
||||
auto type =
|
||||
typeConverter->convertType(const_op.getType()).cast<ShapedType>();
|
||||
if (type != const_op.getType()) {
|
||||
// Signedness conversion.
|
||||
value_attr = value_attr.mapValues(type.getElementType(),
|
||||
[](const APInt& i) { return i; });
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(const_op, type, value_attr);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2291,7 +2297,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
||||
ConstConverter<lmhlo::ConstOp>,
|
||||
ConstConverterBuffer,
|
||||
ConvToLinalgConverter,
|
||||
IotaConverter<lmhlo::IotaOp>,
|
||||
LhloBroadcastInDimConverter,
|
||||
|
@ -2490,7 +2496,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
// clang-format off
|
||||
patterns->insert<
|
||||
BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter,
|
||||
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
|
||||
ConstConverterTensor, HloDynamicBroadcastInDimConverter,
|
||||
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
|
||||
IotaConverter<mhlo::DynamicIotaOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||
|
|
|
@ -2698,3 +2698,26 @@ func @scatter_update_slice(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
|
|||
// CHECK: linalg.yield %[[SELECT]] : i32
|
||||
// CHECK: } -> tensor<6x3xi32>
|
||||
// CHECK: return %[[RES]] : tensor<6x3xi32>
|
||||
|
||||
// -----
|
||||
|
||||
func @const() -> tensor<3xi32> {
|
||||
// CHECK: = constant dense<[1, 2, 3]> : tensor<3xi32>
|
||||
%cst = mhlo.constant dense<[1, 2, 3]> : tensor<3xi32>
|
||||
return %cst : tensor<3xi32>
|
||||
}
|
||||
// -----
|
||||
|
||||
func @const_unsigned() -> tensor<3xui32> {
|
||||
// CHECK: = constant dense<[1, 2, 3]> : tensor<3xi32>
|
||||
%cst = mhlo.constant dense<[1, 2, 3]> : tensor<3xui32>
|
||||
return %cst : tensor<3xui32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @const_splat() -> tensor<3xi16> {
|
||||
// CHECK: = constant dense<1> : tensor<3xi16>
|
||||
%cst = mhlo.constant dense<1> : tensor<3xi16>
|
||||
return %cst : tensor<3xi16>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue