diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index b314bd4..b778e94 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ #include "llvm/ADT/StringRef.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index c9e46ea..2116556 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -422,14 +422,28 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp< class HLOClient_UnaryElementwiseOp traits, Type ArgTensorType, Type ResultTensorType> : HLOClient_Op { + traits # [InferFusibilityOpInterface, NoSideEffect, Elementwise, + SameOperandsAndResultShape, InferShapedTypeOpInterface]> { let arguments = (ins ArgTensorType:$operand); let results = (outs ResultTensorType:$result); let assemblyFormat = [{ $operand attr-dict `:` type($operand) `->` type($result) }]; + + let extraClassDeclaration = [{ + static LogicalResult inferReturnTypeComponents(MLIRContext* context, + Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + return failure(); + } + LogicalResult reifyReturnTypeShapes(OpBuilder& builder, + SmallVectorImpl& reifiedReturnShapes) { + return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); + } + }]; } def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc index 1fa000e..ad149e2 100644 --- a/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -16,7 +16,6 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "llvm/ADT/APFloat.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/utils/broadcast_utils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index 6258ef0..aff12b9 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -66,3 +66,22 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>, : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } + +// ----- + +// Unary CHLO operation. +// CHECK-LABEL: @tan +// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> +func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) ( { + // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>) + // CHECK: %[[TMP0:.*]] = chlo.tan %[[ARG_]] + // CHECK: %[[TMP1:.*]] = chlo.tan %[[TMP0]] + // CHECK: %[[TMP2:.*]] = chlo.tan %[[TMP1]] + // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) + // CHECK: return %[[RES]] + %0 = chlo.tan %arg : tensor<*xf32> -> tensor<*xf32> + %1 = chlo.tan %0 : tensor<*xf32> -> tensor<*xf32> + %2 = chlo.tan %1 : tensor<*xf32> -> tensor<*xf32> + return %2 : tensor<*xf32> +}