python/infinilm/llm/cache_manager.py (6 changes: 6 additions & 0 deletions)
@@ -261,6 +261,12 @@ def try_free_blocks(self, num_required: int) -> bool:
def get_num_free_blocks(self) -> int:
return len(self.free_block_ids)

def get_total_usable_blocks(self) -> int:
freeable_used_blocks = sum(
1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0
)
return len(self.free_block_ids) + freeable_used_blocks

def __repr__(self):
return (
f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, "
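The new helper counts blocks that are either free or held only by the prefix cache (ref_count == 0) and can therefore be reclaimed. A minimal standalone sketch of that accounting, assuming a block manager shaped like the surrounding class (the _Block and _ToyBlockManager names are illustrative, not part of the PR):

from dataclasses import dataclass, field

@dataclass
class _Block:
    ref_count: int = 0

@dataclass
class _ToyBlockManager:
    blocks: dict = field(default_factory=dict)        # block_id -> _Block
    free_block_ids: set = field(default_factory=set)  # never allocated or already released
    used_block_ids: set = field(default_factory=set)  # allocated, possibly only cached

    def get_total_usable_blocks(self) -> int:
        # Usable = truly free blocks + used blocks with no live references
        # (e.g., prefix-cache blocks that can be evicted on demand).
        freeable = sum(1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0)
        return len(self.free_block_ids) + freeable

# 2 free blocks, 3 used blocks of which 1 is unreferenced -> 3 usable blocks.
mgr = _ToyBlockManager(
    blocks={0: _Block(0), 1: _Block(2), 2: _Block(1)},
    free_block_ids={3, 4},
    used_block_ids={0, 1, 2},
)
assert mgr.get_total_usable_blocks() == 3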
python/infinilm/llm/llm.py (66 changes: 62 additions & 4 deletions)
@@ -228,11 +228,39 @@ def _update_requests(
req.generated_token_ids.append(token_id)
if req.is_prefill:
req.is_prefill = False
# vLLM-style replacement character handling is primarily relevant for streaming.
# For offline generation (no output queue), keep the fast incremental path.
if req._output_queue is None:
token_text = self.detokenize([token_id])
req.generated_text += token_text
else:
# Streaming path: compute delta from a full decode so we can hold back
# trailing '\ufffd' (likely an incomplete UTF-8 sequence).
decoded_text = self.detokenize(req.generated_token_ids)

finished_now = False
if self._check_request_finished(req, token_id):
req.mark_finished(req.finish_reason)
finished_now = True

- token_text = self.tokenizer.decode(token_id)
- req.generated_text += token_text
# Update generated_text to the latest decode (used for stop-string checks and debugging)
req.generated_text = decoded_text

holds_back_incomplete_utf8 = (
bool(decoded_text) and decoded_text.endswith("\ufffd")
)

- if self._check_request_finished(req, token_id):
# vLLM-style: hold back only if we are not on the final chunk.
if holds_back_incomplete_utf8 and not finished_now:
token_text = ""
else:
last_len = getattr(req, "_stream_last_yielded_length", 0)
token_text = decoded_text[last_len:]
if token_text:
req._stream_last_yielded_length = len(decoded_text)

# For non-streaming, finish checks happen here.
if req._output_queue is None and self._check_request_finished(req, token_id):
req.mark_finished(req.finish_reason)

# Put output in queue if it exists (for async streaming)
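The streaming branch above reduces to one rule: decode everything generated so far, withhold the chunk if it ends in U+FFFD and the request is not on its final chunk, otherwise emit the delta since the last yield. A minimal sketch of that rule in isolation (the function name and arguments are illustrative):

def next_stream_chunk(decoded_text: str, last_yielded_len: int, is_final: bool):
    """Return (chunk, new_last_yielded_len) with vLLM-style UTF-8 buffering."""
    # A trailing U+FFFD usually means the latest token stopped mid-way through a
    # multi-byte UTF-8 character; hold the chunk back unless this is the final one.
    if decoded_text.endswith("\ufffd") and not is_final:
        return "", last_yielded_len
    chunk = decoded_text[last_yielded_len:]
    if chunk:
        last_yielded_len = len(decoded_text)
    return chunk, last_yielded_len

# The partial decode ends in U+FFFD, so nothing is emitted yet; once the next
# token completes the character, the whole delta comes out in one chunk.
chunk, n = next_stream_chunk("你\ufffd", 0, is_final=False)
assert (chunk, n) == ("", 0)
chunk, n = next_stream_chunk("你好", n, is_final=False)
assert (chunk, n) == ("你好", 2)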
@@ -283,12 +311,15 @@ def apply_chat_template(
self,
messages: List[dict],
add_generation_prompt: bool = True,
chat_template_kwargs: Optional[dict] = None,
) -> str:
"""Apply chat template to messages."""
chat_template_kwargs = chat_template_kwargs or {}
return self.tokenizer.apply_chat_template(
conversation=messages,
add_generation_prompt=add_generation_prompt,
tokenize=False,
**chat_template_kwargs,
)
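chat_template_kwargs is forwarded verbatim to the tokenizer's apply_chat_template, so callers can toggle template-specific options without new engine parameters. A hypothetical call site (enable_thinking is only an example of a template-specific kwarg, not something this PR defines):

def build_prompt(engine, messages, thinking: bool = False) -> str:
    # `engine` is an already-constructed instance of the class patched above.
    # The old single-argument form engine.apply_chat_template(messages) still works;
    # the dict below is handed through to tokenizer.apply_chat_template(**kwargs).
    return engine.apply_chat_template(
        messages,
        add_generation_prompt=True,
        chat_template_kwargs={"enable_thinking": thinking},
    )

# build_prompt(engine, [{"role": "user", "content": "Explain KV-cache paging."}])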


@@ -486,6 +517,10 @@ def __init__(

self._running = False
self._step_thread: Optional[threading.Thread] = None
self._healthy = True

def is_healthy(self) -> bool:
return bool(self._healthy)

def start(self):
"""Start the background inference loop."""
@@ -520,6 +555,7 @@ def _step_loop(self):
time.sleep(0.01)
except Exception as e:
logger.error(f"Error in step loop: {e}", exc_info=True)
self._healthy = False
self._running = False
break
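Once the step loop marks itself unhealthy on an unhandled exception, a serving layer can poll is_healthy() and stop routing new requests. A hypothetical watchdog sketch (none of the surrounding server code is part of this PR):

import time

def watch_engine(engine, poll_interval: float = 1.0) -> None:
    # The step loop sets _healthy = False (and _running = False) before it exits
    # on an unhandled exception, so is_healthy() flipping to False means the
    # engine should be drained, restarted, or taken out of rotation.
    while engine.is_healthy():
        time.sleep(poll_interval)
    print("inference step loop is down; stop accepting requests")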

@@ -581,6 +617,8 @@ def add_chat_request(
request_id: Optional[str] = None,
request_data: Optional[dict] = None,
http_request: Optional[any] = None,
add_generation_prompt: bool = True,
chat_template_kwargs: Optional[dict] = None,
) -> InferenceRequest:
"""Add a chat request to the engine.

@@ -594,7 +632,11 @@
Returns:
The created InferenceRequest object.
"""
- prompt = self.engine.apply_chat_template(messages, add_generation_prompt=True)
prompt = self.engine.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
chat_template_kwargs=chat_template_kwargs,
)
return self.add_request(
prompt=prompt,
sampling_params=sampling_params,
@@ -607,6 +649,7 @@ async def stream_request(
self,
request: InferenceRequest,
timeout: float = 100.0,
request_timeout: Optional[float] = None,
) -> AsyncIterator[TokenOutput]:
"""Stream tokens from a request.

@@ -619,6 +662,7 @@
"""
import asyncio

start = time.time()
while True:
if request.is_finished() and request.output_queue.async_q.empty():
break
@@ -635,6 +679,20 @@
if token_output.finished:
break
except asyncio.TimeoutError:
# Enforce request-level timeout even if no tokens are produced.
if request_timeout is not None:
now = time.time()
if now - start > float(request_timeout):
request.mark_timeout()
yield TokenOutput(
request_id=request.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=FinishReason.TIMEOUT,
generated_text=request.generated_text,
)
break
if request.is_finished():
break
continue
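A hypothetical consumer of the new request_timeout parameter; the request object is whatever add_request()/add_chat_request() returned, and the TokenOutput fields follow the diff above. Note that the deadline is only checked when a queue wait times out, so a short per-wait timeout keeps the deadline responsive:

import asyncio

async def collect_stream(engine, request, deadline_s: float = 30.0) -> str:
    # stream_request() emits a final chunk with finish_reason == FinishReason.TIMEOUT
    # if deadline_s elapses without the request finishing.
    pieces = []
    async for out in engine.stream_request(request, timeout=1.0, request_timeout=deadline_s):
        if out.token_text:
            pieces.append(out.token_text)
        if out.finished:
            break
    return "".join(pieces)

# text = asyncio.run(collect_stream(engine, request))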
python/infinilm/llm/request.py (4 changes: 4 additions & 0 deletions)
@@ -144,6 +144,10 @@ def __init__(
# Output management (for async streaming)
self._output_queue: Optional[janus.Queue] = None

# Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer)
# Used by the engine to compute "delta" text chunks from a full decode.
self._stream_last_yielded_length: int = 0

@property
def output_queue(self) -> janus.Queue:
"""Lazy initialization of output queue."""
python/infinilm/llm/scheduler.py (42 changes: 40 additions & 2 deletions)
@@ -155,12 +155,21 @@ def schedule(self) -> Optional[SchedulerOutput]:
except queue.Empty:
break

if not self.can_accept_request(req):
self.waiting_queue.sync_q.put(req)
break

# Skip requests that were already finished (e.g., timed out/canceled while waiting)
if req.is_finished():
self.complete_requests([req])
continue

req_tokens = req.get_input_tokens()
num_required_blocks = req.get_num_blocks_required(self.block_size)

if not self.cache_manager.can_allocate(num_required_blocks):
if not self.cache_manager.try_free_blocks(num_required_blocks):
raise RuntimeError("No available cache blocks")
raise RuntimeError("No available cache blocks for new request")

# Allocate blocks with automatic prefix caching support
req.block_table, req.slot_mapping, req.num_cached_tokens = (
@@ -185,6 +194,10 @@ def schedule(self) -> Optional[SchedulerOutput]:
req = self.running_queue.sync_q.get_nowait()
except queue.Empty:
break
# Skip requests that were already finished (e.g., timed out/canceled while running)
if req.is_finished():
self.complete_requests([req])
continue

# Decode phase: allocate slot for newly generated token
try:
@@ -197,7 +210,7 @@
scheduled_requests.append(req)

except RuntimeError as e:
raise RuntimeError("No available cache blocks") from e
raise RuntimeError("No available cache blocks for new token") from e

# Return decode batch if any running requests were scheduled
if scheduled_requests:
@@ -237,6 +250,31 @@ def complete_requests(self, requests: List[InferenceRequest]):
# Still running, put back in running queue
self.running_queue.sync_q.put(req)

def can_accept_request(self, request: InferenceRequest) -> bool:
total_required_blocks = 0

# Calculate blocks needed for running requests
running_queue_size = self.running_queue.sync_q.qsize()
for _ in range(running_queue_size):
req = self.running_queue.sync_q.get()
remaining_tokens = (
req.sampling_params.max_tokens - req.get_num_generated_tokens()
)
num_blocks_needed = (
remaining_tokens + self.block_size - 1
) // self.block_size
total_required_blocks += num_blocks_needed
self.running_queue.sync_q.put(req)

# Calculate blocks needed for the new request
total_length = request.get_prompt_length()
total_length += request.sampling_params.max_tokens
num_blocks_needed = (total_length + self.block_size - 1) // self.block_size
total_required_blocks += num_blocks_needed

# Compare with total usable blocks in cache manager
return total_required_blocks <= self.cache_manager.get_total_usable_blocks()

def get_cache_stats(self) -> dict:
"""Get cache statistics."""
return {
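For reference, the admission check is plain ceiling-division block accounting: worst-case remaining tokens of every running request, plus prompt length and full generation budget of the new request, compared against get_total_usable_blocks(). A small worked example with made-up numbers:

import math

def blocks_needed(num_tokens: int, block_size: int) -> int:
    # Same rounding as the scheduler: (n + block_size - 1) // block_size.
    return math.ceil(num_tokens / block_size)

block_size = 16
usable_blocks = 64  # what cache_manager.get_total_usable_blocks() would report

# Worst case for requests already running: tokens they may still generate.
running = [
    {"max_tokens": 256, "generated": 100},   # 156 tokens left -> 10 blocks
    {"max_tokens": 128, "generated": 20},    # 108 tokens left -> 7 blocks
]
demand = sum(blocks_needed(r["max_tokens"] - r["generated"], block_size) for r in running)

# Worst case for the incoming request: full prompt plus full generation budget.
new_prompt_len, new_max_tokens = 200, 128    # 328 tokens -> 21 blocks
demand += blocks_needed(new_prompt_len + new_max_tokens, block_size)

assert demand == 38
print("admit" if demand <= usable_blocks else "queue")   # 38 <= 64 -> admit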