Rustで書くN体シミュレーション
N体シミュレーション
N体シミュレーションは、複数の粒子が相互に引力や斥力を及ぼし合う動きを数値的に解く手法です。天体の軌道計算や粒子の相互作用をモデル化する際に使用され、物理学や天文学などで広く利用されます。
この記事では、重力相互作用を考慮し、各ステップで粒子の加速度・速度・位置を更新しながら粒子に見立てたパーティクルの運動をシミュレートするシンプルなN体シミュレーションをRustで実装します。
実装原理は下記記事の直接計算法を参考にします。直接計算法では、計算量は
3次元ベクトル
まずはベースとなるVector3クラスを実装します。nalgebraという線形代数ライブラリを使うこともできますが、依存関係をクリーンにしておきたいので手元で簡単に実装します。
#[derive(Debug, Clone, Copy, Default)]
struct Vector3<T> {
x: T,
y: T,
z: T,
}
impl<T> Vector3<T> {
#[must_use]
pub fn new(x: T, y: T, z: T) -> Self {
Self { x, y, z }
}
}
ベクトルのノルム(二乗の大きさ)[1]を計算する関数を用意します。ノルムの平方は、ベクトルの各成分の二乗を足し合わせた値です。ノルムそのものではなく、計算を高速化するために平方した値を返します。例えば、ベクトルv.norm_squared()
は
さらに、このノルムの平方根を取ることで、最終的なノルム(ベクトルの大きさ)を返します。これにより、ベクトルの長さを取得できます。
impl Vector3<f64> {
fn norm_squared(&self) -> f64 {
self.x * self.x + self.y * self.y + self.z * self.z
}
fn norm(&self) -> f64 {
self.norm_squared().sqrt()
}
}
シミュレーション内の粒子の初期状態として-0.5
から0.5
の範囲でランダムな
impl Vector3<f64> {
fn random() -> Self {
let mut rng = rand::rng();
Self::new(
rng.random_range(-0.5..0.5),
rng.random_range(-0.5..0.5),
rng.random_range(-0.5..0.5),
)
}
}
ゼロトレイト
ゼロ状態のVector3を明示的に初期化できるようにトレイトと実装を用意します。
pub trait Zero {
type Output;
fn zeros() -> Self::Output;
}
impl Zero for Vector3<f64> {
type Output = Vector3<f64>;
fn zeros() -> Self::Output {
Self::Output {
x: 0f64,
y: 0f64,
z: 0f64,
}
}
}
オペレーター
さらに、Vector3間における*
や+
、+=
などのオペレーターを必要の限り実装します。
Add
impl<T: std::ops::Add<Output = T>> std::ops::Add for Vector3<T> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
Self {
x: self.x + rhs.x,
y: self.y + rhs.y,
z: self.z + rhs.z,
}
}
}
Sub
impl<T: std::ops::Sub<Output = T> + Copy> std::ops::Sub for Vector3<T> {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
Self {
x: self.x - rhs.x,
y: self.y - rhs.y,
z: self.z - rhs.z,
}
}
}
Mul
impl<T: std::ops::Mul<Output = T> + Copy> std::ops::Mul<T> for Vector3<T> {
type Output = Self;
fn mul(self, rhs: T) -> Self {
Self {
x: self.x * rhs,
y: self.y * rhs,
z: self.z * rhs,
}
}
}
AddAssign
impl<T: std::ops::AddAssign + Copy> std::ops::AddAssign for Vector3<T> {
fn add_assign(&mut self, rhs: Self) {
self.x += rhs.x;
self.y += rhs.y;
self.z += rhs.z;
}
}
粒子の状態
各粒子の位置、加速度及び質量を保持するステートを定義します。
#[derive(Debug, Clone, Default)]
struct State {
pos: Vec<Vector3<f64>>,
vel: Vec<Vector3<f64>>,
mass: Vec<f64>,
}
ここで
impl State {
#[must_use]
fn new(n: usize) -> Self {
let mut pos = Vec::with_capacity(n);
let mut vel = Vec::with_capacity(n);
let mass = vec![20.0 / n as f64; n];
for _ in 0..n {
pos.push(Vector3::random());
vel.push(Vector3::zeros());
}
Self { pos, vel, mass }
}
}
シミュレータ
シミュレータを実装します。g
は万有引力定数、softening
は粒子が非常に近くにあるときに数値的な問題を回避するために追加される小さな値[2]、state
は各粒子の状態、dt
は時間間隔、t
はシミュレータ内の現在時間を示します。
N体シミュレーションにおけるdt
は、シミュレーションの進行を表す時間の間隔、つまり1ステップあたりの時間の変化量を指します。簡単に言うと、シミュレーションで粒子の位置や速度を更新するための時間的な刻みです。dt
が小さいほどシミュレーションはより精密になりますが、計算量が増加します。逆に大きすぎるとシミュレーションが不安定になる可能性があり、物理法則に忠実でなくなります。
後に万有引力定数を含めこれら全てのパラメータをコマンドラインから自由に指定できるようにします。
#[derive(Debug, Clone, Default)]
struct Simulator {
g: f64,
softening: f64,
state: State,
dt: f64,
t: f64,
}
impl Simulator {
#[must_use]
fn new(g: f64, n: usize, dt: f64, softening: f64) -> Self {
let state = State::new(n);
Self {
g,
softening,
state,
dt,
t: 0.0,
}
}
}
重力加速度
この関数では、各粒子が他の粒子から受ける重力加速度を計算します。
各粒子
impl Simulator {
fn get_acceleration(&self) -> Vec<Vector3<f64>> {
let n = self.state.pos.len();
let mut acc = vec![Vector3::zeros(); n];
for i in 0..n {
for j in 0..n {
if i != j {
let r = self.state.pos[j] - self.state.pos[i];
let r2 = r.norm_squared() + self.softening * self.softening;
let inv_r3 = 1.0 / (r2.sqrt() * r2);
acc[i] += r * (self.g * self.state.mass[j] * inv_r3);
}
}
}
acc
}
}
粒子の状態更新
この関数では、ヴェルレ積分[3]を用いて粒子の位置と速度を更新します。
ヴェルレ積分法は、物理シミュレーションで位置と速度を時間発展させるために使用される数値積分法です。特に、ニュートン力学に基づく運動方程式の解法において高い精度を持ちながらも安定性があります。ヴェルレ積分法は3つのステップに分かれているので、これをRustで表現します。
まず速度を半ステップ分更新します。
次に位置を更新します。
ここでget_acceleration
を用いて新しい加速度
impl Simulator {
fn step(&mut self) {
let acc = self.get_acceleration();
for i in 0..self.state.pos.len() {
self.state.vel[i] += acc[i] * (self.dt / 2.0);
self.state.pos[i] += self.state.vel[i] * self.dt;
}
let acc_new = self.get_acceleration();
for i in 0..self.state.pos.len() {
self.state.vel[i] += acc_new[i] * (self.dt / 2.0);
}
self.t += self.dt;
}
}
運動エネルギーとポテンシャルエネルギー
この関数では、運動エネルギーとポテンシャルエネルギーを計算します。
運動エネルギーは、各粒子の速度から計算します。
万有引力によるポテンシャルエネルギーは、すべての粒子ペア
ポテンシャルエネルギーについては、数式上0.5
を掛けます。
impl Simulator {
fn get_energy(&self) -> (f64, f64) {
let mut ke = 0.0;
let mut pe = 0.0;
for i in 0..self.state.pos.len() {
ke += 0.5 * self.state.mass[i] * self.state.vel[i].norm_squared();
for j in 0..self.state.pos.len() {
if i != j {
let r = (self.state.pos[j] - self.state.pos[i]).norm();
pe -= self.g * self.state.mass[i] * self.state.mass[j] / r;
}
}
}
(ke, pe * 0.5)
}
}
コマンドライン
最後にコマンドラインを実装します。コマンドラインから自由に万有引力定数clap
の導出を使います。softening
もパラメータとして追加してもよいかもしれません。
cargo add clap --features "derive"
#[derive(Parser, Debug)]
struct Args {
/// The gravitational constant (G).
#[clap(short, long, default_value_t = 6.67430e-11)]
gravity: f64,
/// The time step (dt) used in the simulation to advance each step.
#[clap(short, long, default_value_t = 1.0)]
delta_time: f64,
/// The number of steps to run the simulation.
#[clap(short, long, default_value_t = 1000)]
steps: usize,
/// The number of planets to simulate.
#[clap(short, long, default_value_t = 100)]
num_planets: usize,
}
エントリは下記のようになります。
fn main() {
let args = Args::parse();
let softening = 0.1;
let mut simulator = Simulator::new(args.gravity, args.num_planets, args.delta_time, softening);
for _ in 0..args.steps {
simulator.step();
let (ke, pe) = simulator.get_energy();
println!(
"t: {:.2}, KE: {:.5}, PE: {:.5}, ={:.5}",
simulator.t,
ke,
pe,
ke + pe
);
}
}
可視化
このままでは味気ないので、3D空間に可視化できるようにします。今回は、生のHTML5+tailwindとthreejsというグラフィックライブラリを使ってステップごとの各粒子の描画を行います。
将来的にはWASM用のラッパーを書いてReactで完全にブラウザ上でパラメータを操作してシミュレーションを行えるようにしてもよいかもしれません。
テンプレートHTML
まずはテンプレート用のHTMLを用意して、これをRustプログラムに埋め込んで動的にHTMLを生成できるようにします。%
で囲んだトークンをRustで適切なデータに再配置します。
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>N-Body Visualizer</title>
<script src="https://cdn.tailwindcss.com"></script>
<style>
:root {
font-family: monospace;
font-size: 16px;
color: #fff;
}
body {
margin: 0;
}
canvas {
display: block;
}
</style>
</head>
<body>
<script type="importmap">
{
"imports": {
"three": "https://cdn.jsdelivr.net/npm/three@0.167.0/build/three.module.js",
"three/addons/": "https://cdn.jsdelivr.net/npm/three@0.167.0/examples/jsm/"
}
}
</script>
<script type="module">
import * as THREE from 'three';
import { OrbitControls } from 'three/addons/controls/OrbitControls.js';
const DATA = [%REPLACEME%];
let currentData = DATA[0];
const STEPS = %STEPS%;
const DELTA_TIME = %DELTA_TIME%;
let deltaTime = 1.0;
let scene, camera, renderer, controls;
let particles = [];
function initScene() {
scene = new THREE.Scene();
scene.background = new THREE.Color(0x111111);
camera = new THREE.PerspectiveCamera(
75,
window.innerWidth / window.innerHeight,
0.1,
1000000000000
);
camera.position.z = 200;
renderer = new THREE.WebGLRenderer();
renderer.setSize(window.innerWidth, window.innerHeight);
document.body.appendChild(renderer.domElement);
controls = new OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.dampingFactor = 0.05;
const geometry = new THREE.SphereGeometry(1.5, 32, 32);
const material = new THREE.MeshBasicMaterial({ color: 0x00ff00 });
currentData.forEach(() => {
const particle = new THREE.Mesh(geometry, material);
scene.add(particle);
particles.push(particle);
});
updateParticlePositions(currentData);
}
function updateParticlePositions(data) {
const positions = data.map((p) => [p.x, p.y, p.z]);
const maxVal = Math.max(...positions.flat().map(Math.abs));
const scale = 100 / maxVal;
particles.forEach((particle, i) => {
particle.position.set(
data[i].x * scale,
data[i].y * scale,
data[i].z * scale
);
});
}
function animate() {
requestAnimationFrame(animate);
controls.update();
renderer.render(scene, camera);
}
function handleResize() {
camera.aspect = window.innerWidth / window.innerHeight;
camera.updateProjectionMatrix();
renderer.setSize(window.innerWidth, window.innerHeight);
}
initScene();
animate();
window.addEventListener('resize', handleResize);
window.onload = () => {
const deltaTimeSlider = document.getElementById('trange');
const deltaTimeValueSpan = document.getElementById('tvalue');
deltaTimeSlider.min = 1;
deltaTimeSlider.max = STEPS - 1;
deltaTimeSlider.value = deltaTime;
deltaTimeSlider.addEventListener('input', (event) => {
deltaTime = parseFloat(event.target.value);
deltaTimeValueSpan.textContent = `T=${(deltaTime * DELTA_TIME).toFixed(2)}`;
const data = DATA[Math.floor(deltaTime.toFixed(2))];
updateParticlePositions(data);
});
deltaTimeValueSpan.textContent = `T=${(deltaTime * DELTA_TIME).toFixed(2)}`;
};
</script>
<div class="absolute w-[400px] bottom-5 left-0 right-0 mx-auto z-[999]">
<div class="flex flex-row w-[100%]">
<input
type="range"
min="1"
max="100"
value="50"
class="w-[100%]"
id="trange" />
<span id="tvalue" class="px-2"></span>
</div>
</div>
</body>
</html>
Rustパート
テンプレートHTMLをRustプログラムに埋め込みます。
const TEMPLATE_HTML: &str = include_str!("../template.html");
シミュレータに各ステップごとに呼び出すことのできるデータ化関数を用意します。理想的にはserde
・serde_json
で綺麗にシリアライズして、i/oエラー周りもthiserror
でまとめたいところですが、とりあえずこれで動きます。
impl Simulator {
fn to_visualizer_data<W: Write>(&self, mut writer: W) {
write!(writer, "[").unwrap();
for b in &self.state.pos {
write!(writer, "{{x:{},y:{},z:{}}},", b.x, b.y, b.z).unwrap();
}
write!(writer, "],").unwrap();
}
}
最後にこれをout.html
として出力するようにします。
fn main() {
let args = Args::parse();
let softening = 0.1;
let mut simulator = Simulator::new(args.gravity, args.num_planets, args.delta_time, softening);
let mut buf = Vec::new();
{
let mut writer = BufWriter::new(&mut buf);
for _ in 0..args.steps {
simulator.step();
simulator.to_visualizer_data(&mut writer);
}
writer.flush().unwrap();
}
let output_html = TEMPLATE_HTML
.replace("%STEPS%", &format!("{}", &args.steps))
.replace("%DELTA_TIME%", &format!("{}", &args.delta_time))
.replace("%REPLACEME%", &String::from_utf8_lossy(&buf));
let mut file = OpenOptions::new()
.create(true)
.truncate(true)
.write(true)
.read(true)
.open("out.html")
.unwrap();
file.write_all(output_html.as_bytes()).unwrap();
file.flush().unwrap();
}
これを
cargo run -- -g 1.0 -s 1000 -n 50 -d 0.01
Discussion