Add zeta and broadcasting_zeta to chlo dialect.
PiperOrigin-RevId: 354500879
This commit is contained in:
parent
eb8d5a5e39
commit
e61ef86fdb
|
@ -66,7 +66,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
|
|||
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
|
||||
// shape broadcasting.
|
||||
//
|
||||
// These correspond to operations in the mhlo dialect without the
|
||||
// These correspond to operations in the chlo and mhlo dialects without the
|
||||
// "broadcast_" prefix, except that those ops require same-shaped operands and
|
||||
// results.
|
||||
//
|
||||
|
@ -256,8 +256,31 @@ def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp<
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOCLient_BroadcastZetaOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||
"broadcast_zeta",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "Hurwitz zeta function";
|
||||
|
||||
let description = [{
|
||||
Returns `Zeta(operand, operand)` element-wise.
|
||||
|
||||
$$
|
||||
\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\)
|
||||
$$
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
HLO_FpTensor:$lhs,
|
||||
HLO_FpTensor:$rhs,
|
||||
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
|
||||
// padded rank-broadcast semantics if omitted.
|
||||
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
|
||||
);
|
||||
let results = (outs HLO_FpTensor);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XLA binary elementwise op definitions.
|
||||
// XLA binary logical elementwise op definitions.
|
||||
// The same description as the arithmetic binary elementwise ops applies.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
@ -309,6 +332,38 @@ def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp<
|
|||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XLA non-broadcasting binary operations.
|
||||
//
|
||||
// These are operations that are supported by the XLA Builder API but that are
|
||||
// not part of the HLO compiler instructions as modelled by the MHLO dialect.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HLOClient_ZetaOp : HLOClient_Op<"zeta",
|
||||
[NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Hurwitz zeta function";
|
||||
|
||||
let description = [{
|
||||
Returns `Zeta(operand, operand)` element-wise.
|
||||
|
||||
$$
|
||||
\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\)
|
||||
$$
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
HLO_FpTensor:$lhs,
|
||||
HLO_FpTensor:$rhs
|
||||
);
|
||||
|
||||
let results = (outs HLO_FpTensor);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:`
|
||||
`(` type($lhs) `,` type($rhs) `)` `->` type(results)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Broadcasting complex op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -53,7 +53,7 @@ struct HloCompareAdaptor {
|
|||
};
|
||||
|
||||
// Populate a pattern for each Broadcasting CHlo op. This requires the pattern
|
||||
// to take a ChloOpTy, MhloOpTy, and an Adaptor as templated values.
|
||||
// to take a ChloOpTy, NonBroadcastingOpTy, and an Adaptor as templated values.
|
||||
template <template <typename, typename, typename> class Pattern,
|
||||
typename... ConstructorArgs>
|
||||
void PopulateForBroadcastingBinaryOp(MLIRContext *context,
|
||||
|
@ -79,6 +79,7 @@ void PopulateForBroadcastingBinaryOp(MLIRContext *context,
|
|||
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
|
||||
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
|
||||
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
|
||||
POPULATE_BCAST(BroadcastZetaOp, ZetaOp);
|
||||
|
||||
// Broadcasting ops requiring special construction.
|
||||
patterns
|
||||
|
|
|
@ -318,6 +318,7 @@ BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
|
|||
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
|
||||
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
|
||||
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
|
||||
BROADCAST_BINARY_OP_DEFS(BroadcastZetaOp);
|
||||
|
||||
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
|
||||
#undef BROADCAST_BINARY_OP_DEFS
|
||||
|
|
|
@ -39,14 +39,14 @@ struct ChloLegalizeToHloPass
|
|||
OwningRewritePatternList conversionPatterns;
|
||||
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
|
||||
|
||||
// Consider the mhlo dialect legal for tests.
|
||||
conversionTarget.addLegalDialect<mhlo::MhloDialect>();
|
||||
// Consider the mhlo dialect legal for tests. Also add helper dialects
|
||||
// that are needed by the patterns.
|
||||
conversionTarget.addLegalDialect<
|
||||
MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect,
|
||||
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
|
||||
|
||||
// The conversion uses helpers from the standard dialect.
|
||||
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
|
||||
conversionTarget.addLegalDialect<mlir::tensor::TensorDialect>();
|
||||
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
|
||||
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
|
||||
// TODO(herhut): This is temporary while Zeta cannot be lowered to hlo.
|
||||
conversionTarget.addLegalOp<chlo::ZetaOp>();
|
||||
|
||||
chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
|
||||
|
||||
|
|
|
@ -237,3 +237,13 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
|
|||
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
||||
return %0 : tensor<4xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @ZetaWithoutBroadcast
|
||||
func @ZetaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>)
|
||||
-> tensor<4xf32> {
|
||||
// CHECK: chlo.zeta %arg0, %arg1
|
||||
%0 = chlo.broadcast_zeta %arg0, %arg1
|
||||
: (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue