[HLO] Add a pattern for HLO ConstOp to HLO -> Linalg conversion.

PiperOrigin-RevId: 346718273
This commit is contained in:
Alexander Belyaev 2020-12-09 23:24:23 -08:00 committed by TensorFlow MLIR Team
parent cfcf741932
commit c36afd275e
2 changed files with 38 additions and 12 deletions

View File

@ -733,23 +733,37 @@ class IotaConverter : public OpConversionPattern<OpTy> {
} }
}; };
class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> { template <typename OpTy>
class ConstConverter : public OpConversionPattern<OpTy> {
public: public:
using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern; using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
lmhlo::ConstOp const_op, ArrayRef<Value> args, OpTy const_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = const_op.getLoc(); Location loc = const_op.getLoc();
auto value_attr = const_op.value().cast<DenseElementsAttr>(); auto value_attr = const_op.value().template cast<DenseElementsAttr>();
if (value_attr.getType().getRank() != 0) return failure(); if (value_attr.getType().getRank() != 0) return failure();
auto std_const_op = ReplaceConstOp(loc, const_op, value_attr, rewriter);
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
rewriter.create<mlir::AffineStoreOp>(loc, std_const_op,
const_op.getOperand(), ValueRange());
rewriter.eraseOp(const_op);
return success(); return success();
} }
private:
void ReplaceConstOp(Location loc, mhlo::ConstOp op,
DenseElementsAttr value_attr,
ConversionPatternRewriter& rewriter) const {
Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr);
rewriter.replaceOp(op, {std_tensor_const});
}
void ReplaceConstOp(Location loc, lmhlo::ConstOp op,
DenseElementsAttr value_attr,
ConversionPatternRewriter& rewriter) const {
Value std_scalar_const =
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(),
llvm::None);
rewriter.eraseOp(op);
}
}; };
class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> { class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
@ -908,7 +922,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
// clang-format off // clang-format off
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>, patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
ConstConverter, ConstConverter<lmhlo::ConstOp>,
ConvToLinalgConverter, ConvToLinalgConverter,
IotaConverter<lmhlo::IotaOp>, IotaConverter<lmhlo::IotaOp>,
LhloBroadcastInDimConverter, LhloBroadcastInDimConverter,
@ -1029,7 +1043,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
patterns patterns
->insert<BroadcastConverter<mhlo::BroadcastOp, false>, ->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>, ConstConverter<mhlo::ConstOp>, HloBroadcastInDimConverter,
IotaConverter<mhlo::IotaOp, false>,
PointwiseToLinalgConverter<mhlo::AbsOp, false>, PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>, PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>, PointwiseToLinalgConverter<mhlo::AndOp, false>,

View File

@ -696,3 +696,14 @@ func @shift_right_logical(%lhs: tensor<2x2xi32>,
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32): // CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-LABEL: func @constant
func @constant() {
%result = "mhlo.constant"() {
value = dense<10> : tensor<i32>
} : () -> (tensor<i32>)
return
}
// CHECK: %[[CONSTANT:.*]] = constant dense<10> : tensor<i32>