Doing evolution strategies at hyperscale.
WARNING: This codebase is a research preview. We strongly recommend reading through and running the eggroll.ipynb notebook, in Colab or locally (CPU/GPU), to understand the core ideas, and checking the example in tests/end_to_end_test.py to understand the codebase as a whole. We also recommend looking at the nano-egg repository, which contains a single-file implementation of int8 pretraining of a language model from scratch.
Install the version of JAX that matches your system before running the pip install:
```bash
conda create -n hyperscalees python=3.13
conda activate hyperscalees
pip install "jax[cuda12]"
pip install -e .
```
Also, remember to set the HF_HOME environment variable (e.g., export HF_HOME=/path/to/cache) to a location with enough free space for large caches; this project writes cache files to that directory when using the RWKV LLMs.
Clone the repo, enter the folder (cd HyperscaleES), and replace USERNAME in the Dockerfile. Then build the image and do a test run:
```bash
docker build -t ${USER}_hyperscalees .
docker run -it --rm -v $(pwd):/app --name ${USER}_yourcontainername ${USER}_hyperscalees python tests/end_to_end_test.py
```
To run the LLM experiments, run the following module; check the file for the available command-line arguments:
```bash
python -m llm_experiments.general_do_evolution
```
The two core components of the codebase are the Noiser (src/hyperscalees/noiser) and the Model (src/hyperscalees/models). We strongly recommend reading and running the eggroll.ipynb notebook to see the full implementation and a worked-out example.
To initialize the noiser, call Noiser.init_noiser(params, sigma, lr, [additional noiser-specific keyword arguments]). This returns frozen_noiser_params and noiser_params, where frozen_noiser_params holds static aspects (such as the solver of an optax optimizer) and noiser_params holds dynamic aspects (the optimizer state, sigma, lr, etc.).
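For instance, initialization might look like the following sketch (here `Noiser` stands for whichever concrete noiser class you use from src/hyperscalees/noiser; the sigma/lr values and the absence of extra keyword arguments are placeholders, not confirmed defaults):

```python
# Sketch: `Noiser` is a concrete noiser class from src/hyperscalees/noiser,
# and `params` is the model parameter pytree returned by Model.rand_init.
frozen_noiser_params, noiser_params = Noiser.init_noiser(
    params,
    sigma=0.01,  # perturbation scale (placeholder value)
    lr=0.001,    # learning rate (placeholder value)
)
```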
The noiser is responsible for perturbing the model with noise. The get_noisy_standard function returns noised versions of parameters that are not applied via matmul (such as biases). The do_mm function applies a matmul (x @ param.T) with a noised version of the parameter; similarly, do_Tmm applies the transposed matmul (x @ param), and do_emb applies an embedding lookup (param[x]).
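For intuition, a noised matmul amounts to perturbing the weight before the product. The pure-JAX sketch below illustrates the idea with dense Gaussian noise drawn from an explicit key; the actual noiser may use a different noise structure (e.g., low-rank) and derives its noise from the parameter's es_tree_key and the iterinfo rather than from a key argument:

```python
import jax
import jax.numpy as jnp

def noisy_mm(x, param, key, sigma):
    # Perturb the weight with Gaussian noise, then apply the same
    # matmul convention as do_mm (x @ param.T).
    eps = jax.random.normal(key, param.shape, dtype=param.dtype)
    return x @ (param + sigma * eps).T

x = jnp.ones((4, 8))                      # batch of activations
w = jnp.zeros((16, 8))                    # weight, shape (out, in)
y = noisy_mm(x, w, jax.random.PRNGKey(0), sigma=0.01)
print(y.shape)                            # (4, 16)
```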
The noiser is also responsible for updating the model during evolution. The convert_fitnesses function takes the raw scores (one per generation thread) and, optionally, num_episodes_list (the number of episodes averaged to produce each fitness value, which can be helpful in classic RL settings). The do_update function returns the updated noiser_params and the updated model parameters.
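As a rough illustration of what such an update does, the vanilla-ES step below z-scores the raw fitnesses and moves a parameter along the fitness-weighted mean of the perturbations that produced each thread's score; the library's convert_fitnesses/do_update may use a different fitness transform and noise structure:

```python
import jax
import jax.numpy as jnp

def es_step(param, raw_scores, keys, sigma, lr):
    # Normalize raw scores into fitnesses (one per generation thread).
    f = (raw_scores - raw_scores.mean()) / (raw_scores.std() + 1e-8)
    # Regenerate each thread's perturbation from its key, then take the
    # fitness-weighted average as a gradient estimate.
    eps = jax.vmap(lambda k: jax.random.normal(k, param.shape))(keys)
    grad_est = jnp.tensordot(f, eps, axes=1) / (len(keys) * sigma)
    return param + lr * grad_est

keys = jax.random.split(jax.random.PRNGKey(0), 8)   # 8 generation threads
param = jnp.zeros((16, 8))
raw_scores = jnp.arange(8, dtype=jnp.float32)       # one score per thread
param = es_step(param, raw_scores, keys, sigma=0.01, lr=0.001)
```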
A model class defines how to initialize the parameters (and any auxiliary components needed for evolution) and the forward function. To initialize a model, call Model.rand_init(key, [model-specific arguments]), which returns frozen_params (static aspects of the model), params (the standard parameters), scan_map (defining which parts of the model are scanned over, so that different keys are created for sub-parameters), and es_map (defining whether a parameter is treated as a regular PARAM, an MM_PARAM, an EMB_PARAM, or EXCLUDED from evolution).
After initializing a model, you should also call hyperscalees.models.common.simple_es_tree_key(params, base_key, scan_map) (where base_key is a single JAX PRNG key), which outputs the model's es_tree_key (referred to as the base_keys in the context of the noiser's do_update).
To call the model, run Model.forward(noiser, frozen_noiser_params, noiser_params, frozen_params, params, es_tree_key, iterinfo, [additional model-specific args]). The only parameter not covered above is iterinfo, which is either an (epoch, thread_id) tuple or None (to run the forward pass without noising the parameters).
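Putting the pieces together, the call sequence described above looks roughly like this sketch (here `Model` and `Noiser` are placeholders for concrete classes from src/hyperscalees/models and src/hyperscalees/noiser, and the hyperparameter values are arbitrary):

```python
import jax
from hyperscalees.models.common import simple_es_tree_key

key = jax.random.PRNGKey(0)

# Initialize the model and the noiser (model-specific args omitted).
frozen_params, params, scan_map, es_map = Model.rand_init(key)
frozen_noiser_params, noiser_params = Noiser.init_noiser(params, sigma=0.01, lr=0.001)

# Per-parameter keys (the "base_keys" consumed by do_update).
es_tree_key = simple_es_tree_key(params, jax.random.PRNGKey(1), scan_map)

# A noised forward pass for generation thread 3 of epoch 0 ...
out = Model.forward(Noiser, frozen_noiser_params, noiser_params,
                    frozen_params, params, es_tree_key, iterinfo=(0, 3))
# ... and an unperturbed forward pass.
out_clean = Model.forward(Noiser, frozen_noiser_params, noiser_params,
                          frozen_params, params, es_tree_key, iterinfo=None)
```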
See tests/end_to_end_test.py for a fully worked-out example that trains an MLP to always predict the number 2.
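Schematically, such a training loop evaluates one fitness per generation thread and then applies the update, along these lines (the `evaluate` function is a stand-in for your fitness computation, and the exact signatures of convert_fitnesses/do_update are assumptions; defer to the test file for the real API):

```python
import jax.numpy as jnp

n_threads = 64
for epoch in range(100):
    # One raw score per generation thread, each from a noised forward pass.
    raw_scores = jnp.stack([
        evaluate(Model.forward(Noiser, frozen_noiser_params, noiser_params,
                               frozen_params, params, es_tree_key,
                               iterinfo=(epoch, t)))
        for t in range(n_threads)
    ])
    fitnesses = Noiser.convert_fitnesses(raw_scores)
    noiser_params, params = Noiser.do_update(
        noiser_params, params, es_tree_key, fitnesses)  # argument order assumed
```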