1402 lines
44 KiB
TableGen
1402 lines
44 KiB
TableGen
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
// This is the operation definition file for MHLO ops.
|
|
|
|
#ifndef HLO_OPS
|
|
#define HLO_OPS
|
|
|
|
include "mlir/IR/OpBase.td"
|
|
include "mlir/Interfaces/InferTypeOpInterface.td"
|
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
|
|
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
|
|
|
|
// The MHLO dialect: ops defined below use the "mhlo" prefix and their C++
// classes are generated into the `mhlo` namespace.
def HLO_Dialect : Dialect {
  let cppNamespace = "mhlo";
  let name = "mhlo";
}
|
|
|
|
// Base class for every MHLO op.
//
// On top of the generic MLIR `Op` it adds:
//   * `hasCustomHLOConverter`: a flag (0 by default) that derived ops set to 1
//     when they have a custom conversion to HLO.
//   * a verifier that forwards to a C++ `Verify(<op>)` overload defined
//     elsewhere (presumably hlo_ops.cc -- confirm).
class HLO_Op<string mnemonic, list<OpTrait> traits>
    : Op<HLO_Dialect, mnemonic, traits> {
  // TODO(b/129012527) Much of this custom verification should be expressed as
  // type constraints.
  let verifier = [{ return Verify(*this); }];

  // Whether this operation has a custom conversion to HLO or not.
  bit hasCustomHLOConverter = 0;
}
|
|
|
|
// String-valued enum attribute for the kind of a fusion. The accepted
// spellings are "kLoop", "kInput", "kOutput" and "kCustom".
def HLO_LOOP_FUSION   : StrEnumAttrCase<"kLoop">;
def HLO_INPUT_FUSION  : StrEnumAttrCase<"kInput">;
def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">;
def HLO_CUSTOM_FUSION : StrEnumAttrCase<"kCustom">;

def HLO_FusionKindAttr : StrEnumAttr<"FusionKind", "fusion kind", [
    HLO_LOOP_FUSION,
    HLO_INPUT_FUSION,
    HLO_OUTPUT_FUSION,
    HLO_CUSTOM_FUSION
]>;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MHLO nullary op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// mhlo.constant: produces the statically shaped tensor `output` from the
// elements attribute `value`. The attribute type and result type must agree
// (AllTypesMatch).
def HLO_ConstOp : HLO_Op<"constant",
    [ConstantLike, NoSideEffect, AllTypesMatch<["value", "output"]>]>,
    BASE_HLO_ConstOp {
  let arguments = (ins ElementsAttr:$value);
  let results = (outs HLO_StaticShapeTensor:$output);

  // Textual form: attribute dictionary, then the value attribute.
  let assemblyFormat = "attr-dict $value";

  let builders = [
    OpBuilder<"OpBuilder &builder, OperationState &result, Attribute value">
  ];

  let hasFolder = 1;

  // Constant has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.iota: no operands. The `iota_dimension` attribute parameterizes an
// integer, floating-point or complex `output` tensor.
def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {
  let arguments = (ins I64Attr:$iota_dimension);
  let results = (outs HLO_IntFpOrComplexTensor:$output);

  let hasFolder = 1;
  let hasCanonicalizer = 1;

  // TODO(b/130357376): Iota has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.dynamic_iota: unlike HLO_IotaOp, the result shape is passed in as the
// `output_shape` operand rather than fixed by the result type.
def HLO_DynamicIotaOp : HLO_Op<"dynamic_iota", [NoSideEffect]> {
  let summary = "Create linear increasing values from 0 to length -1.";
  let description = [{
    Produces an HLO Tensor of the specified shape, with an incremental set of
    values along the specified dimension starting at 0.

    Requires:
    - The output length of the tensor result.
  }];

  let arguments = (ins
    HLO_DimensionTensor:$output_shape,
    I64Attr:$iota_dimension
  );
  let results = (outs HLO_Tensor:$result);

  // Cannot be exported to legacy formats.
  let hasCustomHLOConverter = 1;
  let hasCanonicalizer = 1;
}
|
|
|
|
|
|
// mhlo.create_token: takes no operands and yields a single token value.
// `summary`/`description` are inherited fields from `Op`, so they are set with
// `let` rather than redeclared with `string`.
def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
  let summary = "Create Token operator";

  let description = [{
    Produces an HLO token. Tokens are used for ordering side-effecting
    operations. This is exported to HLO as an AfterAll operation with no
    operands to generate a token.
  }];

  let results = (outs HLO_Token:$output);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MHLO unary elementwise op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
|
|
|
|
// Base class for MHLO unary elementwise ops.
//
// `TensorType` constrains both the single operand `$operand` and the single
// (unnamed) result. InferShapedTypeOpInterface and InferFusibilityOpInterface
// are appended to every derived op's trait list; the inline C++ below
// provides their methods.
class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
    Type TensorType> : HLO_Op<mnemonic,
        !listconcat(traits,
                    [InferShapedTypeOpInterface, InferFusibilityOpInterface])> {
  let arguments = (ins TensorType:$operand);
  let results = (outs TensorType);

  let extraClassDeclaration = [{
    // Return-type components are never inferred here: always reports failure.
    static LogicalResult inferReturnTypeComponents(
        MLIRContext* context, Optional<Location> location,
        ValueRange operands, DictionaryAttr attributes, RegionRange regions,
        SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
      return failure();
    }

    // Delegates to deriveShapeFromFirstOperand (defined elsewhere); as the
    // name suggests, the result shape is presumably taken from operand #0.
    LogicalResult reifyReturnTypeShapes(
        OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
      return deriveShapeFromFirstOperand(&builder, getOperation(),
                                         &reifiedReturnShapes);
    }

    // Fusibility hint: every input is reported to match every output's shape.
    bool inferInputOutputShapeEquality(int input, int output) {
      return true;
    }

    // The effective workload shape is that of result #0.
    llvm::Optional<Value> inferEffectiveWorkloadShape() {
      return getOperation()->getResult(0);
    }
  }];
}
|
|
|
|
// Abs supports complex to real, so element type is not guaranteed to match;
// only the shape is tied between operand and result.
def HLO_AbsOp : HLO_UnaryElementwiseOp<"abs",
    [NoSideEffect, SameOperandsAndResultShape],
    TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
  let builders = [
    OpBuilder<"OpBuilder &builder, OperationState &result, Value operand">
  ];
}
|
|
|
|
// mhlo.ceil: HLO_FpTensor operand; result type equals operand type.
def HLO_CeilOp : HLO_UnaryElementwiseOp<"ceil",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpTensor>, BASE_HLO_CeilOp;
|
|
|
|
// mhlo.convert: only the shape is tied between operand and result
// (SameOperandsAndResultShape). The builder takes the target element type.
def HLO_ConvertOp : HLO_UnaryElementwiseOp<"convert",
    [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>,
    BASE_HLO_ConvertOp {
  let builders = [OpBuilder<
    "OpBuilder &, OperationState &tblgen_state, Value operand, "
    "Type result_element_ty"
  >];

  let hasCustomHLOConverter = 1;
  let hasFolder = 1;
}
|
|
|
|
// mhlo.count_leading_zeros: HLO_IntTensor operand; result type equals operand type.
def HLO_ClzOp : HLO_UnaryElementwiseOp<"count_leading_zeros",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_IntTensor>, BASE_HLO_ClzOp;
|
|
|
|
// mhlo.cosine: float or complex operand; result type equals operand type.
def HLO_CosOp : HLO_UnaryElementwiseOp<"cosine",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpOrComplexTensor>, BASE_HLO_CosOp;
|
|
|
|
// mhlo.exponential: float or complex operand; result type equals operand type.
def HLO_ExpOp : HLO_UnaryElementwiseOp<"exponential",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpOrComplexTensor>, BASE_HLO_ExpOp;
|
|
|
|
// mhlo.exponential_minus_one: float or complex operand; result type equals
// operand type.
def HLO_Expm1Op : HLO_UnaryElementwiseOp<"exponential_minus_one",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpOrComplexTensor>, BASE_HLO_Expm1Op;
|
|
|
|
// mhlo.floor: HLO_FpTensor operand; result type equals operand type.
def HLO_FloorOp : HLO_UnaryElementwiseOp<"floor",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpTensor>, BASE_HLO_FloorOp;
|
|
|
|
// mhlo.imag: complex operand, floating-point result; shapes must match.
def HLO_ImagOp : HLO_Op<"imag",
    [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp {
  let arguments = (ins HLO_ComplexTensor);
  let results = (outs HLO_FpTensor);

  let builders = [
    OpBuilder<"OpBuilder &, OperationState &tblgen_state, Value val">
  ];
  let hasFolder = 1;
}
|
|
|
|
// mhlo.is_finite: floating-point operand `x`, predicate result `y` of the same
// shape. Replaces the operand/result that the base class declares.
def HLO_IsFiniteOp : HLO_UnaryElementwiseOp<"is_finite",
    [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>,
    BASE_HLO_IsFiniteOp {
  let results = (outs HLO_PredTensor:$y);
  let arguments = (ins HLO_FpTensor:$x);
}
|
|
|
|
// mhlo.log: float or complex operand; result type equals operand type.
def HLO_LogOp : HLO_UnaryElementwiseOp<"log",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpOrComplexTensor>, BASE_HLO_LogOp;
|
|
|
|
// mhlo.log_plus_one: float or complex operand; result type equals operand type.
def HLO_Log1pOp : HLO_UnaryElementwiseOp<"log_plus_one",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpOrComplexTensor>, BASE_HLO_Log1pOp;
|
|
|
|
// mhlo.logistic: float or complex operand; result type equals operand type.
def HLO_LogisticOp : HLO_UnaryElementwiseOp<"logistic",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpOrComplexTensor>, BASE_HLO_LogisticOp;
|
|
|
|
// mhlo.not: predicate or integer operand; result type equals operand type.
def HLO_NotOp : HLO_UnaryElementwiseOp<"not",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_PredOrIntTensor>, BASE_HLO_NotOp;
|
|
|
|
// mhlo.negate: integer, float or complex operand; result type equals operand
// type.
def HLO_NegOp : HLO_UnaryElementwiseOp<"negate",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_IntFpOrComplexTensor>, BASE_HLO_NegOp;
|
|
|
|
// mhlo.popcnt: HLO_IntTensor operand; result type equals operand type.
def HLO_PopulationCountOp : HLO_UnaryElementwiseOp<"popcnt",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_IntTensor>, BASE_HLO_PopulationCountOp;
|
|
|
|
// mhlo.real: complex operand, floating-point result; shapes must match.
def HLO_RealOp : HLO_Op<"real",
    [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp {
  let arguments = (ins HLO_ComplexTensor);
  let results = (outs HLO_FpTensor);

  let builders = [
    OpBuilder<"OpBuilder &, OperationState &tblgen_state, Value val">
  ];
  let hasFolder = 1;
}
|
|
|
|
// mhlo.round_nearest_afz: HLO_FpTensor operand; result type equals operand type.
def HLO_RoundOp : HLO_UnaryElementwiseOp<"round_nearest_afz",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpTensor>, BASE_HLO_RoundOp;
|
|
|
|
// mhlo.rsqrt: float or complex operand; result type equals operand type.
def HLO_RsqrtOp : HLO_UnaryElementwiseOp<"rsqrt",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpOrComplexTensor>, BASE_HLO_RsqrtOp;
|
|
|
|
// mhlo.sign: signed-integer, float or complex operand; result type equals
// operand type.
def HLO_SignOp : HLO_UnaryElementwiseOp<"sign",
    [NoSideEffect, SameOperandsAndResultType],
    TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_SignOp;
|
|
|
|
// mhlo.sine: float or complex operand; result type equals operand type.
def HLO_SinOp : HLO_UnaryElementwiseOp<"sine",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpOrComplexTensor>, BASE_HLO_SinOp;
|
|
|
|
// mhlo.sqrt: float or complex operand; result type equals operand type.
def HLO_SqrtOp : HLO_UnaryElementwiseOp<"sqrt",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpOrComplexTensor>, BASE_HLO_SqrtOp;
|
|
|
|
// mhlo.tanh: float or complex operand; result type equals operand type.
def HLO_TanhOp : HLO_UnaryElementwiseOp<"tanh",
    [NoSideEffect, SameOperandsAndResultType],
    HLO_FpOrComplexTensor>, BASE_HLO_TanhOp;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MHLO binary elementwise op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
|
|
|
|
// Base class for MHLO binary elementwise ops.
//
// Declares two HLO_Tensor operands `$lhs`/`$rhs` and one HLO_Tensor result.
// InferShapedTypeOpInterface and InferFusibilityOpInterface are appended to
// every derived op's trait list; the inline C++ below provides their methods.
// The custom parser/printer use MLIR's generic "one result, same operand
// type" helpers.
class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
    HLO_Op<mnemonic, !listconcat(traits,
        [InferShapedTypeOpInterface, InferFusibilityOpInterface])> {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs
  );

  let extraClassDeclaration = [{
    // Return-type components are never inferred here: always reports failure.
    static LogicalResult inferReturnTypeComponents(
        MLIRContext* context, Optional<Location> location, ValueRange operands,
        DictionaryAttr attributes, RegionRange regions,
        SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
      return failure();
    }

    // Delegates to deriveShapeFromFirstOperand (defined elsewhere); as the
    // name suggests, the result shape is presumably taken from operand #0.
    LogicalResult reifyReturnTypeShapes(
        OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
      return deriveShapeFromFirstOperand(&builder, getOperation(),
                                         &reifiedReturnShapes);
    }

    // Fusibility hint: both inputs are reported to have equal shapes.
    bool inferInputsShapeEquality(int lhs, int rhs) {
      return true;
    }

    // Fusibility hint: every input is reported to match every output's shape.
    bool inferInputOutputShapeEquality(int input, int output) {
      return true;
    }

    // The effective workload shape is that of result #0.
    llvm::Optional<Value> inferEffectiveWorkloadShape() {
      return getOperation()->getResult(0);
    }
  }];

  let results = (outs HLO_Tensor);
  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}
|
|
|
|
// mhlo.add: commutative; operands and result share one type. Has a folder.
def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
    [Commutative, NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_AddOp {
  let hasFolder = 1;
}
|
|
|
|
// mhlo.atan2: operands and result share one type.
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_Atan2Op;
|
|
|
|
// mhlo.complex: two floating-point operands `lhs`/`rhs`, complex result;
// all shapes must match.
def HLO_ComplexOp : HLO_Op<"complex",
    [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ComplexOp {
  let arguments = (ins
    HLO_FpTensor:$lhs,
    HLO_FpTensor:$rhs
  );
  let results = (outs HLO_ComplexTensor);

  let builders = [
    OpBuilder<"OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">
  ];
  let hasFolder = 1;
}
|
|
|
|
// mhlo.divide: operands and result share one type. Has a folder.
def HLO_DivOp : HLO_BinaryElementwiseOp<"divide",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_DivOp {
  let hasFolder = 1;
}
|
|
|
|
// mhlo.maximum: commutative; operands and result share one type. Has a folder.
def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum",
    [Commutative, NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_MaxOp {
  let hasFolder = 1;
}
|
|
|
|
// mhlo.minimum: commutative; operands and result share one type. Has a folder.
def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum",
    [Commutative, NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_MinOp {
  let hasFolder = 1;
}
|
|
|
|
// mhlo.multiply: commutative; operands and result share one type. Has a folder.
def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply",
    [Commutative, NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_MulOp {
  let hasFolder = 1;
}
|
|
|
|
// mhlo.power: operands and result share one type.
def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_PowOp;
|
|
|
|
// mhlo.remainder: operands and result share one type.
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_RemOp;
|
|
|
|
// mhlo.shift_left: operands and result share one type.
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_ShiftLeftOp;
|
|
|
|
// mhlo.shift_right_arithmetic: operands and result share one type.
def HLO_ShiftRightArithmeticOp :
    HLO_BinaryElementwiseOp<"shift_right_arithmetic",
                            [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_ShiftRightArithmeticOp;
|
|
|
|
// mhlo.shift_right_logical: operands and result share one type.
def HLO_ShiftRightLogicalOp :
    HLO_BinaryElementwiseOp<"shift_right_logical",
                            [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_ShiftRightLogicalOp;
|
|
|
|
// mhlo.subtract: not commutative; operands and result share one type.
// Has a folder.
def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_SubOp {
  let hasFolder = 1;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MHLO binary logical elementwise op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
//
// Base class for the commutative logical binary ops. It is a binary
// elementwise op whose operands are narrowed to predicate-or-integer tensors.
class HLO_BinaryLogicalElementwiseOp<string mnemonic>
    : HLO_BinaryElementwiseOp<mnemonic,
          [Commutative, NoSideEffect, SameOperandsAndResultType]> {
  let arguments = (ins
    HLO_PredOrIntTensor:$lhs,
    HLO_PredOrIntTensor:$rhs
  );
}
|
|
|
|
// Logical/bitwise binary ops over predicate or integer tensors.
def HLO_AndOp : HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp;
def HLO_OrOp  : HLO_BinaryLogicalElementwiseOp<"or">,  BASE_HLO_OrOp;
def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MHLO communication op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
// InfeedWithToken allows ordering of infeed HLO instructions using tokens.
//
// Operands: a token, plus an `infeed_config` string attribute that defaults
// to "". Result: a tuple. The op declares no traits.
def HLO_InfeedOp : HLO_Op<"infeed", []> {
  let summary = "Infeed operator";

  let description = [{
    Reads a single data item from the implicit Infeed streaming interface of
    the device, interpreting the data as the given shape, and returns an XlaOp
    of the data. Multiple Infeed operations are allowed in a computation, but
    there must be a total order among the Infeed operations.

    See https://www.tensorflow.org/xla/operation_semantics#infeed.
  }];

  let arguments = (ins
    HLO_Token:$token,
    DefaultValuedAttr<StrAttr, "">:$infeed_config
  );
  let results = (outs HLO_Tuple);
  let hasCustomHLOConverter = 1;
}
|
|
|
|
// OutfeedOp corresponds to 'OutfeedWithToken' xla client API and not 'Outfeed'.
// OutfeedWithToken allows ordering of outfeed HLO instructions using tokens.
//
// Operands: the data (tensor or tuple), a token, and an `outfeed_config`
// string attribute that defaults to "". Result: a token. Declares no traits.
def HLO_OutfeedOp : HLO_Op<"outfeed", []> {
  let summary = "Outfeed operator";

  let description = [{
    Generates outgoing data transfers for the given data. It takes data and a
    token type operand and produces a token type value. Tokens are used for
    ordering side-effecting operations.

    See https://www.tensorflow.org/xla/operation_semantics#outfeed.
  }];

  let arguments = (ins
    HLO_TensorOrTuple:$operand,
    HLO_Token:$token,
    DefaultValuedAttr<StrAttr, "">:$outfeed_config
  );
  let results = (outs HLO_Token);
  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.send: consumes the data and a token, and yields a token.
// `is_host_transfer` defaults to false. Declares no traits.
def HLO_SendOp : HLO_Op<"send", []> {
  let summary = "Send operator";

  let description = [{
    Sends the given operand data to a Recv instruction in another computation
    that shares the same channel handle. Does not return any data. Similar to
    the Recv operation, Send operation represents synchronous communication,
    and is internally decomposed into 2 HLO instructions (Send and SendDone) to
    enable asynchronous data transfers.

    See https://www.tensorflow.org/xla/operation_semantics#send.
  }];

  let arguments = (ins
    HLO_TensorOrTuple:$operand,
    HLO_Token:$token,
    ChannelHandle<HLO_Dialect>:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );
  let results = (outs HLO_Token);

  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.recv: consumes a token and yields a tuple. `is_host_transfer`
// defaults to false. Declares no traits.
def HLO_RecvOp : HLO_Op<"recv", []> {
  let summary = "Recv operator";

  let description = [{
    Receives data of the given shape from a Send instruction in another
    computation that shares the same channel handle. Returns a tuple containing
    the value of the received data and a token. Recv operation represents
    synchronous communication. However, the instruction is internally
    decomposed into 2 HLO instructions (Recv and RecvDone) to enable
    asynchronous data transfers.

    See https://www.tensorflow.org/xla/operation_semantics#recv.
  }];

  let arguments = (ins
    HLO_Token:$token,
    ChannelHandle<HLO_Dialect>:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );
  let results = (outs HLO_Tuple);

  let hasCustomHLOConverter = 1;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MHLO parallelism related op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// mhlo.replica_id: takes no operands; its single result is an I32Tensor.
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
    BASE_HLO_ReplicaIdOp {
  // TODO(prakalps): The output should be an unsigned 32-bit integer, but MLIR
  // does not differentiate between signed and unsigned int.
  let results = (outs I32Tensor);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MHLO control flow op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// mhlo.after_all: joins any number of tokens into a single token.
def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> {
  let summary = "AfterAll operator";

  let description = [{
    AfterAll takes a variadic number of tokens and produces a single token.
    Tokens are primitive types which can be threaded between side-effecting
    operations to enforce ordering. AfterAll can be used as a join of tokens
    for ordering an operation after a set of operations.

    See https://www.tensorflow.org/xla/operation_semantics#afterall.
  }];

  let arguments = (ins Variadic<HLO_Token>:$operands);
  let results = (outs HLO_Token);
}
|
|
|
|
// Xla Client API has two separate calls for indexed and predicated conditional,
// although both eventually map to kConditional HLO. IfOp maps to predicated
// conditional use of kConditional HLO.
//
// Operands: the predicate `pred` plus one argument per branch (`true_arg`,
// `false_arg`). Regions: `true_branch` and `false_branch`.
def HLO_IfOp : HLO_Op<"if", [RecursiveSideEffects]> {
  let summary = "If operator";

  let description = [{
    Returns the result of executing either the true or the false branch,
    depending on the value of the predicate operand `pred`.

    See https://www.tensorflow.org/xla/operation_semantics#conditional.
  }];

  let arguments = (ins
    HLO_PredTensor:$pred,
    HLO_TensorOrTuple:$true_arg,
    HLO_TensorOrTuple:$false_arg
  );

  let regions = (region AnyRegion:$true_branch,
                        AnyRegion:$false_branch);

  let results = (outs HLO_TensorOrTuple);

  // TODO(b/129422361): ConditionalOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}
|
|
|
|
// Xla Client API has two separate calls for indexed and predicated conditional,
// although both eventually map to kConditional HLO. CaseOp maps to indexed
// conditional use of kConditional HLO.
//
// Operands: an I32Tensor `index` plus variadic `branch_operands`. Regions:
// a variadic list of `branches`. Results: variadic.
def HLO_CaseOp : HLO_Op<"case", [RecursiveSideEffects]>, BASE_HLO_CaseOp {
  let arguments = (ins
    I32Tensor:$index,
    Variadic<HLO_TensorOrTuple>:$branch_operands
  );
  let regions = (region VariadicRegion<AnyRegion>:$branches);
  let results = (outs Variadic<HLO_TensorOrTuple>);

  let hasCustomHLOConverter = 1;
}
|
|
|
|
|
|
// mhlo.while: one tensor-or-tuple operand `val`, with `cond` and `body`
// regions. The result type equals the operand type.
def HLO_WhileOp : HLO_Op<"while",
    [RecursiveSideEffects, SameOperandsAndResultType]>, BASE_HLO_WhileOp {
  let arguments = (ins HLO_TensorOrTuple:$val);
  let results = (outs HLO_TensorOrTuple);
  let regions = (region AnyRegion:$cond, AnyRegion:$body);

  // TODO(b/129422361): WhileOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.all_reduce: the reduction is the single-block `computation` region.
// `channel_id` is optional. The op is not marked NoSideEffect.
def HLO_AllReduceOp : HLO_Op<"all_reduce", [SameOperandsAndResultType]>,
    BASE_HLO_AllReduceOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups,
    OptionalAttr<ChannelHandle<HLO_Dialect>>:$channel_id
  );
  let results = (outs HLO_Tensor);
  let regions = (region SizedRegion<1>:$computation);

  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.all_to_all: configured by the split/concat dimensions, the split
// count, and the replica groups.
def HLO_AllToAllOp : HLO_Op<"all_to_all",
    [NoSideEffect, SameOperandsElementType, SameOperandsShape]>,
    BASE_HLO_AllToAllOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$split_dimension,
    I64Attr:$concat_dimension,
    I64Attr:$split_count,
    I64ElementsAttr:$replica_groups
  );
  let results = (outs HLO_Tensor);
}
|
|
|
|
// mhlo.reduce: variadic `operands` and `init_values` of equal count
// (SameVariadicOperandSize), reduced over `dimensions`. The reduction body is
// a single-block region terminated by ReturnOp.
def HLO_ReduceOp : HLO_Op<"reduce", [
      RecursiveSideEffects,
      SameVariadicOperandSize,
      SingleBlockImplicitTerminator<"ReturnOp">,
      InferFusibilityOpInterface
    ]>, BASE_HLO_ReduceOp {
  let arguments = (ins
    Variadic<HLO_TensorOrTuple>:$operands,
    Variadic<HLO_TensorOrTuple>:$init_values,
    I64ElementsAttr:$dimensions
  );
  let results = (outs Variadic<HLO_TensorOrTuple>);

  // TODO(hinsu): Verify that the attached body arguments and results are
  // compatible with reduce op's operands.
  let regions = (region SizedRegion<1>:$body);

  let builders = [OpBuilder<
    "OpBuilder &, OperationState &state, ValueRange operands, "
    "ValueRange init_values, DenseIntElementsAttr dimensions"
  >];

  let extraClassDeclaration = [{
    // A reduce is never fused into its consumer.
    bool isFusibleWithConsumer() {
      return false;
    }
    // The workload shape is that of operand #0 (the input being reduced).
    llvm::Optional<Value> inferEffectiveWorkloadShape() {
      return getOperation()->getOperand(0);
    }
  }];

  let hasFolder = 1;

  // TODO(hinsu): Implement custom printer and parser.

  // TODO(b/129422361): ReduceOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MHLO tuple op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
// mhlo.get_tuple_element: reads element `index` (an i32 attribute) of a tuple
// operand. The result may be a tensor, token or nested tuple.
def HLO_GetTupleElementOp : HLO_Op<"get_tuple_element", [NoSideEffect]>,
    BASE_HLO_GetTupleElementOp {
  let arguments = (ins
    HLO_Tuple,
    I32Attr:$index
  );
  let results = (outs HLO_TensorOrTokenOrTuple);

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &results, "
    "Value value, int32_t index">];

  let hasFolder = 1;
}
|
|
|
|
// mhlo.tuple: packs a variadic list of tensors/tokens/tuples into one tuple.
def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
  let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val);
  let results = (outs HLO_Tuple);

  let hasCanonicalizer = 1;

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &results, "
    "ValueRange values">];
}
|
|
|
|
// mhlo.compare: two operands of the same type, compared according to
// `comparison_direction`. The result is a predicate tensor of the operands'
// shape.
def HLO_CompareOp : HLO_Op<"compare",
    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>,
    BASE_HLO_CompareOp {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    HLO_ComparisonDirectionAttr:$comparison_direction
  );
  let results = (outs HLO_PredTensor);

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
    "StringAttr comparison_direction"
  >];
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MHLO Slice definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// mhlo.slice: static slice of `operand`. The start, limit and stride
// attributes must share one type (AllTypesMatch), and the result keeps the
// operand's element type.
def HLO_SliceOp : HLO_Op<"slice",
    [NoSideEffect, SameOperandsAndResultElementType,
     AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$start_indices,
    I64ElementsAttr:$limit_indices,
    I64ElementsAttr:$strides
  );
  let results = (outs HLO_Tensor);

  let hasFolder = 1;
  let hasCanonicalizer = 1;

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &result, Value operand, "
    "DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, "
    "DenseIntElementsAttr strides"
  >];

  let extraClassDeclaration = [{
    // Infers output type for given operand and attributes. Result type is
    // unranked if any of the attributes is illegal.
    static Type InferOutputTypes(Builder *builder, Value operand,
                                 DenseIntElementsAttr start_indices,
                                 DenseIntElementsAttr limit_indices,
                                 DenseIntElementsAttr strides);
  }];
}
|
|
|
|
// mhlo.dynamic-slice: start positions are scalar integer operands; slice
// sizes are a static attribute. Operand and result element types match.
def HLO_DynamicSliceOp : HLO_Op<"dynamic-slice",
    [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> {
  let arguments = (ins
    HLO_Tensor:$operand,
    Variadic<HLO_ScalarIntTensor>:$start_indices,
    I64ElementsAttr:$slice_sizes
  );
  let results = (outs HLO_Tensor:$result);

  let hasCanonicalizer = 1;
}
|
|
|
|
// mhlo.dynamic-update-slice: writes `update` into `operand` at scalar integer
// start positions. operand/update/result share an element type, and
// operand/result share a shape.
def HLO_DynamicUpdateSliceOp : HLO_Op<"dynamic-update-slice",
    [NoSideEffect,
     AllElementTypesMatch<["operand", "update", "result"]>,
     AllShapesMatch<["operand", "result"]>]> {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$update,
    Variadic<HLO_ScalarIntTensor>:$start_indices
  );
  let results = (outs HLO_Tensor:$result);
}
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MHLO Other op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// mhlo.batch_norm_grad: five tensor operands plus `epsilon` (f32) and
// `feature_index` attributes; returns a tuple.
def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>,
    BASE_HLO_BatchNormGradOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scale,
    HLO_Tensor:$mean,
    HLO_Tensor:$variance,
    HLO_Tensor:$grad_output,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
  let results = (outs HLO_Tuple);
}
|
|
|
|
// mhlo.batch_norm_inference: five tensor operands plus `epsilon` (f32) and
// `feature_index`; a single tensor result with the operands' element type.
def HLO_BatchNormInferenceOp : HLO_Op<"batch_norm_inference",
    [NoSideEffect, SameOperandsAndResultElementType]>,
    BASE_HLO_BatchNormInferenceOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scale,
    HLO_Tensor:$offset,
    HLO_Tensor:$mean,
    HLO_Tensor:$variance,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
  let results = (outs HLO_Tensor);
}
|
|
|
|
// mhlo.batch_norm_training: operand/scale/offset plus `epsilon` (f32) and
// `feature_index`; returns a tuple.
def HLO_BatchNormTrainingOp : HLO_Op<"batch_norm_training", [NoSideEffect]>,
    BASE_HLO_BatchNormTrainingOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scale,
    HLO_Tensor:$offset,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
  let results = (outs HLO_Tuple);
}
|
|
|
|
// mhlo.bitcast_convert: the operand and result shapes must match; element
// types are not tied.
def HLO_BitcastConvertOp : HLO_Op<"bitcast_convert",
    [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_BitcastConvertOp {
  let hasCustomHLOConverter = 1;

  let arguments = (ins HLO_Tensor:$operand);
  let results = (outs HLO_Tensor);
}
|
|
|
|
// mhlo.broadcast: shape is controlled by the `broadcast_sizes` attribute;
// the element type is preserved.
def HLO_BroadcastOp : HLO_Op<"broadcast",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$broadcast_sizes
  );
  let results = (outs HLO_Tensor);
}
|
|
|
|
// mhlo.broadcast_in_dim: the operand's dimensions are placed according to
// `broadcast_dimensions`. The result must be statically shaped and keeps the
// operand's element type.
def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
    [NoSideEffect, SameOperandsAndResultElementType]>,
    BASE_HLO_BroadcastInDimOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    BroadcastDimAttr:$broadcast_dimensions
  );
  let results = (outs HLO_StaticShapeTensor);

  // Only handles a static subset of the legacy format.
  let hasCustomHLOConverter = 1;
  let hasFolder = 1;
}
|
|
|
|
// mhlo.dynamic_broadcast_in_dim: like broadcast_in_dim, but the output shape
// is supplied by the `output_dimensions` operand.
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
    [NoSideEffect]> {
  let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
  let description = [{
    This is a generalization of the BroadcastInDimOp which accepts its output
    dimensions as an argument. It should eventually supersede the statically
    shaped original, but is being phased in as a separate op in order to
    support compatibility with lowerings and translations that precede dynamic
    shapes.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_DimensionTensor:$output_dimensions,
    BroadcastDimAttr:$broadcast_dimensions
  );
  let results = (outs HLO_Tensor);

  let hasCanonicalizer = 1;
  // Cannot be exported to legacy formats.
  let hasCustomHLOConverter = 1;
}
|
|
|
|
// Note: There is no HLO_CallOp because the standard call operation mlir::CallOp
// is used instead. An mlir::CallOp is exported directly to an HLO call
// instruction.
|
|
|
|
// mhlo.cholesky: float or complex operand `a`. The `lower` flag defaults to
// false. The result element type equals the operand's.
def HLO_CholeskyOp : HLO_Op<"cholesky",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_CholeskyOp {
  let arguments = (ins
    HLO_FpOrComplexTensor:$a,
    DefaultValuedAttr<BoolAttr, "false">:$lower
  );
  let results = (outs HLO_FpOrComplexTensor);
}
|
|
|
|
// mhlo.clamp: operands in the order (min, operand, max); all share one
// element type with the result.
def HLO_ClampOp : HLO_Op<"clamp",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ClampOp {
  let arguments = (ins
    HLO_Tensor:$min,
    HLO_Tensor:$operand,
    HLO_Tensor:$max
  );
  let results = (outs HLO_Tensor);
}
|
|
|
|
// mhlo.concatenate: joins a variadic list of tensors along `dimension`. The
// result type is inferred (InferTypeOpInterface, implemented in C++).
def HLO_ConcatenateOp : HLO_Op<"concatenate",
    [NoSideEffect, SameOperandsAndResultElementType,
     DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
    BASE_HLO_ConcatenateOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$val,
    I64Attr:$dimension
  );
  let results = (outs HLO_Tensor);

  let hasFolder = 1;
  let hasCanonicalizer = 1;
}
|
|
|
|
// mhlo.collective_permute: configured by the `source_target_pairs` attribute;
// the result type equals the operand type.
def HLO_CollectivePermuteOp : HLO_Op<"collective_permute",
    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CollectivePermuteOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$source_target_pairs
  );
  let results = (outs HLO_Tensor);
}
|
|
|
|
// TODO(hinsu): Make this struct dialect independent so that it can be shared
// between HLO and LHLO dialect.
//
// Struct attribute for a convolution's dimension layout. For input, kernel and
// output it records the batch/feature dimension indices (I64Attr) and the
// spatial dimension lists (I64ElementsAttr).
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
    StructFieldAttr<"input_batch_dimension", I64Attr>,
    StructFieldAttr<"input_feature_dimension", I64Attr>,
    StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
    StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
    StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
    StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
    StructFieldAttr<"output_batch_dimension", I64Attr>,
    StructFieldAttr<"output_feature_dimension", I64Attr>,
    StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>
  ]> {
  let description = "Structure of dimension information for conv op";
}
|
|
|
|
// mhlo.convolution: `lhs` and `rhs` operands. The window attributes are
// optional and fall back to the defaults noted inline.
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
  let results = (outs HLO_Tensor);

  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    // Default value: one for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    // Default value: zero for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$padding,
    // Default value: one for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
    // Default value: one for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
    ConvDimensionNumbers:$dimension_numbers,
    I64Attr:$feature_group_count,
    I64Attr:$batch_group_count,
    HLO_PrecisionConfigAttr:$precision_config
  );
}
|
|
|
|
// mhlo.copy: a single unnamed tensor operand whose type matches the result;
// a C++ folder is declared.
def HLO_CopyOp : HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_CopyOp {
  let arguments = (ins HLO_Tensor);
  let results = (outs HLO_Tensor);
  let hasFolder = 1;
}
|
|
|
|
// mhlo.cross-replica-sum: one tensor operand and a `replica_groups` elements
// attribute; the result type must equal the operand type.
def HLO_CrossReplicaSumOp : HLO_Op<"cross-replica-sum",
    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CrossReplicaSumOp {
  let results = (outs HLO_Tensor);
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups
  );
}
|
|
|
|
// mhlo.custom_call: forwards variadic `args` to the target named by
// `call_target_name`. It declares no traits (so no NoSideEffect), and its
// export to HLO is hand-written.
def HLO_CustomCallOp : HLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$args,
    StrAttr:$call_target_name,
    DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
    DefaultValuedAttr<StrAttr, "">:$backend_config
  );
  let results = (outs HLO_Tensor);

  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.dot: binary product of `lhs` and `rhs` with a precision config.
def HLO_DotOp : HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp {
  let results = (outs HLO_Tensor);
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    HLO_PrecisionConfigAttr:$precision_config
  );
}
|
|
|
|
// Batching and contracting dimension lists for each side of a dot_general.
def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [
    StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
    StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
    StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
    StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>]> {
  let description = "Structure of dimension information for dot product";
}
|
|
|
|
// mhlo.dot_general: a dot product whose batching and contracting dimensions
// are given explicitly by `dot_dimension_numbers`.
def HLO_DotGeneralOp : HLO_Op<"dot_general", [NoSideEffect]>,
    BASE_HLO_DotGeneralOp {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    DotDimensionNumbers:$dot_dimension_numbers,
    HLO_PrecisionConfigAttr:$precision_config
  );
  let results = (outs HLO_Tensor);

  // No `verifier` override: HLO_Op already sets
  // `let verifier = [{ return Verify(*this); }];`, which is exactly what this
  // op used to re-declare.
}
|
|
|
|
// Define Base Einsum op within the HLO dialect as these are client ops and
// therefore this class is not common between HLO and LHLO ops.
//
// Shared summary/description mixed into HLO_EinsumOp and HLO_UnaryEinsumOp.
class BASE_EinsumOp {
  string summary = "Einsum operator";

  string description = [{
    Returns a tensor whose elements are defined by equation, which is written
    in a shorthand form inspired by the Einstein summation convention.
  }];
}
|
|
|
|
// mhlo.einsum: binary einsum whose equation is carried in `einsum_config`.
def HLO_EinsumOp : HLO_Op<"einsum", [NoSideEffect]>, BASE_EinsumOp {
  let results = (outs HLO_Tensor);
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    StrAttr:$einsum_config
  );

  // TODO(hinsu): Canonicalize to lower this client side HLO op to server
  // side HLO ops.
}
|
|
|
|
// mhlo.unary_einsum: single-operand einsum whose equation is carried in
// `einsum_config`.
def HLO_UnaryEinsumOp : HLO_Op<"unary_einsum", [NoSideEffect]>,
    BASE_EinsumOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    StrAttr:$einsum_config
  );
  let results = (outs HLO_Tensor);

  // UnaryEinsumOp is unconditionally canonicalized to the binary EinsumOp so
  // the HLO converter shouldn't be invoked.
  let hasCustomHLOConverter = 1;
  let hasCanonicalizer = 1;
}
|
|
|
|
// mhlo.fft: one tensor operand, an FFT kind (`fft_type`) and the
// `fft_length` elements attribute.
def HLO_FftOp : HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
  let results = (outs HLO_Tensor);
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_FftTypeAttr:$fft_type,
    I64ElementsAttr:$fft_length
  );
}
|
|
|
|
// Dimension bookkeeping consumed by HLO_GatherOp.
def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, [
    StructFieldAttr<"offset_dims", I64ElementsAttr>,
    StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
    StructFieldAttr<"start_index_map", I64ElementsAttr>,
    StructFieldAttr<"index_vector_dim", I64Attr>]> {
  let description = "Structure of dimension information for gather";
}
|
|
|
|
// mhlo.gather: `start_indices` must be an integer tensor;
// `indices_are_sorted` defaults to false.
def HLO_GatherOp : HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp {
  let results = (outs HLO_Tensor);
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_IntTensor:$start_indices,
    GatherDimensionNumbers:$dimension_numbers,
    I64ElementsAttr:$slice_sizes,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
  );
}
|
|
|
|
// mhlo.get_dimension_size: queries the size of dimension `dimension` of
// `operand`; the result is restricted to an i32 tensor.
def HLO_GetDimensionSizeOp : HLO_Op<"get_dimension_size", [NoSideEffect]>,
    BASE_HLO_GetDimensionSizeOp {
  let arguments = (ins HLO_Tensor:$operand, I32Attr:$dimension);

  // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
  // XLA semantics is available. This limitation is because of the current XLA
  // implementation.
  let results = (outs I32Tensor);
}
|
|
|
|
// mhlo.map: variadic operands sharing an element type and the result shape,
// plus a single-block `computation` region implicitly terminated by ReturnOp.
def HLO_MapOp : HLO_Op<"map", [
    RecursiveSideEffects,
    SameOperandsElementType,
    SameOperandsAndResultShape,
    SingleBlockImplicitTerminator<"ReturnOp">
  ]>, BASE_HLO_MapOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    I64ElementsAttr:$dimensions
  );
  let results = (outs HLO_Tensor);
  let regions = (region SizedRegion<1>:$computation);

  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.reshape: static-shape variant; the result type must be statically
// shaped (HLO_StaticShapeTensor). See HLO_DynamicReshapeOp for the dynamic form.
def HLO_ReshapeOp : HLO_Op<"reshape",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ReshapeOp {
  let results = (outs HLO_StaticShapeTensor);
  let arguments = (ins HLO_Tensor:$operand);

  let hasCustomHLOConverter = 1;
  let hasFolder = 1;
}
|
|
|
|
// mhlo.dynamic_reshape: the target shape is an SSA operand (`output_shape`)
// rather than part of the result type. The description string previously
// carried stray '|' residue lines; it is restored here.
def HLO_DynamicReshapeOp : HLO_Op<"dynamic_reshape", [NoSideEffect]> {
  let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
  let description = [{
    Reshapes `operand` to `output_shape`.

    Requires:
    - The length of `output_shape` is equal to the rank of `result`.
    - The number of elements in `operand` (that is, the product of extents of
      its shape) is equal to the number of elements in `output_shape` (that is,
      the product of values in `output_shape`).
  }];

  let arguments = (ins HLO_Tensor:$operand, HLO_DimensionTensor:$output_shape);
  let results = (outs HLO_Tensor:$result);

  let hasCanonicalizer = 1;
  // Cannot be exported to legacy formats.
  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.scatter: combines `updates` into `operand` at `scatter_indices` using
// the single-block `update_computation` region. Both flags default to false.
def HLO_ScatterOp : HLO_Op<"scatter", [RecursiveSideEffects]>,
    BASE_HLO_ScatterOp {
  let results = (outs HLO_Tensor);
  let regions = (region SizedRegion<1>:$update_computation);

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scatter_indices,
    HLO_Tensor:$updates,
    ScatterDimensionNumbers<HLO_Dialect>:$scatter_dimension_numbers,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
    DefaultValuedAttr<BoolAttr, "false">:$unique_indices
  );

  let hasCustomHLOConverter = 1;
}
|
|
|
|
// TODO(jpienaar): Add broadcastable trait.
//
// mhlo.select: `pred` is a predicate tensor choosing between `on_true` and
// `on_false`; the result type is inferred through InferTypeOpInterface.
def HLO_SelectOp : HLO_Op<"select",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
    BASE_HLO_SelectOp {
  let results = (outs HLO_Tensor);
  let arguments = (ins
    HLO_PredTensor:$pred,
    HLO_Tensor:$on_true,
    HLO_Tensor:$on_false
  );
}
|
|
|
|
// mhlo.select_and_scatter: two single-block regions, `select` and `scatter`,
// plus optional window dimension/stride/padding attributes.
def HLO_SelectAndScatterOp : HLO_Op<"select_and_scatter",
    [RecursiveSideEffects]>, BASE_HLO_SelectAndScatterOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$source,
    HLO_Tensor:$init_value,
    OptionalAttr<I64ElementsAttr>:$window_dimensions,
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$padding
  );
  let results = (outs HLO_Tensor);
  let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter);

  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.set_dimension_size: `size` is an i32 tensor operand and `dimension`
// an i32 attribute.
def HLO_SetDimensionSizeOp : HLO_Op<"set_dimension_size", [NoSideEffect]>,
    BASE_HLO_SetDimensionSizeOp {
  let results = (outs HLO_Tensor);
  let arguments = (ins
    HLO_Tensor:$operand,
    I32Tensor:$size,
    I32Attr:$dimension
  );
}
|
|
|
|
// mhlo.sort over variadic operands with a single-block `comparator` region.
// Attribute defaults: dimension = -1, is_stable = false.
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    DefaultValuedAttr<I64Attr, "-1">:$dimension,
    DefaultValuedAttr<BoolAttr, "false">:$is_stable
  );
  let results = (outs HLO_TensorOrTuple);
  let regions = (region SizedRegion<1>:$comparator);

  // Convenience builder whose C++ defaults mirror the attribute defaults.
  let builders = [
    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, "
              "int64_t dimension = -1, bool is_stable = false">
  ];

  // TODO(b/129422361): SortOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.reverse: `dimensions` lists the axes to reverse; the result type must
// equal the operand type. A C++ folder is declared.
def HLO_ReverseOp : HLO_Op<"reverse",
    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp {
  let hasFolder = 1;

  let results = (outs HLO_Tensor);
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$dimensions
  );
}
|
|
|
|
// mhlo.pad. The description override previously read "Pads the `operand`
// according to TBD."; it now documents what each padding attribute controls
// (following the XLA Pad semantics).
def HLO_PadOp : HLO_Op<"pad",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PadOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$padding_value,
    I64ElementsAttr:$edge_padding_low,
    I64ElementsAttr:$edge_padding_high,
    I64ElementsAttr:$interior_padding
  );
  let results = (outs HLO_Tensor);

  let description = [{
    Pads `operand` with `padding_value`. For each dimension,
    `edge_padding_low` and `edge_padding_high` give the amount of padding
    added at the start and at the end of that dimension, and
    `interior_padding` gives the amount inserted between each pair of
    adjacent elements.
  }];

  // TODO(b/129422361): PadOp has a custom constructor for HLO.
  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.trace: takes a tensor `operand` and a string `tag`. It declares no
// traits and no results, and its export to HLO is hand-written.
def HLO_TraceOp : HLO_Op<"trace", []>, BASE_HLO_TraceOp {
  let arguments = (ins HLO_Tensor:$operand, StrAttr:$tag);
  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.transpose: `permutation` lists the axis order; the element type is
// preserved. A C++ folder is declared.
def HLO_TransposeOp : HLO_Op<"transpose",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_TransposeOp {
  let hasFolder = 1;

  let results = (outs HLO_Tensor);
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$permutation
  );
}
|
|
|
|
// mhlo.triangular_solve: `a` and `b` are floating-point or complex tensors
// of the same element type as the result. Three boolean flags and a
// transpose mode are supplied as attributes.
def HLO_TriangularSolveOp : HLO_Op<"triangular_solve",
    [NoSideEffect, SameOperandsAndResultElementType]>,
    BASE_HLO_TriangularSolveOp {
  let results = (outs HLO_FpOrComplexTensor);
  let arguments = (ins
    HLO_FpOrComplexTensor:$a,
    HLO_FpOrComplexTensor:$b,
    BoolAttr:$left_side,
    BoolAttr:$lower,
    BoolAttr:$unit_diagonal,
    HLO_TransposeAttr:$transpose_a
  );
}
|
|
|
|
// mhlo.reduce_window: reduces `operand` over windows using the single-block
// `body` region (implicitly terminated by ReturnOp), seeded by `init_value`.
def HLO_ReduceWindowOp : HLO_Op<"reduce_window",
    [RecursiveSideEffects, SingleBlockImplicitTerminator<"ReturnOp">]>,
    BASE_HLO_ReduceWindowOp {
  // TODO(hinsu): Verify that padding attribute is 2-d and the remaining
  // attributes are 1-d. Attributes' leading dimension should match rank of the
  // inputs.
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$init_value,
    I64ElementsAttr:$window_dimensions,
    // If strides or dilations attributes are missing then the default value is
    // one for each of the input dimensions. Similarly, padding values are zero
    // for both low and high in each of the dimensions, if not specified.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$base_dilations,
    OptionalAttr<I64ElementsAttr>:$window_dilations,
    OptionalAttr<I64ElementsAttr>:$padding
  );
  let results = (outs HLO_Tensor);

  // TODO(hinsu): Verify that the attached body arguments and results are
  // compatible with reduce op's operands.
  let regions = (region SizedRegion<1>:$body);

  let hasCustomHLOConverter = 1;

  // TODO(hinsu): Implement custom printer and parser.
}
|
|
|
|
// Region terminator for MHLO ops. The dialect is registered as "mhlo", so the
// op's textual name is `mhlo.return`; the summary previously said
// `hlo.return`.
def HLO_ReturnOp : HLO_Op<"return", [NoSideEffect, Terminator]> {
  let summary = [{
    The `mhlo.return` operation terminates a region and returns values.
  }];

  let arguments = (ins
    Variadic<HLO_TensorOrTuple>:$results
  );

  // Disable conversion operator for return op as the op is not an actual XLA
  // instruction and is only used as a terminator for regions.
  let hasCustomHLOConverter = 1;

  // TODO(hinsu): Implement custom printer and parser.
}
|
|
|
|
// mhlo.torch_index_select: a client-side op taking `input` and `index`
// tensors plus `dim` / `batch_dims` attributes.
def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
  let results = (outs HLO_Tensor);
  let arguments = (ins
    HLO_Tensor:$input,
    HLO_Tensor:$index,
    I64Attr:$dim,
    I64Attr:$batch_dims
  );

  // TODO(hinsu): Canonicalize to lower this client side HLO op to server
  // side HLO ops.
}
|
|
|
|
//===----------------------------------------------------------------------===//
// MHLO RngUniform Operator.
//===----------------------------------------------------------------------===//

// mhlo.rng_uniform: bounds `a` and `b` are pred/int/fp tensors and `shape` is
// a dimension tensor. It declares no traits (so no NoSideEffect).
def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
  let results = (outs HLO_PredIntOrFpTensor);
  let arguments = (ins
    HLO_PredIntOrFpTensor:$a,
    HLO_PredIntOrFpTensor:$b,
    HLO_DimensionTensor:$shape
  );

  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.rng_normal: `mu` and `sigma` are floating-point tensors and `shape` is
// a dimension tensor. It declares no traits (so no NoSideEffect).
def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
  let results = (outs HLO_FpTensor);
  let arguments = (ins
    HLO_FpTensor:$mu,
    HLO_FpTensor:$sigma,
    HLO_DimensionTensor:$shape
  );

  let hasCustomHLOConverter = 1;
}
|
|
|
|
//===----------------------------------------------------------------------===//
// MHLO Quantize Operator.
//===----------------------------------------------------------------------===//

// mhlo.dequantize: takes an i32 tensor `input` and produces a bf16 tensor
// `output`. The range bounds, mode and transpose flag are attributes;
// `is_16bits` defaults to false.
def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>,
    BASE_HLO_DequantizeOp {
  let results = (outs TensorOf<[BF16]>:$output);
  let arguments = (ins
    TensorOf<[I32]>:$input,
    F32Attr:$min_range,
    F32Attr:$max_range,
    HLO_DequantizeModeAttr:$mode,
    BoolAttr:$transpose_output,
    DefaultValuedAttr<BoolAttr, "false">:$is_16bits
  );

  let hasCustomHLOConverter = 1;
}
|
|
|
|
// mhlo.fusion: wraps a single-block `fused_computation` region over variadic
// tensor-or-tuple operands, with an optional fusion kind (kLoop / kInput /
// kOutput / kCustom, per HLO_FusionKindAttr). The description's grammar
// ("is consists of") is fixed.
def HLO_FusionOp : HLO_Op<"fusion", []> {
  let summary = "Fusion operator";
  let description = [{
    Models the fusion instruction.

    A fusion op consists of a group of basic ops (represented as a region
    attached to it). It serves as a hint to the backend that it is beneficial
    to emit the contained ops into a single loop nest or kernel.
  }];
  let regions = (region SizedRegion<1>:$fused_computation);

  let arguments = (ins
    Variadic<HLO_TensorOrTuple>:$operands,
    OptionalAttr<HLO_FusionKindAttr>:$fusion_kind
  );

  let results = (outs
    Variadic<HLO_TensorOrTuple>:$results
  );

  // FusionOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}
|
|
|
|
#endif // HLO_OPS
|