🧪 workspace, varo: Implement RNG with ChaCha8 and add optimization function

- Updated `Cargo.toml` and `Cargo.lock` to downgrade `rand_core` to 0.6.4 and add `rand_chacha` dependency with `serde1` feature.
- Implemented `Rng` struct wrapping `ChaCha8Rng` with seed initialization, stream selection, and number generation methods.
- Added `rng_new`, `rng_from_seed`, `rng_set_stream`, `rng_next_u32`, `rng_next_u64`, `rng_fill_bytes`, `rng_gen_f32`, and `rng_gen_gaussian` functions.
- Made `value` field in `Score` and `values` field in `OptimizationResult` public.
- Implemented `Distribution::sample` method that generates Gaussian numbers when moments are available.
- Added `From<f32>` implementation for `Score`.
- Implemented `optimize` function that evaluates candidates using cloned RNG streams and returns sorted results.
- Added integration test `test_optimize` that verifies optimization sorting behavior.
This commit is contained in:
Markus Scully 2025-08-06 16:46:58 +03:00
parent 5614cfe95f
commit db531c8c73
Signed by: mascully
GPG key ID: 93CA5814B698101C
5 changed files with 103 additions and 27 deletions

18
Cargo.lock generated
View file

@ -1373,7 +1373,7 @@ checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core 0.6.4",
"rand_core",
]
[[package]]
@ -1383,7 +1383,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core 0.6.4",
"rand_core",
"serde",
]
[[package]]
@ -1395,12 +1396,6 @@ dependencies = [
"getrandom 0.2.16",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
[[package]]
name = "redox_syscall"
version = "0.5.13"
@ -1423,7 +1418,7 @@ dependencies = [
"num-traits",
"pkcs1",
"pkcs8",
"rand_core 0.6.4",
"rand_core",
"signature",
"spki",
"subtle",
@ -1554,7 +1549,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de"
dependencies = [
"digest",
"rand_core 0.6.4",
"rand_core",
]
[[package]]
@ -2083,7 +2078,8 @@ version = "0.1.0"
dependencies = [
"blake3",
"digest",
"rand_core 0.9.3",
"rand_chacha",
"rand_core",
"serde",
"serde_json",
"tempfile",

View file

@ -21,4 +21,5 @@ proc-macro2 = "1.0"
sqlx = "0.8.6"
rusqlite = { version = "0.32.1", features = ["bundled"] }
tempfile = "3.12.0"
rand_core = "0.9.3"
rand_core = "0.6.4"
rand_chacha = { version = "0.3.1", features = ["serde1"] }

View file

@ -14,6 +14,7 @@ blake3.workspace = true
vanth = { path = "../vanth" }
vanth_derive = { path = "../vanth_derive" }
rand_core.workspace = true
rand_chacha.workspace = true
[dev-dependencies]
tempfile = { workspace = true }

View file

@ -1,18 +1,52 @@
use rand_core::SeedableRng;
use rand_core::{RngCore, SeedableRng};
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
use std::f32::consts::PI;
use vanth_derive::Vanth;
#[derive(Clone, Debug, Deserialize, Serialize, Vanth)]
pub struct Rng {
// TODO: RNG
inner: ChaCha8Rng,
}
impl Rng {
// TODO
pub fn rng_new() -> Rng {
Rng { inner: ChaCha8Rng::from_seed([0u8; 32]) }
}
pub fn rng_from_seed(seed: [u8; 32]) -> Rng {
Rng { inner: ChaCha8Rng::from_seed(seed) }
}
pub fn rng_set_stream(rng: &mut Rng, stream: u64) {
rng.inner.set_stream(stream);
}
pub fn rng_next_u32(rng: &mut Rng) -> u32 {
rng.inner.next_u32()
}
pub fn rng_next_u64(rng: &mut Rng) -> u64 {
rng.inner.next_u64()
}
pub fn rng_fill_bytes(rng: &mut Rng, dest: &mut [u8]) {
rng.inner.fill_bytes(dest)
}
pub fn rng_gen_f32(rng: &mut Rng) -> f32 {
rng_next_u32(rng) as f32 / u32::MAX as f32
}
pub fn rng_gen_gaussian(rng: &mut Rng, mean: f32, std_dev: f32) -> f32 {
let u = rng_gen_f32(rng);
let v = rng_gen_f32(rng);
let s = (-2.0 * (1.0 - u).ln()).sqrt();
let angle = 2.0 * PI * v;
mean + std_dev * s * angle.cos()
}
pub trait Varo {
/// Produce a random
/// Produce a random instance of `Self` using the provided RNG.
fn next(digest: &mut Rng) -> Self;
}
@ -23,27 +57,41 @@ pub struct Distribution {
}
impl Distribution {
pub fn sample(digest: &mut Rng) -> f32 {
todo!()
pub fn sample(&self, digest: &mut Rng) -> f32 {
if self.moments.len() >= 2 {
rng_gen_gaussian(digest, self.moments[0], self.moments[1].sqrt())
} else {
rng_gen_f32(digest)
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize, Vanth)]
pub struct Score {
value: f32,
pub value: f32,
}
impl From<f32> for Score {
fn from(value: f32) -> Self {
Score { value }
}
}
#[derive(Clone, Debug, Deserialize, Serialize, Vanth)]
pub struct OptimizationResult {
/// List of pairs of evaluation score and Rng used to generate the value.
values: Vec<(Rng, f32)>
pub values: Vec<(Rng, f32)>
}
pub fn optimize<T: Varo>(evaluator: impl Fn(T) -> Score, rng: &mut Rng, rounds: u32) -> OptimizationResult {
// TODO:
// `for i in 0..rounds`: create a clone of `rng` and feed it `i`.
// Call T::next() and pass it to the evaluator.
// Return a sorted list, highest scores first.
todo!()
let mut values: Vec<(Rng, f32)> = Vec::with_capacity(rounds as usize);
for i in 0..rounds {
let mut child = rng.clone();
rng_set_stream(&mut child, i as u64);
let t = T::next(&mut child);
let score = evaluator(t).value;
values.push((child, score));
}
values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
OptimizationResult { values }
}

View file

@ -0,0 +1,30 @@
use varo::{optimize, Rng, Score, Varo};
#[test]
fn test_optimize() {
struct Foo {
x: f32,
}
impl Varo for Foo {
fn next(digest: &mut varo::Rng) -> Self {
let x = varo::rng_gen_f32(digest) * 10.0;
Foo { x }
}
}
fn evaluate(foo: Foo) -> Score {
let x = foo.x;
let score = -0.9 * x.powi(3) + 2.6 * x.powi(2) - 4.0 * x;
score.into()
}
let mut rng = varo::rng_new();
let optimization_result = optimize(evaluate, &mut rng, 10);
assert_eq!(optimization_result.values.len(), 10);
let scores: Vec<f32> = optimization_result.values.iter().map(|pair| pair.1).collect();
for i in 0..scores.len() - 1 {
assert!(scores[i] > scores[i + 1]);
}
println!();
}