// mlir-hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
//
// 748 lines
// 26 KiB
// TableGen
//
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for LMHLO, the "late" MHLO variant of
// the dialect, which operates on buffers instead of tensors.
//
// This file largely overlaps with hlo_ops.td at a logical level. It's tempting
// to merge these two files together, but we need to consider the following
// obstacles:
// * We need to have a common representation for arguments. That is to say,
// HLO_Array<X> translates to HLO_Tensor<X> in HLO dialect, and
// Arg<LHLO_Buffer<X>, "", [Mem(Read|Write)]> in LHLO. Array types within
// tuples also need to be transformed.
// * As of now, TableGen's dag functions are not sufficient to accomplish the
// one above.
// * Traits aren't identical, but need to be copied. For example,
// SameOperandAndResultType in HLO corresponds to SameTypeOperands in LHLO.
// * Also, currently HLO describes the API in XLA's client side, not service
// side. LHLO aims for the service side.
#ifndef LHLO_OPS
#define LHLO_OPS
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td"
//===----------------------------------------------------------------------===//
// LMHLO nullary op definitions.
//===----------------------------------------------------------------------===//
// Base class for all LMHLO ops. Every LMHLO op is given conservative
// MemRead + MemWrite memory effects (ops operate on buffers in place),
// in addition to any op-specific `traits`.
class LHLO_Op<string mnemonic, list<OpTrait> traits> :
Op<LHLO_Dialect, mnemonic,
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
// Materializes a compile-time constant: writes `value` into the
// `output` buffer. Documentation inherited from BASE_HLO_ConstOp.
def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp {
let arguments = (ins
ElementsAttr:$value,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
// Canonicalization patterns are registered in C++.
let hasCanonicalizer = 1;
}
// Fills `output` with values increasing along `iota_dimension`.
// Documentation inherited from BASE_HLO_IotaOp.
def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
let arguments = (ins I64Attr:$iota_dimension,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}
//===----------------------------------------------------------------------===//
// LMHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
// Base class for unary elementwise ops: reads `input`, writes the
// elementwise result to `output`. `BufferType` restricts the allowed
// element types; the default `SameTypeOperands` trait requires input
// and output to have identical types.
class LHLO_UnaryElementwiseOp<string mnemonic,
Type BufferType = LHLO_Buffer,
list<OpTrait> traits = [SameTypeOperands]>
: LHLO_Op<mnemonic, traits> {
let arguments = (ins Arg<BufferType, "", [MemRead]>:$input,
Arg<BufferType, "", [MemWrite]>:$output);
}
def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
// TODO(timshen): add a custom verifier.
// Changes the element type, so only shapes (not full types) must match.
def LHLO_BitcastConvertOp:
LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp;
def LHLO_CbrtOp: LHLO_UnaryElementwiseOp<"cbrt", LHLO_FpBuffer>, BASE_HLO_CbrtOp;
def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp;
def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp;
// TODO(timshen): add a custom verifier.
// Element-type conversion: input and output only need matching shapes.
def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_ConvertOp;
def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer>, BASE_HLO_CosOp;
def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer>, BASE_HLO_ExpOp;
def LHLO_Expm1Op: LHLO_UnaryElementwiseOp<"exponential_minus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Expm1Op;
def LHLO_FloorOp: LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer>, BASE_HLO_FloorOp;
// Extracts the imaginary part: complex input buffer, floating-point
// output buffer; input and output element types differ, so only shapes
// are constrained to match.
def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp {
let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}
// Floating-point input, predicate (i1) output of the same shape.
def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFiniteOp {
let arguments = (ins Arg<LHLO_FpBuffer, "", [MemRead]>:$input,
Arg<LHLO_PredBuffer, "", [MemWrite]>:$output);
}
def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp;
def LHLO_LogisticOp : LHLO_UnaryElementwiseOp<"logistic", LHLO_FpOrComplexBuffer>, BASE_HLO_LogisticOp;
def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp;
def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp;
def LHLO_NotOp: LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer>, BASE_HLO_NotOp;
def LHLO_PopulationCountOp: LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer>, BASE_HLO_PopulationCountOp;
// Extracts the real part: complex input buffer, floating-point output.
def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp {
let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}
def LHLO_RoundOp: LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer>, BASE_HLO_RoundOp;
def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_RsqrtOp;
def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_SqrtOp;
def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;
def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HLO_SinOp;
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp;
//===----------------------------------------------------------------------===//
// LMHLO binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
// Base class for binary elementwise ops: reads `lhs` and `rhs`, writes
// the elementwise result to `out`. Broadcasting, when needed, is
// described by the optional `broadcast_dimensions` attribute.
// `BufferType` restricts the allowed element types.
class LHLO_BinaryElementwiseOp<string mnemonic, Type BufferType = LHLO_Buffer,
list<OpTrait> traits = [SameTypeOperands]> :
LHLO_Op<mnemonic, traits> {
let arguments = (ins
Arg<BufferType, "", [MemRead]>:$lhs,
Arg<BufferType, "", [MemRead]>:$rhs,
Arg<BufferType, "", [MemWrite]>:$out,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
}
def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add">, BASE_HLO_AddOp;
def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer>, BASE_HLO_AndOp;
def LHLO_Atan2Op : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer>, BASE_HLO_Atan2Op;
// Builds a complex buffer from separate floating-point real (`lhs`) and
// imaginary (`rhs`) parts; element types differ between operands and
// output, so only shapes are constrained to match.
def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp {
let arguments = (ins
Arg<LHLO_FpBuffer, "", [MemRead]>:$lhs,
Arg<LHLO_FpBuffer, "", [MemRead]>:$rhs,
Arg<LHLO_ComplexBuffer, "", [MemWrite]>:$output,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
}
def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide">, BASE_HLO_DivOp;
def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum">, BASE_HLO_MaxOp;
def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum">, BASE_HLO_MinOp;
def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply">, BASE_HLO_MulOp;
def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer>, BASE_HLO_OrOp;
def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power">, BASE_HLO_PowOp;
def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer>, BASE_HLO_RemOp;
def LHLO_ShiftLeftOp : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer>, BASE_HLO_ShiftLeftOp;
def LHLO_ShiftRightArithmeticOp : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer>, BASE_HLO_ShiftRightArithmeticOp;
def LHLO_ShiftRightLogicalOp : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer>, BASE_HLO_ShiftRightLogicalOp;
def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp;
def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp;
//===----------------------------------------------------------------------===//
// LMHLO control flow op definitions.
//===----------------------------------------------------------------------===//
// TODO(b/139813999): specify required function signature in a type-safe way.
//
// The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are
// moving towards mhlo.ReturnOp, but some code that needs cleanup still assumes lmhlo.TerminatorOp.
// TODO(timshen): cleanup lmhlo.TerminatorOp.
// Reduces each of `operands` along `dimensions`, starting from the
// corresponding `init_values`, using the reduction computation in the
// `body` region; results are written to `out`. SameVariadicOperandSize
// keeps the three variadic groups the same length.
def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out,
I64ElementsAttr:$dimensions
);
let regions = (region SizedRegion<1>:$body);
}
// Windowed reduction of `operand` with the computation in `body`,
// starting from `init_value`. Documentation inherited from
// BASE_HLO_ReduceWindowOp.
def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
Arg<LHLO_Buffer, "", [MemWrite]>:$out,
I64ElementsAttr:$window_dimensions,
// If strides or dilations attributes are missing then the default value is
// one for each of the input dimensions. Similarly, padding values are zero
// for both low and high in each of the dimensions, if not specified.
OptionalAttr<I64ElementsAttr>:$window_strides,
OptionalAttr<I64ElementsAttr>:$base_dilations,
OptionalAttr<I64ElementsAttr>:$window_dilations,
OptionalAttr<I64ElementsAttr>:$padding
);
let regions = (region SizedRegion<1>:$body);
}
// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example,
// A tuple-like pattern match syntax could work:
// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
// ...
// }, {
// ...
// } : (type_input0, type_input1, type_input2, type_output0, type_output1) -> ()
// Multi-way branch: `index` selects which of the `branches` regions to
// execute. AttrSizedOperandSegments separates the two variadic operand
// groups (branch_operands / out).
def LHLO_CaseOp: LHLO_Op<"case", [
AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"TerminatorOp">
]>, BASE_HLO_CaseOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$index,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$branch_operands,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out
);
let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
}
// TODO(timshen): Add a custom syntax for this.
// Loop: the `cond` region computes the loop condition, `body` the loop
// body. `val` holds the loop-carried input buffers and `output` the
// result buffers (same count, per SameVariadicOperandSize).
def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
BASE_HLO_WhileOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output
);
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}
// Call into an opaque backend-specific target named by
// `call_target_name`. `target_arg_mapping`, when present, maps the
// op's args/output to the target's expected argument list (see
// CustomCallTargetArgMapping). Verification is implemented in C++.
def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>,
BASE_HLO_CustomCallOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
StrAttr:$call_target_name,
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
DefaultValuedAttr<StrAttr, "">:$backend_config,
OptionalAttr<CustomCallTargetArgMapping>:$target_arg_mapping
);
let verifier = [{ return Verify(*this); }];
}
//===----------------------------------------------------------------------===//
// LMHLO comparison op definitions.
//===----------------------------------------------------------------------===//
// Elementwise comparison of `lhs` and `rhs` per `comparison_direction`;
// writes a predicate (i1) buffer. `compare_type`, when set, selects the
// comparison semantics (see HLO_ComparisonTypeAttr).
def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
);
}
//===----------------------------------------------------------------------===//
// LMHLO Slice definitions.
//===----------------------------------------------------------------------===//
// Static slice: extracts the region of `operand` described by the
// per-dimension start/limit/stride attributes (all three must have
// matching types, per AllTypesMatch).
def LHLO_SliceOp: LHLO_Op<
"slice",
[AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$start_indices,
I64ElementsAttr:$limit_indices,
I64ElementsAttr:$strides
);
}
// Dynamic slice: start positions come from the `start_indices` buffers
// at runtime, while the slice extents are the static `slice_sizes`.
def LHLO_DynamicSliceOp: LHLO_Op<"dynamic_slice",
[AllElementTypesMatch<["operand", "output"]>]>, BASE_HLO_DynamicSliceOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$slice_sizes
);
}
// Writes `operand` with `update` overlaid at the runtime position given
// by the `start_indices` buffers into `output`.
def LHLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []>, BASE_HLO_DynamicUpdateSliceOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$update,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
//===----------------------------------------------------------------------===//
// LMHLO Other op definitions.
//===----------------------------------------------------------------------===//
// Batch-norm gradient: given the forward inputs (operand, scale, mean,
// variance) and the incoming gradient `grad_output`, writes gradients
// for the operand, scale and offset. `feature_index` selects the
// feature dimension; `epsilon` is the numerical-stability constant.
def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,
BASE_HLO_BatchNormGradOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$variance,
Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand, // gradient of $operand.
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale,
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
// Batch-norm inference: normalizes `operand` using precomputed `mean`
// and `variance` (no statistics are computed or written).
def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
BASE_HLO_BatchNormInferenceOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$variance,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
// Batch-norm training: normalizes `operand` and additionally writes the
// computed batch statistics to `batch_mean` and `batch_var`.
def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
BASE_HLO_BatchNormTrainingOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean,
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_var,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
// Broadcast by prepending dimensions of extents `broadcast_sizes`.
// Documentation inherited from BASE_HLO_BroadcastOp.
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
[]>, BASE_HLO_BroadcastOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$broadcast_sizes
);
}
// Broadcast where `broadcast_dimensions` maps each operand dimension to
// an output dimension. Documentation inherited from
// BASE_HLO_BroadcastInDimOp.
def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
[]>, BASE_HLO_BroadcastInDimOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
BroadcastDimAttr:$broadcast_dimensions
);
}
// Elementwise clamp of `operand` between `min` and `max`.
def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$min,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$max,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
// Concatenates the `val` buffers along `dimension` into `output`.
def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64Attr:$dimension
);
}
// Convolution of `lhs` (input) with `rhs` (kernel). The attribute list
// (window strides, padding, dimension numbers, etc.) is shared with the
// MHLO op via ConvolutionAttributes and concatenated onto the operands
// with !con.
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output),
ConvolutionAttributes.attributes);
}
// Copies `operand` into `output`. Implements CopyOpInterface so generic
// copy-removal passes can reason about it.
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
let extraClassDeclaration = [{
// CopyOpInterface: source is the read operand, target the written one.
Value getSource() { return operand();}
Value getTarget() { return output(); }
}];
}
// General dot product / matrix multiply; contraction and batch
// dimensions are described by `dot_dimension_numbers`.
def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
DotDimensionNumbers:$dot_dimension_numbers,
HLO_PrecisionConfigAttr:$precision_config,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
// Gathers slices of `slice_sizes` from `operand` at positions given by
// the integer `start_indices` buffer, per `dimension_numbers`.
def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
GatherDimensionNumbers:$dimension_numbers,
I64ElementsAttr:$slice_sizes,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
// Reshape: the output shape is implied by the `output` buffer's type.
def LHLO_ReshapeOp: LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
// Scatters `updates` into `operand` at positions from `scatter_indices`,
// combining overlapping values with the `update_computation` region.
def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
Arg<LHLO_Buffer, "", [MemRead]>:$updates,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
ScatterDimensionNumbers:$scatter_dimension_numbers,
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
);
let regions = (region SizedRegion<1>:$update_computation);
}
// Elementwise select: where `pred` is true take `on_true`, else
// `on_false`.
def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp {
let arguments = (ins
Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
// Windowed select-and-scatter: `select` picks an element per window of
// `operand`; `scatter` combines the corresponding `source` value into
// `out`, starting from `init_value`.
def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", []>,
BASE_HLO_SelectAndScatterOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$source,
Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
Arg<LHLO_Buffer, "", [MemWrite]>:$out,
OptionalAttr<I64ElementsAttr>:$window_dimensions,
OptionalAttr<I64ElementsAttr>:$window_strides,
OptionalAttr<I64ElementsAttr>:$padding
);
let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter);
}
// Reverses `operand` along each dimension listed in `dimensions`.
def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
I64ElementsAttr:$dimensions,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
// Pads `operand` with `padding_value` using per-dimension low/high edge
// padding and interior (between-element) padding.
def LHLO_PadOp: LHLO_Op<"pad", []>, BASE_HLO_PadOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$padding_value,
I64ElementsAttr:$edge_padding_low,
I64ElementsAttr:$edge_padding_high,
I64ElementsAttr:$interior_padding,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
// Permutes the dimensions of `operand` according to `permutation`.
def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
I64ElementsAttr:$permutation,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
// Reduces the precision of a floating-point buffer to the given number
// of exponent and mantissa bits. Documentation inherited from
// BASE_HLO_ReducePrecisionOp.
def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
BASE_HLO_ReducePrecisionOp {
let arguments = (ins
Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
I32Attr:$exponent_bits,
I32Attr:$mantissa_bits
);
}
// Common base class for AllReduce, AllGather, and AllToAll.
class LHLO_CollectiveCommunicationOp<string name, list<OpTrait> traits = []> :
LHLO_Op<name, !listconcat(traits, [SameVariadicOperandSize, SameOperandsElementType])> {
// Shared argument list; subclasses splice op-specific attributes onto
// this dag with !con (or use it as-is).
dag arguments_base = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
I64ElementsAttr:$replica_groups,
DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
OptionalAttr<ChannelHandle>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
);
// Verification is shared across the collective ops and lives in C++.
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// The collective op is cross-replica if channel_id is not set.
bool IsCrossReplica() { return !channel_id().hasValue(); }
}];
}
// All-gather across participants, concatenating along
// `all_gather_dimension`.
def LHLO_AllGatherOp : LHLO_CollectiveCommunicationOp<"all_gather">,
BASE_HLO_AllGatherOp {
let arguments = !con(
arguments_base,
(ins I64Attr:$all_gather_dimension));
}
// All-reduce across participants; the reduction is the `computation`
// region.
def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce">,
BASE_HLO_AllReduceOp {
let arguments = arguments_base;
let regions = (region SizedRegion<1>:$computation);
}
// All-to-all exchange; `split_dimension`, when set, is the dimension
// split among participants.
def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all">,
BASE_HLO_AllToAllOp {
let arguments = !con(
arguments_base,
(ins OptionalAttr<I64Attr>:$split_dimension));
}
// Sends/receives buffers between devices according to the
// (source, target) pairs in `source_target_pairs`.
def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
BASE_HLO_CollectivePermuteOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$source_target_pairs,
OptionalAttr<ChannelHandle>:$channel_id
);
}
// FFT of kind `fft_type` (forward/inverse, real/complex) over the
// trailing `fft_length` dimensions.
def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
HLO_FftTypeAttr:$fft_type,
I64ElementsAttr:$fft_length
);
}
// Cholesky decomposition of `a`; `lower` selects the lower-triangular
// factor (defaults to false, i.e. upper).
def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_CholeskyOp {
let arguments = (ins
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
DefaultValuedAttr<BoolAttr, "false">:$lower
);
}
// Reads data from the host infeed into the `outputs` buffers.
def LHLO_InfeedOp: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$outputs,
DefaultValuedAttr<StrAttr, "">:$config
);
}
// Writes the `operands` buffers to the host outfeed. (No BASE_HLO
// mixin; this op is LMHLO-only in form.)
def LHLO_OutfeedOp: LHLO_Op<"outfeed", []> {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
DefaultValuedAttr<StrAttr, "">:$config
);
}
// Writes the current replica id into an unsigned 32-bit memref.
def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}
// Writes the current partition id into an unsigned 32-bit memref.
def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp {
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}
// Solves triangular systems with coefficient matrix `a` and
// right-hand side `b`. The Bool attrs select which side `a` multiplies
// on, which triangle is used, and whether the diagonal is assumed unit;
// the layout attrs describe the buffer layouts of a/b/output.
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
BASE_HLO_TriangularSolveOp {
let arguments = (ins
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$b,
Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
BoolAttr:$left_side,
BoolAttr:$lower,
BoolAttr:$unit_diagonal,
HLO_TransposeAttr:$transpose_a,
HLO_LayoutAttr:$layout_a,
HLO_LayoutAttr:$layout_b,
HLO_LayoutAttr:$layout_output
);
}
// TODO(timshen): add a custom verifier.
// Applies the `computation` region elementwise over the `operands`
// buffers (all operands share a shape, per SameOperandsShape).
def LHLO_MapOp: LHLO_Op<"map", [SameOperandsShape]>, BASE_HLO_MapOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$dimensions
);
let regions = (region SizedRegion<1>:$computation);
}
// Reads and advances the RNG state held in the ui64 `state` memref by
// `delta` (the only op here that both reads and writes one buffer).
def LHLO_RngGetAndUpdateStateOp: LHLO_Op<"rng_get_and_update_state", []> {
let arguments = (ins
Arg<MemRefOf<[UI64]>, "", [MemRead, MemWrite]>:$state,
I64Attr:$delta
);
}
// TODO(timshen): add a custom verifier.
// Sorts the `operands` buffers together along `dimension` using the
// `comparator` region; `dimension` defaults to -1.
// NOTE(review): -1 presumably means the last dimension — confirm in the
// C++ lowering.
def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]>, BASE_HLO_SortOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
DefaultValuedAttr<I64Attr, "-1">:$dimension,
DefaultValuedAttr<BoolAttr, "false">:$is_stable
);
let regions = (region SizedRegion<1>:$comparator);
}
//===----------------------------------------------------------------------===//
// Late operations
//===----------------------------------------------------------------------===//
def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]> {
let summary = "Fusion operator";
let description = [{
Models the fusion instruction generated by the XLA compiler's fusion pass.
Fusion instructions are generated by the fusion pass of the XLA compiler.
They serve as a hint to the backend that it is beneficial to emit the
contained instructions into a single loop nest or kernel. The XLA fusion
pass is designed such that it only generates fusion nodes that can be
handled by the XLA compilers backends.
The XLA runtime expects this hint to be followed, as it expects a single
kernel per HLO instruction. This restriction might be lifted in the future.
}];
let regions = (region SizedRegion<1>:$region);
let skipDefaultBuilders = 1;
let builders = [
OpBuilderDAG<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
// In all four accessors below, a tensor_load / tensor_store inside the
// fusion region refers to an *outside* buffer exactly when the memref's
// defining region is a proper ancestor of the fusion region.
let extraClassDeclaration = [{
// Memrefs read by tensor_load ops in the region whose memref comes
// from outside the fusion (the fusion's input buffers).
SmallVector<Value, 4> getInputBuffers() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load.memref());
});
return buffers;
}
// Memrefs written by tensor_store ops targeting outside buffers (the
// fusion's output buffers).
SmallVector<Value, 4> getOutputBuffers() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.memref());
});
return buffers;
}
// Same selection as getInputBuffers, but returns the loaded tensor
// values (the tensor_load results) rather than the memrefs.
SmallVector<Value, 4> getFusionParameters() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load);
});
return buffers;
}
// Same selection as getOutputBuffers, but returns the stored tensor
// values rather than the memrefs.
SmallVector<Value, 4> getFusionResults() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.tensor());
});
return buffers;
}
}];
}
def TerminatorOp :
LHLO_Op<"terminator", [Terminator]> {
let summary = "LHLO termination operation";
let description = [{
Terminator operation for the LHLO dialect.
}];
// Convenience builder: forwards `operands` to the generated build
// method with no result types and no attributes.
let builders = [
OpBuilderDAG<(ins "ValueRange":$operands),
[{ build($_builder, $_state, llvm::None, operands, llvm::None); }]>];
}
#endif // LHLO_OPS