[MLIR][HLO] Support CHLO unary operations in rank specialization clustering

PiperOrigin-RevId: 373397321
This commit is contained in:
A. Unique TensorFlower 2021-05-12 10:20:03 -07:00 committed by TensorFlow MLIR Team
parent 596918a6f1
commit 420c42a0a1
4 changed files with 36 additions and 3 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"

View File

@ -422,14 +422,28 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits, class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
Type ArgTensorType, Type ResultTensorType> : HLOClient_Op<mnemonic, Type ArgTensorType, Type ResultTensorType> : HLOClient_Op<mnemonic,
traits # [InferFusibilityOpInterface, NoSideEffect, traits # [InferFusibilityOpInterface, NoSideEffect, Elementwise,
SameOperandsAndResultShape]> { SameOperandsAndResultShape, InferShapedTypeOpInterface]> {
let arguments = (ins ArgTensorType:$operand); let arguments = (ins ArgTensorType:$operand);
let results = (outs ResultTensorType:$result); let results = (outs ResultTensorType:$result);
let assemblyFormat = [{ let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result) $operand attr-dict `:` type($operand) `->` type($result)
}]; }];
let extraClassDeclaration = [{
static LogicalResult inferReturnTypeComponents(MLIRContext* context,
Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
return failure();
}
LogicalResult reifyReturnTypeShapes(OpBuilder& builder,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes);
}
}];
} }
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",

View File

@ -16,7 +16,6 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "llvm/ADT/APFloat.h" #include "llvm/ADT/APFloat.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/utils/broadcast_utils.h" #include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"

View File

@ -66,3 +66,22 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
: (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
// -----
// Unary CHLO operation.
// CHECK-LABEL: @tan
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) ( {
// CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>)
// CHECK: %[[TMP0:.*]] = chlo.tan %[[ARG_]]
// CHECK: %[[TMP1:.*]] = chlo.tan %[[TMP0]]
// CHECK: %[[TMP2:.*]] = chlo.tan %[[TMP1]]
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]])
// CHECK: return %[[RES]]
%0 = chlo.tan %arg : tensor<*xf32> -> tensor<*xf32>
%1 = chlo.tan %0 : tensor<*xf32> -> tensor<*xf32>
%2 = chlo.tan %1 : tensor<*xf32> -> tensor<*xf32>
return %2 : tensor<*xf32>
}