[MHLO:linalg] Be more aggressive about turning mhlo.const into std.constant

On tensors the only difference between these ops is that mhlo.const supports unsigned types.

PiperOrigin-RevId: 377970948
This commit is contained in:
Benjamin Kramer 2021-06-07 11:56:55 -07:00 committed by TensorFlow MLIR Team
parent 25b93c8d66
commit d1c60df2fe
2 changed files with 53 additions and 24 deletions

View File

@ -1077,36 +1077,42 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
} }
}; };
template <typename OpTy> class ConstConverterBuffer : public OpConversionPattern<lmhlo::ConstOp> {
class ConstConverter : public OpConversionPattern<OpTy> {
public: public:
using OpConversionPattern<OpTy>::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
OpTy const_op, ArrayRef<Value> /*args*/, lmhlo::ConstOp const_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
Location loc = const_op.getLoc(); Location loc = const_op.getLoc();
auto value_attr = const_op.value().template cast<DenseElementsAttr>(); auto value_attr = const_op.value().cast<DenseElementsAttr>();
if (value_attr.getType().getRank() != 0) return failure(); if (value_attr.getType().getRank() != 0) return failure();
ReplaceConstOp(loc, const_op, value_attr, rewriter);
return success();
}
private:
void ReplaceConstOp(Location loc, mhlo::ConstOp op,
DenseElementsAttr value_attr,
ConversionPatternRewriter& rewriter) const {
Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr);
rewriter.replaceOp(op, {std_tensor_const});
}
void ReplaceConstOp(Location loc, lmhlo::ConstOp op,
DenseElementsAttr value_attr,
ConversionPatternRewriter& rewriter) const {
Value std_scalar_const = Value std_scalar_const =
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({})); rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(), rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const,
llvm::None); const_op.getOperand(), llvm::None);
rewriter.eraseOp(op); rewriter.eraseOp(const_op);
return success();
}
};
class ConstConverterTensor : public OpConversionPattern<mhlo::ConstOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::ConstOp const_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
auto value_attr = const_op.value().cast<DenseElementsAttr>();
auto type =
typeConverter->convertType(const_op.getType()).cast<ShapedType>();
if (type != const_op.getType()) {
// Signedness conversion.
value_attr = value_attr.mapValues(type.getElementType(),
[](const APInt& i) { return i; });
}
rewriter.replaceOpWithNewOp<ConstantOp>(const_op, type, value_attr);
return success();
} }
}; };
@ -2291,7 +2297,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
// clang-format off // clang-format off
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>, patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
ConstConverter<lmhlo::ConstOp>, ConstConverterBuffer,
ConvToLinalgConverter, ConvToLinalgConverter,
IotaConverter<lmhlo::IotaOp>, IotaConverter<lmhlo::IotaOp>,
LhloBroadcastInDimConverter, LhloBroadcastInDimConverter,
@ -2490,7 +2496,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
// clang-format off // clang-format off
patterns->insert< patterns->insert<
BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter, BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter,
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter, ConstConverterTensor, HloDynamicBroadcastInDimConverter,
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>, HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
IotaConverter<mhlo::DynamicIotaOp, false>, IotaConverter<mhlo::DynamicIotaOp, false>,
PointwiseToLinalgConverter<mhlo::AbsOp, false>, PointwiseToLinalgConverter<mhlo::AbsOp, false>,

View File

@ -2698,3 +2698,26 @@ func @scatter_update_slice(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
// CHECK: linalg.yield %[[SELECT]] : i32 // CHECK: linalg.yield %[[SELECT]] : i32
// CHECK: } -> tensor<6x3xi32> // CHECK: } -> tensor<6x3xi32>
// CHECK: return %[[RES]] : tensor<6x3xi32> // CHECK: return %[[RES]] : tensor<6x3xi32>
// -----
func @const() -> tensor<3xi32> {
// CHECK: = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst = mhlo.constant dense<[1, 2, 3]> : tensor<3xi32>
return %cst : tensor<3xi32>
}
// -----
func @const_unsigned() -> tensor<3xui32> {
// CHECK: = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst = mhlo.constant dense<[1, 2, 3]> : tensor<3xui32>
return %cst : tensor<3xui32>
}
// -----
func @const_splat() -> tensor<3xi16> {
// CHECK: = constant dense<1> : tensor<3xi16>
%cst = mhlo.constant dense<1> : tensor<3xi16>
return %cst : tensor<3xi16>
}