[XLA:GPU] Add conversion from HLO -> MLIR LMHLO for TriangularSolve

- Also add layout attributes for inputs and output for error checking.

PiperOrigin-RevId: 355863625
This commit is contained in:
Rahul Joshi 2021-02-05 09:16:49 -08:00 committed by TensorFlow MLIR Team
parent 1c4521cc42
commit b251712b1d
3 changed files with 21 additions and 2 deletions

View File

@ -99,6 +99,17 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
// Any pred, int or floating-point tensor types
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
// A layout attribute (1D tensor of index type)
def HLO_LayoutAttr : Attr<
And<[IndexElementsAttr.predicate,
CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
== 1}]>]>,
"A 1D tensor of index type (layout)"> {
let storageType = IndexElementsAttr.storageType;
let returnType = IndexElementsAttr.returnType;
let convertFromStorage = IndexElementsAttr.convertFromStorage;
}
//===----------------------------------------------------------------------===//
// MHLO nullary op definitions.
//===----------------------------------------------------------------------===//

View File

@ -641,7 +641,10 @@ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType
BoolAttr:$left_side,
BoolAttr:$lower,
BoolAttr:$unit_diagonal,
HLO_TransposeAttr:$transpose_a
HLO_TransposeAttr:$transpose_a,
HLO_LayoutAttr:$layout_a,
HLO_LayoutAttr:$layout_b,
HLO_LayoutAttr:$layout_output
);
}

View File

@ -865,7 +865,12 @@ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
// CHECK-LABEL: func @triangular_solve_memrefs
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out)
{layout_a = dense<[1, 0]> : tensor<2xindex>,
layout_b = dense<[1, 0]> : tensor<2xindex>,
layout_output = dense<[1, 0]> : tensor<2xindex>,
left_side = true, lower = true, transpose_a = "NO_TRANSPOSE",
unit_diagonal = true}
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
return
}