Accelerated ts.static + added scaling scripts #427
Conversation
torch_sim/autobatching.py
Outdated
```python
bbox[i] += 2.0
volume = bbox.prod() / 1000  # convert A^3 to nm^3
number_density = state.n_atoms / volume.item()
# Use cell volume (O(1)); SimState always has a cell. Avoids O(N) position scan.
```
Non-periodic systems don't have a sensible cell; see #412.
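One hedged way to reconcile the two concerns (a sketch only: `state.pbc`, `state.cell`, and `state.positions` are assumed attribute names, and the bounding-box padding mirrors the diff above):

```python
import torch

def estimate_volume_nm3(state) -> float:
    """Volume in nm^3; attribute names on `state` are assumptions."""
    if getattr(state, "pbc", False):
        # Periodic: O(1) determinant of the 3x3 cell matrix (A^3).
        volume_a3 = torch.linalg.det(state.cell.squeeze()).abs()
    else:
        # Non-periodic: O(N) bounding box of the positions, padded by
        # 2 A per side as in the original bbox logic.
        bbox = state.positions.max(dim=0).values - state.positions.min(dim=0).values
        bbox = bbox + 2.0
        volume_a3 = bbox.prod()
    return volume_a3.item() / 1000.0  # A^3 -> nm^3
```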
I've now minimized the differences from the initial code.
In addition, I added explicit tests for the memory scaler values and verified that the tests still pass with the changes in this PR.
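A minimal sketch of what such a test could look like (the `batched_state` fixture, the `"n_atoms"` metric, and a per-state `calculate_memory_scaler` counterpart are all assumptions; only `calculate_batched_memory_scalers` appears in the diff below):

```python
import torch

from torch_sim.autobatching import (
    calculate_batched_memory_scalers,  # from the diff in this PR
    calculate_memory_scaler,  # assumed per-state counterpart
)

def test_batched_scalers_match_per_state(batched_state):
    """batched_state: any batched SimState fixture (assumption)."""
    batched = calculate_batched_memory_scalers(batched_state, "n_atoms")
    per_state = [
        calculate_memory_scaler(s, "n_atoms") for s in batched_state.split()
    ]
    assert torch.allclose(
        torch.as_tensor(batched, dtype=torch.float64),
        torch.as_tensor(per_state, dtype=torch.float64),
    )
```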
```python
self.memory_scalers = calculate_batched_memory_scalers(
    states, self.memory_scales_with
)
self.state_slices = states.split()
```
batching makes sense here
```python
if isinstance(states, SimState):
    self.batched_states = [[states[index_bin]] for index_bin in self.index_bins]
```
state.split() is identical to this and faster
Reusing `self.state_slices` instead of calling `states.split()` again makes the code 5% faster, so I'd keep it.
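A sketch of the split-once-and-reuse pattern being defended here (the class context and attribute names follow the diff hunks above, and it assumes each `index_bin` is an iterable of integer state indices; both are assumptions):

```python
# Split the batched state once and cache the per-state slices.
if isinstance(states, SimState):
    self.state_slices = states.split()  # one O(n_states) split
else:
    self.state_slices = list(states)

# Later, build the bins from the cached slices instead of calling
# states.split() (or indexing with states[index_bin]) a second time.
self.batched_states = [
    [self.state_slices[i] for i in index_bin] for index_bin in self.index_bins
]
```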
Force-pushed from 3138aed to e91fe92.
examples/scaling/__init__.py
Outdated
this file isn't needed
```python
    )
    self.state_slices = states.split()
else:
    self.state_slices = states
```
Why not concat and then call the batched logic?
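A hedged sketch of that suggestion (assuming a `concatenate_states` helper exists in `torch_sim.state` for joining a list of SimStates; treat the name and import path as assumptions):

```python
from torch_sim.state import SimState, concatenate_states  # path assumed

# Normalize the input: if we got a list of states, concatenate it into
# one batched SimState so a single batched code path handles both cases.
if not isinstance(states, SimState):
    states = concatenate_states(list(states))

self.memory_scalers = calculate_batched_memory_scalers(
    states, self.memory_scales_with
)
self.state_slices = states.split()
```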
Summary
Changes:
The figure below shows the speedup achieved for static evaluations, 10-step atomic relaxation, 10-step NVE MD, and 10-step NVT MD. Prior results are shown in blue, while new results are shown in red. The speedup is calculated as
speedup (%) = (baseline_time / current_time − 1) × 100 (a quick worked example follows the list below). We observe that:

- `ts.static` achieves a 43.9% speedup for 100,000 structures
- `ts.relax` achieves a 2.8% speedup for 1,500 structures
- `ts.integrate` (NVE) achieves a 0.9% speedup for 10,000 structures
- `ts.integrate` (NVT) achieves a 1.4% speedup for 10,000 structures
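As a quick sanity check of the formula (the numbers here are hypothetical, chosen only to illustrate the definition):

```python
def speedup_percent(baseline_time: float, current_time: float) -> float:
    # speedup (%) = (baseline_time / current_time - 1) * 100
    return (baseline_time / current_time - 1.0) * 100.0

# Hypothetical: a baseline of 122 s reduced to 84.8 s is a ~43.9% speedup.
print(round(speedup_percent(122.0, 84.8), 1))  # 43.9
```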
Comments:

From the scaling plots, we can see that the timings of `ts.static` and `ts.integrate` are all consistent with each other. Indeed:

- `ts.static` → 85 s for 100,000 evaluations
- `ts.integrate` NVE → 87 s for 10,000 structures (10 MD steps each) → 87 s for 100,000 evaluations
- `ts.integrate` NVT → 89 s for 10,000 structures (10 MD steps each) → 89 s for 100,000 evaluations

However, when looking at the relaxation:
- `ts.relax` → 63 s for 1,000 structures (10 relax steps each) → 63 s for 10,000 evaluations → ~630 s for 100,000 evaluations
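Normalizing these timings to a per-evaluation cost (evaluations = structures × steps, using only the numbers quoted above) makes the gap explicit:

```python
# Per-evaluation cost from the timings quoted above.
runs = {
    "ts.static": (85.0, 100_000 * 1),  # 100,000 single-point evaluations
    "ts.integrate (NVE)": (87.0, 10_000 * 10),  # 10,000 structures x 10 steps
    "ts.integrate (NVT)": (89.0, 10_000 * 10),
    "ts.relax": (63.0, 1_000 * 10),  # 1,000 structures x 10 relax steps
}
for name, (seconds, n_evals) in runs.items():
    print(f"{name}: {seconds / n_evals * 1e3:.2f} ms/evaluation")
# ts.static / ts.integrate land at ~0.85-0.89 ms; ts.relax at ~6.3 ms.
```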
So `ts.relax` is about 7x slower than `ts.static` or `ts.integrate` (~6.3 ms vs ~0.85 ms per evaluation). The unbatched FrechetCellFilter clearly contributes to that. I'm wondering whether there are additional bottlenecks in the code that we could optimize to reduce that massive 7x cost.