Allow variadic operands/results in MHLO while

This just adds support for variadic operands/results on the op itself, but keeps the production/uses as-is (e.g., a single tensor or tuple), matching what XLA export requires. A follow-up here would be to add an export pass that re-tuples the operands, after which the canonical form could be changed. Given control flow via regions and multi-result operations, tuple'ing adds no representational power, and all the get_tuple_element ops obscure the computation.

The old form allowed a single tensor or tuple operand; the new form allows a variadic number of tensors or tuples. Tuples may be nested, so the input could look like (Tensor<..>, Tuple<Tensor<...>, Tuple<...>, ...>, Tensor<...>), which is why the element type stays HLO_TensorOrTuple: HLO_Tensor doesn't allow tuples.
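To make the contrast concrete, a minimal sketch (hypothetical values, elided region bodies; the test added at the bottom of this change shows a complete new-form loop):

// Old form: loop state packed into a single tuple; every use of a component
// inside the regions goes through mhlo.get_tuple_element.
%init = "mhlo.tuple"(%i, %x) : (tensor<i32>, tensor<4xf32>) -> tuple<tensor<i32>, tensor<4xf32>>
%res = "mhlo.while"(%init) ( {
^bb0(%arg: tuple<tensor<i32>, tensor<4xf32>>):
  %iv = "mhlo.get_tuple_element"(%arg) {index = 0 : i32} : (tuple<tensor<i32>, tensor<4xf32>>) -> tensor<i32>
  ...
}, { ... }) : (tuple<tensor<i32>, tensor<4xf32>>) -> tuple<tensor<i32>, tensor<4xf32>>

// New form: the same state as two operands and two results; no tuple traffic.
%res:2 = "mhlo.while"(%i, %x) ( {
^bb0(%iv: tensor<i32>, %acc: tensor<4xf32>):
  ...
}, { ... }) : (tensor<i32>, tensor<4xf32>) -> (tensor<i32>, tensor<4xf32>)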

PiperOrigin-RevId: 378934388
Jacques Pienaar 2021-06-11 13:07:19 -07:00 committed by TensorFlow MLIR Team
parent 33f95eecc7
commit 95ba03534f
5 changed files with 43 additions and 8 deletions


@@ -205,11 +205,15 @@ Exit:
 The MHLO dialect has no direct export format, it is only meant as an
 intermediate optimization dialect/format. It is also where we can experiment
 cheaply with new ops. This format will be where the representation would differ
-from existing end points.
+from existing endpoints.

 Status: Exists but need to be cleaned up and evolved, in particular with respect
 to supporting dynamic shapes.

+MHLO differs from XLA HLO op set in multiple ways, including:
+1. MHLO While accepts multiple operands and may produce multiple results
+   instead;
+
 ### LMHLO

 LMHLO corresponds to late `mhlo` and operates on buffer domain (e.g., memref)


@@ -904,11 +904,11 @@ def HLO_WhileOp: HLO_Op<"while", [
     See https://www.tensorflow.org/xla/operation_semantics#while.
   }];

-  let arguments = (ins HLO_TensorOrTuple:$val);
+  let arguments = (ins Variadic<HLO_TensorOrTuple>:$arg);

   let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

-  let results = (outs HLO_TensorOrTuple);
+  let results = (outs Variadic<HLO_TensorOrTuple>);

   // TODO(b/129422361): WhileOp has special conversion logic to HLO.
   let hasCustomHLOConverter = 1;


@@ -106,6 +106,9 @@ LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
 }

 LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
+  // TODO(jpienaar): Support multi-operand while op.
+  if (while_op.arg().size() != 1) return failure();
+
   // Converts a MHLO while loop into control flow. This generates a set of MLIR
   // blocks and branches, along with inlining the regions provided by the MHLO
   // while loop. The structure should be similar to below:

@@ -140,7 +143,8 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
   //     <prior operations>
   //     br ^cond(%arg0) // Jumps to the condition statement.
   builder.setInsertionPointToEnd(orig_block);
-  builder.create<mlir::BranchOp>(loc, cond_block, while_op.getOperand());
+  // TODO(jpienaar): Support multi-operand while op.
+  builder.create<mlir::BranchOp>(loc, cond_block, while_op.arg()[0]);

   // Updates the inlined condition blocks by replacing the return op with an
   // tensor.extract and conditional branch. This changes the block below:

@@ -199,8 +203,9 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
   }

   // Erase the original while loop.
-  tail_block->addArgument(while_op.getType());
-  while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
+  // TODO(jpienaar): Support multi-operand while op.
+  tail_block->addArgument(while_op.arg().getType()[0]);
+  while_op.getResult(0).replaceAllUsesWith(tail_block->getArgument(0));
   op_inst->erase();

   return success();
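For reference, the CFG this lowering produces for the single-operand case, reconstructed from the pass's own comments (schematic only; block and value names are illustrative):

  <prior operations>
  br ^cond(%operand)         // Jump to the condition with the while operand.
^cond(%arg0):                // Inlined cond region.
  %pred_tensor = ...         // Computes the loop predicate as tensor<i1>.
  %pred = tensor.extract %pred_tensor[] : tensor<i1>
  cond_br %pred, ^body(%arg0), ^tail(%arg0)
^body(%arg1):                // Inlined body region.
  %next = ...                // Computes the next loop state.
  br ^cond(%next)
^tail(%arg2):                // %arg2 replaces the while op's result.
  <post operations>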


@@ -50,6 +50,9 @@ class ControlFlowToScfPass

 // TODO(jpienaar): Look into reformulating as a pattern.
 void MatchAndRewrite(WhileOp whileOp) {
+  // TODO(jpienaar): Supports multi-operand while op.
+  if (whileOp.arg().size() != 1) return;
+
   // Handle pattern:
   //    x = start
   //    step = ...

@@ -57,7 +60,8 @@ void MatchAndRewrite(WhileOp whileOp) {
   //   while (x < limit) { ... x += step; }

   // Only handling multi value while loops at the moment.
-  auto tupleOp = whileOp.getOperand().getDefiningOp<TupleOp>();
+  // TODO(jpienaar): Support multi-operand while op.
+  auto tupleOp = whileOp.getOperand(0).getDefiningOp<TupleOp>();
   if (!tupleOp) return;
   auto bodyReturn = whileOp.body()
                         .front()
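Shape of the IR this pattern expects, per the comments above (a sketch; identifiers are illustrative): the single while operand must be produced by an mhlo.tuple carrying the induction variable, step, and limit, so the rewrite can recover them.

%init = "mhlo.tuple"(%start, %step, %limit)
    : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
// whileOp.getOperand(0).getDefiningOp<TupleOp>() finds this op; if the operand
// is not fed by a tuple, the rewrite bails out.
%w = "mhlo.while"(%init) ( { ... }, { ... })
    : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>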


@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
+// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | FileCheck %s

 // Tests for types, ops with custom constraints, verifiers, printer or parser
 // methods.

@@ -1502,3 +1502,25 @@ func @conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> ten
            (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
   return %0 : tensor<1x8x8x16xf32>
 }
+
+// -----
+
+// CHECK: func @lt_loop
+func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<4xf32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>) {
+  %cst = constant dense<-1> : tensor<i32>
+  %cst_0 = constant dense<1> : tensor<i32>
+  %cst_1 = constant dense<0> : tensor<i32>
+  %cst_2 = constant dense<1000> : tensor<i32>
+  %0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+  %1:3 = "mhlo.while"(%cst_1, %cst, %cst_2) ( {
+  ^bb0(%arg9: tensor<i32>, %arg10: tensor<i32>, %arg11: tensor<i32>):  // no predecessors
+    %4 = "mhlo.compare"(%arg9, %arg11) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%4) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg9: tensor<i32>, %arg10: tensor<i32>, %arg11: tensor<i32>):  // no predecessors
+    %3 = mhlo.add %arg9, %cst_0 : tensor<i32>
+    %6 = "mhlo.tuple"(%3, %arg10, %arg11) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+    "mhlo.return"(%3, %arg10, %arg11) : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
+  }) : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
+  return %1#0, %1#2, %1#2 : tensor<i32>, tensor<i32>, tensor<i32>
+}