Add support for lowering mhlo.iota to Linalg.

PiperOrigin-RevId: 322799853
This commit is contained in:
Hanhan Wang 2020-07-23 16:18:01 +00:00 committed by Mehdi Amini
parent 4251630426
commit 8f262ae8f5
2 changed files with 62 additions and 44 deletions

View File

@ -640,25 +640,25 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
}
};
class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
template <typename OpTy, bool isLHLO = true>
class IotaConverter : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern;
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
lmhlo::IotaOp iotaOp, ArrayRef<Value> args,
OpTy iotaOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto resultMemrefType =
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
if (!resultMemrefType) return failure();
ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp);
if (!resultShapedType) return failure();
auto resultElementType = resultMemrefType.getElementType();
auto resultElementType = resultShapedType.getElementType();
if (!resultElementType.isSignlessIntOrFloat()) return failure();
// Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = resultMemrefType.getRank();
unsigned nloops = resultShapedType.getRank();
rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), ArrayRef<Type>{}, args,
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), isLHLO ? ArrayRef<Type>{} : resultShapedType, args,
0, // args_in
1, // args_out
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
@ -669,14 +669,16 @@ class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
nestedBuilder.getIntegerType(
resultElementType.getIntOrFloatBitWidth()));
if (resultElementType.isa<FloatType>()) {
if (resultElementType.template isa<FloatType>()) {
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
resultElementType);
}
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
});
rewriter.replaceOp(iotaOp, llvm::None);
if (isLHLO)
rewriter.replaceOp(iotaOp, llvm::None);
else
rewriter.replaceOp(iotaOp, linalgOp.output_tensors());
return success();
}
};
@ -768,7 +770,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
ConstConverter,
ConvToLinalgConverter,
IotaConverter,
IotaConverter<lmhlo::IotaOp>,
LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<lmhlo::AbsOp>,
PointwiseToLinalgConverter<lmhlo::AddOp>,
@ -870,36 +872,37 @@ namespace mhlo {
void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
HloBroadcastInDimConverter,
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);
patterns
->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);
}
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {

View File

@ -557,3 +557,18 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
}
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// -----
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota
func @iota() -> tensor<7x10xf32> {
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
return %result : tensor<7x10xf32>
}
// CHECK: linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32