[MLIR][LHLO] Legalize CallOp that call funcs with tensor args/results.
PiperOrigin-RevId: 320172723
This commit is contained in:
parent
e8cfdee592
commit
b076e018a8
|
@ -422,9 +422,15 @@ struct HloLegalizeToLhlo
|
||||||
return llvm::all_of(inputs, isMemRefType) &&
|
return llvm::all_of(inputs, isMemRefType) &&
|
||||||
converter.isLegal(&op.getBody());
|
converter.isLegal(&op.getBody());
|
||||||
});
|
});
|
||||||
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
|
target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
|
||||||
return std::all_of(returnOp.operand_type_begin(),
|
return std::all_of(op.operand_type_begin(), op.operand_type_end(),
|
||||||
returnOp.operand_type_end(), isMemRefType);
|
isMemRefType) &&
|
||||||
|
std::all_of(op.result_type_begin(), op.result_type_end(),
|
||||||
|
isMemRefType);
|
||||||
|
});
|
||||||
|
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp op) {
|
||||||
|
return std::all_of(op.operand_type_begin(), op.operand_type_end(),
|
||||||
|
isMemRefType);
|
||||||
});
|
});
|
||||||
|
|
||||||
auto module = getOperation();
|
auto module = getOperation();
|
||||||
|
|
|
@ -169,11 +169,12 @@ func @external_func() -> tensor<3xi64>
|
||||||
func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
||||||
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
|
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
|
||||||
%tensor_operand = tensor_load %operand : memref<?x?xf32>
|
%tensor_operand = tensor_load %operand : memref<?x?xf32>
|
||||||
%shape = call @external_func() : () -> tensor<3xi64>
|
%c1 = constant 1 : i64
|
||||||
|
%shape = tensor_from_elements(%c1, %c1, %c1) : tensor<3xi64>
|
||||||
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
|
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
|
||||||
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||||
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||||
// BOTH: %[[SHAPE:.*]] = call @external_func()
|
// BOTH: %[[SHAPE:.*]] = tensor_from_elements
|
||||||
// BOTH: %[[C0:.*]] = constant 0 : index
|
// BOTH: %[[C0:.*]] = constant 0 : index
|
||||||
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
|
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
|
||||||
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
|
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||||
|
|
Loading…
Reference in New Issue