Commit Graph

438 Commits

Author SHA1 Message Date
Geoffrey Martin-Noble f9f7a63870 Add missing dep on RAL pass generation
Without this I see errors about being unable to find the generated header in our project's build.

PiperOrigin-RevId: 379377718
2021-06-14 17:02:26 -07:00
Wenyi Zhao 7f94bd923b PR #50236: [MLIR][DISC] Bufferize TransposeOp and ConcatenateOp
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50236

support hlo-to-lhlo conversion for TransposeOp and ConcatenateOp
Copybara import of the project:

--
62860e717f2a14fbd3ddfb634aa6ff132d245a72 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Bufferize TransposeOp and ConcatenateOp

--
ce2ff57c1edee1172cd2f36346cc0b34ec1c7467 by Wenyi Zhao <reyizero@gmail.com>:

fix

PiperOrigin-RevId: 379330954
2021-06-14 12:37:45 -07:00
Wenyi Zhao 23ebbb28d1 PR #50191: [MLIR][DISC] Add RAL (Runtime abstraction layer) Dialect
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50191

DISC is a e2e flow, including both compiler side and runtime side. For
runtime side, we have different targeting environments (e.g. tensorflow,
pytorch, or sometimes even a standalone binary). In order to simplify
the design of the compiler side, we design a Runtime Abstraction Layer
(RAL) to sperate the compiler side and runtime side. Thus the compiler
side only need to target RAL itself and it is the responsibility of RAL
to handle the differences between different targeting environments.

One of the most important functions of RAL is to manage stateful
resources. To this end, it provides a context object, and hides all
stateful operations behind this context, thus the compiler side itself
doesn't need to care about the resource initialization. For example, a
kernel must be loaded before it can be launched on GPU. However, the
loading operation should only be taken once during the whole lifetime of
the context in order to achieve the best performance. Based on the
initialization-free interfaces provided by RAL, compiler side can focus
on its core optimization logic and lets the RAL to manage the resource
status.

The context mentioned above is passed as a parameter to the entry
function and all RAL APIs should always use the context as their first
argument. This CR also provides a pass to help to ensure this property.
The pass rewrites the entry function to make sure their first argument
is the context. For entry function, the pass also rewrites its inputs
and outputs. To be concrete, all the original inputs and outputs of the
entry function are received from and sent to RAL through a sequence of
RAL API calls correspondingly. The motivation behind this is to hide the
implementation details of I/Os. This design may also potentially enable
partial execution of the compiled module when some of the inputs are
ready.
Copybara import of the project:

--
c4f20a89aed71181e75bcc5265723b88bde23240 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Add RAL (Runtime abstraction layer) Dialect

DISC is a e2e flow, including both compiler side and runtime side. For
runtime side, we have different targeting environments (e.g. tensorflow,
pytorch, or sometimes even a standalone binary). In order to simplify
the design of the compiler side, we design a Runtime Abstraction Layer
(RAL) to sperate the compiler side and runtime side. Thus the compiler
side only need to target RAL itself and it is the responsibility of RAL
to handle the differences between different targeting environments.

One of the most important functions of RAL is to manage stateful
resources. To this end, it provides a context object, and hides all
stateful operations behind this context, thus the compiler side itself
doesn't need to care about the resource initialization. For example, a
kernel must be loaded before it can be launched on GPU. However, the
loading operation should only be taken once during the whole lifetime of
the context in order to achieve the best performance. Based on the
initialization-free interfaces provided by RAL, compiler side can focus
on its core optimization logic and lets the RAL to manage the resource
status.

The context mentioned above is passed as a parameter to the entry
function and all RAL APIs should always use the context as their first
argument. This CR also provides a pass to help to ensure this property.
The pass rewrites the entry function to make sure their first argument
is the context. For entry function, the pass also rewrites its inputs
and outputs. To be concrete, all the original inputs and outputs of the
entry function are received from and sent to RAL through a sequence of
RAL API calls correspondingly. The motivation behind this is to hide the
implementation details of I/Os. This design may also potentially enable
partial execution of the compiled module when some of the inputs are
ready.

--
1991d4f80ab6087943956e1c0fec4940a22ab08d by Wenyi Zhao <reyizero@gmail.com>:

fix

PiperOrigin-RevId: 379317586
2021-06-14 11:27:43 -07:00
Rahul Joshi a6011d0279 [HLO] Add AllReduceScatter to MHLO and LMHLO dialects.
PiperOrigin-RevId: 379296198
2021-06-14 09:37:07 -07:00
Wenyi Zhao 8388303fd2 PR #50211: [MLIR][DISC] Bufferize RealDynamicSliceOp and ReduceOp
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50211

support hlo-to-lhlo conversion for RealDynamicSliceOp and ReduceOp
Copybara import of the project:

--
c417b336670a1fc256f7026dfe8080e46d13d79a by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Bufferize RealDynamicSliceOp and ReduceOp

PiperOrigin-RevId: 378972113
2021-06-11 16:33:15 -07:00
Jacques Pienaar 95ba03534f Allow variadic operands/result in MHLO while
This just adds support for it in the op, but keeps the production/uses as is (e.g., single tensor or tuple) matching what XLA export requires. In follow up here, would be to add pass for export to retuple and then the canonical form could be changed. Tuple'ing given control flow via regions & multi-result operations does not add representational power and all the get_tuple_element ops obscure the computation.

The old form allowed single tensor or tuple. The new variadic number of tensor or tuples as tuples may be nested, so the input could have (Tensor<..>, Tuple<Tensor<...>, Tuple<...>, ...>, Tensor<...>) and HLO_Tensor doesn't allow Tuples.

PiperOrigin-RevId: 378934388
2021-06-11 13:08:28 -07:00
A. Unique TensorFlower bd5752f0bf [MLIR][HLO] Find shape equivalences and use them for better rank specialization
Find shape equivalence classes among the operands and use them for better rank
specialization. If all operands are known to be of the same shape, we can
flatten them to rank one. If there are two shape equivalence classes, we can
generalize the scalar rank specialization cases.

PiperOrigin-RevId: 378844575
2021-06-11 04:00:26 -07:00
Wenyi Zhao 6660234d80 PR #50100: [MLIR][DISC] Bufferize DynamicIotaOp and DynamicPadOp
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50100

support hlo-to-lhlo conversion for DynamicIotaOp and DynamicPadOp
Copybara import of the project:

--
c3aae94954e35d3f8ad265f619ef9765665a5115 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Bufferize DynamicIotaOp and DynamicPadOp

--
adc6996d70b804d61310d56a33fac975d70c8636 by Wenyi Zhao <reyizero@gmail.com>:

minor

PiperOrigin-RevId: 378733284
2021-06-10 14:20:45 -07:00
A. Unique TensorFlower 14093b7906 [XLA:GPU] Add AllReduce{Start,Done} to MLIR LHLO dialect.
PiperOrigin-RevId: 378681070
2021-06-10 10:27:22 -07:00
Chris Jones 968226b9d7 [XLA:GPU] Add AllReduce{Start,Done} to MLIR LHLO dialect.
PiperOrigin-RevId: 378640706
2021-06-10 06:54:42 -07:00
Adrian Kuegel 6088eb697c Fix Cosh approximation for F16.
We should upcast F16 to F32 to prevent precision loss.
E.g. cosh(-9) would evaluate to 4042 previously instead of 4052.
This allows to enable the MLIR generated kernel for F16 type.
Also move template instantiation for Sinh to inside the #ifdef block.
This was missed in a previous commit.

PiperOrigin-RevId: 378635042
2021-06-10 06:16:44 -07:00
A. Unique TensorFlower 9f67417b41 [MLIR][HLO] Avoid duplicate cluster operands when merging
When merging rank specialization clusters, avoid duplicating operands. A fewer
number of operands usually allows better rank specialization.

PiperOrigin-RevId: 378445946
2021-06-09 10:54:55 -07:00
A. Unique TensorFlower b580722041 [MLIR][KernelGen] Merge rank specialization clusters
Merge adjacent rank specialization clusters. Combine their operands, bodies, and
results.

PiperOrigin-RevId: 378433222
2021-06-09 10:07:47 -07:00
A. Unique TensorFlower b9e45007d5 [MLIR][HLO] Extend broadcast propagation pass to enable more fusion
Move element-wise operations into assuming regions. This enables fusion
opportunities within the region.

PiperOrigin-RevId: 378362725
2021-06-09 03:03:37 -07:00
A. Unique TensorFlower d828b457b3 Handle empty tensors in SimplifyConcatSlice.
If the result of the slice is an empty tensor, do nothing.
This fixes a crash: we can't create a `concat` with an
empty operand range.

PiperOrigin-RevId: 378354956
2021-06-09 02:15:47 -07:00
Adrian Kuegel 9a8c254526 Support complex types for Sinh.
Because mhlo::ConstantLike doesn't support complex types, we need to use
GetScalarOfType and broadcast it to the needed shape.
Disable the tf2xla fallback, now that MLIR fully supports Sinh.

PiperOrigin-RevId: 378123151
2021-06-08 04:23:19 -07:00
A. Unique TensorFlower c47869f931 [MLIR][HLO] Rename `move-up-dynamic-broadcasts-for-fusion` to `broadcast-propagation`
PiperOrigin-RevId: 378102608
2021-06-08 01:51:10 -07:00
Benjamin Kramer d1c60df2fe [MHLO:linalg] Be more aggressive about turning mhlo.const into std.constant
On tensors the only difference between these ops is that mhlo.const supports unsigned types.

PiperOrigin-RevId: 377970948
2021-06-07 11:58:23 -07:00
Hanhan Wang 25b93c8d66 Add support for lowering mhlo.iota/dynamic_iota to Linalg on unsigned types.
PiperOrigin-RevId: 377956338
2021-06-07 10:59:33 -07:00
Adrian Kuegel 5315997402 Fix Sinh approximation for F16.
We should upcast F16 to F32 to prevent precision loss.
E.g. sinh(-9) would evaluate to -4042 previously instead of -4052.
This allows to enable the MLIR generated kernel for F16 type.

PiperOrigin-RevId: 377901896
2021-06-07 06:38:42 -07:00
Tobias Gysi fc723380e6 Update lhlo to use the new structured op interface.
Replace deprecated methods in lhlo_fuse_linalg.cc. The new structured op interface has been introduced in https://reviews.llvm.org/D103394.

PiperOrigin-RevId: 377875452
2021-06-07 03:11:03 -07:00
Wenyi Zhao ade873a5e0 PR #49970: [MLIR][DISC] bufferize DynamicReshape and DynamicBroadcastInDim
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49970

1, add hlo-to-lhlo support for DynamicReshape and DynamicBroadcastInDim

2, add a flag `convert-to-lmhlo-only` to seperate following two case:
   - hlo-to-lhlo only. Simply lowers all mhlo ops to their lmhlo
     counterparts, do not apply any optimization (e.g. elide any
     buffer copy). Buffer optimization is not easy in dynamic
     shape world especially when involving control flow, thus we
     leave this to another dedicated pass.

   - hlo-to-lhlo-or-memref-directly. Lowers some metadata-only mhlo
     ops (e.g. reshape) to memref dialect directly and Lowers others
     to their lmhlo counterparts.
Copybara import of the project:

--
562bd65a368f6194405c4ae6900e3b4388a5ec03 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] bufferize DynamicReshape and DynamicBroadcastInDim

1, add hlo-to-lhlo support for DynamicReshape and DynamicBroadcastInDim

2, add a flag `convert-to-lmhlo-only` to seperate following two case:
   - hlo-to-lhlo only. Simply lowers all mhlo ops to their lmhlo
     counterparts, do not apply any optimization (e.g. elide any
     buffer copy). Buffer optimization is not easy in dynamic
     shape world especially when involving control flow, thus we
     leave this to another dedicated pass.

   - hlo-to-lhlo-or-memref-directly. Lowers some metadata-only mhlo
     ops (e.g. reshape) to memref dialect directly and Lowers others
     to their lmhlo counterparts.

PiperOrigin-RevId: 377603395
2021-06-04 15:36:03 -07:00
A. Unique TensorFlower db05388a3c Integrate LLVM at llvm/llvm-project@da3ed58b97
Updates LLVM usage to match
[da3ed58b97c1](https://github.com/llvm/llvm-project/commit/da3ed58b97c1)

PiperOrigin-RevId: 377432380
2021-06-03 20:45:18 -07:00
A. Unique TensorFlower aba16adfa5 Add `mhlo.all_gather` op to MHLO dialect.
Adds import/export/verifier support as well.
Also makes `channel_handle` uniform across mhlo.all_reduce and mhlo.all-gather.

PiperOrigin-RevId: 377323468
2021-06-03 10:45:29 -07:00
A. Unique TensorFlower 4620410f18 Integrate LLVM at llvm/llvm-project@b25546a4b4
Updates LLVM usage to match
[b25546a4b406](https://github.com/llvm/llvm-project/commit/b25546a4b406)

PiperOrigin-RevId: 377077163
2021-06-02 09:32:59 -07:00
A. Unique TensorFlower 75a1c450ea [MLIR][KernelGen] Fix Windows build failure
Fix usage of default constructor. Instead, always use the parameterized
constructor and make the maximum supported rank explicit.

PiperOrigin-RevId: 377037155
2021-06-02 05:34:44 -07:00
wyzhao 968d4b8709 PR #49598: [MLIR][DISC] legalize tensor_load inserted during hlo-to-lhlo conversion
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49598

This PR implements logic for lowering memref.tensor_load ops that are
inserted during `mhlo-legalize-to-lmhlo`
Copybara import of the project:

--
80eb377af4e02182e1aecc943a41ca5d7d1c2100 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] legalize tensor_load inserted during hlo-to-lhlo conversion

This PR implements logic for lowering memref.tensor_load ops that are
inserted during `mhlo-legalize-to-lmhlo`.

--
ac452fe3dcd591211cd5c59be9189fe2f7153b41 by Wenyi Zhao <reyizero@gmail.com>:

minor fix

--
6b36017f8632a06adbc3e05a62975fa641d0260f by Wenyi Zhao <reyizero@gmail.com>:

minor refine

--
846005cc76d0033112e47825c2e9a97790b6925f by Wenyi Zhao <reyizero@gmail.com>:

minor fix

--
f6a4becaa287d5ca323b2d152a4d0ae053730fd9 by Wenyi Zhao <reyizero@gmail.com>:

fix

--
5555749f60f7fce8f57962860ef65efccf0362ba by Wenyi Zhao <reyizero@gmail.com>:

fix

--
8873b9b6d9315c1199ca9f7c133ecf377ecd2fa6 by Wenyi Zhao <reyizero@gmail.com>:

fix

PiperOrigin-RevId: 376942547
2021-06-01 16:27:56 -07:00
A. Unique TensorFlower d1828625ab [MLIR][KernelGen] Make maximum supported rank in rank specialization configurable
The maximum supported target rank of 5 is sufficient for all operations but
`select`. Make the maximum target rank configurable in the rank specialization.
This reduces the number of generated kernels for operations that don't require
it.

PiperOrigin-RevId: 376822496
2021-06-01 06:54:31 -07:00
A. Unique TensorFlower c7c245eaf1 [MLIR][KernelGen] Add MLIR-generated Xlogy kernel
Add the first MLIR-generated kernel that relies on an in-TF lowering. Fusion for
this kernel relies on the generalized rank specialization for operation groups.

PiperOrigin-RevId: 376805435
2021-06-01 04:48:18 -07:00
A. Unique TensorFlower f16e5a3a67 [MLIR][HLO] Use canonicalization patterns in broadcast propagation pass
Replace local canonicalization patterns with those from upstream.

PiperOrigin-RevId: 376794178
2021-06-01 03:14:31 -07:00
A. Unique TensorFlower 31536431e0 [MLIR][HLO] Eliminate duplicate broadcastable constraints
PiperOrigin-RevId: 376718433
2021-05-31 13:50:23 -07:00
A. Unique TensorFlower 0f341012c6 [MLIR][HLO] Eliminate duplicate broadcastable constraints
PiperOrigin-RevId: 376715240
2021-05-31 13:08:02 -07:00
A. Unique TensorFlower 511a1db4f3 [MLIR][HLO] Use canonicalization patterns in broadcast propagation pass
Replace local canonicalization patterns with those from upstream.

PiperOrigin-RevId: 376708719
2021-05-31 12:01:26 -07:00
A. Unique TensorFlower 5f5db13715 [MLIR][HLO] Use canonicalization patterns in broadcast propagation pass
Replace local canonicalization patterns with those from upstream.

PiperOrigin-RevId: 376707588
2021-05-31 11:44:50 -07:00
A. Unique TensorFlower cc1b22e618 [HLO][Linalg] Support scalar broadcasts in point-wise converter
This is needed for operations that support this limited form of broadcasting,
namely `mhlo.select`.

PiperOrigin-RevId: 376655844
2021-05-31 03:50:23 -07:00
Hanhan Wang 402b74ed7f Fix type bug in mhlo.dynamic-update-slice lowering.
The operand type can be f32. We should not use operand type to do clamp
operations.

PiperOrigin-RevId: 376286524
2021-05-27 17:53:49 -07:00
Adrian Kuegel a4fa6afa07 [mlir][hlo] Avoid dyn_cast_or_null when called with getDefiningOp result (NFC)
PiperOrigin-RevId: 376110457
2021-05-27 00:20:42 -07:00
Hanhan Wang 28c411606f Add support for lowering mhlo.dynamic-update-slice ops to Linalg and std ops.
PiperOrigin-RevId: 376042810
2021-05-26 15:31:05 -07:00
Robert Suderman 26a0053d7d Remove linalg.indexed_generic from mhlo lowerings to linalg
IndexedGeneric is going away. Transition to using linalg.Index instead.

PiperOrigin-RevId: 376002501
2021-05-26 12:24:23 -07:00
A. Unique TensorFlower 4ebcebf31c [MLIR][HLO] Exploit scalar properties in rank specialization lowering
Take advantage of the fact that scalars are already ranked and that they are
neutral elements to broadcasting. Do not reshape scalars, do not consider them
for broadcasting, and materialize ranked operations on scalars accordingly.

PiperOrigin-RevId: 375968371
2021-05-26 09:59:13 -07:00
A. Unique TensorFlower cb46298a07 [MLIR][HLO] Support all smaller ranks in rank specialization cases
Rank specialization cases can be applied to all argument tensors of smaller
ranks than the expected maximum rank. This is crucial if all operands are
effectively scalars and the maximum reduced rank is 0.

PiperOrigin-RevId: 375712020
2021-05-25 08:38:53 -07:00
wyzhao b93e54d8a4 PR #49454: [MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface.
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49454

The new interface is more safe to be used during dialect conversion
(e.g. converting from tensor world to buffer world).
Copybara import of the project:

--
a6968072d59bec3c3bbaef0121d297e807c37c91 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface.

The new interface is more safe to be used during dialect conversion
(e.g. converting from tensor world to buffer world).

--
55e7c6b7f2f99b99e226645a57e2433fae3e90ed by Wenyi Zhao <reyizero@gmail.com>:

minor fix

PiperOrigin-RevId: 375500273
2021-05-24 10:11:55 -07:00
Stella Laurenzo 28c4112f35 Downgrade some emitErrors in patterns to notifyMatchFailure.
PiperOrigin-RevId: 375160543
2021-05-21 14:09:23 -07:00
Hanhan Wang 1ba4c714c9 Add support for lowering mhlo.scatter ops to Linalg.
This only works for updating tensors, not add/min/max computations. It requires
the index depth to be 1 because of the limitation in Linalg. We can not compare
multiple indices without packing indices.

PiperOrigin-RevId: 375137721
2021-05-21 12:17:14 -07:00
A. Unique TensorFlower 97e6103933 [MLIR][HLO] Reshape to scalars in rank specialization
Scalars were incorrectly casted to scalar tensors when they have to be reshaped.

PiperOrigin-RevId: 375049088
2021-05-21 03:12:16 -07:00
A. Unique TensorFlower 3daf65578a [MLIR][HLO] Add scalar cases for binary rank specialization
For rank specialization clusters that have only two operands, we can materialize
two extra cases in which either of them is a scalar. This avoids redundant index
computations in these cases.

PiperOrigin-RevId: 375037390
2021-05-21 01:35:44 -07:00
Feiwen a7884196f5 PR #49228: [MLIR][DISC] porting dynamic shape related OPs to mhlo and lmhlo dialect
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49228

We are porting our MLIR-based dynamic shape compiler to tf community (From OP def, Patttern, to Optimization pass, etc).
This is the first PR, which including some dynamic shape OPs def in mhlo and lmhlo dialect.
For mhlo dialect, we add:
- HLO_RealDynamicSliceOp
- HLO_DynamicPadOp
- HLO_DynamicGatherOp
- HLO_DynamicConvOp

For lmhlo dialect, we add:
- LHLO_RealDynamicSliceOp
- LHLO_DynamicBroadcastInDimOp
- LHLO_DynamicGatherOp
- LHLO_DynamicPadOp
- LHLO_DynamicBitcastOp
- LHLO_DynamicConvOp
- LHLO_DynamicIotaOp
- LHLO_DynamicReshapeOp
- LHLO_DotGeneralOp
- LHLO_BitcastOp

Rest Ops to add:
* We will send a separate PR containing LHLO_DynamicWhileOp and LHLO_DynamicCaseOp for control flow.
* We will add a separate dedicated dialect like mhlo_ral, which including D2HOp/H2DOp/DebugPrintOp/TopKOp, etc.

Previous discussions:[RFC](https://groups.google.com/a/tensorflow.org/g/mlir/c/_X48poNcbDI/m/jCC8BWIICQAJ), [discussion_1](https://llvm.discourse.group/t/updates-on-mlir-based-dynamic-shape-compiler/2384), [Recording of meeting](https://drive.google.com/file/d/1_uEISlV5MUWdG9faKAdKlCWnPtGjRC-D/view?usp=sharing).
Copybara import of the project:

--
e22d9e61106e00a1a1c6f368cc4a03e3bd1f414c by azazhu <azazhu@gmail.com>:

[DISC]fea: porting mhlo and lmhlo OPs

--
9ec3e76290da07cbd53d7da5fa86ff67179441a1 by azazhu <azazhu@gmail.com>:

[DISC][MLIR] 1. add summary and description for dynamic OPs in mhlo and lmhlo; 2. rm InferOutputTypes; 3. add verify for RealDynamicSliceOp and DynamicPadOp

--
0d68cd135555fd935991c12456b21329e628f23f by azazhu <azazhu@gmail.com>:

[DISC][MLIR] 1.remove D2H,H2D and DebugPrint Ops from mhlo/lmhlo dialect; 2. add type constraint to DynamicPadOp and RealDynamicSliceOp; 3.refine lmhlo type constraint; 4.rename RealDynamicSliceOp as name conflict.

--
698762a77d60f6a844cb1ab3f32740d4ef3c5843 by azazhu <azazhu@gmail.com>:

[DISC][MLIR] 1. replace dyn_cast to cast 2. refine code

PiperOrigin-RevId: 375022260
2021-05-20 23:16:47 -07:00
Hanhan Wang cd8f585cf7 [MHLO:Linalg] Add support for lowering torch_index_select of unsigned tensors
Also fixes typos in tests.

PiperOrigin-RevId: 374979460
2021-05-20 17:03:05 -07:00
Rahul Joshi 41f663ce47 [HLO] Adopt custom syntax for convolution dimensions and window attributes (HLO)
PiperOrigin-RevId: 374923250
2021-05-20 12:13:50 -07:00
Rahul Joshi fc88cf1ff4 [HLO] Adopt custom syntax for convolution dims and window attributes for LMHLO_GPU
PiperOrigin-RevId: 374889917
2021-05-20 09:41:48 -07:00