
WIP VLM SFT Training via SkyRL with SkyRL train#1282

Draft
nithinvc wants to merge 14 commits into NovaSky-AI:main from nithinvc:nithinc/vlm-tinker-v3

Conversation


@nithinvc nithinvc commented Mar 5, 2026

A work-in-progress, end-to-end implementation of RFC #1200. When merged, this will be broken into 4 PRs (tinker types, plumbing to the backend, inference changes, training changes).

Target script: VLM classifier training via SFT.



@nithinvc nithinvc marked this pull request as draft March 5, 2026 20:00
@nithinvc nithinvc changed the title VLM SFT Training via SkyRL with SkyRL train WIP VLM SFT Training via SkyRL with SkyRL train Mar 5, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables Vision Language Model (VLM) Supervised Fine-Tuning (SFT) by introducing support for multimodal inputs through ModelInput chunks, adding a new /sample endpoint, and extending TensorBatch to handle list[Tensor] for variable-sized image features across the Tinker API, backend, and worker implementations. Critically, the changes introduce security vulnerabilities: a Server-Side Request Forgery (SSRF) vector arising from the use of a user-controlled Host header in VLLMServerActor, and insufficient validation of user-provided image URLs, which could allow unauthorized access to internal network resources. There is also a risk of leaking sensitive information by logging URLs that contain credentials. Beyond security, the review highlights opportunities to improve maintainability, performance, and robustness by removing code duplication, narrowing overly broad exception handling, and strengthening type safety.

sampling_params = data["sampling_params"]
model = data.get("model", self._cli_args.model)

base_url = f"{request.url.scheme}://{request.url.netloc}"

security-high

The base_url is derived from the Host header of the incoming request (request.url.netloc). This URL is then used to make internal POST requests to other endpoints. An attacker can provide a malicious Host header to redirect these internal requests to an arbitrary server, facilitating a Server-Side Request Forgery (SSRF) attack. Since the server's internal IP and port are already known to the actor, they should be used instead.

Suggested change
base_url = f"{request.url.scheme}://{request.url.netloc}"
base_url = f"http://{self._ip}:{self._port}"

Comment on lines +307 to +312
@field_validator("location")
@classmethod
def validate_location(cls, v: str) -> str:
if not v:
raise ValueError("location must not be empty")
return v

security-high

The location field in ImageAssetPointerChunk is used as a URL for fetching images but lacks proper validation. It only checks if the string is non-empty. This allows users to provide malicious URLs (e.g., pointing to internal services or cloud metadata endpoints like 169.254.169.254), which are then passed to internal components for fetching, facilitating SSRF.
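A stricter validator along the lines the reviewer suggests could look like the sketch below. The field/class context follows the snippet above; the allow-list policy and helper name are assumptions, not the repo's actual code, and a complete fix would also re-check the resolved address at fetch time to guard against DNS rebinding.

```python
# Hypothetical replacement for the permissive validate_location check above.
# Policy details (allowed schemes, private-range blocking) are illustrative.
import ipaddress
from urllib.parse import urlparse

ALLOWED_SCHEMES = {"http", "https"}

def validate_image_url(v: str) -> str:
    parsed = urlparse(v)
    if parsed.scheme not in ALLOWED_SCHEMES:
        raise ValueError(f"disallowed URL scheme: {parsed.scheme!r}")
    host = parsed.hostname
    if not host:
        raise ValueError("location must include a host")
    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        # Host is a name, not a literal IP; resolve-and-recheck at fetch
        # time would be needed to fully close the SSRF window.
        return v
    if addr.is_private or addr.is_loopback or addr.is_link_local:
        raise ValueError(f"location host {host} is a private/internal address")
    return v
```

This rejects, for example, the cloud metadata endpoint 169.254.169.254 (link-local) while letting ordinary public URLs through.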

Comment on lines +22 to +30
async def _render_image_url(session: aiohttp.ClientSession, base_url: str, location: str, model: str) -> list:
"""Call /v1/chat/completions/render with an image URL."""
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": location}}]}]
async with session.post(
f"{base_url}/v1/chat/completions/render",
json={"model": model, "messages": messages},
) as resp:
resp.raise_for_status()
return await resp.json()

security-high

The _render_image_url function accepts a user-provided location URL and passes it to an internal endpoint that is expected to fetch the image. This facilitates Server-Side Request Forgery (SSRF), as an attacker can provide URLs pointing to internal resources or metadata services. Validation should be implemented at the API entry point or within this helper to restrict allowed URLs.

raise ValueError(
"_SKYRL_USE_NEW_INFERENCE=1 requires " "generator.inference_engine.external_proxy_url to be set."
)
logger.info(f"Using RemoteInferenceClient: proxy_url={proxy_url}, server_urls={server_urls}")

security-medium

The application logs the proxy_url and server_urls used by the RemoteInferenceClient. These URLs may contain sensitive information such as API keys or basic authentication credentials (e.g., http://user:pass@host). Logging these URLs in plaintext exposes these secrets in the logs.
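One common mitigation is to strip the userinfo component from URLs before logging them. The helper below is a minimal sketch of that idea (the function name is hypothetical, not part of the PR):

```python
# Hypothetical helper: drop "user:pass@" from a URL before it reaches logs.
from urllib.parse import urlsplit, urlunsplit

def redact_url(url: str) -> str:
    parts = urlsplit(url)
    if parts.username is None and parts.password is None:
        return url  # no credentials embedded, safe to log as-is
    host = parts.hostname or ""
    if parts.port:
        host = f"{host}:{parts.port}"
    return urlunsplit((parts.scheme, host, parts.path, parts.query, parts.fragment))
```

The log line would then use, e.g., `redact_url(proxy_url)` instead of the raw value.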

Comment on lines +81 to +126
elif chunk["type"] == "image":
b64 = chunk["data"] # already a base64 string in JSON
fmt = chunk["format"]
render_resp = await _render_image(session, base_url, b64, fmt, model)
placeholder_tokens, mm_hash = _extract_render_info(render_resp)

expected = chunk.get("expected_tokens")
if expected is not None and expected != len(placeholder_tokens):
raise ValueError(
f"ImageChunk.expected_tokens={expected} but render returned "
f"{len(placeholder_tokens)} placeholder tokens (mm_hash={mm_hash})"
)

features.append(
{
"modality": "image",
"mm_hash": mm_hash,
"offset": len(assembled_tokens),
"length": len(placeholder_tokens),
"kwargs_data": None,
}
)
assembled_tokens.extend(placeholder_tokens)

elif chunk["type"] == "image_asset_pointer":
location = chunk["location"]
render_resp = await _render_image_url(session, base_url, location, model)
placeholder_tokens, mm_hash = _extract_render_info(render_resp)

expected = chunk.get("expected_tokens")
if expected is not None and expected != len(placeholder_tokens):
raise ValueError(
f"ImageAssetPointerChunk.expected_tokens={expected} but render returned "
f"{len(placeholder_tokens)} placeholder tokens (mm_hash={mm_hash})"
)

features.append(
{
"modality": "image",
"mm_hash": mm_hash,
"offset": len(assembled_tokens),
"length": len(placeholder_tokens),
"kwargs_data": None,
}
)
assembled_tokens.extend(placeholder_tokens)

medium

There is significant code duplication in how image and image_asset_pointer chunks are processed. The logic for rendering the image, extracting information, validating the number of expected tokens, and constructing the feature dictionary is nearly identical for both chunk types. To improve maintainability and reduce redundancy, this common logic should be refactored into a shared helper function.

Comment on lines +386 to +390
import aiohttp
from fastapi.responses import JSONResponse
from skyrl.backends.skyrl_train.inference_servers._sample_helpers import (
_assemble_tokens_from_chunks,
)

medium

The imports for aiohttp, JSONResponse, and _assemble_tokens_from_chunks are located inside the _sample async function. This will cause them to be re-imported on every API call, which can introduce a small but unnecessary performance overhead. It is standard practice to place imports at the top of the file for better performance and code organization, unless there is a specific reason for lazy loading (e.g., to avoid circular dependencies). Please move these imports to the top of the file.

self._vision_processor = VisionProcessor.from_pretrained(self.base_model)
self._image_token_id = self._tokenizer.convert_tokens_to_ids(self._vision_processor.image_token)
logger.info("VisionProcessor loaded for %s (image_token_id=%s)", self.base_model, self._image_token_id)
except Exception:

medium

The use of a broad except Exception: to handle potential errors when loading the VisionProcessor is not ideal. It can suppress unexpected errors, making debugging more difficult. It would be more robust to catch specific exceptions that are anticipated, such as ImportError, OSError (for file-not-found issues), KeyError, or AttributeError. This will make the error handling more precise and the code easier to maintain.

Suggested change
except Exception:
except (ImportError, OSError, KeyError, AttributeError):

await self._inference_engine_client.resume_generation()

def set_inference_engine_client(self, inference_engine_client: InferenceEngineClient) -> None:
def set_inference_engine_client(self, inference_engine_client) -> None:

medium

The type hint for the inference_engine_client parameter was removed, which reduces type safety and code clarity. To accommodate both InferenceEngineClient and the new RemoteInferenceClient, you could use a Union type hint. This would restore type safety and make the code easier to understand and maintain. You may need to add a TYPE_CHECKING block for the import of RemoteInferenceClient to avoid circular dependencies.

Suggested change
def set_inference_engine_client(self, inference_engine_client) -> None:
def set_inference_engine_client(self, inference_engine_client: Union[InferenceEngineClient, "RemoteInferenceClient"]) -> None:
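The TYPE_CHECKING pattern mentioned above could look like this sketch. The `RemoteInferenceClient` module path follows the one cited elsewhere in this review; the other import path and the enclosing class are illustrative assumptions.

```python
# Sketch of the Union-with-TYPE_CHECKING pattern; import paths are assumed.
from __future__ import annotations

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Only evaluated by type checkers, so no runtime circular import.
    from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import (
        RemoteInferenceClient,
    )
    from skyrl.inference import InferenceEngineClient  # assumed path

class Trainer:  # illustrative stand-in for the real class
    def set_inference_engine_client(
        self,
        inference_engine_client: Union["InferenceEngineClient", "RemoteInferenceClient"],
    ) -> None:
        self._inference_engine_client = inference_engine_client
```

With `from __future__ import annotations`, the hints stay strings at runtime, so the guarded imports are never needed outside type checking.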


@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 3 potential issues.

View 7 additional findings in Devin Review.


prompt_token_ids=prompt,
num_samples=1, # Tinker batches multiple samples separately
prompt_token_ids=prepared_batch.all_prompts[i],
chunks=prepared_batch.all_prompt_inputs[i].model_dump()["chunks"],

🔴 model_dump() on types.ImageChunk produces non-JSON-serializable bytes, crashing RemoteInferenceClient

When _SKYRL_USE_NEW_INFERENCE=1 and the prompt contains an ImageChunk, calling model_dump() on the ModelInput returns the ImageChunk.data field as raw bytes (since types.ImageChunk.data is typed as bytes at skyrl/tinker/types.py:102). This dict is then passed as chunks to RemoteInferenceClient.sample() at skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py:334, which attempts to JSON-serialize it via session.post(url, json=payload). Since bytes is not JSON-serializable, this crashes with TypeError: Object of type bytes is not JSON serializable. Additionally, the server-side _render_image at skyrl/backends/skyrl_train/inference_servers/_sample_helpers.py:10 expects b64_data: str, not raw bytes. The fix is to use model_dump(mode='json') which serializes bytes fields as base64 strings.

Suggested change
chunks=prepared_batch.all_prompt_inputs[i].model_dump()["chunks"],
chunks=prepared_batch.all_prompt_inputs[i].model_dump(mode='json')["chunks"],
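The difference between the two dump modes can be seen with a minimal stand-in model (the real `types.ImageChunk` has more fields; this sketch only demonstrates the bytes-serialization behavior):

```python
# Minimal demonstration of model_dump() vs model_dump(mode="json") on a
# bytes field; ImageChunk here is a stand-in, not the real types.ImageChunk.
import json

from pydantic import BaseModel

class ImageChunk(BaseModel):
    type: str = "image"
    data: bytes  # raw bytes, as in the real chunk type

chunk = ImageChunk(data=b"iVBORw0KGgo=")  # e.g. base64 text stored as bytes

python_dump = chunk.model_dump()            # "data" stays bytes
json_dump = chunk.model_dump(mode="json")   # "data" becomes a JSON-safe str

try:
    json.dumps(python_dump)
except TypeError:
    pass  # bytes is not JSON serializable, the crash described above

json.dumps(json_dump)  # succeeds
```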


Comment on lines +345 to +346
}


🟡 Missing blank line between sample() and chat_completion() methods makes them appear joined

There is no blank line (or any separator) between the closing of sample() (line 345: }) and the definition of async def chat_completion (line 346). While Python will parse this correctly, PEP 8 requires two blank lines between top-level definitions and one between method definitions in a class. More importantly, this looks like a merge/rebase artifact where a required blank line was dropped, which hurts readability and maintainability.

Suggested change
}
}

async def chat_completion(


Comment on lines +413 to +414
return JSONResponse(content=result)


🟡 Missing blank line between _sample endpoint handler and shutdown method

There is no blank line between the return JSONResponse(content=result) closing the _sample endpoint handler (line 413, inside _add_custom_endpoints) and the async def shutdown(self) method definition (line 414). This is the same kind of missing-separator issue as in remote_inference_client.py — likely an artifact of code generation or rebase. It hurts readability since shutdown is a different class-level method from _add_custom_endpoints.

Suggested change
return JSONResponse(content=result)
return JSONResponse(content=result)

async def shutdown(self) -> None:


@nithinvc nithinvc force-pushed the nithinc/vlm-tinker-v3 branch from 6dd4358 to b9643ee Compare March 5, 2026 20:46
