[HLO] Add a pattern for HLO ConstOp to HLO -> Linalg conversion.
PiperOrigin-RevId: 346718273
This commit is contained in:
parent
cfcf741932
commit
c36afd275e
|
@ -733,23 +733,37 @@ class IotaConverter : public OpConversionPattern<OpTy> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
|
template <typename OpTy>
|
||||||
|
class ConstConverter : public OpConversionPattern<OpTy> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
lmhlo::ConstOp const_op, ArrayRef<Value> args,
|
OpTy const_op, ArrayRef<Value> /*args*/,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
auto loc = const_op.getLoc();
|
Location loc = const_op.getLoc();
|
||||||
auto value_attr = const_op.value().cast<DenseElementsAttr>();
|
auto value_attr = const_op.value().template cast<DenseElementsAttr>();
|
||||||
if (value_attr.getType().getRank() != 0) return failure();
|
if (value_attr.getType().getRank() != 0) return failure();
|
||||||
auto std_const_op =
|
ReplaceConstOp(loc, const_op, value_attr, rewriter);
|
||||||
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
|
|
||||||
rewriter.create<mlir::AffineStoreOp>(loc, std_const_op,
|
|
||||||
const_op.getOperand(), ValueRange());
|
|
||||||
rewriter.eraseOp(const_op);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void ReplaceConstOp(Location loc, mhlo::ConstOp op,
|
||||||
|
DenseElementsAttr value_attr,
|
||||||
|
ConversionPatternRewriter& rewriter) const {
|
||||||
|
Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr);
|
||||||
|
rewriter.replaceOp(op, {std_tensor_const});
|
||||||
|
}
|
||||||
|
void ReplaceConstOp(Location loc, lmhlo::ConstOp op,
|
||||||
|
DenseElementsAttr value_attr,
|
||||||
|
ConversionPatternRewriter& rewriter) const {
|
||||||
|
Value std_scalar_const =
|
||||||
|
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
|
||||||
|
rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(),
|
||||||
|
llvm::None);
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
||||||
|
@ -908,7 +922,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
||||||
ConstConverter,
|
ConstConverter<lmhlo::ConstOp>,
|
||||||
ConvToLinalgConverter,
|
ConvToLinalgConverter,
|
||||||
IotaConverter<lmhlo::IotaOp>,
|
IotaConverter<lmhlo::IotaOp>,
|
||||||
LhloBroadcastInDimConverter,
|
LhloBroadcastInDimConverter,
|
||||||
|
@ -1029,7 +1043,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
patterns
|
patterns
|
||||||
->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
|
->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
|
||||||
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
|
ConstConverter<mhlo::ConstOp>, HloBroadcastInDimConverter,
|
||||||
|
IotaConverter<mhlo::IotaOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||||
|
|
|
@ -696,3 +696,14 @@ func @shift_right_logical(%lhs: tensor<2x2xi32>,
|
||||||
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
|
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
|
||||||
// CHECK-NEXT: %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32
|
// CHECK-NEXT: %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32
|
||||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @constant
|
||||||
|
func @constant() {
|
||||||
|
%result = "mhlo.constant"() {
|
||||||
|
value = dense<10> : tensor<i32>
|
||||||
|
} : () -> (tensor<i32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: %[[CONSTANT:.*]] = constant dense<10> : tensor<i32>
|
||||||
|
|
Loading…
Reference in New Issue