diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 3838d3d..187b145 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -422,9 +422,15 @@ struct HloLegalizeToLhlo return llvm::all_of(inputs, isMemRefType) && converter.isLegal(&op.getBody()); }); - target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { - return std::all_of(returnOp.operand_type_begin(), - returnOp.operand_type_end(), isMemRefType); + target.addDynamicallyLegalOp([&](CallOp op) { + return std::all_of(op.operand_type_begin(), op.operand_type_end(), + isMemRefType) && + std::all_of(op.result_type_begin(), op.result_type_end(), + isMemRefType); + }); + target.addDynamicallyLegalOp([&](mlir::ReturnOp op) { + return std::all_of(op.operand_type_begin(), op.operand_type_end(), + isMemRefType); }); auto module = getOperation(); diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 0db595c..a555935 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -169,11 +169,12 @@ func @external_func() -> tensor<3xi64> func @dyn_broadcast(%operand: memref) { // BOTH-SAME: (%[[OPERAND:.*]]: memref) %tensor_operand = tensor_load %operand : memref - %shape = call @external_func() : () -> tensor<3xi64> + %c1 = constant 1 : i64 + %shape = tensor_from_elements(%c1, %c1, %c1) : tensor<3xi64> %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> } : (tensor, tensor<3xi64>) -> tensor - // BOTH: %[[SHAPE:.*]] = call @external_func() + // BOTH: %[[SHAPE:.*]] = tensor_from_elements // BOTH: %[[C0:.*]] = constant 0 : index // BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64> // BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index