Add support for lowering mhlo.iota to Linalg.
PiperOrigin-RevId: 322799853
This commit is contained in:
parent
4251630426
commit
8f262ae8f5
|
@ -640,25 +640,25 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|||
}
|
||||
};
|
||||
|
||||
class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
|
||||
template <typename OpTy, bool isLHLO = true>
|
||||
class IotaConverter : public OpConversionPattern<OpTy> {
|
||||
public:
|
||||
using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern;
|
||||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
lmhlo::IotaOp iotaOp, ArrayRef<Value> args,
|
||||
OpTy iotaOp, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto resultMemrefType =
|
||||
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
|
||||
if (!resultMemrefType) return failure();
|
||||
ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp);
|
||||
if (!resultShapedType) return failure();
|
||||
|
||||
auto resultElementType = resultMemrefType.getElementType();
|
||||
auto resultElementType = resultShapedType.getElementType();
|
||||
if (!resultElementType.isSignlessIntOrFloat()) return failure();
|
||||
|
||||
// Construct the indexing maps needed for linalg.generic ops.
|
||||
unsigned nloops = resultMemrefType.getRank();
|
||||
unsigned nloops = resultShapedType.getRank();
|
||||
|
||||
rewriter.create<linalg::IndexedGenericOp>(
|
||||
iotaOp.getLoc(), ArrayRef<Type>{}, args,
|
||||
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
|
||||
iotaOp.getLoc(), isLHLO ? ArrayRef<Type>{} : resultShapedType, args,
|
||||
0, // args_in
|
||||
1, // args_out
|
||||
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
||||
|
@ -669,14 +669,16 @@ class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
|
|||
nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
|
||||
nestedBuilder.getIntegerType(
|
||||
resultElementType.getIntOrFloatBitWidth()));
|
||||
if (resultElementType.isa<FloatType>()) {
|
||||
if (resultElementType.template isa<FloatType>()) {
|
||||
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
|
||||
resultElementType);
|
||||
}
|
||||
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
|
||||
});
|
||||
|
||||
rewriter.replaceOp(iotaOp, llvm::None);
|
||||
if (isLHLO)
|
||||
rewriter.replaceOp(iotaOp, llvm::None);
|
||||
else
|
||||
rewriter.replaceOp(iotaOp, linalgOp.output_tensors());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -768,7 +770,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
||||
ConstConverter,
|
||||
ConvToLinalgConverter,
|
||||
IotaConverter,
|
||||
IotaConverter<lmhlo::IotaOp>,
|
||||
LhloBroadcastInDimConverter,
|
||||
PointwiseToLinalgConverter<lmhlo::AbsOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::AddOp>,
|
||||
|
@ -870,36 +872,37 @@ namespace mhlo {
|
|||
|
||||
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
|
||||
HloBroadcastInDimConverter,
|
||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CosOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::DivOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SinOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||
ReverseConverter<mhlo::ReverseOp, false>,
|
||||
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
||||
patterns
|
||||
->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
|
||||
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CosOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::DivOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SinOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||
ReverseConverter<mhlo::ReverseOp, false>,
|
||||
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||
|
|
|
@ -557,3 +557,18 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
|||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @iota
|
||||
func @iota() -> tensor<7x10xf32> {
|
||||
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
|
||||
return %result : tensor<7x10xf32>
|
||||
}
|
||||
// CHECK: linalg.indexed_generic
|
||||
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
|
||||
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index):
|
||||
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
|
||||
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
|
||||
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
|
||||
|
|
Loading…
Reference in New Issue