diff --git a/src/murfey/server/ispyb.py b/src/murfey/server/ispyb.py index 47860c89..f6cebb30 100644 --- a/src/murfey/server/ispyb.py +++ b/src/murfey/server/ispyb.py @@ -186,13 +186,21 @@ def do_update_atlas( atlas_image: str, pixel_size: float, slot: int | None, + collection_mode: str | None = None, + color_flags: dict[str, str | int] | None = None, ): + color_flags = color_flags or {} try: with ISPyBSession() as db: atlas = db.query(Atlas).filter(Atlas.atlasId == atlas_id).one() atlas.atlasImage = atlas_image or atlas.atlasImage atlas.pixelSize = pixel_size or atlas.pixelSize atlas.cassetteSlot = slot or atlas.cassetteSlot + atlas.mode = collection_mode or atlas.mode + # Optionally insert colour flags if present + if color_flags: + for col_name, value in color_flags.items(): + setattr(atlas, col_name, value) db.add(atlas) db.commit() return {"success": True, "return_value": atlas.atlasId} @@ -209,7 +217,9 @@ def do_insert_grid_square( atlas_id: int, grid_square_id: int, grid_square_parameters: GridSquareParameters, + color_flags: dict[str, int] | None = None, ): + color_flags = color_flags or {} # most of this is for mypy if ( grid_square_parameters.pixel_size is not None @@ -234,7 +244,12 @@ def do_insert_grid_square( stageLocationX=grid_square_parameters.x_stage_position, stageLocationY=grid_square_parameters.y_stage_position, pixelSize=grid_square_parameters.pixel_size, + mode=grid_square_parameters.collection_mode, ) + # Optionally insert colour flags + if color_flags: + for col_name, value in color_flags.items(): + setattr(record, col_name, value) try: with ISPyBSession() as db: db.add(record) @@ -250,8 +265,12 @@ def do_insert_grid_square( return {"success": False, "return_value": None} def do_update_grid_square( - self, grid_square_id: int, grid_square_parameters: GridSquareParameters + self, + grid_square_id: int, + grid_square_parameters: GridSquareParameters, + color_flags: dict[str, int] | None = None, ): + color_flags = color_flags or {} try: with ISPyBSession() as db: grid_square: GridSquare = ( @@ -290,6 +309,12 @@ def do_update_grid_square( grid_square.stageLocationY = grid_square_parameters.y_stage_position if grid_square_parameters.pixel_size: grid_square.pixelSize = grid_square_parameters.pixel_size + if grid_square_parameters.collection_mode: + grid_square.mode = grid_square_parameters.collection_mode + # Optionally insert colour flags + if color_flags: + for col_name, value in color_flags.items(): + setattr(grid_square, col_name, value) db.add(grid_square) db.commit() return {"success": True, "return_value": grid_square.gridSquareId} diff --git a/src/murfey/util/db.py b/src/murfey/util/db.py index 193e0ff2..a46a1267 100644 --- a/src/murfey/util/db.py +++ b/src/murfey/util/db.py @@ -294,6 +294,14 @@ class CLEMImageSeries(SQLModel, table=True): # type: ignore sa_relationship_kwargs={"cascade": "delete"}, ) # One to many number_of_members: Optional[int] = Field(default=None) + has_grey: Optional[bool] = Field(default=None) + has_red: Optional[bool] = Field(default=None) + has_green: Optional[bool] = Field(default=None) + has_blue: Optional[bool] = Field(default=None) + has_cyan: Optional[bool] = Field(default=None) + has_magenta: Optional[bool] = Field(default=None) + has_yellow: Optional[bool] = Field(default=None) + collection_mode: Optional[str] = Field(default=None) # Shape and resolution information image_pixels_x: Optional[int] = Field(default=None) diff --git a/src/murfey/util/models.py b/src/murfey/util/models.py index 4d079dbd..1d59a6b2 100644 --- a/src/murfey/util/models.py +++ b/src/murfey/util/models.py @@ -157,6 +157,9 @@ class GridSquareParameters(BaseModel): pixel_size: Optional[float] = None angle: Optional[float] = None + # Collection mode + collection_mode: Optional[str] = None + class FoilHoleParameters(BaseModel): tag: str diff --git a/src/murfey/workflows/clem/register_preprocessing_results.py b/src/murfey/workflows/clem/register_preprocessing_results.py index d72859b4..aa8d8dda 100644 --- a/src/murfey/workflows/clem/register_preprocessing_results.py +++ b/src/murfey/workflows/clem/register_preprocessing_results.py @@ -11,6 +11,7 @@ import logging import re import traceback +from collections.abc import Collection from importlib.metadata import entry_points from pathlib import Path from typing import Literal, Optional @@ -66,6 +67,27 @@ def _is_clem_atlas(result: CLEMPreprocessingResult): ) +COLOR_FLAGS_MURFEY = { + "gray": "has_grey", + "red": "has_red", + "green": "has_green", + "blue": "has_blue", + "cyan": "has_cyan", + "magenta": "has_magenta", + "yellow": "has_yellow", +} + + +def _get_color_flags( + colors: Collection[str] | None = None, +): + colors = colors or [] + color_flags = dict.fromkeys(COLOR_FLAGS_MURFEY.values(), False) + for color in colors: + color_flags[COLOR_FLAGS_MURFEY[color]] = True + return color_flags + + def _register_clem_image_series( session_id: int, result: CLEMPreprocessingResult, @@ -159,6 +181,11 @@ def _register_clem_image_series( clem_img_series.image_search_string = str(output_file.parent / "*tiff") clem_img_series.data_type = "atlas" if _is_clem_atlas(result) else "grid_square" clem_img_series.number_of_members = result.number_of_members + for col_name, value in _get_color_flags(result.output_files.keys()).items(): + setattr(clem_img_series, col_name, value) + clem_img_series.collection_mode = _determine_collection_mode( + result.output_files.keys() + ) clem_img_series.image_pixels_x = result.pixels_x clem_img_series.image_pixels_y = result.pixels_y clem_img_series.image_pixel_size = result.pixel_size @@ -186,6 +213,31 @@ def _register_clem_image_series( logger.info(f"CLEM preprocessing results registered for {result.series_name!r} ") +def _determine_collection_mode( + colors: Collection[str] | None = None, +): + if not colors: + logger.warning("No colours were present in returned result") + return None + if "gray" in colors: + if len(colors) == 1: + return "Bright Field" + else: + return "Bright Field and Fluorescent" + else: + return "Fluorescent" + + +def _snake_to_camel_case(string: str): + parts = string.split("_") + return parts[0] + "".join(part.capitalize() for part in parts[1:]) + + +COLOR_FLAGS_MURFEY_TO_ISPYB = { + value: _snake_to_camel_case(value) for value in COLOR_FLAGS_MURFEY.values() +} + + def _register_dcg_and_atlas( session_id: int, instrument_name: str, @@ -225,9 +277,17 @@ def _register_dcg_and_atlas( else: atlas_name = str(output_file.parent / "*.tiff") atlas_pixel_size = result.pixel_size + # Translate colour flags into ISPyB convention + color_flags = { + COLOR_FLAGS_MURFEY_TO_ISPYB[key]: int(value) + for key, value in _get_color_flags(result.output_files.keys()).items() + } + collection_mode = _determine_collection_mode(result.output_files.keys()) else: atlas_name = "" atlas_pixel_size = 0.0 + color_flags = None + collection_mode = None if dcg_search := murfey_db.exec( select(MurfeyDB.DataCollectionGroup) @@ -245,6 +305,8 @@ def _register_dcg_and_atlas( "atlas": atlas_name, "atlas_pixel_size": atlas_pixel_size, "sample": dcg_entry.sample, + "color_flags": color_flags, + "collection_mode": collection_mode, } if entry_point_result := entry_points( group="murfey.workflows", name="atlas_update" @@ -269,6 +331,8 @@ def _register_dcg_and_atlas( "atlas": atlas_name, "atlas_pixel_size": atlas_pixel_size, "sample": None, + "color_flags": color_flags, + "collection_mode": collection_mode, } if entry_point_result := entry_points( group="murfey.workflows", name="data_collection_group" @@ -342,6 +406,8 @@ def _register_grid_square( and atlas_entry.x1 is not None and atlas_entry.y0 is not None and atlas_entry.y1 is not None + and atlas_entry.thumbnail_pixels_x is not None + and atlas_entry.thumbnail_pixels_y is not None ): atlas_width_real = atlas_entry.x1 - atlas_entry.x0 atlas_height_real = atlas_entry.y1 - atlas_entry.y0 @@ -356,34 +422,31 @@ def _register_grid_square( and clem_img_series.x1 is not None and clem_img_series.y0 is not None and clem_img_series.y1 is not None - and clem_img_series.thumbnail_pixels_x is not None - and clem_img_series.thumbnail_pixels_y is not None - and clem_img_series.thumbnail_pixel_size is not None ): # Find pixel corresponding to image midpoint on atlas x_mid_real = ( 0.5 * (clem_img_series.x0 + clem_img_series.x1) - atlas_entry.x0 ) x_mid_px = int( - x_mid_real / atlas_width_real * clem_img_series.thumbnail_pixels_x + x_mid_real / atlas_width_real * atlas_entry.thumbnail_pixels_x ) y_mid_real = ( 0.5 * (clem_img_series.y0 + clem_img_series.y1) - atlas_entry.y0 ) y_mid_px = int( - y_mid_real / atlas_height_real * clem_img_series.thumbnail_pixels_y + y_mid_real / atlas_height_real * atlas_entry.thumbnail_pixels_y ) - # Find the size of the image, in pixels, when overlaid the atlas + # Find the size of the image, in pixels, when overlaid on the atlas width_scaled = int( (clem_img_series.x1 - clem_img_series.x0) / atlas_width_real - * clem_img_series.thumbnail_pixels_x + * atlas_entry.thumbnail_pixels_x ) height_scaled = int( (clem_img_series.y1 - clem_img_series.y0) / atlas_height_real - * clem_img_series.thumbnail_pixels_y + * atlas_entry.thumbnail_pixels_y ) else: logger.warning( @@ -410,7 +473,13 @@ def _register_grid_square( y_stage_position=0.5 * (clem_img_series.y0 + clem_img_series.y1), pixel_size=clem_img_series.image_pixel_size, image=clem_img_series.thumbnail_search_string, + collection_mode=clem_img_series.collection_mode, ) + # Construct colour flags for ISPyB + color_flags = { + ispyb_color_flags: int(getattr(clem_img_series, murfey_color_flags, 0)) + for murfey_color_flags, ispyb_color_flags in COLOR_FLAGS_MURFEY_TO_ISPYB.items() + } # Register or update the grid square entry as required if grid_square_result := murfey_db.exec( select(MurfeyDB.GridSquare) @@ -435,6 +504,7 @@ def _register_grid_square( _transport_object.do_update_grid_square( grid_square_id=grid_square_entry.id, grid_square_parameters=grid_square_params, + color_flags=color_flags, ) else: # Look up data collection group for current series @@ -448,6 +518,7 @@ def _register_grid_square( atlas_id=dcg_entry.atlas_id, grid_square_id=clem_img_series.id, grid_square_parameters=grid_square_params, + color_flags=color_flags, ) # Register to Murfey grid_square_entry = MurfeyDB.GridSquare( diff --git a/src/murfey/workflows/register_atlas_update.py b/src/murfey/workflows/register_atlas_update.py index 6ff68cc4..28bd77c9 100644 --- a/src/murfey/workflows/register_atlas_update.py +++ b/src/murfey/workflows/register_atlas_update.py @@ -19,10 +19,13 @@ def run( logger.info(f"Registering updated atlas: \n{message}") _transport_object.do_update_atlas( - message["atlas_id"], - message["atlas"], - message["atlas_pixel_size"], - message["sample"], + atlas_id=message["atlas_id"], + atlas_image=message["atlas"], + pixel_size=message["atlas_pixel_size"], + slot=message["sample"], + # Extract optional parameters + collection_mode=message.get("collection_mode"), + color_flags=message.get("color_flags", {}), ) if dcg_hooks := entry_points(group="murfey.hooks", name="data_collection_group"): try: diff --git a/src/murfey/workflows/register_data_collection_group.py b/src/murfey/workflows/register_data_collection_group.py index a225936f..13557550 100644 --- a/src/murfey/workflows/register_data_collection_group.py +++ b/src/murfey/workflows/register_data_collection_group.py @@ -67,6 +67,12 @@ def run(message: dict, murfey_db: SQLModelSession) -> dict[str, bool]: pixelSize=message.get("atlas_pixel_size", 0), cassetteSlot=message.get("sample"), ) + # Optionally set the collection mode and color flags + if collection_mode := message.get("collection_mode"): + atlas_record.mode = collection_mode + if color_flags := message.get("color_flags", {}): + for col_name, value in color_flags.items(): + setattr(atlas_record, col_name, value) atlas_id = _transport_object.do_insert_atlas(atlas_record).get( "return_value", None ) diff --git a/tests/workflows/clem/test_register_preprocessing_results.py b/tests/workflows/clem/test_register_preprocessing_results.py index 3353be98..b701a6b6 100644 --- a/tests/workflows/clem/test_register_preprocessing_results.py +++ b/tests/workflows/clem/test_register_preprocessing_results.py @@ -12,9 +12,13 @@ import murfey.util.db as MurfeyDB from murfey.workflows.clem.register_preprocessing_results import ( + COLOR_FLAGS_MURFEY_TO_ISPYB, + _determine_collection_mode, + _get_color_flags, _register_clem_image_series, _register_dcg_and_atlas, _register_grid_square, + _snake_to_camel_case, run, ) from tests.conftest import ExampleVisit, get_or_create_db_entry @@ -22,7 +26,6 @@ visit_name = f"{ExampleVisit.proposal_code}{ExampleVisit.proposal_number}-{ExampleVisit.visit_number}" processed_dir_name = "processed" grid_name = "Grid_1" -colors = ("gray", "green", "red") @pytest.fixture @@ -33,6 +36,7 @@ def rsync_basepath(tmp_path: Path): def generate_preprocessing_messages( rsync_basepath: Path, session_id: int, + colors: list[str], ): # Make directory to where data for current grid is stored visit_dir = rsync_basepath / "2020" / visit_name @@ -116,25 +120,120 @@ def generate_preprocessing_messages( return messages -@pytest.mark.skip +@pytest.mark.parametrize( + "test_params", + ( + ( + ["gray"], + { + "has_grey": True, + }, + ), + ( + ["gray", "red"], + { + "has_grey": True, + "has_red": True, + }, + ), + ( + ["red", "green", "blue"], + { + "has_red": True, + "has_green": True, + "has_blue": True, + }, + ), + ( + ["cyan", "magenta", "yellow"], + { + "has_cyan": True, + "has_magenta": True, + "has_yellow": True, + }, + ), + ), +) +def test_get_color_flags(test_params: tuple[list[str], dict[str, bool]]): + colors, positive_flags = test_params + expected_result = dict.fromkeys( + ( + "has_grey", + "has_red", + "has_green", + "has_blue", + "has_cyan", + "has_magenta", + "has_yellow", + ), + False, + ) + for flag, value in positive_flags.items(): + expected_result[flag] = value + assert _get_color_flags(colors) == expected_result + + def test_register_clem_image_series(): - assert _register_clem_image_series + _register_clem_image_series + + +@pytest.mark.parametrize( + "test_params", + ( + (["gray"], "Bright Field"), + (["gray", "blue"], "Bright Field and Fluorescent"), + (["red", "green", "blue"], "Fluorescent"), + ), +) +def test_determine_collection_mode(test_params: tuple[list[str], str]): + colors, expected_result = test_params + assert _determine_collection_mode(colors) == expected_result + + +@pytest.mark.parametrize( + "test_params", + ( + ("has_grey", "hasGrey"), + ("has_red", "hasRed"), + ("has_green", "hasGreen"), + ("has_blue", "hasBlue"), + ("has_cyan", "hasCyan"), + ("has_magenta", "hasMagenta"), + ("has_yellow", "hasYellow"), + ), +) +def test_snake_to_camel_case( + test_params: tuple[str, str], +): + string, expected_result = test_params + assert _snake_to_camel_case(string) == expected_result -@pytest.mark.skip def test_register_dcg_and_atlas(): - assert _register_dcg_and_atlas + _register_dcg_and_atlas -@pytest.mark.skip def test_register_grid_square(): - assert _register_grid_square + _register_grid_square +@pytest.mark.parametrize( + "test_params", + ( # Colors + (["gray"],), + (["gray", "green"],), + (["red", "green", "blue"],), + (["cyan", "magenta", "blue"],), + ), +) def test_run( mocker: MockerFixture, rsync_basepath: Path, + test_params: tuple[list[str]], ): + # Unpack test params + (colors,) = test_params + # Mock the MurfeyDB connection mock_murfey_session_entry = MagicMock() mock_murfey_session_entry.instrument_name = ExampleVisit.instrument_name @@ -161,6 +260,7 @@ def test_run( preprocessing_messages = generate_preprocessing_messages( rsync_basepath=rsync_basepath, session_id=ExampleVisit.murfey_session_id, + colors=colors, ) for message in preprocessing_messages: result = run( @@ -171,29 +271,34 @@ def test_run( assert mock_register_clem_series.call_count == len(preprocessing_messages) assert mock_register_dcg_and_atlas.call_count == len(preprocessing_messages) assert mock_register_grid_square.call_count == len(preprocessing_messages) - assert mock_align_and_merge_call.call_count == len(preprocessing_messages) * len( - colors - ) - - -test_matrix = ( - # Reverse order of list - (False,), - (True,), + if ("gray" not in colors) or ("gray" in colors and len(colors) == 1): + assert mock_align_and_merge_call.call_count == len(preprocessing_messages) + else: + assert mock_align_and_merge_call.call_count == len(preprocessing_messages) * 3 + + +@pytest.mark.parametrize( + "test_params", + ( + # Reverse list order? | Colors + (False, ["gray"]), + (True, ["gray"]), + (False, ["red", "green", "blue"]), + (True, ["cyan", "magenta", "yellow"]), + (False, ["gray", "red", "green", "blue"]), + (True, ["gray", "cyan", "magenta", "yellow"]), + ), ) - - -@pytest.mark.parametrize("test_params", test_matrix) def test_run_with_db( mocker: MockerFixture, rsync_basepath: Path, mock_ispyb_credentials, murfey_db_session: SQLModelSession, ispyb_db_session: SQLAlchemySession, - test_params: tuple[bool], + test_params: tuple[bool, list[str]], ): # Unpack test params - (shuffle_message,) = test_params + (shuffle_message, colors) = test_params # Create a session to insert for this test murfey_session: MurfeyDB.Session = get_or_create_db_entry( @@ -258,6 +363,7 @@ def test_run_with_db( preprocessing_messages = generate_preprocessing_messages( rsync_basepath=rsync_basepath, session_id=murfey_session.id, + colors=colors, ) if shuffle_message: preprocessing_messages.reverse() @@ -270,9 +376,10 @@ def test_run_with_db( # Each message should call the align-and-merge workflow thrice # if gray and colour channels are both present - assert mock_align_and_merge_call.call_count == len(preprocessing_messages) * len( - colors - ) + if ("gray" not in colors) or ("gray" in colors and len(colors) == 1): + assert mock_align_and_merge_call.call_count == len(preprocessing_messages) + else: + assert mock_align_and_merge_call.call_count == len(preprocessing_messages) * 3 # Both databases should have entries for data collection group, and grid squares # ISPyB database should additionally have an atlas entry @@ -313,7 +420,19 @@ def test_run_with_db( ) assert len(ispyb_atlas_search) == 1 + # Determine the color flags and collection mode + color_flags = { + COLOR_FLAGS_MURFEY_TO_ISPYB[flag]: int(value) + for flag, value in _get_color_flags(colors).items() + } + collection_mode = _determine_collection_mode(colors) + ispyb_atlas = ispyb_atlas_search[0] + # Check that the Atlas color flags and collection mode are set correctly + for flag, value in color_flags.items(): + assert getattr(ispyb_atlas, flag) == value + assert ispyb_atlas.mode == collection_mode + ispyb_gs_search = ( ispyb_db_session.execute( sa_select(ISPyBDB.GridSquare).where( @@ -324,6 +443,11 @@ def test_run_with_db( .all() ) assert len(ispyb_gs_search) == len(preprocessing_messages) - 1 + for gs in ispyb_gs_search: + # Check that the Atlas color flags and collection mode are set correctly + for flag, value in color_flags.items(): + assert getattr(gs, flag) == value + assert gs.mode == collection_mode murfey_db_session.close() ispyb_db_session.close() diff --git a/tests/workflows/test_register_atlas_update.py b/tests/workflows/test_register_atlas_update.py index 983c739b..adfdd628 100644 --- a/tests/workflows/test_register_atlas_update.py +++ b/tests/workflows/test_register_atlas_update.py @@ -25,9 +25,11 @@ def test_run( # Run the function and check the results and calls made result = run(message, mock_murfey_db) mock_transport_object.do_update_atlas.assert_called_once_with( - message["atlas_id"], - message["atlas"], - message["atlas_pixel_size"], - message["sample"], + atlas_id=message["atlas_id"], + atlas_image=message["atlas"], + pixel_size=message["atlas_pixel_size"], + slot=message["sample"], + collection_mode=message.get("collection_mode"), + color_flags=message.get("color_flags", {}), ) assert result == {"success": True} diff --git a/tests/workflows/test_register_data_collection_group.py b/tests/workflows/test_register_data_collection_group.py index 3324dec9..9d2783f7 100644 --- a/tests/workflows/test_register_data_collection_group.py +++ b/tests/workflows/test_register_data_collection_group.py @@ -6,28 +6,29 @@ from murfey.workflows.register_data_collection_group import run from tests.conftest import ExampleVisit -register_data_collection_group_params_matrix = ( - # ISPyB session ID | # DCG search result | # DCG insert result | # Atlas insert result - (0, 0, 0, 0), - (0, 0, 0, None), - (0, 0, None, 0), - (0, 0, None, None), - (0, None, 0, 0), - (0, None, 0, None), - (0, None, None, 0), - (0, None, None, None), - (None, 0, 0, 0), - (None, 0, 0, None), - (None, 0, None, 0), - (None, 0, None, None), - (None, None, 0, 0), - (None, None, 0, None), - (None, None, None, 0), - (None, None, None, None), -) - -@pytest.mark.parametrize("test_params", register_data_collection_group_params_matrix) +@pytest.mark.parametrize( + "test_params", + ( + # ISPyB session ID | # DCG search result | # DCG insert result | # Atlas insert result + (0, 0, 0, 0), + (0, 0, 0, None), + (0, 0, None, 0), + (0, 0, None, None), + (0, None, 0, 0), + (0, None, 0, None), + (0, None, None, 0), + (0, None, None, None), + (None, 0, 0, 0), + (None, 0, 0, None), + (None, 0, None, 0), + (None, 0, None, None), + (None, None, 0, 0), + (None, None, 0, None), + (None, None, None, 0), + (None, None, None, None), + ), +) def test_run( mocker: MockerFixture, test_params: tuple[int | None, int | None, int | None, int | None],