[MLIR][LHLO] Legalize CallOp that call funcs with tensor args/results.

PiperOrigin-RevId: 320172723
This commit is contained in:
Alexander Belyaev 2020-07-08 12:59:45 +00:00 committed by Mehdi Amini
parent e8cfdee592
commit b076e018a8
2 changed files with 12 additions and 5 deletions

View File

@ -422,9 +422,15 @@ struct HloLegalizeToLhlo
return llvm::all_of(inputs, isMemRefType) && return llvm::all_of(inputs, isMemRefType) &&
converter.isLegal(&op.getBody()); converter.isLegal(&op.getBody());
}); });
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) { target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
return std::all_of(returnOp.operand_type_begin(), return std::all_of(op.operand_type_begin(), op.operand_type_end(),
returnOp.operand_type_end(), isMemRefType); isMemRefType) &&
std::all_of(op.result_type_begin(), op.result_type_end(),
isMemRefType);
});
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp op) {
return std::all_of(op.operand_type_begin(), op.operand_type_end(),
isMemRefType);
}); });
auto module = getOperation(); auto module = getOperation();

View File

@ -169,11 +169,12 @@ func @external_func() -> tensor<3xi64>
func @dyn_broadcast(%operand: memref<?x?xf32>) { func @dyn_broadcast(%operand: memref<?x?xf32>) {
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>) // BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
%tensor_operand = tensor_load %operand : memref<?x?xf32> %tensor_operand = tensor_load %operand : memref<?x?xf32>
%shape = call @external_func() : () -> tensor<3xi64> %c1 = constant 1 : i64
%shape = tensor_from_elements(%c1, %c1, %c1) : tensor<3xi64>
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32> } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// BOTH: %[[SHAPE:.*]] = call @external_func() // BOTH: %[[SHAPE:.*]] = tensor_from_elements
// BOTH: %[[C0:.*]] = constant 0 : index // BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64> // BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index // BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index