From 95ba03534f994d7fb148e17166a947181d74c53c Mon Sep 17 00:00:00 2001
From: Jacques Pienaar
Date: Fri, 11 Jun 2021 13:07:19 -0700
Subject: [PATCH] Allow variadic operands/results in MHLO while

This just adds support for it in the op, but keeps the production/uses as
is (e.g., single tensor or tuple), matching what XLA export requires. A
follow-up here would be to add a pass that re-tuples for export, after
which the canonical form could be changed. Tupling does not add
representational power given control flow via regions and multi-result
operations, and all the get_tuple_element ops obscure the computation.

The old form allowed a single tensor or tuple. The new form allows a
variadic number of tensors or tuples; as tuples may be nested, the input
could be (Tensor<...>, Tuple<Tensor<...>, Tuple<...>, ...>, Tensor<...>),
and HLO_Tensor doesn't allow Tuples.

PiperOrigin-RevId: 378934388
---
 README.md                                     |  6 ++++-
 include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td   |  4 ++--
 .../mhlo/transforms/legalize_control_flow.cc  | 11 ++++++---
 .../transforms/mhlo_control_flow_to_scf.cc    |  6 ++++-
 tests/ops.mlir                                | 24 ++++++++++++++++++-
 5 files changed, 43 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 05aabe3..524d9b8 100644
--- a/README.md
+++ b/README.md
@@ -205,11 +205,15 @@ Exit:
 The MHLO dialect has no direct export format, it is only meant as an
 intermediate optimization dialect/format. It is also where we can experiment
 cheaply with new ops. This format will be where the representation would differ
-from existing end points.
+from existing endpoints.
 
 Status: Exists but need to be cleaned up and evolved, in particular with
 respect to supporting dynamic shapes.
 
+MHLO differs from the XLA HLO op set in multiple ways, including:
+1. MHLO While accepts multiple operands and may produce multiple results
+   instead of a single tuple;
+
 ### LMHLO
 
 LMHLO corresponds to late `mhlo` and operates on buffer domain (e.g., memref)
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index c9fc9ae..5ac612a 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -904,11 +904,11 @@ def HLO_WhileOp: HLO_Op<"while", [
     See https://www.tensorflow.org/xla/operation_semantics#while.
   }];
 
-  let arguments = (ins HLO_TensorOrTuple:$val);
+  let arguments = (ins Variadic<HLO_TensorOrTuple>:$arg);
 
   let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
 
-  let results = (outs HLO_TensorOrTuple);
+  let results = (outs Variadic<HLO_TensorOrTuple>);
 
   // TODO(b/129422361): WhileOp has special conversion logic to HLO.
   let hasCustomHLOConverter = 1;
diff --git a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
index 3f876b8..17c017d 100644
--- a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
@@ -106,6 +106,9 @@ LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
 }
 
 LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
+  // TODO(jpienaar): Support multi-operand while op.
+  if (while_op.arg().size() != 1) return failure();
+
   // Converts a MHLO while loop into control flow. This generates a set of MLIR
   // blocks and branches, along with inlining the regions provided by the MHLO
   // while loop. The structure should be similar to below:
@@ -140,7 +143,8 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
   //
   //   br ^cond(%arg0) // Jumps to the condition statement.
   builder.setInsertionPointToEnd(orig_block);
-  builder.create<mlir::BranchOp>(loc, cond_block, while_op.getOperand());
+  // TODO(jpienaar): Support multi-operand while op.
+  builder.create<mlir::BranchOp>(loc, cond_block, while_op.arg()[0]);
 
   // Updates the inlined condition blocks by replacing the return op with an
   // tensor.extract and conditional branch. This changes the block below:
@@ -199,8 +203,9 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
   }
 
   // Erase the original while loop.
-  tail_block->addArgument(while_op.getType());
-  while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
+  // TODO(jpienaar): Support multi-operand while op.
+  tail_block->addArgument(while_op.arg().getType()[0]);
+  while_op.getResult(0).replaceAllUsesWith(tail_block->getArgument(0));
   op_inst->erase();
 
   return success();
diff --git a/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc b/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc
index 1b6b125..d7b8537 100644
--- a/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc
+++ b/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc
@@ -50,6 +50,9 @@ class ControlFlowToScfPass
 
 // TODO(jpienaar): Look into reformulating as a pattern.
 void MatchAndRewrite(WhileOp whileOp) {
+  // TODO(jpienaar): Support multi-operand while op.
+  if (whileOp.arg().size() != 1) return;
+
   // Handle pattern:
   // x = start
   // step = ...
@@ -57,7 +60,8 @@
   // limit = ...
   // while (x < limit) { ... x += step; }
   // Only handling multi value while loops at the moment.
-  auto tupleOp = whileOp.getOperand().getDefiningOp<TupleOp>();
+  // TODO(jpienaar): Support multi-operand while op.
+  auto tupleOp = whileOp.getOperand(0).getDefiningOp<TupleOp>();
   if (!tupleOp) return;
   auto bodyReturn = whileOp.body()
                         .front()
diff --git a/tests/ops.mlir b/tests/ops.mlir
index 48192fe..f5d5002 100644
--- a/tests/ops.mlir
+++ b/tests/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
+// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | FileCheck %s
 
 // Tests for types, ops with custom constraints, verifiers, printer or parser
 // methods.
@@ -1502,3 +1502,25 @@ func @conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> ten
          (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
   return %0 : tensor<1x8x8x16xf32>
 }
+
+// -----
+
+// CHECK: func @lt_loop
+func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<4xf32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>) {
+  %cst = constant dense<-1> : tensor<i32>
+  %cst_0 = constant dense<1> : tensor<i32>
+  %cst_1 = constant dense<0> : tensor<i32>
+  %cst_2 = constant dense<1000> : tensor<i32>
+  %0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+  %1:3 = "mhlo.while"(%cst_1, %cst, %cst_2) ( {
+  ^bb0(%arg9: tensor<i32>, %arg10: tensor<i32>, %arg11: tensor<i32>): // no predecessors
+    %4 = "mhlo.compare"(%arg9, %arg11) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%4) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg9: tensor<i32>, %arg10: tensor<i32>, %arg11: tensor<i32>): // no predecessors
+    %3 = mhlo.add %arg9, %cst_0 : tensor<i32>
+    %6 = "mhlo.tuple"(%3, %arg10, %arg11) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+    "mhlo.return"(%3, %arg10, %arg11) : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
+  }) : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
+  return %1#0, %1#2, %1#2: tensor<i32>, tensor<i32>, tensor<i32>
+}
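
For illustration only, and not part of the patch: below is a minimal sketch of the tuple-based form that this change moves away from, written with made-up SSA names and i32 scalars. The loop state is packed into one tuple operand and unpacked with get_tuple_element in both regions, which is exactly the boilerplate the variadic form (as in the lt_loop test above) avoids.

  %init = "mhlo.tuple"(%i, %limit) : (tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>>
  %w = "mhlo.while"(%init) ( {
  ^bb0(%t: tuple<tensor<i32>, tensor<i32>>):
    // Unpack the carried state just to compute the predicate.
    %i0 = "mhlo.get_tuple_element"(%t) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>>) -> tensor<i32>
    %l0 = "mhlo.get_tuple_element"(%t) {index = 1 : i32} : (tuple<tensor<i32>, tensor<i32>>) -> tensor<i32>
    %p = "mhlo.compare"(%i0, %l0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "mhlo.return"(%p) : (tensor<i1>) -> ()
  }, {
  ^bb0(%t: tuple<tensor<i32>, tensor<i32>>):
    // Unpack, update, and re-pack the state on every iteration.
    %i0 = "mhlo.get_tuple_element"(%t) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>>) -> tensor<i32>
    %l0 = "mhlo.get_tuple_element"(%t) {index = 1 : i32} : (tuple<tensor<i32>, tensor<i32>>) -> tensor<i32>
    %c1 = mhlo.constant dense<1> : tensor<i32>
    %next = mhlo.add %i0, %c1 : tensor<i32>
    %packed = "mhlo.tuple"(%next, %l0) : (tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>>
    "mhlo.return"(%packed) : (tuple<tensor<i32>, tensor<i32>>) -> ()
  }) : (tuple<tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>>

With this change the same loop can carry %i and %limit directly as two operands and two results, as the new lt_loop test shows.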