[MLIR][KernelGen] Implement InferShapedTypeOpInterface for `mhlo.compare/select`
PiperOrigin-RevId: 332227340
This commit is contained in:
parent
91f16172a4
commit
b1fd4d27cf
|
@ -678,9 +678,10 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def HLO_CompareOp: HLO_Op<"compare",
|
||||
[NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>,
|
||||
BASE_HLO_CompareOp {
|
||||
def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
|
||||
SameOperandsAndResultShape,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapes"]>]>, BASE_HLO_CompareOp {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$lhs,
|
||||
HLO_Tensor:$rhs,
|
||||
|
@ -1152,7 +1153,10 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>,
|
|||
}
|
||||
|
||||
// TODO(jpienaar): Add broadcastable trait.
|
||||
def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]>, BASE_HLO_SelectOp {
|
||||
def HLO_SelectOp: HLO_Op<"select", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapes"]>, DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
]>, BASE_HLO_SelectOp {
|
||||
let arguments = (ins
|
||||
HLO_PredTensor:$pred,
|
||||
HLO_Tensor:$on_true,
|
||||
|
|
|
@ -1678,6 +1678,20 @@ LogicalResult SelectOp::inferReturnTypes(
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SelectOp::inferReturnTypeComponents(
|
||||
mlir::MLIRContext*, llvm::Optional<mlir::Location>, mlir::ValueRange,
|
||||
mlir::DictionaryAttr, mlir::RegionRange,
|
||||
llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
|
||||
// TODO(b/168772852)
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult SelectOp::reifyReturnTypeShapes(
|
||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
return deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2473,9 +2487,22 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
|
|||
build(builder, result, new_type, lhs, rhs, comparison_direction);
|
||||
}
|
||||
|
||||
LogicalResult CompareOp::inferReturnTypeComponents(
|
||||
mlir::MLIRContext*, llvm::Optional<mlir::Location>, mlir::ValueRange,
|
||||
mlir::DictionaryAttr, mlir::RegionRange,
|
||||
llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
|
||||
// TODO(b/168772852)
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult CompareOp::reifyReturnTypeShapes(
|
||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
return deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-hlo-opt -mhlo-test-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||
// RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @broadcast_add
|
||||
// Note that all broadcast_ops are expanded from the same template, so
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
// RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file %s | FileCheck %s
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @select
|
||||
// CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>,
|
||||
func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
|
||||
-> tensor<2xi64> {
|
||||
// CHECK: %[[C2:.*]] = constant 2 : i64
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
|
||||
// CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
|
||||
// CHECK: return %[[SHAPE]] : tensor<2xi64>
|
||||
%0 = "mhlo.select"(%pred, %a, %b)
|
||||
: (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
|
||||
%1 = "mhlo_test.reify_return_type_shapes"(%0)
|
||||
: (tensor<2x?xf32>) -> tensor<2xi64>
|
||||
return %1 : tensor<2xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @compare
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>,
|
||||
func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> {
|
||||
// CHECK: %[[C2:.*]] = constant 2 : i64
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
|
||||
// CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
|
||||
// CHECK: return %[[SHAPE]] : tensor<2xi64>
|
||||
%0 = "mhlo.compare"(%a, %b) { comparison_direction = "NE" }
|
||||
: (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
|
||||
%1 = "mhlo_test.reify_return_type_shapes"(%0)
|
||||
: (tensor<2x?xi1>) -> tensor<2xi64>
|
||||
return %1 : tensor<2xi64>
|
||||
}
|
||||
|
Loading…
Reference in New Issue