From 5213b90d48b3d9d2884e0cd61e69e2bc09f2eb68 Mon Sep 17 00:00:00 2001 From: Pathways-on-Cloud Team Date: Tue, 10 Mar 2026 12:24:34 -0700 Subject: [PATCH] Add type hints to pathwaysutils This change adds type annotations to various functions, methods, and variables across the pathwaysutils package, improving code clarity and maintainability. It also includes minor adjustments to existing type hints, such as using `|` for unions and casting where necessary. PiperOrigin-RevId: 881566030 --- pathwaysutils/__init__.py | 7 +++-- pathwaysutils/collect_profile.py | 6 ++-- pathwaysutils/experimental/reshard.py | 7 ++--- .../experimental/split_by_mesh_axis.py | 6 ++-- pathwaysutils/lru_cache.py | 10 +++++-- pathwaysutils/persistence/helper.py | 17 ++++++----- pathwaysutils/plugin_executable.py | 7 ++--- pathwaysutils/profiling.py | 29 +++++++++---------- pathwaysutils/proxy_backend.py | 2 +- pathwaysutils/reshard.py | 3 +- 10 files changed, 50 insertions(+), 44 deletions(-) diff --git a/pathwaysutils/__init__.py b/pathwaysutils/__init__.py index 81d0427..62506f6 100644 --- a/pathwaysutils/__init__.py +++ b/pathwaysutils/__init__.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Package of Pathways-on-Cloud utilities.""" +from collections.abc import Callable from pathwaysutils import _initialize -initialize = _initialize.initialize -is_pathways_backend_used = _initialize.is_pathways_backend_used +initialize: Callable[[], None] = _initialize.initialize +is_pathways_backend_used: Callable[[], bool] = _initialize.is_pathways_backend_used del _initialize # When changing this, also update the CHANGELOG.md. -__version__ = "v0.1.5" +__version__: str = "v0.1.5" diff --git a/pathwaysutils/collect_profile.py b/pathwaysutils/collect_profile.py index 26a01e6..57c971e 100644 --- a/pathwaysutils/collect_profile.py +++ b/pathwaysutils/collect_profile.py @@ -26,7 +26,7 @@ _logger.setLevel(logging.INFO) -_DESCRIPTION = """ +_DESCRIPTION: str = """ To profile running JAX programs, you first need to start the profiler server in the program of interest. You can do this via `jax.profiler.start_server()`. Once the program is running and the @@ -36,7 +36,7 @@ """ -def _get_parser(): +def _get_parser() -> argparse.ArgumentParser: """Returns an argument parser for the collect_profile script.""" parser = argparse.ArgumentParser(description=_DESCRIPTION) parser.add_argument( @@ -62,7 +62,7 @@ def _get_parser(): return parser -def main(): +def main() -> None: parser = _get_parser() args = parser.parse_args() diff --git a/pathwaysutils/experimental/reshard.py b/pathwaysutils/experimental/reshard.py index c2cb906..9ba4f0b 100644 --- a/pathwaysutils/experimental/reshard.py +++ b/pathwaysutils/experimental/reshard.py @@ -14,13 +14,12 @@ """Experimental resharding API for elastic device sets.""" import base64 -import collections -from collections.abc import Mapping +from collections.abc import Callable, Mapping, Sequence import json import logging import math import operator -from typing import Any, Callable, Dict, Mapping, Sequence +from typing import Any import jax from pathwaysutils import jax as pw_jax @@ -57,7 +56,7 @@ def __init__( ): def ifrt_hlo_sharding( aval: jax.core.ShapedArray, sharding: jax.sharding.Sharding - ) -> Dict[str, Any]: + ) -> dict[str, Any]: result = { "devices": { "device_ids": [ diff --git a/pathwaysutils/experimental/split_by_mesh_axis.py b/pathwaysutils/experimental/split_by_mesh_axis.py index 06fa833..88e0e97 100644 --- a/pathwaysutils/experimental/split_by_mesh_axis.py +++ b/pathwaysutils/experimental/split_by_mesh_axis.py @@ -13,7 +13,8 @@ # limitations under the License. """Experimental split by mesh axis API.""" -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any, cast import jax from pathwaysutils import jax as pw_jax @@ -167,7 +168,8 @@ def split_by_mesh_axis( mesh_axis_sizes=mesh.axis_sizes, mesh_axis_idx=mesh_axis_idx, mesh_axis_sections=mesh_axis_sections, - submesh_shardings=submesh_shardings, + # TODO: b/491156211 - Remove cast once type mismatch is fixed. + submesh_shardings=cast(Any, submesh_shardings), donate=donate, ) diff --git a/pathwaysutils/lru_cache.py b/pathwaysutils/lru_cache.py index 6608704..13cfb8a 100644 --- a/pathwaysutils/lru_cache.py +++ b/pathwaysutils/lru_cache.py @@ -13,15 +13,19 @@ # limitations under the License. """An LRU cache that will be cleared when JAX clears its internal cache.""" +from collections.abc import Callable import functools -from typing import Any, Callable +from typing import Any, TypeVar from jax.extend import backend +_F = TypeVar("_F", bound=Callable[..., Any]) + + def lru_cache( maxsize: int = 4096, -) -> Callable[[Callable[..., Any]], Callable[..., Any]]: +) -> Callable[[_F], _F]: """An LRU cache that will be cleared when JAX clears its internal cache. Args: @@ -32,7 +36,7 @@ def lru_cache( A function that can be used to decorate a function to cache its results. """ - def wrap(f): + def wrap(f: _F) -> _F: cached = functools.lru_cache(maxsize=maxsize)(f) wrapper = functools.wraps(f)(cached) diff --git a/pathwaysutils/persistence/helper.py b/pathwaysutils/persistence/helper.py index 48d3021..5d8b535 100644 --- a/pathwaysutils/persistence/helper.py +++ b/pathwaysutils/persistence/helper.py @@ -14,10 +14,11 @@ """Helper functions for persistence.""" import base64 +from collections.abc import Sequence import concurrent.futures import datetime import json -from typing import Any, Sequence, Tuple, Union +from typing import Any import jax from jax import core @@ -93,7 +94,7 @@ def get_hlo_sharding_string( def get_shape_info( dtype: np.dtype, dimensions: Sequence[int], -) -> dict[str, Union[Sequence[int], str]]: +) -> dict[str, Sequence[int] | str]: """Returns shape info in the format expected by read requests.""" return { "xla_primitive_type_str": dtype_to_xla_primitive_type_str(dtype), @@ -107,7 +108,7 @@ def get_write_request( jax_array: jax.Array, timeout: datetime.timedelta, return_dict: bool = False, -) -> Union[str, dict[str, Any]]: +) -> str | dict[str, Any]: """Returns a string representation of the plugin program which writes the given jax_array to the given location.""" sharding = jax_array.sharding assert isinstance(sharding, jax.sharding.Sharding), sharding @@ -171,7 +172,7 @@ def get_read_request( devices: Sequence[jax.Device], timeout: datetime.timedelta, return_dict: bool = False, -) -> Union[str, dict[str, Any]]: +) -> str | dict[str, Any]: """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding.""" if not isinstance(devices, np.ndarray): devices = np.array(devices) @@ -256,9 +257,9 @@ def read_one_array( dtype: np.dtype, shape: Sequence[int], shardings: jax.sharding.Sharding, - devices: Union[Sequence[jax.Device], np.ndarray], + devices: Sequence[jax.Device] | np.ndarray, timeout: datetime.timedelta, -): +) -> jax.Array: """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result.""" read_request = get_read_request( location, @@ -284,9 +285,9 @@ def read_arrays( dtypes: Sequence[np.dtype], shapes: Sequence[Sequence[int]], shardings: Sequence[jax.sharding.Sharding], - devices: Union[Sequence[jax.Device], np.ndarray], + devices: Sequence[jax.Device] | np.ndarray, timeout: datetime.timedelta, -) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]: +) -> tuple[Sequence[jax.Array], concurrent.futures.Future[None]]: """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result.""" bulk_read_request = get_bulk_read_request( diff --git a/pathwaysutils/plugin_executable.py b/pathwaysutils/plugin_executable.py index 789d4b5..e1c8956 100644 --- a/pathwaysutils/plugin_executable.py +++ b/pathwaysutils/plugin_executable.py @@ -13,10 +13,9 @@ # limitations under the License. """PluginExecutable is a class for executing plugin programs.""" +from collections.abc import Sequence import concurrent.futures import threading -from typing import List, Sequence, Tuple, Union - import jax from jax.extend import ifrt_programs from jax.interpreters import pxla @@ -36,11 +35,11 @@ def __init__(self, prog_str: str): def call( self, - in_arr: Sequence[Union[jax.Array, List[jax.Array]]] = (), + in_arr: Sequence[jax.Array | Sequence[jax.Array]] = (), out_shardings: Sequence[jax.sharding.Sharding] = (), out_avals: Sequence[jax.core.ShapedArray] = (), out_committed: bool = True, - ) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]: + ) -> tuple[Sequence[jax.Array], concurrent.futures.Future[None]]: """Runs the compiled IFRT program and returns the result and a future.""" results_with_token = self.compiled.execute_sharded(in_arr, with_tokens=True) diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index e6f2a4c..d0c5e10 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -19,7 +19,6 @@ import logging import os import threading -import time from typing import Any import urllib.parse @@ -38,11 +37,11 @@ class _ProfileState: executable: plugin_executable.PluginExecutable | None = None lock: threading.Lock - def __init__(self): + def __init__(self) -> None: self.executable = None self.lock = threading.Lock() - def reset(self): + def reset(self) -> None: self.executable = None @@ -52,7 +51,7 @@ def reset(self): _original_stop_trace = jax.profiler.stop_trace -def toy_computation(): +def toy_computation() -> None: """A toy computation to run before the first profile.""" x = jax.jit(lambda x: x + 1)(jnp.array(1)) x.block_until_ready() @@ -154,7 +153,7 @@ def start_trace( ) -def stop_trace(): +def stop_trace() -> None: """Stops the currently-running profiler trace.""" try: with _profile_state.lock: @@ -172,7 +171,7 @@ def stop_trace(): _profiler_thread: threading.Thread | None = None -def start_server(port: int): +def start_server(port: int) -> None: """Starts the profiling server on port `port`. The signature is slightly different from `jax.profiler.start_server` @@ -192,7 +191,7 @@ class ProfilingConfig: repository_path: str @app.post("/profiling") - async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable + async def profiling(pc: ProfilingConfig) -> dict[str, str]: # pylint: disable=unused-variable _logger.debug("Capturing profiling data for %s ms", pc.duration_ms) _logger.debug("Writing profiling data to %s", pc.repository_path) await asyncio.to_thread(jax.profiler.start_trace, pc.repository_path) @@ -210,7 +209,7 @@ async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable _profiler_thread.start() -def stop_server(): +def stop_server() -> None: """Raises an error if there is no active profiler server. Pathways profiling servers are not stoppable at this time. @@ -257,7 +256,7 @@ def collect_profile( return True -def monkey_patch_jax(): +def monkey_patch_jax() -> None: """Monkey patches JAX with Pathways versions of functions. The signatures in patched functions should match the original. @@ -279,7 +278,7 @@ def start_trace_patch( profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument ) -> None: _logger.debug("jax.profile.start_trace patched with pathways' start_trace") - return start_trace( + start_trace( log_dir, create_perfetto_link=create_perfetto_link, create_perfetto_trace=create_perfetto_trace, @@ -291,21 +290,21 @@ def start_trace_patch( def stop_trace_patch() -> None: _logger.debug("jax.profile.stop_trace patched with pathways' stop_trace") - return stop_trace() + stop_trace() jax.profiler.stop_trace = stop_trace_patch jax._src.profiler.stop_trace = stop_trace_patch # pylint: disable=protected-access - def start_server_patch(port: int): + def start_server_patch(port: int) -> None: _logger.debug( "jax.profile.start_server patched with pathways' start_server" ) - return start_server(port) + start_server(port) jax.profiler.start_server = start_server_patch - def stop_server_patch(): + def stop_server_patch() -> None: _logger.debug("jax.profile.stop_server patched with pathways' stop_server") - return stop_server() + stop_server() jax.profiler.stop_server = stop_server_patch diff --git a/pathwaysutils/proxy_backend.py b/pathwaysutils/proxy_backend.py index cc9ccce..d599f21 100644 --- a/pathwaysutils/proxy_backend.py +++ b/pathwaysutils/proxy_backend.py @@ -18,7 +18,7 @@ from pathwaysutils import jax as pw_jax -def register_backend_factory(): +def register_backend_factory() -> None: backend.register_backend_factory( "proxy", lambda: pw_jax.ifrt_proxy.get_client( diff --git a/pathwaysutils/reshard.py b/pathwaysutils/reshard.py index 496e6be..e112145 100644 --- a/pathwaysutils/reshard.py +++ b/pathwaysutils/reshard.py @@ -14,7 +14,8 @@ """Resharding API using the IFRT RemapArray API.""" import collections -from typing import Any, Callable, Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence +from typing import Any import jax import pathwaysutils.jax