Try to avoid a segfault if we don't support a lowering.

It can happen that a lowering for a certain type is not implemented yet.
We should not segfault in such a case, but instead return a failure().

PiperOrigin-RevId: 347801106
This commit is contained in:
Adrian Kuegel 2020-12-16 04:57:35 -08:00 committed by TensorFlow MLIR Team
parent e6e8920921
commit 61244b136c
1 changed files with 7 additions and 1 deletions

View File

@ -131,6 +131,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
common_indexing_map);
bool failed = false;
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, op_result_types, inputs, output_buffers,
/*initTensors=*/ValueRange{}, indexing_maps,
@ -141,8 +142,13 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, body_result_types,
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
if (op_result == nullptr) {
failed = true;
} else {
nested_builder.create<linalg::YieldOp>(loc, op_result);
}
});
if (failed) return failure();
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
return success();
}