diff --git a/pathwaysutils/__init__.py b/pathwaysutils/__init__.py index 81d0427..62506f6 100644 --- a/pathwaysutils/__init__.py +++ b/pathwaysutils/__init__.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Package of Pathways-on-Cloud utilities.""" +from collections.abc import Callable from pathwaysutils import _initialize -initialize = _initialize.initialize -is_pathways_backend_used = _initialize.is_pathways_backend_used +initialize: Callable[[], None] = _initialize.initialize +is_pathways_backend_used: Callable[[], bool] = _initialize.is_pathways_backend_used del _initialize # When changing this, also update the CHANGELOG.md. -__version__ = "v0.1.5" +__version__: str = "v0.1.5" diff --git a/pathwaysutils/collect_profile.py b/pathwaysutils/collect_profile.py index 26a01e6..57c971e 100644 --- a/pathwaysutils/collect_profile.py +++ b/pathwaysutils/collect_profile.py @@ -26,7 +26,7 @@ _logger.setLevel(logging.INFO) -_DESCRIPTION = """ +_DESCRIPTION: str = """ To profile running JAX programs, you first need to start the profiler server in the program of interest. You can do this via `jax.profiler.start_server()`. Once the program is running and the @@ -36,7 +36,7 @@ """ -def _get_parser(): +def _get_parser() -> argparse.ArgumentParser: """Returns an argument parser for the collect_profile script.""" parser = argparse.ArgumentParser(description=_DESCRIPTION) parser.add_argument( @@ -62,7 +62,7 @@ def _get_parser(): return parser -def main(): +def main() -> None: parser = _get_parser() args = parser.parse_args() diff --git a/pathwaysutils/experimental/reshard.py b/pathwaysutils/experimental/reshard.py index c2cb906..9ba4f0b 100644 --- a/pathwaysutils/experimental/reshard.py +++ b/pathwaysutils/experimental/reshard.py @@ -14,13 +14,12 @@ """Experimental resharding API for elastic device sets.""" import base64 -import collections -from collections.abc import Mapping +from collections.abc import Callable, Mapping, Sequence import json import logging import math import operator -from typing import Any, Callable, Dict, Mapping, Sequence +from typing import Any import jax from pathwaysutils import jax as pw_jax @@ -57,7 +56,7 @@ def __init__( ): def ifrt_hlo_sharding( aval: jax.core.ShapedArray, sharding: jax.sharding.Sharding - ) -> Dict[str, Any]: + ) -> dict[str, Any]: result = { "devices": { "device_ids": [ diff --git a/pathwaysutils/experimental/split_by_mesh_axis.py b/pathwaysutils/experimental/split_by_mesh_axis.py index 06fa833..88e0e97 100644 --- a/pathwaysutils/experimental/split_by_mesh_axis.py +++ b/pathwaysutils/experimental/split_by_mesh_axis.py @@ -13,7 +13,8 @@ # limitations under the License. """Experimental split by mesh axis API.""" -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any, cast import jax from pathwaysutils import jax as pw_jax @@ -167,7 +168,8 @@ def split_by_mesh_axis( mesh_axis_sizes=mesh.axis_sizes, mesh_axis_idx=mesh_axis_idx, mesh_axis_sections=mesh_axis_sections, - submesh_shardings=submesh_shardings, + # TODO: b/491156211 - Remove cast once type mismatch is fixed. + submesh_shardings=cast(Any, submesh_shardings), donate=donate, ) diff --git a/pathwaysutils/lru_cache.py b/pathwaysutils/lru_cache.py index 6608704..13cfb8a 100644 --- a/pathwaysutils/lru_cache.py +++ b/pathwaysutils/lru_cache.py @@ -13,15 +13,19 @@ # limitations under the License. """An LRU cache that will be cleared when JAX clears its internal cache.""" +from collections.abc import Callable import functools -from typing import Any, Callable +from typing import Any, TypeVar from jax.extend import backend +_F = TypeVar("_F", bound=Callable[..., Any]) + + def lru_cache( maxsize: int = 4096, -) -> Callable[[Callable[..., Any]], Callable[..., Any]]: +) -> Callable[[_F], _F]: """An LRU cache that will be cleared when JAX clears its internal cache. Args: @@ -32,7 +36,7 @@ def lru_cache( A function that can be used to decorate a function to cache its results. """ - def wrap(f): + def wrap(f: _F) -> _F: cached = functools.lru_cache(maxsize=maxsize)(f) wrapper = functools.wraps(f)(cached) diff --git a/pathwaysutils/persistence/helper.py b/pathwaysutils/persistence/helper.py index 48d3021..5d8b535 100644 --- a/pathwaysutils/persistence/helper.py +++ b/pathwaysutils/persistence/helper.py @@ -14,10 +14,11 @@ """Helper functions for persistence.""" import base64 +from collections.abc import Sequence import concurrent.futures import datetime import json -from typing import Any, Sequence, Tuple, Union +from typing import Any import jax from jax import core @@ -93,7 +94,7 @@ def get_hlo_sharding_string( def get_shape_info( dtype: np.dtype, dimensions: Sequence[int], -) -> dict[str, Union[Sequence[int], str]]: +) -> dict[str, Sequence[int] | str]: """Returns shape info in the format expected by read requests.""" return { "xla_primitive_type_str": dtype_to_xla_primitive_type_str(dtype), @@ -107,7 +108,7 @@ def get_write_request( jax_array: jax.Array, timeout: datetime.timedelta, return_dict: bool = False, -) -> Union[str, dict[str, Any]]: +) -> str | dict[str, Any]: """Returns a string representation of the plugin program which writes the given jax_array to the given location.""" sharding = jax_array.sharding assert isinstance(sharding, jax.sharding.Sharding), sharding @@ -171,7 +172,7 @@ def get_read_request( devices: Sequence[jax.Device], timeout: datetime.timedelta, return_dict: bool = False, -) -> Union[str, dict[str, Any]]: +) -> str | dict[str, Any]: """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding.""" if not isinstance(devices, np.ndarray): devices = np.array(devices) @@ -256,9 +257,9 @@ def read_one_array( dtype: np.dtype, shape: Sequence[int], shardings: jax.sharding.Sharding, - devices: Union[Sequence[jax.Device], np.ndarray], + devices: Sequence[jax.Device] | np.ndarray, timeout: datetime.timedelta, -): +) -> jax.Array: """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result.""" read_request = get_read_request( location, @@ -284,9 +285,9 @@ def read_arrays( dtypes: Sequence[np.dtype], shapes: Sequence[Sequence[int]], shardings: Sequence[jax.sharding.Sharding], - devices: Union[Sequence[jax.Device], np.ndarray], + devices: Sequence[jax.Device] | np.ndarray, timeout: datetime.timedelta, -) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]: +) -> tuple[Sequence[jax.Array], concurrent.futures.Future[None]]: """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result.""" bulk_read_request = get_bulk_read_request( diff --git a/pathwaysutils/plugin_executable.py b/pathwaysutils/plugin_executable.py index 789d4b5..e1c8956 100644 --- a/pathwaysutils/plugin_executable.py +++ b/pathwaysutils/plugin_executable.py @@ -13,10 +13,9 @@ # limitations under the License. """PluginExecutable is a class for executing plugin programs.""" +from collections.abc import Sequence import concurrent.futures import threading -from typing import List, Sequence, Tuple, Union - import jax from jax.extend import ifrt_programs from jax.interpreters import pxla @@ -36,11 +35,11 @@ def __init__(self, prog_str: str): def call( self, - in_arr: Sequence[Union[jax.Array, List[jax.Array]]] = (), + in_arr: Sequence[jax.Array | Sequence[jax.Array]] = (), out_shardings: Sequence[jax.sharding.Sharding] = (), out_avals: Sequence[jax.core.ShapedArray] = (), out_committed: bool = True, - ) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]: + ) -> tuple[Sequence[jax.Array], concurrent.futures.Future[None]]: """Runs the compiled IFRT program and returns the result and a future.""" results_with_token = self.compiled.execute_sharded(in_arr, with_tokens=True) diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index e6f2a4c..d0c5e10 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -19,7 +19,6 @@ import logging import os import threading -import time from typing import Any import urllib.parse @@ -38,11 +37,11 @@ class _ProfileState: executable: plugin_executable.PluginExecutable | None = None lock: threading.Lock - def __init__(self): + def __init__(self) -> None: self.executable = None self.lock = threading.Lock() - def reset(self): + def reset(self) -> None: self.executable = None @@ -52,7 +51,7 @@ def reset(self): _original_stop_trace = jax.profiler.stop_trace -def toy_computation(): +def toy_computation() -> None: """A toy computation to run before the first profile.""" x = jax.jit(lambda x: x + 1)(jnp.array(1)) x.block_until_ready() @@ -154,7 +153,7 @@ def start_trace( ) -def stop_trace(): +def stop_trace() -> None: """Stops the currently-running profiler trace.""" try: with _profile_state.lock: @@ -172,7 +171,7 @@ def stop_trace(): _profiler_thread: threading.Thread | None = None -def start_server(port: int): +def start_server(port: int) -> None: """Starts the profiling server on port `port`. The signature is slightly different from `jax.profiler.start_server` @@ -192,7 +191,7 @@ class ProfilingConfig: repository_path: str @app.post("/profiling") - async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable + async def profiling(pc: ProfilingConfig) -> dict[str, str]: # pylint: disable=unused-variable _logger.debug("Capturing profiling data for %s ms", pc.duration_ms) _logger.debug("Writing profiling data to %s", pc.repository_path) await asyncio.to_thread(jax.profiler.start_trace, pc.repository_path) @@ -210,7 +209,7 @@ async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable _profiler_thread.start() -def stop_server(): +def stop_server() -> None: """Raises an error if there is no active profiler server. Pathways profiling servers are not stoppable at this time. @@ -257,7 +256,7 @@ def collect_profile( return True -def monkey_patch_jax(): +def monkey_patch_jax() -> None: """Monkey patches JAX with Pathways versions of functions. The signatures in patched functions should match the original. @@ -279,7 +278,7 @@ def start_trace_patch( profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument ) -> None: _logger.debug("jax.profile.start_trace patched with pathways' start_trace") - return start_trace( + start_trace( log_dir, create_perfetto_link=create_perfetto_link, create_perfetto_trace=create_perfetto_trace, @@ -291,21 +290,21 @@ def start_trace_patch( def stop_trace_patch() -> None: _logger.debug("jax.profile.stop_trace patched with pathways' stop_trace") - return stop_trace() + stop_trace() jax.profiler.stop_trace = stop_trace_patch jax._src.profiler.stop_trace = stop_trace_patch # pylint: disable=protected-access - def start_server_patch(port: int): + def start_server_patch(port: int) -> None: _logger.debug( "jax.profile.start_server patched with pathways' start_server" ) - return start_server(port) + start_server(port) jax.profiler.start_server = start_server_patch - def stop_server_patch(): + def stop_server_patch() -> None: _logger.debug("jax.profile.stop_server patched with pathways' stop_server") - return stop_server() + stop_server() jax.profiler.stop_server = stop_server_patch diff --git a/pathwaysutils/proxy_backend.py b/pathwaysutils/proxy_backend.py index cc9ccce..d599f21 100644 --- a/pathwaysutils/proxy_backend.py +++ b/pathwaysutils/proxy_backend.py @@ -18,7 +18,7 @@ from pathwaysutils import jax as pw_jax -def register_backend_factory(): +def register_backend_factory() -> None: backend.register_backend_factory( "proxy", lambda: pw_jax.ifrt_proxy.get_client( diff --git a/pathwaysutils/reshard.py b/pathwaysutils/reshard.py index 496e6be..e112145 100644 --- a/pathwaysutils/reshard.py +++ b/pathwaysutils/reshard.py @@ -14,7 +14,8 @@ """Resharding API using the IFRT RemapArray API.""" import collections -from typing import Any, Callable, Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence +from typing import Any import jax import pathwaysutils.jax