diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index ce5dfa3..c9050df 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -2775,6 +2775,45 @@ LogicalResult ONNXDropoutOp::inferShapes() { return success(); } +//===----------------------------------------------------------------------===// +// OneHotEncoder +//===----------------------------------------------------------------------===// + +LogicalResult ONNXOneHotEncoderOp::inferShapes() { + ShapedType inputType = X().getType().dyn_cast<ShapedType>(); + if (!inputType) + return emitError("Non-shaped input type"); + auto shape = inputType.getShape(); + int64_t outDim = 0; + + // If the input is a tensor of float, int32, or double, + // the data will be cast to integers and + // the cats_int64s category list will be used for the lookups. + if (inputType.getElementType().isIntOrFloat()) { + if (!cats_int64s()) + return emitError("input is a tensor of float, int32, or double, but no " + "cats_int64s attribute"); + outDim = ArrayAttrSize(cats_int64s()); + } else { + if (!cats_strings()) + return emitError("input is not a tensor of float, int32, or double, but " + "no cats_strings attribute"); + outDim = ArrayAttrSize(cats_strings()); + } + + // Encoded output data, having one more dimension than X + // total category count will determine the size of the extra dimension + SmallVector<int64_t, 2> dims; + for (int i = 0; i != shape.size(); ++i) { + dims.emplace_back(shape[i]); + } + dims.emplace_back(outDim); + + getResult().setType( + RankedTensorType::get(dims, FloatType::getF32(getContext()))); + return success(); +} + //===----------------------------------------------------------------------===// // ONNX type related code //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 6a3a16c..250c69c 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ 
-6019,7 +6019,7 @@ def ONNXNormalizerOp:ONNX_Op<"Normalizer", } def ONNXOneHotEncoderOp:ONNX_Op<"OneHotEncoder", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { let summary = "ONNX OneHotEncoder operation"; let description = [{ "Replace each input element with an array of ones and zeros, where a single" diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index a51f4e3..02ff661 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -1480,3 +1480,51 @@ func @test_dropout(%arg0: tensor<1x2x3x4xf32>) -> (tensor<*xf32>, tensor<*xi1>) { // CHECK: [[RES:%.+]], [[MASK:%.+]] = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>) // CHECK: return [[RES]], [[MASK]] : tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1> } + +// ----- + +//===----------------------------------------------------------------------===// +/// Test shape inference for OneHotEncoder. 
+//===----------------------------------------------------------------------===// + +func @test_onehotencoder_string1 (%arg0: tensor<20x1x!onnx.String>) -> tensor<*xf32> { + %0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1x!onnx.String>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_onehotencoder_string1 + // CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1x!onnx.String>) -> tensor<20x1x2xf32> + // CHECK: return [[RES]] : tensor<20x1x2xf32> +} + +// ----- + +func @test_onehotencoder_string2 (%arg0: tensor<20x2x!onnx.String>) -> tensor<*xf32> { + %0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x!onnx.String>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_onehotencoder_string2 + // CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x!onnx.String>) -> tensor<20x2x2xf32> + // CHECK: return [[RES]] : tensor<20x2x2xf32> +} + +// ----- + +func @test_onehotencoder_float1(%arg0: tensor<20x1xf32>) -> tensor<*xf32> { + %0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], cats_int64s = [1, 2, 4], zeros = 1 : i64} : (tensor<20x1xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_onehotencoder_float1 + // CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1xf32>) -> tensor<20x1x3xf32> + // CHECK: return [[RES]] : tensor<20x1x3xf32> +} + +// ----- + +func @test_onehotencoder_float2(%arg0: tensor<20x2x3xf32>) -> tensor<*xf32> { + %0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], cats_int64s = [1, 2, 4], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // 
CHECK-LABEL: test_onehotencoder_float2 + // CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<20x2x3x3xf32> + // CHECK: return [[RES]] : tensor<20x2x3x3xf32> +} \ No newline at end of file diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 3c33374..1459ac1 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -280,6 +280,7 @@ OpsWithShapeInference=[ 'Min', 'Mul', 'Neg', + 'OneHotEncoder', 'Or', 'Pad', 'Pow',