[XLA/GPU] Add XLA HLO -> LMHLO conversion to several ops, and implement them in XLA/GPU.
PiperOrigin-RevId: 353158172
This commit is contained in:
parent
ae10640d78
commit
d1c785381d
|
@ -713,7 +713,7 @@ def HLO_SliceOp: HLO_Op<
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
|
def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
|
||||||
[NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> {
|
[NoSideEffect, AllElementTypesMatch<["operand", "result"]>]>, BASE_HLO_DynamicSliceOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Tensor:$operand,
|
HLO_Tensor:$operand,
|
||||||
Variadic<HLO_ScalarIntTensor>:$start_indices,
|
Variadic<HLO_ScalarIntTensor>:$start_indices,
|
||||||
|
@ -726,7 +726,7 @@ def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
|
||||||
|
|
||||||
def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice",
|
def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice",
|
||||||
[NoSideEffect, AllElementTypesMatch<["operand", "update", "result"]>,
|
[NoSideEffect, AllElementTypesMatch<["operand", "update", "result"]>,
|
||||||
AllShapesMatch<["operand", "result"]>]> {
|
AllShapesMatch<["operand", "result"]>]>, BASE_HLO_DynamicUpdateSliceOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Tensor:$operand,
|
HLO_Tensor:$operand,
|
||||||
HLO_Tensor:$update,
|
HLO_Tensor:$update,
|
||||||
|
|
|
@ -214,9 +214,7 @@ def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_Reduce
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
|
def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp {
|
||||||
SingleBlockImplicitTerminator<"TerminatorOp">
|
|
||||||
]>, BASE_HLO_ReduceWindowOp {
|
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
@ -309,12 +307,22 @@ def LHLO_SliceOp: LHLO_Op<
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
|
def LHLO_DynamicSliceOp: LHLO_Op<"dynamic_slice",
|
||||||
|
[AllElementTypesMatch<["operand", "output"]>]>, BASE_HLO_DynamicSliceOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
I64ElementsAttr:$slice_sizes
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []>, BASE_HLO_DynamicUpdateSliceOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$update,
|
Arg<LHLO_Buffer, "", [MemRead]>:$update,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices,
|
||||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -135,6 +135,15 @@ DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
|
||||||
std::string LmhloToMhloOpName(llvm::StringRef op_name,
|
std::string LmhloToMhloOpName(llvm::StringRef op_name,
|
||||||
mlir::MLIRContext *context) {
|
mlir::MLIRContext *context) {
|
||||||
assert(op_name.startswith("lmhlo.") && "Expected an LMHLO op");
|
assert(op_name.startswith("lmhlo.") && "Expected an LMHLO op");
|
||||||
|
|
||||||
|
if (op_name == "lmhlo.dot") {
|
||||||
|
return "mhlo.dot_general";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_name == "lmhlo.dynamic_slice") {
|
||||||
|
return "mhlo.dynamic-slice";
|
||||||
|
}
|
||||||
|
|
||||||
std::string mhlo_op_name(op_name.drop_front(1));
|
std::string mhlo_op_name(op_name.drop_front(1));
|
||||||
if (context->isOperationRegistered(mhlo_op_name)) return mhlo_op_name;
|
if (context->isOperationRegistered(mhlo_op_name)) return mhlo_op_name;
|
||||||
return "";
|
return "";
|
||||||
|
|
Loading…
Reference in New Issue