diff --git a/README.md b/README.md index 05aabe3..524d9b8 100644 --- a/README.md +++ b/README.md @@ -205,11 +205,15 @@ Exit: The MHLO dialect has no direct export format, it is only meant as an intermediate optimization dialect/format. It is also where we can experiment cheaply with new ops. This format will be where the representation would differ -from existing end points. +from existing endpoints. Status: Exists but need to be cleaned up and evolved, in particular with respect to supporting dynamic shapes. +MHLO differs from XLA HLO op set in multiple ways, including: +1. MHLO While accepts multiple operands and may produce multiple results + instead; + ### LMHLO LMHLO corresponds to late `mhlo` and operates on buffer domain (e.g., memref) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index c9fc9ae..5ac612a 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -904,11 +904,11 @@ def HLO_WhileOp: HLO_Op<"while", [ See https://www.tensorflow.org/xla/operation_semantics#while. }]; - let arguments = (ins HLO_TensorOrTuple:$val); + let arguments = (ins Variadic:$arg); let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); - let results = (outs HLO_TensorOrTuple); + let results = (outs Variadic); // TODO(b/129422361): WhileOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; diff --git a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc index 3f876b8..17c017d 100644 --- a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc +++ b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc @@ -106,6 +106,9 @@ LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) { } LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { + // TODO(jpienaar): Support multi-operand while op. + if (while_op.arg().size() != 1) return failure(); + // Converts a MHLO while loop into control flow. This generates a set of MLIR // blocks and branches, along with inlining the regions provided by the MHLO // while loop. The structure should be similar to below: @@ -140,7 +143,8 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { // // br ^cond(%arg0) // Jumps to the condition statement. builder.setInsertionPointToEnd(orig_block); - builder.create(loc, cond_block, while_op.getOperand()); + // TODO(jpienaar): Support multi-operand while op. + builder.create(loc, cond_block, while_op.arg()[0]); // Updates the inlined condition blocks by replacing the return op with an // tensor.extract and conditional branch. This changes the block below: @@ -199,8 +203,9 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { } // Erase the original while loop. - tail_block->addArgument(while_op.getType()); - while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0)); + // TODO(jpienaar): Support multi-operand while op. + tail_block->addArgument(while_op.arg().getType()[0]); + while_op.getResult(0).replaceAllUsesWith(tail_block->getArgument(0)); op_inst->erase(); return success(); diff --git a/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc b/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc index 1b6b125..d7b8537 100644 --- a/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc +++ b/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc @@ -50,6 +50,9 @@ class ControlFlowToScfPass // TODO(jpienaar): Look into reformulating as a pattern. void MatchAndRewrite(WhileOp whileOp) { + // TODO(jpienaar): Supports multi-operand while op. + if (whileOp.arg().size() != 1) return; + // Handle pattern: // x = start // step = ... @@ -57,7 +60,8 @@ void MatchAndRewrite(WhileOp whileOp) { // while (x < limit) { ... x += step; } // Only handling multi value while loops at the moment. - auto tupleOp = whileOp.getOperand().getDefiningOp(); + // TODO(jpienaar): Support multi-operand while op. + auto tupleOp = whileOp.getOperand(0).getDefiningOp(); if (!tupleOp) return; auto bodyReturn = whileOp.body() .front() diff --git a/tests/ops.mlir b/tests/ops.mlir index 48192fe..f5d5002 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | FileCheck %s // Tests for types, ops with custom constraints, verifiers, printer or parser // methods. @@ -1502,3 +1502,25 @@ func @conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> ten (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> return %0 : tensor<1x8x8x16xf32> } + +// ----- + +// CHECK: func @lt_loop +func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor, %arg2: tensor, %arg3: tensor<4xf32>, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor) -> (tensor, tensor, tensor) { + %cst = constant dense<-1> : tensor + %cst_0 = constant dense<1> : tensor + %cst_1 = constant dense<0> : tensor + %cst_2 = constant dense<1000> : tensor + %0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor, tensor, tensor) -> tuple, tensor, tensor> + %1:3 = "mhlo.while"(%cst_1, %cst, %cst_2) ( { + ^bb0(%arg9: tensor, %arg10: tensor, %arg11: tensor): // no predecessors + %4 = "mhlo.compare"(%arg9, %arg11) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%4) : (tensor) -> () + }, { + ^bb0(%arg9: tensor, %arg10: tensor, %arg11: tensor): // no predecessors + %3 = mhlo.add %arg9, %cst_0 : tensor + %6 = "mhlo.tuple"(%3, %arg10, %arg11) : (tensor, tensor, tensor) -> tuple, tensor, tensor> + "mhlo.return"(%3, %arg10, %arg11) : (tensor, tensor, tensor) -> () + }) : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + return %1#0, %1#2, %1#2: tensor, tensor, tensor +}