/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
// This is the operation definition file for LMHLO, the "late" MHLO variant of
|
|
// the dialect, which operates on buffers instead of tensors.
|
|
//
|
|
// This file largely overlaps with hlo_ops.td at a logical level. It's tempting
|
|
// to merge these two files together, but we need to consider the following
|
|
// obstacles:
|
|
// * We need to have a common representation for arguments. That is to say,
|
|
// HLO_Array<X> translates to HLO_Tensor<X> in HLO dialect, and
|
|
// Arg<LHLO_Buffer<X>, "", [Mem(Read|Write)]> in LHLO. Array types within
|
|
// tuples also need to be transformed.
|
|
// * As of now, TableGen's dag functions are not sufficient to accomplish the
|
|
// one above.
|
|
// * Traits aren't identical, but need to be copied. For example,
|
|
// SameOperandAndResultType in HLO corresponds to SameTypeOperands in LHLO.
|
|
// * Also, currently HLO describes the API in XLA's client side, not service
|
|
// side. LHLO aims for the service side.
|
|
|
|
#ifndef LHLO_OPS
|
|
#define LHLO_OPS
|
|
|
|
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
|
include "mlir/IR/OpBase.td"
|
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
|
include "mlir/Interfaces/CopyOpInterface.td"
|
|
include "mlir/Interfaces/LoopLikeInterface.td"
|
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
|
include "mlir/Interfaces/ViewLikeInterface.td"
|
|
include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td"
|
|
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
|
|
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td"
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LMHLO nullary op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Base class for all LMHLO ops. Unconditionally attaches conservative
// MemRead and MemWrite effects (LMHLO ops operate on pre-allocated buffers
// that they may read and/or write) in addition to any op-specific traits.
class LHLO_Op<string mnemonic, list<OpTrait> traits> :
  Op<LHLO_Dialect, mnemonic,
   !listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
|
|
|
|
def LHLO_ConstOp : LHLO_Op<"constant", []> {
  let summary = "Constant operator";
  let description = [{
    Represents a constant value.
  }];
  let arguments = (ins
    ElementsAttr:$value,
    // Unlike the tensor-based HLO variant, the constant is materialized into
    // a caller-provided output buffer.
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );

  let hasCanonicalizer = 1;
}

def LHLO_IotaOp : LHLO_Op<"iota", []> {
  let summary = "Iota operator";
  let description = [{
    Creates a rank 1 array of values starting at zero and incrementing by one.
  }];
  // $iota_dimension selects the output dimension along which values increase.
  let arguments = (ins I64Attr:$iota_dimension,
                   Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LMHLO unary elementwise op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
|
|
|
|
// Base class for unary elementwise ops: one buffer read ($input), one buffer
// written ($output). By default both operands must have identical types
// (SameTypeOperands); ops whose input/output element types legitimately
// differ (e.g. abs of a complex value) pass [SameOperandsShape] instead.
class LHLO_UnaryElementwiseOp<string mnemonic,
                              Type BufferType = LHLO_Buffer,
                              list<OpTrait> traits = [SameTypeOperands]>
    : LHLO_Op<mnemonic, traits> {
  let arguments = (ins Arg<BufferType, "", [MemRead]>:$input,
                       Arg<BufferType, "", [MemWrite]>:$output);
}
|
|
|
|
// Abs supports complex to real, so element type is not guaranteed to match.
def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs", LHLO_Buffer, [SameOperandsShape]> {
  let summary = "Absolute value operator";
  let description = [{
    Returns `abs(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
  let verifier = [{ return Verify(*this); }];
}

// TODO(timshen): add a custom verifier.
def LHLO_BitcastConvertOp:
    LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]> {
  let summary = "BitcastConvert operator";
  let description = [{
    Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast
    operation from a data shape to a target shape. The dimensions must match,
    and the conversion is an element-wise one. Bitcast is implemented as a
    low-level cast, so machines with different floating-point representations
    will give different results.

    See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype.
  }];
}

def LHLO_CbrtOp: LHLO_UnaryElementwiseOp<"cbrt", LHLO_FpBuffer> {
  let summary = "Cubic root operator";
  let description = [{
    Returns element-wise cubic root of the operand.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer> {
  let summary = "Ceil operator";
  let description = [{
    Returns `Ceil(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer> {
  let summary = "Count-leading-zeros (Clz) operator";
  let description = [{
    Returns the number of leading zeros in each operand element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

// TODO(timshen): add a custom verifier.
// Element types intentionally differ (the op converts between types), hence
// SameOperandsShape rather than the default SameTypeOperands.
def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]> {
  let summary = "Convert operator";
  let description = [{
    Performs element-wise conversion of values from one type to another, e.g.
    float to int.

    See https://www.tensorflow.org/xla/operation_semantics#convertelementtype.
  }];
}

def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer> {
  let summary = "Cos operator";
  let description = [{
    Returns `Cos(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer> {
  let summary = "Exponential operator";
  let description = [{
    Returns `e^(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_Expm1Op: LHLO_UnaryElementwiseOp<"exponential_minus_one", LHLO_FpOrComplexBuffer> {
  let summary = "Exponential minus one operator";
  let description = [{
    Returns `e^(operand) - 1` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_FloorOp: LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer> {
  let summary = "Floor operator";
  let description = [{
    Returns `Floor(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
|
|
// Not a LHLO_UnaryElementwiseOp because input (complex) and output (float)
// buffer types differ; only the shapes must match.
def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]> {
  let summary = "Imag operator";
  let description = [{
    Returns `Imag(operand)` element-wise.
  }];
  let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
                       Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}

// Float input, predicate (i1) output; hence a standalone op rather than a
// LHLO_UnaryElementwiseOp.
def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]> {
  let summary = "IsFinite operator";
  let description = [{
    Tests whether each element of operand is finite, i.e., is not positive or
    negative infinity, and is not NaN. Returns a tensor of 1-bit integers with
    the same shape as the input, where each element is nonzero (i.e. true) if
    and only if the corresponding input element is finite.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
  let arguments = (ins Arg<LHLO_FpBuffer, "", [MemRead]>:$input,
                       Arg<LHLO_PredBuffer, "", [MemWrite]>:$output);
}
|
|
|
|
def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer> {
  let summary = "Logarithm operator";
  let description = [{
    Returns `log(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_LogisticOp : LHLO_UnaryElementwiseOp<"logistic", LHLO_FpOrComplexBuffer> {
  let summary = "Logistic operator";
  let description = [{
    Returns `logistic(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer> {
  let summary = "Log1p operator";
  let description = [{
    Returns `log(operand+1)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate"> {
  let summary = "Negation operator";
  let description = [{
    Returns `-operand` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_NotOp: LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer> {
  let summary = "Not operator";
  let description = [{
    Returns `!operand` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_PopulationCountOp: LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer> {
  let summary = "PopulationCount operator";
  let description = [{
    Returns the number of bits set in each operand element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

// Not a LHLO_UnaryElementwiseOp because input (complex) and output (float)
// buffer types differ; only the shapes must match.
def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]> {
  let summary = "Real operator";
  let description = [{
    Returns `Real(operand)` element-wise.
  }];
  let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
                       Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}

def LHLO_RoundOp: LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer> {
  let summary = "Round operator";
  let description = [{
    Returns `Round(operand)` element-wise, rounding to nearest integer with
    half-way cases rounding away from zero.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer> {
  let summary = "Reciprocal Square-root operator";
  let description = [{
    Returns `1.0 / sqrt(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt", LHLO_FpOrComplexBuffer> {
  let summary = "Square-root operator";
  let description = [{
    Returns `sqrt(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign"> {
  let summary = "Sign operator";
  let description = [{
    Returns `sign(operand)` element-wise, where

    ```
    sign(x) = -1  : x < 0
            = -0  : x = -0
            = NaN : x = NaN
            = +0  : x = +0
            = 1   : x > 0
    ```

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer> {
  let summary = "Sin operator";
  let description = [{
    Returns `Sin(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer> {
  let summary = "Tanh operator";
  let description = [{
    Returns `tanh(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
|
|
//===----------------------------------------------------------------------===//
|
|
// LMHLO binary elementwise op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
|
|
|
|
// Base class for binary elementwise ops: $lhs and $rhs are read, $out is
// written. $broadcast_dimensions is optional; its exact semantics are
// defined by BroadcastDimAttr (declared elsewhere) — presumably the legacy
// XLA implicit-broadcast dimensions. TODO(review): confirm.
class LHLO_BinaryElementwiseOp<string mnemonic, Type BufferType = LHLO_Buffer,
                               list<OpTrait> traits = [SameTypeOperands]> :
        LHLO_Op<mnemonic, traits> {
  let arguments = (ins
    Arg<BufferType, "", [MemRead]>:$lhs,
    Arg<BufferType, "", [MemRead]>:$rhs,
    Arg<BufferType, "", [MemWrite]>:$out,
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
}
|
|
|
|
def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add"> {
  let summary = "Addition operator";
  let description = [{
    Returns `lhs + rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer> {
  let summary = "Logical and";
  let description = [{
    Returns `logical_and(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
def LHLO_Atan2Op : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer> {
  let summary = "Atan2 operator";
  // Fixed description: atan2 is a two-argument function; the previous text
  // `atan2(lhs/rhs)` misread as a division of lhs by rhs.
  let description = [{
    Returns `atan2(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
// Not a LHLO_BinaryElementwiseOp because the inputs (float) and output
// (complex) buffer types differ; only the shapes must match.
def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]> {
  let summary = "Complex operator";
  let description = [{
    Performs element-wise conversion of a pair of real and imaginary values to
    a complex value.
  }];
  let arguments = (ins
    Arg<LHLO_FpBuffer, "", [MemRead]>:$lhs,
    Arg<LHLO_FpBuffer, "", [MemRead]>:$rhs,
    Arg<LHLO_ComplexBuffer, "", [MemWrite]>:$output,
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
}
|
|
|
|
def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide"> {
  let summary = "Division operator";
  let description = [{
    Returns `lhs / rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum"> {
  let summary = "Maximum operator";
  let description = [{
    Returns `max(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum"> {
  let summary = "Minimum operator";
  let description = [{
    Returns `min(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply"> {
  let summary = "Multiplication operator";
  let description = [{
    Returns `lhs * rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer> {
  let summary = "Logical or";
  let description = [{
    Returns `logical_or(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power"> {
  let summary = "Power operator";
  let description = [{
    Returns `lhs ^ rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer> {
  let summary = "Remainder operator";
  let description = [{
    Returns `lhs % rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_ShiftLeftOp : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer> {
  let summary = "Shift Left operator";
  let description = [{
    Returns `lhs << rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_ShiftRightArithmeticOp : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer> {
  let summary = "Shift right arithmetic operator";
  let description = [{
    Returns arithmetic `lhs >> rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_ShiftRightLogicalOp : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer> {
  let summary = "Shift right logical operator";
  let description = [{
    Returns logical `lhs >> rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract"> {
  let summary = "Subtraction operator";
  let description = [{
    Returns `lhs - rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer> {
  let summary = "Logical xor";
  let description = [{
    Returns `logical_xor(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
//===----------------------------------------------------------------------===//
|
|
// LMHLO control flow op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TODO(b/139813999): specify required function signature in a type-safe way.
//
// The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are
// moving towards mhlo.ReturnOp, but some code that needs cleanup still assumes lmhlo.TerminatorOp.
// TODO(timshen): cleanup lmhlo.TerminatorOp.
def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]> {
  let summary = "Reduce operator";
  let description = [{
    Returns the result of executing a reduction function on one or more arrays
    in parallel.

    See https://www.tensorflow.org/xla/operation_semantics#reduce.
  }];
  // $inputs, $init_values and $out are parallel variadic lists
  // (SameVariadicOperandSize); $dimensions lists the axes being reduced.
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$inputs,
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out,
    I64ElementsAttr:$dimensions
  );

  let regions = (region SizedRegion<1>:$body);

  let hasCanonicalizer = 1;
}
|
|
|
|
def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [SameVariadicOperandSize]> {
  let summary = "ReduceWindow operator";
  let description = [{
    Returns the result of executing a reduction function over all elements in
    each window of one or more arrays in parallel.

    See https://www.tensorflow.org/xla/operation_semantics#reducewindow.
  }];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$inputs,
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out,
    I64ElementsAttr:$window_dimensions,
    // If strides or dilations attributes are missing then the default value is
    // one for each of the input dimensions. Similarly, padding values are zero
    // for both low and high in each of the dimensions, if not specified.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$base_dilations,
    OptionalAttr<I64ElementsAttr>:$window_dilations,
    OptionalAttr<I64ElementsAttr>:$padding
  );

  // The reduction computation; see LHLO_ReduceOp for terminator conventions.
  let regions = (region SizedRegion<1>:$body);
  let verifier = [{ return Verify(*this); }];
}
|
|
|
|
// TODO(timshen): Add a custom syntax for this.
def LHLO_CaseOp: LHLO_Op<"case", [
      SingleBlockImplicitTerminator<"TerminatorOp">,
      DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
  let summary = "Switch-Case operator";
  // Fixed description markup: the code span around `branches[N-1]` was
  // missing its closing backtick, which corrupted the generated docs.
  let description = [{
    Returns the result of executing `branches[index]`. If
    `index` is < 0 or >= N, then `branches[N-1]` is executed as
    the default branch.

    Each branch `branches[b]` must take in a single argument of same type as
    `branch_operands[b]` and will be invoked with `branch_operands[b]`. The type
    of the returned value of each branch must be the same.

    Note that only one of the branches will be executed depending on the value
    of index.
    See https://www.tensorflow.org/xla/operation_semantics#conditional.
  }];

  let arguments = (ins Arg<LHLO_PredOrIntBuffer, "", [MemRead]>:$index);

  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
}
|
|
|
|
// TODO(timshen): Add a custom syntax for this.
def LHLO_WhileOp: LHLO_Op<"while", [
      DeclareOpInterfaceMethods<RegionBranchOpInterface>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
  let summary = "While operator";
  let description = [{
    Returns the result of executing a body function until the cond body returns
    true.

    See https://www.tensorflow.org/xla/operation_semantics#while.
  }];
  // $cond_val: buffer(s) the `cond` region writes its predicate result into.
  // $trip_count: optional statically-known iteration count.
  let arguments = (ins
    Arg<Variadic<LHLO_PredBuffer>, "", [MemWrite]>:$cond_val,
    OptionalAttr<I64Attr>:$trip_count);

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}
|
|
|
|
def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]> {
  let summary = "CustomCall operator";
  let description = [{
    A custom call invokes code external to XLA. The `args` are passed to the
    external code, and the external code is expected to produce a result of the
    given type. The exact mechanism is backend-specific. For example, in the CPU
    backend, a call instruction is emitted which targets a symbol with the name
    `call_target_name`.

    `call_target_name` and `backend_config` can be arbitrary strings, but
    `call_target_name` should be short as it may be used in labels.
    `backend_config` can encode arbitrarily large amounts of information.

    See https://www.tensorflow.org/xla/operation_semantics#customcall.
  }];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
    StrAttr:$call_target_name,
    DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
    DefaultValuedAttr<StrAttr, "">:$backend_config,
    // Optional mapping from args/output positions to the target's expected
    // argument positions (semantics defined by CustomCallTargetArgMapping).
    OptionalAttr<CustomCallTargetArgMapping>:$target_arg_mapping
  );
  let verifier = [{ return Verify(*this); }];
}
|
|
|
|
//===----------------------------------------------------------------------===//
// LMHLO comparison op definitions.
//===----------------------------------------------------------------------===//
|
|
|
|
def LHLO_CompareOp: LHLO_Op<"compare", []> {
  let summary = "Comparison operator";
  let description = [{
    Compares `lhs` and `rhs` elementwise according to `comparison_direction`
    and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
    types, SIGNED for signed element types and UNSIGNED for unsigned element
    types.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
  }];
  // Result buffer is a predicate (i1) buffer regardless of input element type.
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
    Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
    HLO_ComparisonDirectionAttr:$comparison_direction,
    OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
  );
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LMHLO Slice definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Extracts a strided sub-array from $operand, bounded by $start_indices
// (inclusive) and $limit_indices (exclusive), into $output.
// See https://www.tensorflow.org/xla/operation_semantics#slice.
def LHLO_SliceOp: LHLO_Op<
      "slice",
      [AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$start_indices,
    I64ElementsAttr:$limit_indices,
    I64ElementsAttr:$strides
  );
}
|
|
|
|
def LHLO_DynamicSliceOp: LHLO_Op<"dynamic_slice",
    [AllElementTypesMatch<["operand", "output"]>]> {
  let summary = "Dynamic Slice operator";
  let description = [{
    Extracts a sub-array from the input array at dynamic start_indices.

    See https://www.tensorflow.org/xla/operation_semantics#dynamicslice.
  }];
  // $start_indices are runtime values (one scalar buffer per dimension),
  // unlike LHLO_SliceOp whose indices are static attributes.
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$slice_sizes
  );
}

def LHLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
  let summary = "Dynamic Update Slice operator";
  let description = [{
    DynamicUpdateSlice generates a result which is the value of the input array
    operand, with a slice update overwritten at start_indices.

    See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$update,
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LMHLO Other op definitions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []> {
  let summary = "Batch Normalization Gradient";
  let description = [{
    Calculates gradients of batch norm.

    See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad
  }];

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
    Arg<LHLO_Buffer, "", [MemRead]>:$mean,
    Arg<LHLO_Buffer, "", [MemRead]>:$variance,
    Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand,  // gradient of $operand.
    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale,
    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );

}

def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []> {
  let summary = "Batch Normalization for Inference";
  let description = [{
    Normalizes an array across batch and spatial dimensions.

    See https://www.tensorflow.org/xla/operation_semantics#batchnorminference
  }];

  // Inference variant: uses precomputed $mean/$variance rather than
  // computing batch statistics.
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
    Arg<LHLO_Buffer, "", [MemRead]>:$offset,
    Arg<LHLO_Buffer, "", [MemRead]>:$mean,
    Arg<LHLO_Buffer, "", [MemRead]>:$variance,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
}

def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []> {
  let summary = "Batch Normalization for Training";
  let description = [{
    Normalizes an array across batch and spatial dimensions.

    See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining
  }];

  // Training variant: computes and writes the batch statistics
  // ($batch_mean, $batch_var) in addition to the normalized $output.
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
    Arg<LHLO_Buffer, "", [MemRead]>:$offset,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean,
    Arg<LHLO_Buffer, "", [MemWrite]>:$batch_var,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
}
|
|
|
|
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
    []> {
  let summary = "Broadcast a tensor to a higher rank by prepending dimensions";
  let description = [{
    Broadcasts the operand tensor to a higher rank by prepending
    `broadcast_sizes` to the dimensions. The current values of the operand are
    copied into the other dimensions.

    This is a more limited form of broadcasting, that corresponds to the XLA
    client Broadcast method. For a more general form of broadcasting, see the
    BroadcastInDimOp.

    See https://www.tensorflow.org/xla/operation_semantics#broadcast.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$broadcast_sizes
  );
}
|
|
|
|
def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
    []> {
  let summary = "Broadcast a tensor into the given shape by adding dimensions.";
  // Fixed description: removed an accidental duplicated "The" at a line break
  // ("must be empty. The / The scalar value ...").
  let description = [{
    Broadcasts the `operand` tensor to a higher rank. This is not the limited
    form of broadcasting exposed as the XLA client broadcast op, but rather the
    more powerful "InDim" broadcasting, which is closer to the HLO broadcast op
    and exposed in the XLA client BroadcastInDim method.

    `broadcast_dimensions` maps the operand dimension number to the target shape
    dimension number. It must have the same size as the rank of the operand. The
    mapped dimensions must either be the same size or the dimension being
    broadcast from must be size 1 (degenerate broadcasting).

    For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The
    scalar value will be broadcast to every element in the target shape.

    See https://www.tensorflow.org/xla/broadcasting.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    BroadcastDimAttr:$broadcast_dimensions
  );
}
|
|
|
|
def LHLO_ClampOp : LHLO_Op<"clamp", []> {
  let summary = "Clamp operator";
  let description = [{
    Clamps an operand to within the range between a minimum and maximum value.

    Note: All three arrays must be the same shape. Alternatively, as a
    restricted form of broadcasting, min and/or max can be a scalar (0D
    tensor) of the element type of the tensor operand.

    See https://www.tensorflow.org/xla/operation_semantics#clamp.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$min,
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$max,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}

def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []> {
  let summary = "XLA's concatenate op";
  let description = [{
    Concatenates a set of tensors along the specified dimension.

    See https://www.tensorflow.org/xla/operation_semantics#concatenate.
  }];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64Attr:$dimension
  );
}
|
|
|
|
// Convolution. Summary/description are inherited from BASE_HLO_ConvOp;
// convolution attributes (window strides, padding, dimension numbers, ...)
// are concatenated in from ConvolutionAttributes.
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
  let arguments = !con(
    (ins
       Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
       Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
       Arg<LHLO_Buffer, "", [MemWrite]>:$output),
    ConvolutionAttributes.attributes);
}
|
|
|
|
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]> {
  let summary = "Copy operator";
  let description = [{
    Returns a copy of `operand`.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );

  // CopyOpInterface accessors: maps the interface's source/target onto this
  // op's operand/output buffers.
  let extraClassDeclaration = [{
    Value getSource() { return operand();}
    Value getTarget() { return output(); }
  }];
}
|
|
|
|
// Dot product / matrix multiply. Contraction dimensions are described by
// `dot_dimension_numbers`; `precision_config` carries the XLA precision hint.
def LHLO_DotOp: LHLO_Op<"dot", []> {
  let summary = "Dot operator";
  let description = [{
    Performs dot products between vectors, vector/matrix and matrix/matrix
    multiplication.

    See https://www.tensorflow.org/xla/operation_semantics#dot.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
    DotDimensionNumbers:$dot_dimension_numbers,
    HLO_PrecisionConfigAttr:$precision_config,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
|
|
|
|
// Gather. Adds the summary/description that every sibling op in this file
// carries but this def was missing.
def LHLO_GatherOp: LHLO_Op<"gather", []> {
  let summary = "Gather operator";
  let description = [{
    Stitches together several slices of `operand` from offsets specified in
    `start_indices`, with each slice of shape `slice_sizes`, writing the
    result into `output`.

    See https://www.tensorflow.org/xla/operation_semantics#gather.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    // Integer buffer holding the starting indices of each gathered slice.
    Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
    GatherDimensionNumbers:$dimension_numbers,
    I64ElementsAttr:$slice_sizes,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
|
|
|
|
// Reshape: the target shape is implied by the type of the `output` buffer.
def LHLO_ReshapeOp: LHLO_Op<"reshape", []> {
  let summary = "Reshape operator";
  let description = [{
    Reshapes the dimensions of `operand` into a new configuration.

    See https://www.tensorflow.org/xla/operation_semantics#reshape.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
|
|
|
|
// Scatter. The combining function lives in the single-block `update_computation`
// region; `indices_are_sorted` / `unique_indices` are optimization hints that
// default to the conservative value (false).
def LHLO_ScatterOp: LHLO_Op<"scatter", []> {
  let summary = "Scatter operator";
  let description = [{
    Generates a result which is the value of the input array `operand`,
    with several slices (at indices specified by `scatter_indices`)
    updated with the values in `updates` using `update_computation`.

    See https://www.tensorflow.org/xla/operation_semantics#scatter.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
    Arg<LHLO_Buffer, "", [MemRead]>:$updates,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    ScatterDimensionNumbers:$scatter_dimension_numbers,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
    DefaultValuedAttr<BoolAttr, "false">:$unique_indices
  );

  let regions = (region SizedRegion<1>:$update_computation);
}
|
|
|
|
// Elementwise select: `pred` is a predicate (i1-element) buffer choosing
// between `on_true` and `on_false` per element.
def LHLO_SelectOp: LHLO_Op<"select", []> {
  let summary = "Select operator";
  let description = [{
    Constructs an output tensor from the elements of `on_true` and `on_false`
    based on the values of `pred`.

    `pred`, `on_true` and `on_false` must be broadcast compatible.
  }];
  let arguments = (ins
    Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
    Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
    Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
|
|
|
|
// SelectAndScatter. Two single-block regions: `select` picks an element within
// each window, `scatter` combines colliding scattered values. Note the result
// buffer is named `out` (not `output`) — callers rely on the generated
// accessor name, so it must stay as-is.
def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", []> {
  let summary = "SelectAndScatter operator";
  let description = [{
    Runs a windowed selection `select` function over `operand` with shape
    `window_dimensions` and stride `window_strides`. This will produce an amount
    of selected locations whose shape matches `source`. These are then scattered
    to the output which is initialized with `init_value`.
    Multiple scattered elements which land in the same output location are
    combined using the `scatter` function.

    See https://www.tensorflow.org/xla/operation_semantics#selectandscatter.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$source,
    Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
    Arg<LHLO_Buffer, "", [MemWrite]>:$out,
    OptionalAttr<I64ElementsAttr>:$window_dimensions,
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$padding
  );

  let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter);
}
|
|
|
|
// Reverse the listed dimensions of `operand` into `output`.
def LHLO_ReverseOp: LHLO_Op<"reverse", []> {
  let summary = "Reverse operator";
  let description = [{
    Reverses the specified dimensions of `operand` according to the given
    `dimensions`.

    See https://www.tensorflow.org/xla/operation_semantics#rev_reverse.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    I64ElementsAttr:$dimensions,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
|
|
|
|
// Pad: low/high edge padding plus interior (between-element) padding, filled
// with the scalar read from `padding_value`.
def LHLO_PadOp: LHLO_Op<"pad", []> {
  let summary = "Pad operator";
  let description = [{
    Pads the edges of `operand` with the `padding_value` and according to
    the passed configuration.

    See https://www.tensorflow.org/xla/operation_semantics#pad.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$padding_value,
    I64ElementsAttr:$edge_padding_low,
    I64ElementsAttr:$edge_padding_high,
    I64ElementsAttr:$interior_padding,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
|
|
|
|
// Transpose according to `permutation`:
// res_dimensions[i] = operand_dimensions[permutation[i]].
def LHLO_TransposeOp: LHLO_Op<"transpose", []> {
  let summary = "Transpose operator";
  let description = [{
    Permutes the dimensions of `operand` according to the given `permutation`.

    `res_dimensions[i] = operand_dimensions[permutation[i]]`

    See https://www.tensorflow.org/xla/operation_semantics#transpose.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    I64ElementsAttr:$permutation,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
|
|
|
|
// ReducePrecision. The description text was garbled by an auto-formatter
// ("floating - point", "IEEE - FP16", broken hyphenation across lines); this
// restores readable prose without changing the op in any way.
def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]> {
  let summary = "Reduce precision operator";
  let description = [{
    Models the effect of converting floating-point values to a lower-precision
    format (such as IEEE-FP16) and back to the original format. The number of
    exponent and mantissa bits in the lower-precision format can be specified
    arbitrarily, although all bit sizes may not be supported on all hardware
    implementations.

    See https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
  }];
  let arguments = (ins
    // Floating-point buffers only: precision reduction is a float concept.
    Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
    Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
    I32Attr:$exponent_bits,
    I32Attr:$mantissa_bits
  );
}
|
|
|
|
// Common base class for AllReduce, AllGather, and AllToAll.
// Subclasses either use `arguments_base` directly or extend it with !con.
class LHLO_CollectiveCommunicationOp<string name, list<OpTrait> traits = []> :
    LHLO_Op<name, !listconcat(traits, [SameVariadicOperandSize])> {
  // Shared argument list; `operands` and `results` are parallel variadic
  // buffer lists (enforced by the SameVariadicOperandSize trait above).
  dag arguments_base = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
    I64ElementsAttr:$replica_groups,
    DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
    OptionalAttr<ChannelHandle>:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
  );
  // Shared verifier implemented in C++ (free function Verify).
  let verifier = [{ return Verify(*this); }];
  let extraClassDeclaration = [{
    // AllGather is cross replica if channel_id is not set.
    bool IsCrossReplica() { return !channel_id().hasValue(); }
  }];
}
|
|
|
|
// AllGather: base collective arguments plus the dimension to gather along.
def LHLO_AllGatherOp : LHLO_CollectiveCommunicationOp<"all_gather"> {
  let summary = "AllGather operator";
  let description = [{
    Performs concatenation across replicas.

    See https://www.tensorflow.org/xla/operation_semantics#allgather
  }];
  let arguments = !con(
    arguments_base,
    (ins I64Attr:$all_gather_dimension));
}
|
|
|
|
// AllReduce: the reduction function is carried in the single-block
// `computation` region; arguments come unchanged from the base class.
def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce", [SameOperandsElementType]> {
  let summary = "AllReduce operator";
  let description = [{
    Performs a custom reduction across replicas.

    See https://www.tensorflow.org/xla/operation_semantics#allreduce.
  }];
  let arguments = arguments_base;
  let regions = (region SizedRegion<1>:$computation);
}
|
|
|
|
// AllToAll. Adds the summary/description that its sibling collectives
// (AllGather, AllReduce) already have but this def was missing.
def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all", [SameOperandsElementType]> {
  let summary = "AllToAll operator";
  let description = [{
    Performs an all-to-all exchange of data across replicas, optionally
    splitting along `split_dimension`.

    See https://www.tensorflow.org/xla/operation_semantics#alltoall.
  }];
  let arguments = !con(
    arguments_base,
    (ins OptionalAttr<I64Attr>:$split_dimension));
}
|
|
|
|
// CollectivePermute: point-to-point sends/receives between replicas described
// by `source_target_pairs`; constraints are checked by the C++ Verify below.
def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]> {
  let summary = "CollectivePermute operator";
  let description = [{
    CollectivePermute is a collective operation that sends and receives data
    cross replicas.
    Note that there are the following restrictions on the source_target_pair:
    - Any two pairs should not have the same target replica id, and they should
    not have the same source replica id.
    - If a replica id is not a target in any pair, then the output on that
    replica is a tensor consists of 0(s) with the same shape as the input.

    See https://www.tensorflow.org/xla/operation_semantics#collectivepermute.
  }];

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$source_target_pairs,
    OptionalAttr<ChannelHandle>:$channel_id
  );
  let verifier = [{ return Verify(*this); }];
}
|
|
|
|
// FFT: transform kind selected by `fft_type`, lengths by `fft_length`.
def LHLO_FftOp: LHLO_Op<"fft", []> {
  let summary = "Fast fourier transform operator";
  let description = [{
    Returns the fast-fourier-transform of the input array.

    See
    https://www.tensorflow.org/xla/operation_semantics#fft.
  }];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    HLO_FftTypeAttr:$fft_type,
    I64ElementsAttr:$fft_length
  );
}
|
|
|
|
// Cholesky decomposition over float or complex buffers; `lower` chooses the
// triangle, defaulting to upper-triangular (false).
def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]> {
  let summary = "Cholesky operator";
  let description = [{
    Computes the Cholesky decomposition of a batch of symmetric (Hermitian)
    positive definite matrices.

    If lower is true, computes lower-triangular matrices l such that
    `a=l.Transpose(l)`. If lower is false, computes upper-triangular matrices u such
    that `a=Transpose(u).u`.

    Input data is read only from the lower/upper triangle of a, depending on the
    value of lower. Values from the other triangle are ignored. Output data is
    returned in the same triangle; the values in the other triangle are
    implementation-defined and may be anything.

    If the rank of a is greater than 2, a is treated as a batch of matrices, where
    all except the minor 2 dimensions are batch dimensions.

    If a is not symmetric (Hermitian) positive definite, the result is
    implementation-defined.

    See https://www.tensorflow.org/xla/operation_semantics#cholesky.
  }];
  let arguments = (ins
    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
    Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
    DefaultValuedAttr<BoolAttr, "false">:$lower
  );
}
|
|
|
|
// Infeed: side-effecting read from the device's infeed stream into the
// variadic `outputs` buffers; no readable operands.
def LHLO_InfeedOp: LHLO_Op<"infeed", []> {
  let summary = "Infeed operator";
  let description = [{
    Reads a single data item from the implicit Infeed streaming interface of
    the device, interpreting the data as the given shape and its layout, and
    returns an LHLO op of the data. Multiple Infeed operations are allowed in a
    computation, but there must be a total order among the Infeed operations.
    For example, two Infeeds in the code below have a total order since there
    is a dependency between the while loops.

    See https://www.tensorflow.org/xla/operation_semantics#infeed
  }];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$outputs,
    // Opaque backend configuration string; empty by default.
    DefaultValuedAttr<StrAttr, "">:$config
  );
}
|
|
|
|
// Outfeed. Adds the summary/description that the paired Infeed op already
// carries but this def was missing.
def LHLO_OutfeedOp: LHLO_Op<"outfeed", []> {
  let summary = "Outfeed operator";
  let description = [{
    Writes the given operands to the implicit Outfeed streaming interface of
    the device.

    See https://www.tensorflow.org/xla/operation_semantics#outfeed.
  }];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
    // Opaque backend configuration string; empty by default.
    DefaultValuedAttr<StrAttr, "">:$config
  );
}
|
|
|
|
// ReplicaId: writes the replica id into an (unnamed) ui32 memref argument.
def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []> {
  let summary = "ReplicaId operator";
  let description = [{
    Returns the unique ID (int32 scalar) of the replica.

    The unique ID of each replica is an unsigned integer in the interval [0, N),
    where N is the number of replicas. Since all the replicas are running the
    same program, a ReplicaId() call in the program will return a different
    value on each replica.

    See https://www.tensorflow.org/xla/operation_semantics#replicaid.
  }];
  let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}
|
|
|
|
// PartitionId: writes the partition id into an (unnamed) ui32 memref argument;
// mirrors ReplicaIdOp above.
def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []> {
  let summary = "PartitionId operator";
  let description = [{
    Returns the unique ID (int32 scalar) of the partition.
  }];
  let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}
|
|
|
|
// TriangularSolve over float/complex buffers. The three HLO_LayoutAttr
// attributes carry per-buffer layouts (a, b, output) for the backend.
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]> {
  let summary = "TriangularSolve operator";
  let description = [{
    Solves systems of linear equations with lower or upper triangular
    coefficient matrices by forward- or back-substitution. Broadcasting along
    leading dimensions, this routine solves one of the matrix systems
    op(a) * x = b, or x * op(a) = b, for the variable x, given a and b, where
    op(a) is either op(a) = a, or op(a) = Transpose(a), or
    op(a) = Conj(Transpose(a)).

    Input data is read only from the lower/upper triangle of a, depending on the
    value of lower. Values from the other triangle are ignored. Output data is
    returned in the same triangle; the values in the other triangle are
    implementation-defined and may be anything.

    If the rank of a and b are greater than 2, they are treated as batches of
    matrices, where all except the minor 2 dimensions are batch dimensions. a
    and b must have equal batch dimensions.

    See https://www.tensorflow.org/xla/operation_semantics#triangularsolve.
  }];
  let arguments = (ins
    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$b,
    Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
    BoolAttr:$left_side,
    BoolAttr:$lower,
    BoolAttr:$unit_diagonal,
    HLO_TransposeAttr:$transpose_a,
    HLO_LayoutAttr:$layout_a,
    HLO_LayoutAttr:$layout_b,
    HLO_LayoutAttr:$layout_output
  );
}
|
|
|
|
// TODO(timshen): add a custom verifier.
// Map: applies the scalar `computation` region over the variadic operand
// buffers elementwise; all operands share a shape (SameOperandsShape).
def LHLO_MapOp: LHLO_Op<"map", [SameOperandsShape]> {
  let summary = "Map operator";
  let description = [{
    Applies a scalar function over the given operands arrays, producing an array
    of the same dimensions where each element is the result of the mapped function
    applied to the corresponding elements in the input arrays.

    The mapped function is an arbitrary computation with the restriction that it
    has N inputs of scalar type T and a single output with type S. The output has
    the same dimensions as the operands except that the element type T is replaced
    with S.

    See https://www.tensorflow.org/xla/operation_semantics#map.
  }];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$dimensions
  );
  let regions = (region SizedRegion<1>:$computation);
}
|
|
|
|
// RNG state op: `state` is a ui64 buffer that is both read and written
// in place. NOTE(review): presumably the state is advanced by `delta`
// (as the op name suggests) — the exact update is implemented in C++,
// not visible here; confirm against the op's lowering.
def LHLO_RngGetAndUpdateStateOp: LHLO_Op<"rng_get_and_update_state", []> {
  let arguments = (ins
    Arg<MemRefOf<[UI64]>, "", [MemRead, MemWrite]>:$state,
    I64Attr:$delta
  );
}
|
|
|
|
// TODO(timshen): add a custom verifier.
// Sort: variadic key/value buffers sorted together by the `comparator` region.
// `dimension` defaults to -1 (by convention the last dimension in XLA —
// TODO confirm against the C++ lowering); `is_stable` defaults to false.
def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]> {
  let summary = "Sort operator";
  let description = [{
    Sorts the given `operands` at the given `dimension` with the given
    `comparator`.

    See https://www.tensorflow.org/xla/operation_semantics#sort.
  }];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
    DefaultValuedAttr<I64Attr, "-1">:$dimension,
    DefaultValuedAttr<BoolAttr, "false">:$is_stable
  );

  let regions = (region SizedRegion<1>:$comparator);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Late operations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Fusion region op. The single block is implicitly terminated by TerminatorOp;
// default builders are suppressed because the op takes no SSA operands, only
// attributes.
def FusionOp : LHLO_Op<"fusion", [
     SingleBlockImplicitTerminator<"TerminatorOp">,
     DeclareOpInterfaceMethods<RegionBranchOpInterface>
    ]> {
  let summary = "Fusion operator";
  let description = [{
    Models the fusion instruction generated by the XLA compiler's fusion pass.

    Fusion instructions are generated by the fusion pass of the XLA compiler.
    They serve as a hint to the backend that it is beneficial to emit the
    contained instructions into a single loop nest or kernel. The XLA fusion
    pass is designed such that it only generates fusion nodes that can be
    handled by the XLA compilers backends.
    The XLA runtime expects this hint to be followed, as it expects a single
    kernel per HLO instruction. This restriction might be lifted in the future.
  }];
  let regions = (region SizedRegion<1>:$region);

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
  ];

  // Helpers that recover the buffer-level interface of the fusion from the
  // tensor_load/tensor_store ops bridging buffers and tensors in the region.
  let extraClassDeclaration = [{
    // Memrefs read by the fused computation (operands of memref.tensor_load).
    SmallVector<Value> getInputBuffers() {
      SmallVector<Value> buffers;
      for (auto load : region().front().getOps<memref::TensorLoadOp>()) {
        buffers.push_back(load.memref());
      }
      return buffers;
    }

    // Memrefs written by the fused computation (targets of memref.tensor_store).
    SmallVector<Value> getOutputBuffers() {
      SmallVector<Value> buffers;
      for (auto store : region().front().getOps<memref::TensorStoreOp>()) {
        buffers.push_back(store.memref());
      }
      return buffers;
    }

    // Tensor values produced by the loads, i.e. the fusion's inputs in
    // tensor form.
    SmallVector<Value> getFusionParameters() {
      SmallVector<Value> buffers;
      for (auto load : region().front().getOps<memref::TensorLoadOp>()) {
        buffers.push_back(load);
      }
      return buffers;
    }

    // Tensor values consumed by the stores, i.e. the fusion's results in
    // tensor form.
    SmallVector<Value> getFusionResults() {
      SmallVector<Value> buffers;
      for (auto store : region().front().getOps<memref::TensorStoreOp>()) {
        buffers.push_back(store.tensor());
      }
      return buffers;
    }
  }];
}
|
|
|
|
// Region terminator for LHLO region ops (e.g. FusionOp's implicit terminator).
def TerminatorOp :
    LHLO_Op<"terminator", [ReturnLike, Terminator]> {
  let summary = "LHLO termination operation";
  let description = [{
    Terminator operation for the LHLO dialect.
  }];
  // Convenience builder forwarding operands with no result types or regions.
  let builders = [
    OpBuilder<(ins "ValueRange":$operands),
    [{ build($_builder, $_state, llvm::None, operands, llvm::None); }]>];
}
|
|
|
|
#endif // LHLO_OPS
|