diff --git a/python/infinilm/llm/cache_manager.py b/python/infinilm/llm/cache_manager.py
index 1f1e48c8..44ca1376 100644
--- a/python/infinilm/llm/cache_manager.py
+++ b/python/infinilm/llm/cache_manager.py
@@ -261,6 +261,12 @@ def try_free_blocks(self, num_required: int) -> bool:
     def get_num_free_blocks(self) -> int:
         return len(self.free_block_ids)
 
+    def get_total_usable_blocks(self) -> int:
+        freeable_used_blocks = sum(
+            1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0
+        )
+        return len(self.free_block_ids) + freeable_used_blocks
+
     def __repr__(self):
         return (
             f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, "
diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py
index c152d6e4..53eec8a0 100644
--- a/python/infinilm/llm/llm.py
+++ b/python/infinilm/llm/llm.py
@@ -228,11 +228,39 @@ def _update_requests(
             req.generated_token_ids.append(token_id)
             if req.is_prefill:
                 req.is_prefill = False
 
+            # vLLM-style replacement character handling is primarily relevant for streaming.
+            # For offline generation (no output queue), keep the fast incremental path.
+            if req._output_queue is None:
+                token_text = self.detokenize([token_id])
+                req.generated_text += token_text
+            else:
+                # Streaming path: compute delta from a full decode so we can hold back
+                # trailing '\ufffd' (likely an incomplete UTF-8 sequence).
+                decoded_text = self.detokenize(req.generated_token_ids)
+
+                finished_now = False
+                if self._check_request_finished(req, token_id):
+                    req.mark_finished(req.finish_reason)
+                    finished_now = True
-            token_text = self.tokenizer.decode(token_id)
-            req.generated_text += token_text
 
+                # Update generated_text to the latest decode (used for stop-string checks and debugging)
+                req.generated_text = decoded_text
+
+                holds_back_incomplete_utf8 = (
+                    bool(decoded_text) and decoded_text.endswith("\ufffd")
+                )
-            if self._check_request_finished(req, token_id):
+                # vLLM-style: hold back only if we are not on the final chunk.
+                if holds_back_incomplete_utf8 and not finished_now:
+                    token_text = ""
+                else:
+                    last_len = getattr(req, "_stream_last_yielded_length", 0)
+                    token_text = decoded_text[last_len:]
+                    if token_text:
+                        req._stream_last_yielded_length = len(decoded_text)
+
+            # For non-streaming, finish checks happen here.
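+            # (Streaming requests were already finish-checked above, so the final delta is not held back.)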
+            if req._output_queue is None and self._check_request_finished(req, token_id):
                 req.mark_finished(req.finish_reason)
 
         # Put output in queue if it exists (for async streaming)
@@ -283,12 +311,15 @@ def apply_chat_template(
         self,
         messages: List[dict],
         add_generation_prompt: bool = True,
+        chat_template_kwargs: Optional[dict] = None,
     ) -> str:
         """Apply chat template to messages."""
+        chat_template_kwargs = chat_template_kwargs or {}
         return self.tokenizer.apply_chat_template(
             conversation=messages,
             add_generation_prompt=add_generation_prompt,
             tokenize=False,
+            **chat_template_kwargs,
         )
 
@@ -486,6 +517,10 @@ def __init__(
         self._running = False
         self._step_thread: Optional[threading.Thread] = None
+        self._healthy = True
+
+    def is_healthy(self) -> bool:
+        return bool(self._healthy)
 
     def start(self):
         """Start the background inference loop."""
@@ -520,6 +555,7 @@ def _step_loop(self):
                 time.sleep(0.01)
             except Exception as e:
                 logger.error(f"Error in step loop: {e}", exc_info=True)
+                self._healthy = False
                 self._running = False
                 break
@@ -581,6 +617,8 @@ def add_chat_request(
         request_id: Optional[str] = None,
         request_data: Optional[dict] = None,
         http_request: Optional[any] = None,
+        add_generation_prompt: bool = True,
+        chat_template_kwargs: Optional[dict] = None,
     ) -> InferenceRequest:
         """Add a chat request to the engine.
 
         Args:
             ...
 
         Returns:
             The created InferenceRequest object.
         """
-        prompt = self.engine.apply_chat_template(messages, add_generation_prompt=True)
+        prompt = self.engine.apply_chat_template(
+            messages,
+            add_generation_prompt=add_generation_prompt,
+            chat_template_kwargs=chat_template_kwargs,
+        )
         return self.add_request(
             prompt=prompt,
             sampling_params=sampling_params,
@@ -607,6 +649,7 @@ async def stream_request(
         self,
         request: InferenceRequest,
         timeout: float = 100.0,
+        request_timeout: Optional[float] = None,
     ) -> AsyncIterator[TokenOutput]:
         """Stream tokens from a request.
@@ -619,6 +662,7 @@
         """
         import asyncio
 
+        start = time.time()
         while True:
             if request.is_finished() and request.output_queue.async_q.empty():
                 break
@@ -635,6 +679,20 @@
                     if token_output.finished:
                         break
                 except asyncio.TimeoutError:
+                    # Enforce request-level timeout even if no tokens are produced.
+                    if request_timeout is not None:
+                        now = time.time()
+                        if now - start > float(request_timeout):
+                            request.mark_timeout()
+                            yield TokenOutput(
+                                request_id=request.request_id,
+                                token_id=-1,
+                                token_text="",
+                                finished=True,
+                                finish_reason=FinishReason.TIMEOUT,
+                                generated_text=request.generated_text,
+                            )
+                            break
                     if request.is_finished():
                         break
                     continue
diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py
index d6e08aef..224828d1 100644
--- a/python/infinilm/llm/request.py
+++ b/python/infinilm/llm/request.py
@@ -144,6 +144,10 @@ def __init__(
         # Output management (for async streaming)
         self._output_queue: Optional[janus.Queue] = None
+        # Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer)
+        # Used by the engine to compute "delta" text chunks from a full decode.
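+        # Holds the number of characters of the full decode that have already been emitted downstream.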
+        self._stream_last_yielded_length: int = 0
+
 
     @property
     def output_queue(self) -> janus.Queue:
         """Lazy initialization of output queue."""
diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py
index b1853292..b3188c9b 100644
--- a/python/infinilm/llm/scheduler.py
+++ b/python/infinilm/llm/scheduler.py
@@ -155,12 +155,21 @@ def schedule(self) -> Optional[SchedulerOutput]:
             except queue.Empty:
                 break
 
+            if not self.can_accept_request(req):
+                self.waiting_queue.sync_q.put(req)
+                break
+
+            # Skip requests that were already finished (e.g., timed out/canceled while waiting)
+            if req.is_finished():
+                self.complete_requests([req])
+                continue
+
             req_tokens = req.get_input_tokens()
             num_required_blocks = req.get_num_blocks_required(self.block_size)
 
             if not self.cache_manager.can_allocate(num_required_blocks):
                 if not self.cache_manager.try_free_blocks(num_required_blocks):
-                    raise RuntimeError("No available cache blocks")
+                    raise RuntimeError("No available cache blocks for new request")
 
             # Allocate blocks with automatic prefix caching support
             req.block_table, req.slot_mapping, req.num_cached_tokens = (
@@ -185,6 +194,10 @@
                 req = self.running_queue.sync_q.get_nowait()
             except queue.Empty:
                 break
+            # Skip requests that were already finished (e.g., timed out/canceled while running)
+            if req.is_finished():
+                self.complete_requests([req])
+                continue
 
             # Decode phase: allocate slot for newly generated token
             try:
@@ -197,7 +210,7 @@
                 scheduled_requests.append(req)
 
             except RuntimeError as e:
-                raise RuntimeError("No available cache blocks") from e
+                raise RuntimeError("No available cache blocks for new token") from e
 
         # Return decode batch if any running requests were scheduled
         if scheduled_requests:
@@ -237,6 +250,31 @@ def complete_requests(self, requests: List[InferenceRequest]):
             # Still running, put back in running queue
             self.running_queue.sync_q.put(req)
 
+    def can_accept_request(self, request: InferenceRequest) -> bool:
+        total_required_blocks = 0
+
+        # Calculate blocks needed for running requests
+        running_queue_size = self.running_queue.sync_q.qsize()
+        for _ in range(running_queue_size):
+            req = self.running_queue.sync_q.get()
+            remaining_tokens = (
+                req.sampling_params.max_tokens - req.get_num_generated_tokens()
+            )
+            num_blocks_needed = (
+                remaining_tokens + self.block_size - 1
+            ) // self.block_size
+            total_required_blocks += num_blocks_needed
+            self.running_queue.sync_q.put(req)
+
+        # Calculate blocks needed for the new request
+        total_length = request.get_prompt_length()
+        total_length += request.sampling_params.max_tokens
+        num_blocks_needed = (total_length + self.block_size - 1) // self.block_size
+        total_required_blocks += num_blocks_needed
+
+        # Compare with total usable blocks in cache manager
+        return total_required_blocks <= self.cache_manager.get_total_usable_blocks()
+
     def get_cache_stats(self) -> dict:
         """Get cache statistics."""
         return {
diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py
index 99e1988d..59eecf8a 100644
--- a/python/infinilm/server/inference_server.py
+++ b/python/infinilm/server/inference_server.py
@@ -10,9 +10,10 @@
 import argparse
 import uvicorn
 import logging
+import os
 
 from fastapi import FastAPI, Request
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse, Response
 
 from infinilm.llm import AsyncLLMEngine, SamplingParams, FinishReason
@@ -22,7 +23,7 @@
 
 DEFAULT_REQUEST_TIMEOUT = 1000.0
 
 
-def chunk_json(id_, content=None, role=None, finish_reason=None):
+def chunk_json(id_, content=None, role=None, finish_reason=None, model: str = "unknown"):
     """Generate JSON chunk for streaming response."""
     delta = {}
     if content:
@@ -33,12 +34,11 @@
         "id": id_,
         "object": "chat.completion.chunk",
         "created": int(time.time()),
-        "model": "jiuge",
+        "model": model,
         "system_fingerprint": None,
         "choices": [
             {
                 "index": 0,
-                "text": content,
                 "delta": delta,
                 "logprobs": None,
                 "finish_reason": finish_reason,
             }
@@ -47,6 +47,44 @@
     }
 
 
+def completion_json(
+    id_: str,
+    content: str,
+    role: str,
+    finish_reason: str,
+    model: str,
+    prompt_tokens: int = 0,
+    completion_tokens: int = 0,
+) -> dict:
+    """Generate JSON response for non-streaming chat completion (OpenAI-compatible format)."""
+    response = {
+        "id": id_,
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": role,
+                    "content": content,
+                },
+                "finish_reason": finish_reason,
+            }
+        ],
+    }
+
+    # Add usage field if token counts are available
+    if prompt_tokens > 0 or completion_tokens > 0:
+        response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+
+    return response
+
+
 class InferenceServer:
     """HTTP server for LLM inference."""
 
@@ -65,6 +103,7 @@ def __init__(
         top_k: int = 1,
         host: str = "0.0.0.0",
         port: int = 8000,
+        max_context_length: int = None,
     ):
         """Initialize inference server.
 
@@ -82,8 +121,12 @@
             top_k: Default top-k sampling parameter.
             host: Server host address.
             port: Server port number.
+            max_context_length: Maximum context length for input prompts. If None,
+                uses the model's max_position_embeddings. If set, must be <= max_position_embeddings.
         """
         self.model_path = model_path
+        # vLLM-like served model id: directory name of model_path
+        self.model_id = os.path.basename(os.path.normpath(model_path)) or "model"
         self.device = device
         self.dtype = dtype
         self.tensor_parallel_size = tensor_parallel_size
@@ -96,6 +139,7 @@
         self.top_k = top_k
         self.host = host
         self.port = port
+        self.max_context_length = max_context_length
 
         self.engine: AsyncLLMEngine = None
 
@@ -133,10 +177,214 @@ async def lifespan(app: FastAPI):
         self._register_routes(app)
         return app
 
+    def _normalize_messages(self, messages: list) -> list:
+        """Normalize messages to handle multimodal content (list format).
+
+        Converts content from list format [{"type": "text", "text": "..."}]
+        to string format for chat template compatibility.
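+
+        For example (illustrative), a message like
+        {"role": "user", "content": [{"type": "text", "text": "Hi"}]}
+        becomes {"role": "user", "content": "Hi"}.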
+        """
+        normalized = []
+        for msg in messages:
+            if not isinstance(msg, dict):
+                normalized.append(msg)
+                continue
+
+            content = msg.get("content")
+            if isinstance(content, list):
+                # Extract text from multimodal content list
+                text_parts = []
+                for part in content:
+                    if isinstance(part, dict):
+                        if part.get("type") == "text" and "text" in part:
+                            text_parts.append(part["text"])
+                        elif isinstance(part, str):
+                            text_parts.append(part)
+                    elif isinstance(part, str):
+                        text_parts.append(part)
+                # Join all text parts
+                normalized_msg = msg.copy()
+                normalized_msg["content"] = "".join(text_parts) if text_parts else ""
+                normalized.append(normalized_msg)
+            else:
+                normalized.append(msg)
+
+        return normalized
+
+    def _truncate_messages_if_needed(
+        self, messages: list, max_tokens: int = None, reserve_tokens: int = 256
+    ) -> list:
+        """Truncate messages if they exceed the model's max context length.
+
+        Args:
+            messages: List of message dicts.
+            max_tokens: Maximum tokens to generate (used to reserve space).
+            reserve_tokens: Additional tokens to reserve for generation and safety margin.
+
+        Returns:
+            Truncated messages list if needed, original list otherwise.
+        """
+        if not self.engine or not self.engine.engine:
+            return messages
+
+        try:
+            # Get max context length from model config
+            model_config = self.engine.engine.model_engine.config
+            model_max_position_embeddings = getattr(
+                model_config, "max_position_embeddings", None
+            )
+            if model_max_position_embeddings is None:
+                # Try to get from tokenizer config as fallback
+                tokenizer = self.engine.engine.tokenizer
+                if hasattr(tokenizer, "model_max_length") and tokenizer.model_max_length:
+                    model_max_position_embeddings = tokenizer.model_max_length
+                else:
+                    # Default fallback
+                    logger.warning(
+                        "Could not determine max context length, using default 2048"
+                    )
+                    model_max_position_embeddings = 2048
+
+            # Use server-level max_context_length if set, otherwise use model's max_position_embeddings
+            if self.max_context_length is not None:
+                if self.max_context_length > model_max_position_embeddings:
+                    logger.warning(
+                        f"Server max_context_length ({self.max_context_length}) exceeds "
+                        f"model's max_position_embeddings ({model_max_position_embeddings}). "
+                        f"Using model's max_position_embeddings instead."
+                    )
+                    max_context_len = model_max_position_embeddings
+                else:
+                    max_context_len = self.max_context_length
+                    logger.debug(
+                        f"Using server-level max_context_length: {max_context_len} "
+                        f"(model max_position_embeddings: {model_max_position_embeddings})"
+                    )
+            else:
+                max_context_len = model_max_position_embeddings
+
+            # Calculate available length for prompt
+            # Reserve space for generation tokens and safety margin
+            if max_tokens is None:
+                max_tokens = self.max_tokens
+            available_len = max_context_len - max_tokens - reserve_tokens
+            if available_len <= 0:
+                available_len = max_context_len - reserve_tokens
+
+            # Apply chat template to get prompt
+            try:
+                prompt = self.engine.engine.apply_chat_template(
+                    messages, add_generation_prompt=True
+                )
+            except Exception as e:
+                logger.warning(f"Failed to apply chat template for length check: {e}")
+                return messages
+
+            # Tokenize to check length
+            tokenizer = self.engine.engine.tokenizer
+            encoded = tokenizer.encode(prompt, add_special_tokens=False)
+            prompt_len = len(encoded)
+
+            if prompt_len <= available_len:
+                return messages
+
+            # Prompt is too long, need to truncate
+            logger.warning(
+                f"Prompt length ({prompt_len}) exceeds available context length "
+                f"({available_len}). Truncating messages..."
+ ) + + # Try to truncate by removing oldest non-system messages + # Keep system message if present + system_messages = [msg for msg in messages if msg.get("role") == "system"] + non_system_messages = [msg for msg in messages if msg.get("role") != "system"] + + # Binary search for the right number of messages to keep + # Start with keeping all system messages and last user/assistant pair + truncated = system_messages.copy() + if non_system_messages: + # Keep the last message (usually user's current question) + truncated.append(non_system_messages[-1]) + + # Try adding more messages from the end until we hit the limit + for i in range(len(non_system_messages) - 2, -1, -1): + test_messages = system_messages + non_system_messages[i:] + try: + test_prompt = self.engine.engine.apply_chat_template( + test_messages, add_generation_prompt=True + ) + test_encoded = tokenizer.encode( + test_prompt, add_special_tokens=False + ) + if len(test_encoded) <= available_len: + truncated = test_messages + else: + break + except Exception: + break + + # Final check - if still too long, remove more messages + try: + final_prompt = self.engine.engine.apply_chat_template( + truncated, add_generation_prompt=True + ) + final_encoded = tokenizer.encode(final_prompt, add_special_tokens=False) + final_len = len(final_encoded) + if final_len > available_len: + # Still too long, keep only system message and last user message + if system_messages and non_system_messages: + truncated = system_messages + [non_system_messages[-1]] + elif non_system_messages: + truncated = [non_system_messages[-1]] + else: + truncated = system_messages + + # Check again + final_prompt = self.engine.engine.apply_chat_template( + truncated, add_generation_prompt=True + ) + final_encoded = tokenizer.encode(final_prompt, add_special_tokens=False) + final_len = len(final_encoded) + if final_len > available_len: + logger.warning( + f"Even minimal messages result in prompt length {final_len}, " + f"which exceeds available {available_len}. " + f"Model may crash or produce errors." + ) + except Exception as e: + logger.warning(f"Error in final truncation check: {e}") + + removed_count = len(messages) - len(truncated) + if removed_count > 0: + # Get final length for logging + try: + final_prompt_check = self.engine.engine.apply_chat_template( + truncated, add_generation_prompt=True + ) + final_encoded_check = tokenizer.encode( + final_prompt_check, add_special_tokens=False + ) + final_len_check = len(final_encoded_check) + except Exception: + final_len_check = "unknown" + logger.warning( + f"Removed {removed_count} message(s) to fit within context limit " + f"(original: {prompt_len} tokens, truncated: {final_len_check} tokens)" + ) + + return truncated + + except Exception as e: + logger.error(f"Error during message truncation: {e}", exc_info=True) + # Return original messages on error to avoid breaking requests + return messages + def _register_routes(self, app: FastAPI): """Register API routes.""" + # OpenAI-compatible chat completions endpoint. + # Support both legacy path and OpenAI-style /v1 prefix for proxy/router compatibility. 
         @app.post("/chat/completions")
+        @app.post("/v1/chat/completions")
         async def chat_completions(request: Request):
             try:
                 data = await request.json()
@@ -153,8 +401,23 @@
                 else:
                     data["messages"] = [{"role": "user", "content": data.get("prompt")}]
 
+                # Normalize messages to handle multimodal content (list format)
+                data["messages"] = self._normalize_messages(data.get("messages", []))
+
+                # Truncate messages if they exceed model's max context length
+                max_tokens = data.get("max_tokens") or data.get("max_new_tokens") or self.max_tokens
+                data["messages"] = self._truncate_messages_if_needed(
+                    data["messages"], max_tokens=max_tokens
+                )
+
                 stream = data.get("stream", False)
                 request_id = f"cmpl-{uuid.uuid4().hex}"
+                messages_info = data.get("messages", [])
+                if messages_info and len(messages_info) > 0:
+                    first_msg_content = messages_info[0].get("content", "") if isinstance(messages_info[0], dict) else str(messages_info[0])
+                    logger.info(f"Received request {request_id} with {len(messages_info)} message(s), first message length={len(first_msg_content)}")
+                else:
+                    logger.info(f"Received request {request_id} with empty messages")
 
                 if stream:
                     return StreamingResponse(
@@ -169,15 +432,31 @@
         @app.get("/health")
         async def health():
+            # Expose engine health so babysitter/registry can treat backend as unhealthy.
+            if (
+                self.engine is not None
+                and hasattr(self.engine, "is_healthy")
+                and not self.engine.is_healthy()
+            ):
+                return JSONResponse(content={"status": "unhealthy"}, status_code=503)
             return {"status": "healthy"}
 
-        @app.get("/v1/models")
-        async def list_models():
+        @app.get("/metrics")
+        async def metrics():
+            """Prometheus-compatible metrics endpoint."""
+            # Return empty metrics for now to avoid 404
+            # Can be extended with actual metrics later
+            return Response(
+                content="# InfiniLM Metrics\n# No metrics collected yet\n",
+                media_type="text/plain; version=0.0.4; charset=utf-8"
+            )
+
+        def _models_payload():
             return {
                 "object": "list",
                 "data": [
                     {
-                        "id": "jiuge",
+                        "id": self.model_id,
                         "object": "model",
                         "created": int(time.time()),
                         "owned_by": "infinilm",
@@ -185,14 +464,53 @@
                     }
                 ],
             }
 
+        # Support both /v1/models (OpenAI) and /models (common legacy) for compatibility.
+        @app.get("/v1/models")
+        async def list_models():
+            return _models_payload()
+
+        @app.get("/models")
+        async def list_models_legacy():
+            return _models_payload()
+
     def _build_sampling_params(self, data: dict) -> SamplingParams:
         """Build SamplingParams from request data."""
+        # Support both:
+        # - top-level OpenAI-ish fields: temperature/top_p/top_k/max_tokens/stop
+        # - nested dict: sampling_params: { ... }
+        sp = data.get("sampling_params") or {}
+        if not isinstance(sp, dict):
+            sp = {}
+
+        def pick(key: str, default):
+            # Priority: explicit top-level field > nested sampling_params > server default
+            if key in data and data.get(key) is not None:
+                return data.get(key)
+            if key in sp and sp.get(key) is not None:
+                return sp.get(key)
+            return default
+
+        # Accept common alias
+        max_tokens = pick("max_tokens", self.max_tokens)
+        if max_tokens is None:
+            # Some clients use max_new_tokens
+            max_tokens = pick("max_new_tokens", self.max_tokens)
+
+        stop = pick("stop", None)
+        if isinstance(stop, str):
+            stop = [stop]
+
+        stop_token_ids = pick("stop_token_ids", None)
+        if isinstance(stop_token_ids, int):
+            stop_token_ids = [stop_token_ids]
+
         return SamplingParams(
-            temperature=data.get("temperature", self.temperature),
-            top_p=data.get("top_p", self.top_p),
-            top_k=data.get("top_k", self.top_k),
-            max_tokens=data.get("max_tokens", self.max_tokens),
-            stop=data.get("stop"),
+            temperature=float(pick("temperature", self.temperature)),
+            top_p=float(pick("top_p", self.top_p)),
+            top_k=int(pick("top_k", self.top_k)),
+            max_tokens=int(max_tokens) if max_tokens is not None else None,
+            stop=stop,
+            stop_token_ids=stop_token_ids,
         )
 
     async def _stream_chat(self, request_id: str, data: dict, http_request: Request):
@@ -210,22 +528,26 @@
                 request_id=request_id,
                 request_data=data,
                 http_request=http_request,
+                add_generation_prompt=bool(data.get("add_generation_prompt", True)),
+                chat_template_kwargs=data.get("chat_template_kwargs") or {},
             )
 
             async for token_output in self.engine.stream_request(
-                req, timeout=DEFAULT_STREAM_TIMEOUT
+                req,
+                timeout=DEFAULT_STREAM_TIMEOUT,
+                request_timeout=DEFAULT_REQUEST_TIMEOUT,
             ):
-                # Check timeout
-                if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
+                # If stream_request enforces timeout, we can just surface the state to the client.
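+                # (stream_request yields one final chunk with finish_reason=TIMEOUT once request_timeout elapses.)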
+                if token_output.finish_reason == FinishReason.TIMEOUT:
                     logger.warning(
                         f"Request {request_id} timed out after {DEFAULT_REQUEST_TIMEOUT}s"
                     )
-                    req.mark_timeout()
                     error_chunk = json.dumps(
                         chunk_json(
                             request_id,
                             content="[Request timeout]",
                             finish_reason="timeout",
+                            model=self.model_id,
                         ),
                         ensure_ascii=False,
                     )
@@ -240,7 +562,9 @@
                 # Send token
                 chunk = json.dumps(
-                    chunk_json(request_id, content=token_output.token_text),
+                    chunk_json(
+                        request_id, content=token_output.token_text, model=self.model_id
+                    ),
                     ensure_ascii=False,
                 )
                 yield f"data: {chunk}\n\n"
@@ -250,10 +574,20 @@
                         token_output.finish_reason
                     )
                     chunk = json.dumps(
-                        chunk_json(request_id, finish_reason=finish_reason),
+                        chunk_json(
+                            request_id, finish_reason=finish_reason, model=self.model_id
+                        ),
                         ensure_ascii=False,
                     )
                     yield f"data: {chunk}\n\n"
+                    elapsed_time = time.time() - start_time
+                    generated_tokens = req.get_num_generated_tokens() if req else 0
+                    logger.info(
+                        f"Stream completed for {request_id}: "
+                        f"finish_reason={finish_reason}, "
+                        f"tokens={generated_tokens}, "
+                        f"duration={elapsed_time:.2f}s"
+                    )
                     break
 
         except Exception as e:
@@ -262,7 +596,10 @@
                 req.mark_failed()
             error_chunk = json.dumps(
                 chunk_json(
-                    request_id, content=f"[Error: {str(e)}]", finish_reason="error"
+                    request_id,
+                    content=f"[Error: {str(e)}]",
+                    finish_reason="error",
+                    model=self.model_id,
                 ),
                 ensure_ascii=False,
             )
@@ -273,6 +610,11 @@
                 req.mark_canceled()
             if req:
                 await req.close()
+            elapsed_time = time.time() - start_time
+            logger.info(
+                f"Stream ended for {request_id}: "
+                f"total_duration={elapsed_time:.2f}s"
+            )
 
         yield "data: [DONE]\n\n"
 
     async def _chat(self, request_id: str, data: dict, http_request: Request):
@@ -290,17 +632,20 @@
                 request_id=request_id,
                 request_data=data,
                 http_request=http_request,
+                add_generation_prompt=bool(data.get("add_generation_prompt", True)),
+                chat_template_kwargs=data.get("chat_template_kwargs") or {},
             )
 
             # Collect all generated tokens
            output_text = ""
            async for token_output in self.engine.stream_request(
-                req, timeout=DEFAULT_STREAM_TIMEOUT
+                req,
+                timeout=DEFAULT_STREAM_TIMEOUT,
+                request_timeout=DEFAULT_REQUEST_TIMEOUT,
            ):
-                # Check timeout
-                if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
+                # Request-level timeout is handled inside stream_request.
+                if token_output.finish_reason == FinishReason.TIMEOUT:
                     logger.warning(f"Request {request_id} timed out")
-                    req.mark_timeout()
                     break
 
                 # Check client disconnect
@@ -317,11 +662,27 @@
             output_text = output_text.strip()
 
             finish_reason = self._convert_finish_reason(req.finish_reason)
-            response = chunk_json(
-                request_id,
+            # Get token counts from request
+            prompt_tokens = req.get_prompt_length() if req else 0
+            completion_tokens = req.get_num_generated_tokens() if req else 0
+
+            elapsed_time = time.time() - start_time
+            logger.info(
+                f"Non-streaming request completed for {request_id}: "
+                f"finish_reason={finish_reason or 'stop'}, "
+                f"prompt_tokens={prompt_tokens}, "
+                f"completion_tokens={completion_tokens}, "
+                f"duration={elapsed_time:.2f}s"
+            )
+
+            response = completion_json(
+                id_=request_id,
                 content=output_text,
                 role="assistant",
                 finish_reason=finish_reason or "stop",
+                model=self.model_id,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
             )
 
             return response
@@ -336,6 +697,11 @@
                 req.mark_canceled()
             if req:
                 await req.close()
+            elapsed_time = time.time() - start_time
+            logger.info(
+                f"Non-streaming request ended for {request_id}: "
+                f"total_duration={elapsed_time:.2f}s"
+            )
 
     def _convert_finish_reason(self, reason: FinishReason) -> str:
         """Convert FinishReason enum to string."""
@@ -401,6 +767,12 @@
     parser.add_argument("--top_k", type=int, default=1, help="Top-k sampling parameter")
     parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
     parser.add_argument("--port", type=int, default=8000, help="Server port")
+    parser.add_argument(
+        "--max_context_length",
+        type=int,
+        default=None,
+        help="Maximum context length for input prompts. If not set, uses model's max_position_embeddings.",
+    )
     parser.add_argument("--cpu", action="store_true", help="Use CPU")
     parser.add_argument("--nvidia", action="store_true", help="Use NVIDIA GPU")
     parser.add_argument("--metax", action="store_true", help="Use MetaX device")
@@ -459,6 +831,7 @@
         top_k=args.top_k,
         host=args.host,
         port=args.port,
+        max_context_length=args.max_context_length,
     )
 
     server.start()
diff --git a/test/bench/test_benchmark.py b/test/bench/test_benchmark.py
index b23241ea..9999f353 100644
--- a/test/bench/test_benchmark.py
+++ b/test/bench/test_benchmark.py
@@ -4,7 +4,6 @@
 import time
 import re
 import csv
-from datasets import load_dataset, Dataset
 import numpy as np
 import infinicore
 from infinilm.modeling_utils import load_model_state_dict_by_file
@@ -12,6 +11,7 @@
 from infinilm.cache import StaticKVCacheConfig
 from infinilm.infer_engine import GenerationConfig, InferEngine
 from infinilm.cache import StaticKVCacheConfig
+from datasets import load_dataset, Dataset
 
 from abc import ABC, abstractmethod
 
@@ -67,9 +67,9 @@ def __init__(
             "nvidia": "cuda",
             "cambricon": "mlu",
             "ascend": "ascend",
-            "metax": "metax",
-            "moore": "moore",
-            "iluvatar": "iluvatar",
+            "metax": "cuda",
+            "moore": "musa",
+            "iluvatar": "cuda",
             "kunlun": "kunlun",
             "hygon": "hygon",
         }
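
Reviewer note (not part of the patch): the sketch below is a minimal, illustrative client for the OpenAI-compatible endpoints added above. It assumes a server launched with this diff's defaults (listening on port 8000 locally) and the third-party `requests` package; the request and response field names mirror chunk_json/completion_json in inference_server.py.

    import json
    import requests

    BASE = "http://localhost:8000"  # assumed local deployment

    # Non-streaming chat completion (served at both /chat/completions and /v1/chat/completions).
    resp = requests.post(
        f"{BASE}/v1/chat/completions",
        json={
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 32,
            "sampling_params": {"temperature": 0.7, "top_p": 0.9},
        },
    )
    body = resp.json()
    print(body["choices"][0]["message"]["content"], body.get("usage"))

    # Streaming chat completion: server-sent "data: ..." chunks, terminated by "data: [DONE]".
    with requests.post(
        f"{BASE}/chat/completions",
        json={"messages": [{"role": "user", "content": "Hello"}], "stream": True},
        stream=True,
    ) as stream:
        for line in stream.iter_lines():
            if not line or not line.startswith(b"data: "):
                continue
            payload = line[len(b"data: "):]
            if payload == b"[DONE]":
                break
            chunk = json.loads(payload)
            print(chunk["choices"][0]["delta"].get("content", ""), end="")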