[MLIR][HLO] Support CHLO unary operations in rank specialization clustering
PiperOrigin-RevId: 373397321
This commit is contained in:
parent
596918a6f1
commit
420c42a0a1
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
|
|
@ -422,14 +422,28 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
|
|||
|
||||
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||
Type ArgTensorType, Type ResultTensorType> : HLOClient_Op<mnemonic,
|
||||
traits # [InferFusibilityOpInterface, NoSideEffect,
|
||||
SameOperandsAndResultShape]> {
|
||||
traits # [InferFusibilityOpInterface, NoSideEffect, Elementwise,
|
||||
SameOperandsAndResultShape, InferShapedTypeOpInterface]> {
|
||||
let arguments = (ins ArgTensorType:$operand);
|
||||
let results = (outs ResultTensorType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static LogicalResult inferReturnTypeComponents(MLIRContext* context,
|
||||
Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||
return failure();
|
||||
}
|
||||
LogicalResult reifyReturnTypeShapes(OpBuilder& builder,
|
||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
|
||||
|
|
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/utils/broadcast_utils.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
|
|
@ -66,3 +66,22 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
|
|||
: (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Unary CHLO operation.
|
||||
// CHECK-LABEL: @tan
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
|
||||
func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) ( {
|
||||
// CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>)
|
||||
// CHECK: %[[TMP0:.*]] = chlo.tan %[[ARG_]]
|
||||
// CHECK: %[[TMP1:.*]] = chlo.tan %[[TMP0]]
|
||||
// CHECK: %[[TMP2:.*]] = chlo.tan %[[TMP1]]
|
||||
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]])
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = chlo.tan %arg : tensor<*xf32> -> tensor<*xf32>
|
||||
%1 = chlo.tan %0 : tensor<*xf32> -> tensor<*xf32>
|
||||
%2 = chlo.tan %1 : tensor<*xf32> -> tensor<*xf32>
|
||||
return %2 : tensor<*xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue