[MLIR][LHLO] Legalize CallOp that call funcs with tensor args/results.
PiperOrigin-RevId: 320172723
This commit is contained in:
parent
e8cfdee592
commit
b076e018a8
|
@ -422,9 +422,15 @@ struct HloLegalizeToLhlo
|
|||
return llvm::all_of(inputs, isMemRefType) &&
|
||||
converter.isLegal(&op.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
|
||||
return std::all_of(returnOp.operand_type_begin(),
|
||||
returnOp.operand_type_end(), isMemRefType);
|
||||
target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
|
||||
return std::all_of(op.operand_type_begin(), op.operand_type_end(),
|
||||
isMemRefType) &&
|
||||
std::all_of(op.result_type_begin(), op.result_type_end(),
|
||||
isMemRefType);
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp op) {
|
||||
return std::all_of(op.operand_type_begin(), op.operand_type_end(),
|
||||
isMemRefType);
|
||||
});
|
||||
|
||||
auto module = getOperation();
|
||||
|
|
|
@ -169,11 +169,12 @@ func @external_func() -> tensor<3xi64>
|
|||
func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
||||
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
|
||||
%tensor_operand = tensor_load %operand : memref<?x?xf32>
|
||||
%shape = call @external_func() : () -> tensor<3xi64>
|
||||
%c1 = constant 1 : i64
|
||||
%shape = tensor_from_elements(%c1, %c1, %c1) : tensor<3xi64>
|
||||
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
|
||||
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// BOTH: %[[SHAPE:.*]] = call @external_func()
|
||||
// BOTH: %[[SHAPE:.*]] = tensor_from_elements
|
||||
// BOTH: %[[C0:.*]] = constant 0 : index
|
||||
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
|
||||
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||
|
|
Loading…
Reference in New Issue