feat: RL training support for VERL #1196
base: main
Conversation
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Weichao Luo <luoweichao@sensetime.com>
Co-authored-by: shihaobai <1798930569@qq.com>
Co-authored-by: sufubao <sufubao@sensetime.com>
Summary of Changes
Hello @shihaobai, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly upgrades the LightLLM serving framework by introducing robust support for Reinforcement Learning (RL) training, particularly for VERL. The core changes revolve around advanced GPU memory management, enabling dynamic model weight updates, and refining the server's architecture for better control and communication. These enhancements pave the way for more flexible and efficient online model training and deployment scenarios.
Code Review
This pull request introduces significant changes to support reinforcement learning (RL) training, including dynamic memory management and online weight updates. The refactoring of I/O structures and communication patterns is a positive step towards a more modular and maintainable codebase. The introduction of TorchMemorySaverWrapper provides a clean way to optionally use memory-saving features.
However, I've identified several issues that need attention, including a critical bug in the distributed handling of aborted requests, a logic error in request scheduling, and a potential race condition with asyncio.Event. There are also opportunities for code simplification and improved error handling. Please see the detailed comments for each issue.
aborted_req_mask = torch.tensor(
    [req.is_aborted for req in self.running_batch.reqs], dtype=torch.bool, device="cpu"
)
if self.is_multinode_tp:
    dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
The all_reduce operation on aborted_req_mask uses dist.ReduceOp.MIN. For a boolean tensor, this is equivalent to a logical AND operation across all ranks. This means a request will only be considered aborted if it is marked as aborted on all ranks. The correct behavior should be to consider a request aborted if it is aborted on any rank. This requires a logical OR, which corresponds to dist.ReduceOp.MAX.
Suggested change:
-    dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+    dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MAX, group=self.mulitnode_group)
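For reference, here is a minimal, self-contained sketch (not part of this PR) that simulates the cross-rank reduction locally with torch.minimum / torch.maximum instead of a real process group; the rank0 / rank1 masks are made-up example data. It shows why MIN behaves like a logical AND and MAX like a logical OR on boolean masks:

```python
# Simulate an element-wise all_reduce over two ranks for a boolean abort mask.
import torch

# Hypothetical per-rank abort masks for three requests.
rank0 = torch.tensor([True, False, True])
rank1 = torch.tensor([False, False, True])

# MIN across ranks == logical AND: aborted only if aborted on *all* ranks.
min_reduced = torch.minimum(rank0.to(torch.uint8), rank1.to(torch.uint8)).bool()
# MAX across ranks == logical OR: aborted if aborted on *any* rank.
max_reduced = torch.maximum(rank0.to(torch.uint8), rank1.to(torch.uint8)).bool()

print(min_reduced)  # tensor([False, False,  True])
print(max_reduced)  # tensor([ True, False,  True])
```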
aborted_req_mask = torch.tensor(
    [req.is_aborted for req in new_batch.reqs], dtype=torch.bool, device="cpu"
)
dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
Same issue as above: for a boolean mask, dist.ReduceOp.MIN acts as a logical AND across ranks, so a request would only be treated as aborted if every rank marks it. Use dist.ReduceOp.MAX (logical OR) so that a request aborted on any rank is treated as aborted.
Suggested change:
-    dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+    dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MAX, group=self.mulitnode_group)
select_req_ids.append(req_id)

aborted_req_mask = torch.tensor(aborted_req_mask, dtype=torch.bool, device="cpu")
dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
Same issue again: dist.ReduceOp.MIN behaves as a logical AND for boolean masks; switch to dist.ReduceOp.MAX so that a request aborted on any rank is treated as aborted.
Suggested change:
-    dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+    dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MAX, group=self.mulitnode_group)
return "unknown"  # need fix
# raise RuntimeError("No GPU available")
The function get_current_device_name returns "unknown" when no GPU is available. This could mask potential configuration errors. It would be better to raise a RuntimeError as suggested in the commented-out code to fail fast and provide a clear error message when a GPU is expected but not found.
Suggested change:
-    return "unknown"  # need fix
-    # raise RuntimeError("No GPU available")
+    raise RuntimeError("No GPU available")
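A minimal sketch of the fail-fast variant, assuming the helper is backed by torch.cuda (the actual body of get_current_device_name in the PR may differ):

```python
import torch

def get_current_device_name() -> str:
    # Return the name of the current CUDA device when one is present.
    if torch.cuda.is_available():
        return torch.cuda.get_device_name(torch.cuda.current_device())
    # Fail fast instead of returning a placeholder string that can hide
    # a misconfigured environment further down the stack.
    raise RuntimeError("No GPU available")
```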
async def http_to_model_special_request(
    self, request: GeneralHttpToModelRpcReq, timeout: int = 300
) -> GeneralModelToHttpRpcRsp:
    event = await self.get_event_for_func(request.func_name)
    await self.transfer_to_next_module(request)
    try:
        await asyncio.wait_for(event.wait(), timeout=timeout)
        ret = event.result
    except asyncio.TimeoutError:
        ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response timeout", func_name=request.func_name)
    except Exception as e:
        ret = GeneralModelToHttpRpcRsp(
            success=False, msg="wait for response error: %s" % str(e), func_name=request.func_name
        )
    return ret
The asyncio.Event object is not cleared after being used. If the same special request function is called again, the old event will still be set, causing event.wait() to return immediately with potentially stale data. This can lead to incorrect behavior. The event should be cleared before it's waited on.
Suggested change:
     event = await self.get_event_for_func(request.func_name)
+    event.clear()
     await self.transfer_to_next_module(request)
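The following standalone sketch (assumed names, not the PR's code) shows why a reused asyncio.Event must be cleared before the next wait: once set, wait() returns immediately, so a second request would read the previous, stale result.

```python
import asyncio

async def main():
    event = asyncio.Event()

    # First round-trip: the responder attaches a result and sets the event.
    event.result = "response-1"   # plain attribute, mirroring how event.result is used above
    event.set()
    await asyncio.wait_for(event.wait(), timeout=1)
    print(event.result)           # response-1

    # Second round-trip without clear(): wait() returns at once with stale data.
    await asyncio.wait_for(event.wait(), timeout=1)
    print(event.result)           # still response-1, even though no new response arrived

    # Clearing before waiting restores the intended blocking behaviour.
    event.clear()
    try:
        await asyncio.wait_for(event.wait(), timeout=0.1)
    except asyncio.TimeoutError:
        print("timed out as expected until a real response sets the event")

asyncio.run(main())
```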
     cur_group_reqs, is_busy, new_batch_first_router_need_tokens
 )
-if ok_insert:
+if ok_insert and False:
The condition if ok_insert and False: will always evaluate to false, which prevents the last group of requests (cur_group_reqs) from being added to the can_run_list. This appears to be a logic error, likely a leftover from debugging, and will cause requests to be unnecessarily delayed or starved.
Suggested change:
-    if ok_insert and False:
+    if ok_insert:
def release_all(self):
    self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
    self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
    self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
The release_all method can be made more concise and maintainable by iterating over a list of memory tags. This avoids code repetition and makes it easier to add or remove tags in the future.
Suggested change:
 def release_all(self):
-    self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
-    self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
-    self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
+    for tag in [MemoryTag.WEIGHT, MemoryTag.KV_CACHE, MemoryTag.GRAPH]:
+        self.torch_memory_saver.pause(tag=tag)
def resume_all(self):
    self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
    self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
    self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
Likewise, resume_all can iterate over the same list of memory tags, which avoids the repetition and keeps it in sync with release_all.
Suggested change:
 def resume_all(self):
-    self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
-    self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
-    self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
+    for tag in [MemoryTag.WEIGHT, MemoryTag.KV_CACHE, MemoryTag.GRAPH]:
+        self.torch_memory_saver.resume(tag=tag)
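Going one step further, the tag list itself could live in one place so release_all and resume_all cannot drift apart. This is only a sketch with an assumed class layout and a stand-in MemoryTag definition; the PR's TorchMemorySaverWrapper may be organised differently:

```python
from enum import Enum, auto

class MemoryTag(Enum):
    # Assumed definition, mirroring the tags used above.
    WEIGHT = auto()
    KV_CACHE = auto()
    GRAPH = auto()

# Single source of truth for everything that can be paused/resumed.
ALL_TAGS = (MemoryTag.WEIGHT, MemoryTag.KV_CACHE, MemoryTag.GRAPH)

class TorchMemorySaverWrapper:
    def __init__(self, torch_memory_saver):
        # `torch_memory_saver` stands in for whatever pause/resume backend the wrapper holds.
        self.torch_memory_saver = torch_memory_saver

    def release_all(self):
        for tag in ALL_TAGS:
            self.torch_memory_saver.pause(tag=tag)

    def resume_all(self):
        for tag in ALL_TAGS:
            self.torch_memory_saver.resume(tag=tag)
```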
if abort_all:
    for group_req_id in list(self.req_id_to_out_inf.keys()):
        await self.abort(group_req_id)
    pass
No description provided.