Add zeta and broadcasting_zeta to chlo dialect.

PiperOrigin-RevId: 354500879
This commit is contained in:
Stephan Herhut 2021-01-29 03:21:59 -08:00 committed by TensorFlow MLIR Team
parent eb8d5a5e39
commit e61ef86fdb
5 changed files with 77 additions and 10 deletions

View File

@ -66,7 +66,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate // broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
// shape broadcasting. // shape broadcasting.
// //
// These correspond to operations in the mhlo dialect without the // These correspond to operations in the chlo and mhlo dialects without the
// "broadcast_" prefix, except that those ops require same-shaped operands and // "broadcast_" prefix, except that those ops require same-shaped operands and
// results. // results.
// //
@ -256,8 +256,31 @@ def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp<
}]; }];
} }
def HLOCLient_BroadcastZetaOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_zeta",
[NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "Hurwitz zeta function";
let description = [{
Returns `Zeta(operand, operand)` element-wise.
$$
\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\)
$$
}];
let arguments = (ins
HLO_FpTensor:$lhs,
HLO_FpTensor:$rhs,
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
// padded rank-broadcast semantics if omitted.
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
let results = (outs HLO_FpTensor);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions. // XLA binary logical elementwise op definitions.
// The same description as the arithmetic binary elementwise ops applies. // The same description as the arithmetic binary elementwise ops applies.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -309,6 +332,38 @@ def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp<
}]; }];
} }
//===----------------------------------------------------------------------===//
// XLA non-broadcasting binary operations.
//
// These are operations that are supported by the XLA Builder API but that are
// not part of the HLO compiler instructions as modelled by the MHLO dialect.
//===----------------------------------------------------------------------===//
def HLOClient_ZetaOp : HLOClient_Op<"zeta",
[NoSideEffect, SameOperandsAndResultType]> {
let summary = "Hurwitz zeta function";
let description = [{
Returns `Zeta(operand, operand)` element-wise.
$$
\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\)
$$
}];
let arguments = (ins
HLO_FpTensor:$lhs,
HLO_FpTensor:$rhs
);
let results = (outs HLO_FpTensor);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:`
`(` type($lhs) `,` type($rhs) `)` `->` type(results)
}];
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Broadcasting complex op // Broadcasting complex op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -53,7 +53,7 @@ struct HloCompareAdaptor {
}; };
// Populate a pattern for each Broadcasting CHlo op. This requires the pattern // Populate a pattern for each Broadcasting CHlo op. This requires the pattern
// to take a ChloOpTy, MhloOpTy, and an Adaptor as templated values. // to take a ChloOpTy, NonBroadcastingOpTy, and an Adaptor as templated values.
template <template <typename, typename, typename> class Pattern, template <template <typename, typename, typename> class Pattern,
typename... ConstructorArgs> typename... ConstructorArgs>
void PopulateForBroadcastingBinaryOp(MLIRContext *context, void PopulateForBroadcastingBinaryOp(MLIRContext *context,
@ -79,6 +79,7 @@ void PopulateForBroadcastingBinaryOp(MLIRContext *context,
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp); POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp); POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp); POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
POPULATE_BCAST(BroadcastZetaOp, ZetaOp);
// Broadcasting ops requiring special construction. // Broadcasting ops requiring special construction.
patterns patterns

View File

@ -318,6 +318,7 @@ BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp); BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp); BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp); BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
BROADCAST_BINARY_OP_DEFS(BroadcastZetaOp);
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS #undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
#undef BROADCAST_BINARY_OP_DEFS #undef BROADCAST_BINARY_OP_DEFS

View File

@ -39,14 +39,14 @@ struct ChloLegalizeToHloPass
OwningRewritePatternList conversionPatterns; OwningRewritePatternList conversionPatterns;
conversionTarget.addIllegalDialect<chlo::HloClientDialect>(); conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
// Consider the mhlo dialect legal for tests. // Consider the mhlo dialect legal for tests. Also add helper dialects
conversionTarget.addLegalDialect<mhlo::MhloDialect>(); // that are needed by the patterns.
conversionTarget.addLegalDialect<
MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect,
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
// The conversion uses helpers from the standard dialect. // TODO(herhut): This is temporary while Zeta cannot be lowered to hlo.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>(); conversionTarget.addLegalOp<chlo::ZetaOp>();
conversionTarget.addLegalDialect<mlir::tensor::TensorDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns); chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);

View File

@ -237,3 +237,13 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1> return %0 : tensor<4xi1>
} }
// -----
// CHECK-LABEL: @ZetaWithoutBroadcast
func @ZetaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>)
-> tensor<4xf32> {
// CHECK: chlo.zeta %arg0, %arg1
%0 = chlo.broadcast_zeta %arg0, %arg1
: (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}