Add a pass to decompose ONNX operations (#9)

Author: Tung D. Le, 2020-03-05 00:53:59 +09:00, committed by GitHub
parent 7c1dd0279b
commit e97df0b343
11 changed files with 191 additions and 115 deletions

View File

@@ -51,8 +51,7 @@ OpsWithShapeInference = [
# Operations supporting canonicalization.
OpsWithCanonicalizer = [
- 'Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
- 'ReduceLogSumExp', 'ReduceSumSquare', 'Gemm'
+ 'Add', 'Identity', 'Gemm'
]
# Add an Op in this list if the Op needs result type deduction which is required

View File

@@ -11,6 +11,7 @@ add_library(compiler
dialect/onnx/onnxop.inc
pass/onnx_combine.cpp
pass/onnx_rewrite.cpp
+ pass/onnx_decompose.cpp
pass/passes.hpp)
# Include root src directory.
@@ -25,6 +26,11 @@ target_link_libraries(compiler
${MLIRLibs}
curses)
+ set(LLVM_TARGET_DEFINITIONS pass/onnx_decompose.td)
+ onnf_tablegen(onnx_decompose.inc -gen-rewriters)
+ add_public_tablegen_target(gen_onnx_decompose)
+ add_dependencies(compiler gen_onnx_decompose)
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
@@ -55,6 +61,13 @@ onnf_tablegen(krnl.cpp.inc -gen-op-defs)
add_public_tablegen_target(gen_krnl_ops)
add_dependencies(compiler gen_krnl_ops)
+ add_library(onnf_onnx_decompose pass/onnx_decompose.cpp)
+ target_include_directories(onnf_onnx_decompose
+     PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
+     ${ONNF_SRC_ROOT})
+ target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
+ add_dependencies(onnf_onnx_decompose gen_krnl_ops)
add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
target_include_directories(onnf_shape_inference
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
@@ -90,7 +103,7 @@ add_subdirectory(runtime)
add_executable(onnf main.cpp)
- target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_shape_inference onnf_lower_frontend)
+ target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_onnx_decompose onnf_shape_inference onnf_lower_frontend)
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
find_package(ZLIB REQUIRED)
target_link_libraries(onnf ${ZLIB_LIBRARIES})

View File

@@ -2296,7 +2296,6 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal",
def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
[NoSideEffect]> {
- let hasCanonicalizer = 1;
let summary = "ONNX ReduceL1 operation";
let description = [{
"Computes the L1 norm of the input tensor's element along the provided axes. The resulted"
@@ -2314,7 +2313,6 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
[NoSideEffect]> {
- let hasCanonicalizer = 1;
let summary = "ONNX ReduceL2 operation";
let description = [{
"Computes the L2 norm of the input tensor's element along the provided axes. The resulted"
@@ -2332,7 +2330,6 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
[NoSideEffect]> {
- let hasCanonicalizer = 1;
let summary = "ONNX ReduceLogSum operation";
let description = [{
"Computes the log sum of the input tensor's element along the provided axes. The resulted"
@@ -2350,7 +2347,6 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
[NoSideEffect]> {
- let hasCanonicalizer = 1;
let summary = "ONNX ReduceLogSumExp operation";
let description = [{
"Computes the log sum exponent of the input tensor's element along the provided axes. The resulted"
@@ -2465,7 +2461,6 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
[NoSideEffect]> {
- let hasCanonicalizer = 1;
let summary = "ONNX ReduceSumSquare operation";
let description = [{
"Computes the sum square of the input tensor's element along the provided axes. The resulted"

View File

@@ -122,8 +122,9 @@ int main(int argc, char *argv[]) {
}
mlir::PassManager pm(&context);
- pm.addPass(mlir::createCanonicalizerPass());
+ pm.addPass(mlir::createDecomposeONNXToONNXPass());
pm.addPass(mlir::createShapeInferencePass());
+ pm.addPass(mlir::createCanonicalizerPass());
if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass());

View File

@@ -0,0 +1,65 @@
//===- onnx_decompose.cpp - ONNX High Level Rewriting ---------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a set of rewriters to decompose an ONNX operation into
// a composition of other ONNX operations.
//
// This pass is applied before any other pass so that there is no need to
// implement shape inference for the decomposed operations. Hence, no
// knowledge about tensor shapes is expected at this point.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"
using namespace mlir;
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/onnx_decompose.inc"
struct DecomposeONNXToONNXPass : public FunctionPass<DecomposeONNXToONNXPass> {
  void runOnFunction() final;
};
} // end anonymous namespace.
void DecomposeONNXToONNXPass::runOnFunction() {
  auto function = getFunction();
  MLIRContext *context = &getContext();
  ConversionTarget target(getContext());
  target.addLegalDialect<ONNXOpsDialect>();
  // These ops will be decomposed into other ONNX ops. Hence, they will not be
  // available after this pass.
  target.addIllegalOp<ONNXReduceL1Op>();
  target.addIllegalOp<ONNXReduceL2Op>();
  target.addIllegalOp<ONNXReduceLogSumOp>();
  target.addIllegalOp<ONNXReduceLogSumExpOp>();
  target.addIllegalOp<ONNXReduceSumSquareOp>();
  OwningRewritePatternList patterns;
  populateWithGenerated(context, &patterns);
  if (failed(applyPartialConversion(function, target, patterns)))
    signalPassFailure();
} // end runOnFunction
/*!
* Create a DecomposeONNX pass.
*/
std::unique_ptr<mlir::Pass> mlir::createDecomposeONNXToONNXPass() {
  return std::make_unique<DecomposeONNXToONNXPass>();
}
static PassRegistration<DecomposeONNXToONNXPass> pass(
    "decompose-onnx",
    "Decompose ONNX operations into a composition of other ONNX operations.");

View File

@@ -0,0 +1,57 @@
//===----------------------------------------------------------------------===//
//=- onnx_decompose.td - Rewriting for decomposing ONNX Ops -*- tablegen -*===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// Defines pattern match rewritings for decomposing ONNX operations using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
#ifndef ONNX_DECOMPOSE
#define ONNX_DECOMPOSE
#ifndef OP_BASE
include "dialect/onnx/onnx.td"
#endif // OP_BASE
/// Note: The DRR definition used for defining patterns is shown below:
///
/// class Pattern<
/// dag sourcePattern, list<dag> resultPatterns,
/// list<dag> additionalConstraints = [],
/// dag benefitsAdded = (addBenefit 0)
/// >;
//===----------------------------------------------------------------------===//
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
//===----------------------------------------------------------------------===//
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
//===----------------------------------------------------------------------===//
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp %X)
//===----------------------------------------------------------------------===//
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp %X)
//===----------------------------------------------------------------------===//
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
//===----------------------------------------------------------------------===//
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
//===----------------------------------------------------------------------===//
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
//===----------------------------------------------------------------------===//
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
#endif // ONNX_DECOMPOSE
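Note that ReduceL2OpPattern produces an ONNXReduceSumSquareOp, which this pass also marks illegal, so applyPartialConversion keeps rewriting until only legal ops remain and ReduceL2 ultimately lowers to Sqrt(ReduceSum(Mul %X, %X)). A hypothetical single-step pattern with the same net effect (not part of this commit, shown only for comparison) would be:

def ReduceL2FusedPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
    (ONNXSqrtOp (ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims))>;

Keeping the two rules separate matches one ONNX-level identity per pattern and lets ReduceSumSquareOpPattern fire on its own for standalone ReduceSumSquare ops.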

View File

@@ -118,35 +118,6 @@ struct SplitConvOpPattern : public RewritePattern {
};
} // end anonymous namespace
- /// on the ONNXReduceL1Op.
- void ONNXReduceL1Op::getCanonicalizationPatterns(
-     OwningRewritePatternList &results, MLIRContext *context) {
-   results.insert<ReduceL1OpPattern>(context);
- }
- /// on the ONNXReduceL2Op.
- void ONNXReduceL2Op::getCanonicalizationPatterns(
-     OwningRewritePatternList &results, MLIRContext *context) {
-   results.insert<ReduceL2OpPattern>(context);
- }
- /// on the ONNXReduceLogSumOp.
- void ONNXReduceLogSumOp::getCanonicalizationPatterns(
-     OwningRewritePatternList &results, MLIRContext *context) {
-   results.insert<ReduceLogSumOpPattern>(context);
- }
- /// on the ONNXReduceLogSumExpOp.
- void ONNXReduceLogSumExpOp::getCanonicalizationPatterns(
-     OwningRewritePatternList &results, MLIRContext *context) {
-   results.insert<ReduceLogSumExpOpPattern>(context);
- }
- /// on the ONNXReduceSumSquareOp.
- void ONNXReduceSumSquareOp::getCanonicalizationPatterns(
-     OwningRewritePatternList &results, MLIRContext *context) {
-   results.insert<ReduceSumSquareOpPattern>(context);
- }
/// on the ONNXConvNoBiasOp.
void ONNXConvNoBiasOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {

View File

@@ -24,34 +24,4 @@ include "dialect/onnx/onnx.td"
/// dag benefitsAdded = (addBenefit 0)
/// >;
//===----------------------------------------------------------------------===//
- // ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
- //===----------------------------------------------------------------------===//
- def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
-     (ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
- //===----------------------------------------------------------------------===//
- // ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp %X)
- //===----------------------------------------------------------------------===//
- def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
-     (ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
- //===----------------------------------------------------------------------===//
- // ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp %X)
- //===----------------------------------------------------------------------===//
- def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
-     (ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
- //===----------------------------------------------------------------------===//
- // ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
- //===----------------------------------------------------------------------===//
- def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
-     (ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
- //===----------------------------------------------------------------------===//
- // ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
- //===----------------------------------------------------------------------===//
- def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
-     (ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
#endif // ONNX_REWRITE

View File

@@ -15,6 +15,9 @@
namespace mlir {
class Pass;
+ /// Pass for rewriting inside frontend dialect.
+ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
std::unique_ptr<Pass> createShapeInferencePass();
/// Add pass for lowering to Krnl IR.

View File

@@ -38,53 +38,6 @@ func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) ->
"std.return"(%2) : (tensor<10x10xf32>) -> ()
}
- // CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
- func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
-   %0 = "onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
-   "std.return"(%0) : (tensor<*xf32>) -> ()
-   // CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
-   // CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
- }
- // CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
- func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
-   %0 = "onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
-   "std.return"(%0) : (tensor<*xf32>) -> ()
-   // CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
-   // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
-   // CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
- }
- // CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
- func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
-   %0 = "onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
-   "std.return"(%0) : (tensor<*xf32>) -> ()
-   // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
-   // CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
- }
- // CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
- func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
-   %0 = "onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
-   "std.return"(%0) : (tensor<*xf32>) -> ()
-   // CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
-   // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
-   // CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
- }
- // CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
- func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
-   %0 = "onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
-   "std.return"(%0) : (tensor<*xf32>) -> ()
-   // CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
-   // CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
- }
// CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32> {
func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32>

View File

@@ -0,0 +1,49 @@
// RUN: onnf-opt --decompose-onnx %s -split-input-file | FileCheck %s
// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
}