GPU Accelerating My Electrodynamics Solver Using Claude Code

March 1, 2026

Claude Code with Opus 4.6 wrote a JAX-based GPU backend for a complex multiparticle boundary-condition electrodynamics solver, achieving up to a 12X speedup


Revisiting code from my grad school days

When I was in grad school, I wrote computational software to perform electrodynamic simulations. The method is known as the Generalized Multiparticle Mie Theory (GMMT), a boundary-condition method for solving Maxwell’s equations for a multiparticle system. In my software implementation, I utilized C++ and OpenMP parallelization for performance-critical code and exposed it to Python via pybind11 for a flexible interface. I used this software package to study the fascinating dynamical properties of nanoparticle systems illuminated by optical tweezers. This software led to many publications and my PhD thesis.

It was during my 6 years in grad school that I discovered my primary passion was writing software, especially high-performance scientific software with powerful, clean user interfaces. I was spending a lot of time writing Python user interfaces, studying modern C++, and distributing my software to collaborators at UChicago, Argonne, and elsewhere. After briefly working in HPC, I transitioned to industry as a Software Engineer, building cloud-based data and compute platforms for a biotech company. As the years went by, I of course left my PhD software behind to collect dust. Despite past collaborators continuing to use my software, emailing me about it, and strangers opening issues on GitHub, I simply had no time to revisit it. Just getting back up to speed would have been a major chore and time-sink.

In 2025, agentic AI upended how software is developed; I experienced this firsthand at my job and quickly pivoted into this new computational paradigm. By late 2025, with Claude Code, Opus 4.6, and all my new software engineering skills, it became apparent that revisiting my grad school software might not be such a hurdle anymore! In just a few prompts, I could be back to developing my grad school software as if I had never missed a beat. I allocated some of my spare time to this, largely out of curiosity about agentic AI capabilities in scientific software development, but also to reignite some of the original passion that drove me to grad school in the first place.

Step one was to modernize the software stack. My original code relied on a mess of manual C++ installs and conda packages. Now I know the value of package managers and lock files, so I worked with Claude to do the following:

  1. Manage C++ dependencies with vcpkg and CMake
  2. Manage Python dependencies with uv and a scikit-build backend
  3. Use ruff for linting and formatting
  4. Configure a GitHub Actions release pipeline to build wheels for Linux, macOS (ARM), and Windows and publish them to PyPI

This streamlined development and distribution of the software. Now anyone can run it with one command:

uvx -p 3.13 miepy

This enabled me to address various install issues former collaborators were having.

Then I got curious. Could I revisit some of the ambitious goals I had in grad school for my software, now with the power of Claude Code? So I gave Claude a hard problem: implement a JAX-based GPU backend for the electrodynamic solver that achieves faster solve times at large particle numbers than the C++/OpenMP implementation. Upfront planning was a must.

Planning with Claude Cowork

Cowork is a great interface for research that spans a codebase and online literature. Beyond the JAX-based GPU backend, I also asked about general performance improvements we could make, for example by comparing my software to others:

Let’s review the current state of literature relative to how MiePy computes its coefficients to identify potential improvements in performance or algorithm. This could be around (1) how the coefficients are computed, (2) how the matrix system is solved, (3) parallelization or hardware acceleration. We’re looking specifically at this generalized multiparticle mie theory approach

Cowork then provided a comparison with other available solvers:

Cowork researches existing electrodynamic solvers for various capabilities

Then I prompted Cowork to help me create a series of Linear issues for improving MiePy’s performance, benchmarking its CPU implementation, and implementing a JAX backend.

There is a linear project named “MiePy”

Let’s create individual issue tasks for these items. Provide plenty of details and external references, including references to existing miepy source code.

In addition to these tasks, there is a preparation task (or tasks) to enable measuring performance. This would enable us to benchmark speed improvement under different conditions. There exist some reference benchmarks already at src/miepy/examples/benchmarks/ but perhaps new ones will be useful, especially for measuring performance at different N particles, or time spent in matrix solver vs building the matrix, or multi-core performance.

The goal being, we will start with creating some benchmarks, then start to implement each of these others tasks, measuring benchmarks and running existing test suite for correctness

Cowork creates detailed Linear issues for performance improvements to MiePy

Each Linear issue contains detailed implementation guidance. Additionally, Claude prioritized the issues and established dependencies between them:

Linear issues created by Cowork

The New JAX Backend

There were two parts of the algorithm that Claude ported to JAX:

  1. Assemble the aggregate T-matrix. Iterate over all particle pairs and compute spherical Hankel functions, associated Legendre functions, Gaunt coefficients, and vector spherical harmonic (VSH) translation coefficients.
  2. Solve the boundary-condition linear system. Use either a direct solver or the iterative BiCGSTAB solver to capture the full electrodynamic interactions of the particle system.

These are the only parts of the algorithm that scale as $\mathcal{O}(N^2)$, or $\mathcal{O}(N^3)$ for the direct solver, where $N$ is the number of particles. We might therefore expect a GPU implementation to outperform a CPU one at large $N$.
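A rough operation count shows where those exponents come from. The per-particle unknown count below is an assumption for illustration (two polarizations times $l_{\max}(l_{\max}+2)$ multipole modes), and constant factors are ignored:

```latex
\begin{aligned}
P &= \binom{N}{2} = \frac{N(N-1)}{2}
  && \text{pairs to assemble, so } \mathcal{O}(N^2) \text{ at fixed } l_{\max} \\
M &= 2\, r_{\max} N, \quad r_{\max} = l_{\max}(l_{\max}+2)
  && \text{unknowns in the linear system (assumed convention)} \\
\text{direct (LU) solve} &\sim \mathcal{O}(M^3) = \mathcal{O}(N^3) \\
\text{BiCGSTAB} &\sim 2k\, M^2 = \mathcal{O}(k N^2)
  && k \text{ iterations, two matrix-vector products each}
\end{aligned}
```

At fixed $l_{\max}$, the iterative solve scales like the assembly as long as the iteration count $k$ stays modest, while the direct solve eventually dominates as $N$ grows.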

Implementing #2 was trivial because JAX already provides these solvers:

  • jax.scipy.sparse.linalg.bicgstab - iterative solver
  • jax.numpy.linalg.solve - direct solver
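A minimal sketch of calling both JAX solvers on the same system and cross-checking them. The test matrix is a random, well-conditioned stand-in (an assumption; real interaction matrices can be much harder for BiCGSTAB), and the helper names `make_test_system`, `solve_direct`, and `solve_iterative` are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

# 64-bit precision so both solvers can reach tight tolerances.
jax.config.update("jax_enable_x64", True)


def make_test_system(n: int, seed: int = 0):
    """Random complex system A x = b with A = I + small perturbation.

    Keeping A close to the identity is an assumption that makes BiCGSTAB
    converge quickly; it is not a property of real GMMT matrices.
    """
    key_re, key_im, key_b = jax.random.split(jax.random.PRNGKey(seed), 3)
    off_diag = (jax.random.normal(key_re, (n, n))
                + 1j * jax.random.normal(key_im, (n, n))) / n
    A = jnp.eye(n) + 0.5 * off_diag
    b = jax.random.normal(key_b, (n,)).astype(jnp.complex128)
    return A, b


@jax.jit
def solve_direct(A, b):
    # Dense LU-based solve: O(M^3) work, no convergence concerns.
    return jnp.linalg.solve(A, b)


@jax.jit
def solve_iterative(A, b):
    # BiCGSTAB only needs products A @ v; a matvec function could be passed
    # instead of the dense matrix. JAX returns None as the second value
    # (a placeholder for convergence info), so it is discarded here.
    x, _ = bicgstab(A, b, tol=1e-10, maxiter=500)
    return x


if __name__ == "__main__":
    A, b = make_test_system(200)
    x_direct = solve_direct(A, b)
    x_iter = solve_iterative(A, b)
    print("max |x_direct - x_iter|:", jnp.max(jnp.abs(x_direct - x_iter)))
    print("relative residual:", jnp.linalg.norm(A @ x_iter - b) / jnp.linalg.norm(b))
```

Since JAX's `bicgstab` does not report convergence the way SciPy's does, computing the residual yourself is the simplest way to confirm the iterative result.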

Implementing #1, on the other hand, was far from trivial. The C++ code is highly imperative and cannot simply be ported and JIT-compiled. Claude made several notable transformations to accomplish this:

  1. Split the computation into two phases, where phase 1 runs once to produce XLA constants at compile time.
  2. Flattened the C++ nested loops over four wave-function indices into a single vectorized dimension
  3. Padded variable-length Gaunt sums, where the C++ code relied on conditional branching
  4. Converted complex128 values into separate float64 real and imaginary components to avoid crippling GPU performance
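Transformations 1 and 3 can be sketched in miniature: a NumPy "phase 1" builds padded index, coefficient, and mask tables once, and a jitted "phase 2" gathers and sums them, similar to the A_mask/B_mask pattern in the snippet that follows. The selection rule, the placeholder coefficients, and the names `build_padded_tables` and `make_translation_sum` are all hypothetical:

```python
import numpy as np
import jax
import jax.numpy as jnp


def build_padded_tables(n_max: int):
    """Phase 1 (plain NumPy, runs once): flatten (n, v) into one mode axis and
    pad each variable-length sum over p to a common length Q.

    Toy selection rule (an assumption loosely modeled on Gaunt sums):
    p runs from |n - v| to n + v in steps of 2. The coefficient 1 / (p + 1)
    is a placeholder, not a real Gaunt coefficient.
    """
    rows = []
    for n in range(1, n_max + 1):
        for v in range(1, n_max + 1):
            ps = list(range(abs(n - v), n + v + 1, 2))
            rows.append((ps, [1.0 / (p + 1) for p in ps]))

    Q = max(len(ps) for ps, _ in rows)     # longest sum sets the padding
    M = len(rows)                          # number of flattened modes
    idx = np.zeros((M, Q), dtype=np.int32) # padded slots point at p = 0
    coeff = np.zeros((M, Q))
    mask = np.zeros((M, Q))
    for m, (ps, cs) in enumerate(rows):
        idx[m, :len(ps)] = ps
        coeff[m, :len(cs)] = cs
        mask[m, :len(ps)] = 1.0            # 1 for real terms, 0 for padding
    return idx, coeff, mask


def make_translation_sum(n_max: int):
    # The NumPy tables are captured by the closure, so XLA sees them as
    # compile-time constants when translation_sum is jitted.
    idx, coeff, mask = build_padded_tables(n_max)

    @jax.jit
    def translation_sum(zn_all):
        # zn_all: [P, 2*n_max + 1] radial values for each particle pair.
        zn = zn_all[:, idx]                          # gather -> [P, M, Q]
        return jnp.sum(coeff * zn * mask, axis=-1)   # masked sum -> [P, M]

    return translation_sum
```

The padding trades some wasted multiply-adds on zero slots for fixed array shapes, which is what lets XLA compile one branch-free kernel.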

Here is a snippet from the JAX implementation:

src/miepy/backends/jax/interactions.py
import jax
import jax.numpy as jnp

# Step 1: Pair geometry [P = N*(N-1)/2 pairs]
dji = positions[i_vals] - positions[j_vals]         # [P, 3]
rad = jnp.linalg.norm(dji, axis=1)                  # [P]
cos_theta = dji[:, 2] / rad                          # [P]
phi = jnp.arctan2(dji[:, 1], dji[:, 0])             # [P]

# Step 2: Special functions via vmap (Python loops unrolled at trace time)
zn_all = jax.vmap(lambda r: spherical_hn_recursion(p_max, k * r))(rad)  # [P, p_max+1]
Pnm_all = jax.vmap(lambda ct: associated_legendre_recursion(p_max, ct))(cos_theta)  # [P, pnm_size]

# Step 3: Vectorized translation computation
exp_phi = jnp.exp(1j * mu_sum * phi[:, None])         # [P, M]

# A translations: sum_q A_coeff[q] * Pnm[idx_A[q]] * zn[p_A[q]]
Pnm_A = Pnm_all[:, A_pnm_idx]                        # [P, M, Q]
zn_A = zn_all[:, A_zn_idx]                             # [P, M, Q]
sum_A = jnp.sum(A_coeffs * Pnm_A * zn_A * A_mask, axis=-1)  # [P, M]
A_trans = factor * exp_phi * sum_A                      # [P, M]

# B translations: -factor * exp_phi * sum_q B_coeff[q] * Pnm[idx_B[q]] * zn[p_B[q]]
Pnm_B = Pnm_all[:, B_pnm_idx]                        # [P, M, Q]
zn_B = zn_all[:, B_zn_idx]                             # [P, M, Q]
sum_B = jnp.sum(B_coeffs * Pnm_B * zn_B * B_mask, axis=-1)  # [P, M]
B_trans = -factor * exp_phi * sum_B                     # [P, M]

transfer = jnp.stack([A_trans, B_trans], axis=-1)       # [P, M, 2]

# Step 4: Scatter into flat output matrix
# Determine particle index for row/col/mie of each (pair, entry)
row_particle = jnp.where(s_row_is_i, i_vals[:, None], j_vals[:, None])  # [P, E]
col_particle = jnp.where(s_col_is_i, i_vals[:, None], j_vals[:, None])  # [P, E]
mie_particle = jnp.where(s_mie_is_i, i_vals[:, None], j_vals[:, None])  # [P, E]

rows = row_particle * block_size + s_row_local         # [P, E]
cols = col_particle * block_size + s_col_local         # [P, E]

transfer_vals = transfer[:, s_mode_idx, s_ab_parity]   # [P, E]
mie_vals = mie[mie_particle, s_mie_pol, s_mie_order]  # [P, E]
vals = s_sign * transfer_vals * mie_vals                # [P, E]
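The excerpt ends once rows, cols, and vals are computed. Here is a self-contained toy version of steps 1, 2, and 4, assuming the final scatter uses JAX's .at[].add (the excerpt stops before that call). Each pair contributes a single symmetric scalar instead of a full block of multipole entries, and `spherical_h1_upward` and `assemble_toy_matrix` are hypothetical helpers, not MiePy's functions:

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # complex128 for accurate recurrences


def spherical_h1_upward(n_max, x):
    """h_n^(1)(x) for n = 0..n_max via h_{n+1} = (2n+1)/x * h_n - h_{n-1}.

    n_max is a Python int, so the loop is unrolled at trace time.
    """
    eix = jnp.exp(1j * x)
    h = [-1j * eix / x,                 # h_0(x) = -i e^{ix} / x
         -eix * (x + 1j) / x**2]        # h_1(x) = -e^{ix} (x + i) / x^2
    for n in range(1, n_max):
        h.append((2 * n + 1) / x * h[n] - h[n - 1])
    return jnp.stack(h[: n_max + 1])    # [n_max + 1]


def assemble_toy_matrix(positions, k, n_max):
    """Toy pairwise assembly scattered into a dense N x N matrix.

    Each pair (i, j) contributes h_{n_max}(k r_ij) at both (i, j) and (j, i).
    """
    N = positions.shape[0]
    i_vals, j_vals = np.triu_indices(N, k=1)          # P = N*(N-1)/2 pairs
    dji = positions[i_vals] - positions[j_vals]        # [P, 3]
    rad = jnp.linalg.norm(dji, axis=1)                 # [P]
    zn_all = jax.vmap(lambda r: spherical_h1_upward(n_max, k * r))(rad)  # [P, n_max+1]
    vals = zn_all[:, n_max]                            # [P]

    # Scatter: .at[].add sums any duplicate (row, col) targets, which is
    # the usual behavior wanted when assembling a matrix from pieces.
    rows = jnp.concatenate([i_vals, j_vals])
    cols = jnp.concatenate([j_vals, i_vals])
    mat = jnp.zeros((N, N), dtype=vals.dtype)
    return mat.at[rows, cols].add(jnp.concatenate([vals, vals]))


# n_max must be static because it controls the unrolled recurrence length.
assemble = jax.jit(assemble_toy_matrix, static_argnames=("n_max",))
```

Because the pair indices come from NumPy at trace time, the gathers and the scatter all have fixed shapes, so the whole assembly compiles to a single XLA program.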

Verifying correctness

Arriving at correct results required having the agent test its output against the existing, working C++ implementation. Agents excel when they have a reference to check against. The following unit tests verified the correctness of individual functions:

| Function | Tests | Reference |
|---|---|---|
| spherical_jn/yn/hn | real z, complex z, derivatives | cpp_special.* |
| spherical_hn_recursion | nmax up to 15 | individual cpp_special.spherical_hn calls |
| riccati_1/2/3 | nmax up to 10 | NumPy/scipy miepy.special_functions |
| associated_legendre | 20 (n,m) combos, including negatives | cpp_special.associated_legendre |
| associated_legendre_recursion | nmax up to 8, all modes | individual C++ calls |
| pi_func/tau_func | general + pole cases (θ=0, θ=π) | cpp_special.* |
| Emn | n up to 8 | cpp_vsh.Emn |
| wigner_3j | known values + selection rules | cpp_special.wigner_3j |
| wigner_3j_batch | 6 parameter combos | individual C++ calls |
| gaunt_batch | 10 (m,n,u,v) combos | cpp_special.a_func/b_func |
| Mie coefficients | scattering, interior, conducting, array-k | NumPy reference mie_sphere |

The following tests verified full-solver parity:

| Test class | Configuration | What’s compared |
|---|---|---|
| TestSingleParticle | 1 sphere, no interactions | cross-sections, p_src/p_inc/p_scat |
| TestTwoParticlesExact | 2 spheres, LU solver | p_inc, p_scat, cross-sections, force, torque |
| TestMultiParticleBiCGSTAB | 5 spheres, iterative | p_inc, cross-sections, force |
| TestHigherLmax | 2 spheres, lmax=4 | p_inc, cross-sections |
| TestDisplacedParticle | JAX-only, origin shift | cross-section invariance |
| TestInteractionsOff | 2 spheres, no coupling | extinction = 2× single Mie |
| TestEFieldParity | 2 spheres, exact | scattered E-field, total E-field |
| TestClusterCoefficients | 2 spheres, exact | p_cluster |
| TestSolverMethodParity | exact vs bicgstab | cross-backend + cross-solver |
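Several of the function-level tests compare one batched JAX call against many individual reference calls. A minimal sketch of that pattern, using a hypothetical batched Legendre-polynomial recurrence (`legendre_upward`) and NumPy's `legval` as the scalar reference in place of the C++ functions:

```python
import numpy as np
import jax
import jax.numpy as jnp
from numpy.polynomial import legendre

jax.config.update("jax_enable_x64", True)


def legendre_upward(n_max: int, x):
    """P_0..P_{n_max} at scalar x via Bonnet's recurrence
    (n+1) P_{n+1} = (2n+1) x P_n - n P_{n-1}."""
    P = [jnp.ones_like(x), x]
    for n in range(1, n_max):
        P.append(((2 * n + 1) * x * P[n] - n * P[n - 1]) / (n + 1))
    return jnp.stack(P[: n_max + 1])


def batched_legendre(n_max: int, xs):
    # n_max is closed over as a Python int so the recurrence unrolls at trace time.
    return jax.vmap(lambda x: legendre_upward(n_max, x))(xs)  # [B, n_max + 1]


def check_parity_against_scalar_reference(n_max: int = 8, n_points: int = 50, seed: int = 0):
    """One batched call vs. many individual reference calls."""
    xs = np.random.default_rng(seed).uniform(-1.0, 1.0, size=n_points)
    batched = np.asarray(batched_legendre(n_max, jnp.asarray(xs)))
    for b, x in enumerate(xs):
        for n in range(n_max + 1):
            coeffs = np.zeros(n + 1)
            coeffs[n] = 1.0  # a Legendre series with only the P_n term
            np.testing.assert_allclose(batched[b, n], legendre.legval(x, coeffs),
                                       rtol=1e-12, atol=1e-12)
```

Comparing against per-point reference calls is slow, but that is fine in a test: the point is that the reference is simple enough to trust.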

GPU vs CPU Performance

The GPU solver is faster than the CPU solver for roughly $N > 25$. Approaching 1,000 nanoparticles, the CPU solver takes close to 6 seconds versus the GPU solver’s 0.5 seconds, a 12X speedup. An order-of-magnitude speedup like this matters in any scenario where the electrodynamics must be solved repeatedly:

  1. Studying optical properties at various wavelengths
  2. Studying different geometries and particle configurations
  3. Solving particle trajectories under optical tweezer illumination over longer times
GPU acceleration for solving Maxwell's Equations for N interacting particles
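For anyone reproducing GPU-vs-CPU timings like these with JAX, two pitfalls are worth guarding against: JAX dispatches work asynchronously, so a timer can stop before the GPU finishes, and the first call to a jitted function includes tracing and XLA compilation. A generic sketch (the helper `time_jax` is hypothetical, not the benchmark code behind the results above):

```python
import time

import jax
import jax.numpy as jnp


def time_jax(fn, *args, repeats: int = 5):
    """Return (first_call_s, best_steady_state_s) for jax.jit(fn)(*args).

    block_until_ready() waits for the device to finish, so the timer
    measures the computation rather than just the launch. The first call
    (tracing + compilation + run) is reported separately.
    """
    jitted = jax.jit(fn)
    t0 = time.perf_counter()
    jax.block_until_ready(jitted(*args))
    first = time.perf_counter() - t0

    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        jax.block_until_ready(jitted(*args))
        best = min(best, time.perf_counter() - t0)
    return first, best


if __name__ == "__main__":
    n = 2000
    A = jnp.eye(n) + 0.01 * jax.random.normal(jax.random.PRNGKey(0), (n, n))
    b = jnp.ones(n)
    first, steady = time_jax(jnp.linalg.solve, A, b)
    print(f"first call {first:.3f} s, steady state {steady:.4f} s")
```

Whether compile time should count depends on the use case: for a long optical-tweezer trajectory, one compilation is amortized over many solves of the same shape.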