@kwen2501 commented Jan 20, 2026

Description

Each rank gathers inputs from all peer GPUs and performs a matrix multiplication with its local weight.

Peer inputs are made visible via PyTorch Symmetric Memory, i.e.

    import torch.distributed._symmetric_memory as symm_mem
    symm_mem.empty(...)
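
For context, a minimal allocation sketch (the symmetric-memory API is private and its exact signatures vary across PyTorch versions; the buffer sizes below are hypothetical):

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Allocate the local input from the symmetric-memory allocator so that
    # peer ranks can map and read it directly.
    inp = symm_mem.empty(64, 128, dtype=torch.bfloat16, device=f"cuda:{rank}")

    # Exchange handles across the group; afterwards each rank can access its
    # peers' copies of `inp` without an explicit collective.
    symm_mem.rendezvous(inp, dist.group.WORLD.group_name)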

The fused kernel is equivalent to:

    dist.all_gather_into_tensor(ag_out, inp, group)
    out = ag_out @ w
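
Spelled out with shapes, the unfused reference looks like this (a sketch; `m`, `k`, `n` and the helper name are illustrative, not from this PR):

    import torch
    import torch.distributed as dist

    def ag_matmul_reference(inp, w, group=None):
        # inp: (m, k) local input shard; w: (k, n) local weight.
        world = dist.get_world_size(group)
        m, k = inp.shape
        ag_out = torch.empty(world * m, k, dtype=inp.dtype, device=inp.device)
        dist.all_gather_into_tensor(ag_out, inp, group=group)
        return ag_out @ w  # (world * m, n)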

The fusion overlaps communication and computation at a fine granularity.
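
For intuition only, the overlap can be emulated on the host by consuming each peer's shard as soon as its communication completes; the fused kernel does this at much finer (tile) granularity on device. A sketch assuming the default process group (the helper name is hypothetical):

    import torch
    import torch.distributed as dist

    def ag_matmul_overlapped(inp, w, group=None):
        # inp: (m, k) local input shard; w: (k, n) local weight.
        world = dist.get_world_size(group)
        rank = dist.get_rank(group)
        m, _ = inp.shape
        shards = [torch.empty_like(inp) for _ in range(world)]
        shards[rank].copy_(inp)
        # One async broadcast per peer shard; they progress in the background.
        handles = [
            dist.broadcast(shards[r], src=r, group=group, async_op=True)
            for r in range(world)
        ]
        out = torch.empty(world * m, w.shape[1], dtype=inp.dtype,
                          device=inp.device)
        # Launch each partial GEMM as soon as its shard has arrived, so the
        # remaining broadcasts overlap with compute.
        for r in range(world):
            handles[r].wait()
            torch.matmul(shards[r], w, out=out[r * m:(r + 1) * m])
        return out

Note that with NCCL the broadcasts share a communication stream, so the overlap in this host-side emulation is approximate; avoiding that serialization is precisely what the fused kernel is for.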

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Argparser

Add test

Signed-off-by: Ke Wen <kwen@nvidia.com>