Try to avoid a segfault if we don't support a lowering.
It can happen that a lowering for a certain type is not implemented yet. We should not segfault in such a case, but instead return a failure(). PiperOrigin-RevId: 347801106
This commit is contained in:
parent
e6e8920921
commit
61244b136c
|
@ -131,6 +131,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
||||||
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
|
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
|
||||||
common_indexing_map);
|
common_indexing_map);
|
||||||
|
|
||||||
|
bool failed = false;
|
||||||
auto linalg_op = rewriter.create<linalg::GenericOp>(
|
auto linalg_op = rewriter.create<linalg::GenericOp>(
|
||||||
loc, op_result_types, inputs, output_buffers,
|
loc, op_result_types, inputs, output_buffers,
|
||||||
/*initTensors=*/ValueRange{}, indexing_maps,
|
/*initTensors=*/ValueRange{}, indexing_maps,
|
||||||
|
@ -141,8 +142,13 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
||||||
Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
|
Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
|
||||||
op, body_result_types,
|
op, body_result_types,
|
||||||
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
|
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
|
||||||
|
if (op_result == nullptr) {
|
||||||
|
failed = true;
|
||||||
|
} else {
|
||||||
nested_builder.create<linalg::YieldOp>(loc, op_result);
|
nested_builder.create<linalg::YieldOp>(loc, op_result);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
if (failed) return failure();
|
||||||
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
|
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue