Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pathwaysutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Package of Pathways-on-Cloud utilities."""
from collections.abc import Callable
from pathwaysutils import _initialize

initialize = _initialize.initialize
is_pathways_backend_used = _initialize.is_pathways_backend_used
initialize: Callable[[], None] = _initialize.initialize
is_pathways_backend_used: Callable[[], bool] = _initialize.is_pathways_backend_used

del _initialize

# When changing this, also update the CHANGELOG.md.
__version__ = "v0.1.5"
__version__: str = "v0.1.5"
6 changes: 3 additions & 3 deletions pathwaysutils/collect_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
_logger.setLevel(logging.INFO)


_DESCRIPTION = """
_DESCRIPTION: str = """
To profile running JAX programs, you first need to start the profiler server
in the program of interest. You can do this via
`jax.profiler.start_server(<port>)`. Once the program is running and the
Expand All @@ -36,7 +36,7 @@
"""


def _get_parser():
def _get_parser() -> argparse.ArgumentParser:
"""Returns an argument parser for the collect_profile script."""
parser = argparse.ArgumentParser(description=_DESCRIPTION)
parser.add_argument(
Expand All @@ -62,7 +62,7 @@ def _get_parser():
return parser


def main():
def main() -> None:
parser = _get_parser()
args = parser.parse_args()

Expand Down
7 changes: 3 additions & 4 deletions pathwaysutils/experimental/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
"""Experimental resharding API for elastic device sets."""

import base64
import collections
from collections.abc import Mapping
from collections.abc import Callable, Mapping, Sequence
import json
import logging
import math
import operator
from typing import Any, Callable, Dict, Mapping, Sequence
from typing import Any

import jax
from pathwaysutils import jax as pw_jax
Expand Down Expand Up @@ -57,7 +56,7 @@ def __init__(
):
def ifrt_hlo_sharding(
aval: jax.core.ShapedArray, sharding: jax.sharding.Sharding
) -> Dict[str, Any]:
) -> dict[str, Any]:
result = {
"devices": {
"device_ids": [
Expand Down
6 changes: 4 additions & 2 deletions pathwaysutils/experimental/split_by_mesh_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.
"""Experimental split by mesh axis API."""

from typing import Any, Sequence
from collections.abc import Sequence
from typing import Any, cast

import jax
from pathwaysutils import jax as pw_jax
Expand Down Expand Up @@ -167,7 +168,8 @@ def split_by_mesh_axis(
mesh_axis_sizes=mesh.axis_sizes,
mesh_axis_idx=mesh_axis_idx,
mesh_axis_sections=mesh_axis_sections,
submesh_shardings=submesh_shardings,
# TODO: b/491156211 - Remove cast once type mismatch is fixed.
submesh_shardings=cast(Any, submesh_shardings),
donate=donate,
)

Expand Down
10 changes: 7 additions & 3 deletions pathwaysutils/lru_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
# limitations under the License.
"""An LRU cache that will be cleared when JAX clears its internal cache."""

from collections.abc import Callable
import functools
from typing import Any, Callable
from typing import Any, TypeVar

from jax.extend import backend


_F = TypeVar("_F", bound=Callable[..., Any])


def lru_cache(
maxsize: int = 4096,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
) -> Callable[[_F], _F]:
"""An LRU cache that will be cleared when JAX clears its internal cache.

Args:
Expand All @@ -32,7 +36,7 @@ def lru_cache(
A function that can be used to decorate a function to cache its results.
"""

def wrap(f):
def wrap(f: _F) -> _F:
cached = functools.lru_cache(maxsize=maxsize)(f)
wrapper = functools.wraps(f)(cached)

Expand Down
17 changes: 9 additions & 8 deletions pathwaysutils/persistence/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
"""Helper functions for persistence."""

import base64
from collections.abc import Sequence
import concurrent.futures
import datetime
import json
from typing import Any, Sequence, Tuple, Union
from typing import Any

import jax
from jax import core
Expand Down Expand Up @@ -93,7 +94,7 @@ def get_hlo_sharding_string(
def get_shape_info(
dtype: np.dtype,
dimensions: Sequence[int],
) -> dict[str, Union[Sequence[int], str]]:
) -> dict[str, Sequence[int] | str]:
"""Returns shape info in the format expected by read requests."""
return {
"xla_primitive_type_str": dtype_to_xla_primitive_type_str(dtype),
Expand All @@ -107,7 +108,7 @@ def get_write_request(
jax_array: jax.Array,
timeout: datetime.timedelta,
return_dict: bool = False,
) -> Union[str, dict[str, Any]]:
) -> str | dict[str, Any]:
"""Returns a string representation of the plugin program which writes the given jax_array to the given location."""
sharding = jax_array.sharding
assert isinstance(sharding, jax.sharding.Sharding), sharding
Expand Down Expand Up @@ -171,7 +172,7 @@ def get_read_request(
devices: Sequence[jax.Device],
timeout: datetime.timedelta,
return_dict: bool = False,
) -> Union[str, dict[str, Any]]:
) -> str | dict[str, Any]:
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
Expand Down Expand Up @@ -256,9 +257,9 @@ def read_one_array(
dtype: np.dtype,
shape: Sequence[int],
shardings: jax.sharding.Sharding,
devices: Union[Sequence[jax.Device], np.ndarray],
devices: Sequence[jax.Device] | np.ndarray,
timeout: datetime.timedelta,
):
) -> jax.Array:
"""Creates the read array plugin program string, compiles it to an executable, calls it and returns the result."""
read_request = get_read_request(
location,
Expand All @@ -284,9 +285,9 @@ def read_arrays(
dtypes: Sequence[np.dtype],
shapes: Sequence[Sequence[int]],
shardings: Sequence[jax.sharding.Sharding],
devices: Union[Sequence[jax.Device], np.ndarray],
devices: Sequence[jax.Device] | np.ndarray,
timeout: datetime.timedelta,
) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]:
) -> tuple[Sequence[jax.Array], concurrent.futures.Future[None]]:
"""Creates the read array plugin program string, compiles it to an executable, calls it and returns the result."""

bulk_read_request = get_bulk_read_request(
Expand Down
7 changes: 3 additions & 4 deletions pathwaysutils/plugin_executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
# limitations under the License.
"""PluginExecutable is a class for executing plugin programs."""

from collections.abc import Sequence
import concurrent.futures
import threading
from typing import List, Sequence, Tuple, Union

import jax
from jax.extend import ifrt_programs
from jax.interpreters import pxla
Expand All @@ -36,11 +35,11 @@ def __init__(self, prog_str: str):

def call(
self,
in_arr: Sequence[Union[jax.Array, List[jax.Array]]] = (),
in_arr: Sequence[jax.Array | Sequence[jax.Array]] = (),
out_shardings: Sequence[jax.sharding.Sharding] = (),
out_avals: Sequence[jax.core.ShapedArray] = (),
out_committed: bool = True,
) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]:
) -> tuple[Sequence[jax.Array], concurrent.futures.Future[None]]:
"""Runs the compiled IFRT program and returns the result and a future."""
results_with_token = self.compiled.execute_sharded(in_arr, with_tokens=True)

Expand Down
29 changes: 14 additions & 15 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import logging
import os
import threading
import time
from typing import Any
import urllib.parse

Expand All @@ -38,11 +37,11 @@ class _ProfileState:
executable: plugin_executable.PluginExecutable | None = None
lock: threading.Lock

def __init__(self):
def __init__(self) -> None:
self.executable = None
self.lock = threading.Lock()

def reset(self):
def reset(self) -> None:
self.executable = None


Expand All @@ -52,7 +51,7 @@ def reset(self):
_original_stop_trace = jax.profiler.stop_trace


def toy_computation():
def toy_computation() -> None:
"""A toy computation to run before the first profile."""
x = jax.jit(lambda x: x + 1)(jnp.array(1))
x.block_until_ready()
Expand Down Expand Up @@ -154,7 +153,7 @@ def start_trace(
)


def stop_trace():
def stop_trace() -> None:
"""Stops the currently-running profiler trace."""
try:
with _profile_state.lock:
Expand All @@ -172,7 +171,7 @@ def stop_trace():
_profiler_thread: threading.Thread | None = None


def start_server(port: int):
def start_server(port: int) -> None:
"""Starts the profiling server on port `port`.

The signature is slightly different from `jax.profiler.start_server`
Expand All @@ -192,7 +191,7 @@ class ProfilingConfig:
repository_path: str

@app.post("/profiling")
async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable
async def profiling(pc: ProfilingConfig) -> dict[str, str]: # pylint: disable=unused-variable
_logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
_logger.debug("Writing profiling data to %s", pc.repository_path)
await asyncio.to_thread(jax.profiler.start_trace, pc.repository_path)
Expand All @@ -210,7 +209,7 @@ async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable
_profiler_thread.start()


def stop_server():
def stop_server() -> None:
"""Raises an error if there is no active profiler server.

Pathways profiling servers are not stoppable at this time.
Expand Down Expand Up @@ -257,7 +256,7 @@ def collect_profile(
return True


def monkey_patch_jax():
def monkey_patch_jax() -> None:
"""Monkey patches JAX with Pathways versions of functions.

The signatures in patched functions should match the original.
Expand All @@ -279,7 +278,7 @@ def start_trace_patch(
profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument
) -> None:
_logger.debug("jax.profile.start_trace patched with pathways' start_trace")
return start_trace(
start_trace(
log_dir,
create_perfetto_link=create_perfetto_link,
create_perfetto_trace=create_perfetto_trace,
Expand All @@ -291,21 +290,21 @@ def start_trace_patch(

def stop_trace_patch() -> None:
_logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
return stop_trace()
stop_trace()

jax.profiler.stop_trace = stop_trace_patch
jax._src.profiler.stop_trace = stop_trace_patch # pylint: disable=protected-access

def start_server_patch(port: int):
def start_server_patch(port: int) -> None:
_logger.debug(
"jax.profile.start_server patched with pathways' start_server"
)
return start_server(port)
start_server(port)

jax.profiler.start_server = start_server_patch

def stop_server_patch():
def stop_server_patch() -> None:
_logger.debug("jax.profile.stop_server patched with pathways' stop_server")
return stop_server()
stop_server()

jax.profiler.stop_server = stop_server_patch
2 changes: 1 addition & 1 deletion pathwaysutils/proxy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pathwaysutils import jax as pw_jax


def register_backend_factory():
def register_backend_factory() -> None:
backend.register_backend_factory(
"proxy",
lambda: pw_jax.ifrt_proxy.get_client(
Expand Down
3 changes: 2 additions & 1 deletion pathwaysutils/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
"""Resharding API using the IFRT RemapArray API."""

import collections
from typing import Any, Callable, Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
from typing import Any

import jax
import pathwaysutils.jax
Expand Down
Loading