[MHLO:linalg] Be more aggressive about turning mhlo.const into std.constant
On tensors the only difference between these ops is that mhlo.const supports unsigned types. PiperOrigin-RevId: 377970948
This commit is contained in:
parent
25b93c8d66
commit
d1c60df2fe
|
@ -1077,36 +1077,42 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename OpTy>
|
class ConstConverterBuffer : public OpConversionPattern<lmhlo::ConstOp> {
|
||||||
class ConstConverter : public OpConversionPattern<OpTy> {
|
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
OpTy const_op, ArrayRef<Value> /*args*/,
|
lmhlo::ConstOp const_op, ArrayRef<Value> /*args*/,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
Location loc = const_op.getLoc();
|
Location loc = const_op.getLoc();
|
||||||
auto value_attr = const_op.value().template cast<DenseElementsAttr>();
|
auto value_attr = const_op.value().cast<DenseElementsAttr>();
|
||||||
if (value_attr.getType().getRank() != 0) return failure();
|
if (value_attr.getType().getRank() != 0) return failure();
|
||||||
ReplaceConstOp(loc, const_op, value_attr, rewriter);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
void ReplaceConstOp(Location loc, mhlo::ConstOp op,
|
|
||||||
DenseElementsAttr value_attr,
|
|
||||||
ConversionPatternRewriter& rewriter) const {
|
|
||||||
Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr);
|
|
||||||
rewriter.replaceOp(op, {std_tensor_const});
|
|
||||||
}
|
|
||||||
void ReplaceConstOp(Location loc, lmhlo::ConstOp op,
|
|
||||||
DenseElementsAttr value_attr,
|
|
||||||
ConversionPatternRewriter& rewriter) const {
|
|
||||||
Value std_scalar_const =
|
Value std_scalar_const =
|
||||||
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
|
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
|
||||||
rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(),
|
rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const,
|
||||||
llvm::None);
|
const_op.getOperand(), llvm::None);
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(const_op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ConstConverterTensor : public OpConversionPattern<mhlo::ConstOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mhlo::ConstOp const_op, ArrayRef<Value> /*args*/,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto value_attr = const_op.value().cast<DenseElementsAttr>();
|
||||||
|
auto type =
|
||||||
|
typeConverter->convertType(const_op.getType()).cast<ShapedType>();
|
||||||
|
if (type != const_op.getType()) {
|
||||||
|
// Signedness conversion.
|
||||||
|
value_attr = value_attr.mapValues(type.getElementType(),
|
||||||
|
[](const APInt& i) { return i; });
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<ConstantOp>(const_op, type, value_attr);
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2291,7 +2297,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
||||||
ConstConverter<lmhlo::ConstOp>,
|
ConstConverterBuffer,
|
||||||
ConvToLinalgConverter,
|
ConvToLinalgConverter,
|
||||||
IotaConverter<lmhlo::IotaOp>,
|
IotaConverter<lmhlo::IotaOp>,
|
||||||
LhloBroadcastInDimConverter,
|
LhloBroadcastInDimConverter,
|
||||||
|
@ -2490,7 +2496,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter,
|
BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter,
|
||||||
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
|
ConstConverterTensor, HloDynamicBroadcastInDimConverter,
|
||||||
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
|
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
|
||||||
IotaConverter<mhlo::DynamicIotaOp, false>,
|
IotaConverter<mhlo::DynamicIotaOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||||
|
|
|
@ -2698,3 +2698,26 @@ func @scatter_update_slice(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
|
||||||
// CHECK: linalg.yield %[[SELECT]] : i32
|
// CHECK: linalg.yield %[[SELECT]] : i32
|
||||||
// CHECK: } -> tensor<6x3xi32>
|
// CHECK: } -> tensor<6x3xi32>
|
||||||
// CHECK: return %[[RES]] : tensor<6x3xi32>
|
// CHECK: return %[[RES]] : tensor<6x3xi32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @const() -> tensor<3xi32> {
|
||||||
|
// CHECK: = constant dense<[1, 2, 3]> : tensor<3xi32>
|
||||||
|
%cst = mhlo.constant dense<[1, 2, 3]> : tensor<3xi32>
|
||||||
|
return %cst : tensor<3xi32>
|
||||||
|
}
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @const_unsigned() -> tensor<3xui32> {
|
||||||
|
// CHECK: = constant dense<[1, 2, 3]> : tensor<3xi32>
|
||||||
|
%cst = mhlo.constant dense<[1, 2, 3]> : tensor<3xui32>
|
||||||
|
return %cst : tensor<3xui32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @const_splat() -> tensor<3xi16> {
|
||||||
|
// CHECK: = constant dense<1> : tensor<3xi16>
|
||||||
|
%cst = mhlo.constant dense<1> : tensor<3xi16>
|
||||||
|
return %cst : tensor<3xi16>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue