OneHotEncoder Shape Inference (#265)
* move scalerop to decompose * change clang format * change clang format * add shape inference for scaler op * fixing generated onnxop * generate onnx.md * redo get onnx.md and onnxop.td.inc using onnx 1.6 * Add shape inference for scaler op * add benefit for scaler decompose and simplify scaler shape inference * add scaler decompose benefit num and simplify shape inference * add cast builder * cast rewrite only for float * add cast op same type rewrite rule * working on cast lowering * cast lowering working * add cast lowering * fix format * Delete OpBuildTable.inc * complete requested changes Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
parent
d3dcee7366
commit
00299910f3
|
@ -2775,6 +2775,45 @@ LogicalResult ONNXDropoutOp::inferShapes() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OneHotEncoder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXOneHotEncoderOp::inferShapes() {
|
||||
ShapedType inputType = X().getType().dyn_cast<ShapedType>();
|
||||
if (!inputType)
|
||||
return emitError("Non-shaped input type");
|
||||
auto shape = inputType.getShape();
|
||||
int64_t outDim = 0;
|
||||
|
||||
// If the input is a tensor of float, int32, or double,
|
||||
// the data will be cast to integers and
|
||||
// the cats_int64s category list will be used for the lookups.
|
||||
if (inputType.getElementType().isIntOrFloat()) {
|
||||
if (!cats_int64s())
|
||||
return emitError("input is a tensor of float, int32, or double, but no "
|
||||
"cats_int64s attribute");
|
||||
outDim = ArrayAttrSize(cats_int64s());
|
||||
} else {
|
||||
if (!cats_strings())
|
||||
return emitError("input is not a tensor of float, int32, or double, but "
|
||||
"no cats_strings attribute");
|
||||
outDim = ArrayAttrSize(cats_strings());
|
||||
}
|
||||
|
||||
// Encoded output data, having one more dimension than X
|
||||
// total category count will determine the size of the extra dimension
|
||||
SmallVector<int64_t, 2> dims;
|
||||
for (int i = 0; i != shape.size(); ++i) {
|
||||
dims.emplace_back(shape[i]);
|
||||
}
|
||||
dims.emplace_back(outDim);
|
||||
|
||||
getResult().setType(
|
||||
RankedTensorType::get(dims, FloatType::getF32(getContext())));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNX type related code
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -6019,7 +6019,7 @@ def ONNXNormalizerOp:ONNX_Op<"Normalizer",
|
|||
}
|
||||
|
||||
def ONNXOneHotEncoderOp:ONNX_Op<"OneHotEncoder",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX OneHotEncoder operation";
|
||||
let description = [{
|
||||
"Replace each input element with an array of ones and zeros, where a single"
|
||||
|
|
|
@ -1480,3 +1480,51 @@ func @test_dropout(%arg0: tensor<1x2x3x4xf32>) -> (tensor<*xf32>, tensor<*xi1>)
|
|||
// CHECK: [[RES:%.+]], [[MASK:%.+]] = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>)
|
||||
// CHECK: return [[RES]], [[MASK]] : tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Test shape inference for OneHotEncoder.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @test_onehotencoder_string1 (%arg0: tensor<20x1x!onnx.String>) -> tensor<*xf32> {
|
||||
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1x!onnx.String>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_onehotencoder_string1
|
||||
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1x!onnx.String>) -> tensor<20x1x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<20x1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_onehotencoder_string2 (%arg0: tensor<20x2x!onnx.String>) -> tensor<*xf32> {
|
||||
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x!onnx.String>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_onehotencoder_string2
|
||||
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x!onnx.String>) -> tensor<20x2x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<20x2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_onehotencoder_float1(%arg0: tensor<20x1xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], cats_int64s = [1, 2, 4], zeros = 1 : i64} : (tensor<20x1xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_onehotencoder_float1
|
||||
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1xf32>) -> tensor<20x1x3xf32>
|
||||
// CHECK: return [[RES]] : tensor<20x1x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_onehotencoder_float2(%arg0: tensor<20x2x3xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], cats_int64s = [1, 2, 4], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_onehotencoder_float2
|
||||
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<20x2x3x3xf32>
|
||||
// CHECK: return [[RES]] : tensor<20x2x3x3xf32>
|
||||
}
|
|
@ -280,6 +280,7 @@ OpsWithShapeInference=[
|
|||
'Min',
|
||||
'Mul',
|
||||
'Neg',
|
||||
'OneHotEncoder',
|
||||
'Or',
|
||||
'Pad',
|
||||
'Pow',
|
||||
|
|
Loading…
Reference in New Issue