✨🧪⬆➕ workspace, varo: Implement RNG with ChaCha8 and add optimization function
- Updated `Cargo.toml` and `Cargo.lock` to downgrade `rand_core` to 0.6.4 and add `rand_chacha` dependency with `serde1` feature. - Implemented `Rng` struct wrapping `ChaCha8Rng` with seed initialization, stream selection, and number generation methods. - Added `rng_new`, `rng_from_seed`, `rng_set_stream`, `rng_next_u32`, `rng_next_u64`, `rng_fill_bytes`, `rng_gen_f32`, and `rng_gen_gaussian` functions. - Made `value` field in `Score` and `values` field in `OptimizationResult` public. - Implemented `Distribution::sample` method that generates Gaussian numbers when moments are available. - Added `From<f32>` implementation for `Score`. - Implemented `optimize` function that evaluates candidates using cloned RNG streams and returns sorted results. - Added integration test `test_optimize` that verifies optimization sorting behavior.
This commit is contained in:
parent
5614cfe95f
commit
db531c8c73
5 changed files with 103 additions and 27 deletions
18
Cargo.lock
generated
18
Cargo.lock
generated
|
@ -1373,7 +1373,7 @@ checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"rand_chacha",
|
"rand_chacha",
|
||||||
"rand_core 0.6.4",
|
"rand_core",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1383,7 +1383,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ppv-lite86",
|
"ppv-lite86",
|
||||||
"rand_core 0.6.4",
|
"rand_core",
|
||||||
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1395,12 +1396,6 @@ dependencies = [
|
||||||
"getrandom 0.2.16",
|
"getrandom 0.2.16",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "rand_core"
|
|
||||||
version = "0.9.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.5.13"
|
version = "0.5.13"
|
||||||
|
@ -1423,7 +1418,7 @@ dependencies = [
|
||||||
"num-traits",
|
"num-traits",
|
||||||
"pkcs1",
|
"pkcs1",
|
||||||
"pkcs8",
|
"pkcs8",
|
||||||
"rand_core 0.6.4",
|
"rand_core",
|
||||||
"signature",
|
"signature",
|
||||||
"spki",
|
"spki",
|
||||||
"subtle",
|
"subtle",
|
||||||
|
@ -1554,7 +1549,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de"
|
checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"digest",
|
"digest",
|
||||||
"rand_core 0.6.4",
|
"rand_core",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2083,7 +2078,8 @@ version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"blake3",
|
"blake3",
|
||||||
"digest",
|
"digest",
|
||||||
"rand_core 0.9.3",
|
"rand_chacha",
|
||||||
|
"rand_core",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
|
|
|
@ -21,4 +21,5 @@ proc-macro2 = "1.0"
|
||||||
sqlx = "0.8.6"
|
sqlx = "0.8.6"
|
||||||
rusqlite = { version = "0.32.1", features = ["bundled"] }
|
rusqlite = { version = "0.32.1", features = ["bundled"] }
|
||||||
tempfile = "3.12.0"
|
tempfile = "3.12.0"
|
||||||
rand_core = "0.9.3"
|
rand_core = "0.6.4"
|
||||||
|
rand_chacha = { version = "0.3.1", features = ["serde1"] }
|
||||||
|
|
|
@ -14,6 +14,7 @@ blake3.workspace = true
|
||||||
vanth = { path = "../vanth" }
|
vanth = { path = "../vanth" }
|
||||||
vanth_derive = { path = "../vanth_derive" }
|
vanth_derive = { path = "../vanth_derive" }
|
||||||
rand_core.workspace = true
|
rand_core.workspace = true
|
||||||
|
rand_chacha.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = { workspace = true }
|
tempfile = { workspace = true }
|
||||||
|
|
|
@ -1,18 +1,52 @@
|
||||||
use rand_core::SeedableRng;
|
use rand_core::{RngCore, SeedableRng};
|
||||||
|
use rand_chacha::ChaCha8Rng;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::f32::consts::PI;
|
||||||
use vanth_derive::Vanth;
|
use vanth_derive::Vanth;
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, Vanth)]
|
#[derive(Clone, Debug, Deserialize, Serialize, Vanth)]
|
||||||
pub struct Rng {
|
pub struct Rng {
|
||||||
// TODO: RNG
|
inner: ChaCha8Rng,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Rng {
|
pub fn rng_new() -> Rng {
|
||||||
// TODO
|
Rng { inner: ChaCha8Rng::from_seed([0u8; 32]) }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rng_from_seed(seed: [u8; 32]) -> Rng {
|
||||||
|
Rng { inner: ChaCha8Rng::from_seed(seed) }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rng_set_stream(rng: &mut Rng, stream: u64) {
|
||||||
|
rng.inner.set_stream(stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rng_next_u32(rng: &mut Rng) -> u32 {
|
||||||
|
rng.inner.next_u32()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rng_next_u64(rng: &mut Rng) -> u64 {
|
||||||
|
rng.inner.next_u64()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rng_fill_bytes(rng: &mut Rng, dest: &mut [u8]) {
|
||||||
|
rng.inner.fill_bytes(dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rng_gen_f32(rng: &mut Rng) -> f32 {
|
||||||
|
rng_next_u32(rng) as f32 / u32::MAX as f32
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rng_gen_gaussian(rng: &mut Rng, mean: f32, std_dev: f32) -> f32 {
|
||||||
|
let u = rng_gen_f32(rng);
|
||||||
|
let v = rng_gen_f32(rng);
|
||||||
|
let s = (-2.0 * (1.0 - u).ln()).sqrt();
|
||||||
|
let angle = 2.0 * PI * v;
|
||||||
|
mean + std_dev * s * angle.cos()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Varo {
|
pub trait Varo {
|
||||||
/// Produce a random
|
/// Produce a random instance of `Self` using the provided RNG.
|
||||||
fn next(digest: &mut Rng) -> Self;
|
fn next(digest: &mut Rng) -> Self;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,27 +57,41 @@ pub struct Distribution {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Distribution {
|
impl Distribution {
|
||||||
pub fn sample(digest: &mut Rng) -> f32 {
|
pub fn sample(&self, digest: &mut Rng) -> f32 {
|
||||||
todo!()
|
if self.moments.len() >= 2 {
|
||||||
|
rng_gen_gaussian(digest, self.moments[0], self.moments[1].sqrt())
|
||||||
|
} else {
|
||||||
|
rng_gen_f32(digest)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, Vanth)]
|
#[derive(Clone, Debug, Deserialize, Serialize, Vanth)]
|
||||||
pub struct Score {
|
pub struct Score {
|
||||||
value: f32,
|
pub value: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<f32> for Score {
|
||||||
|
fn from(value: f32) -> Self {
|
||||||
|
Score { value }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, Vanth)]
|
#[derive(Clone, Debug, Deserialize, Serialize, Vanth)]
|
||||||
pub struct OptimizationResult {
|
pub struct OptimizationResult {
|
||||||
/// List of pairs of evaluation score and Rng used to generate the value.
|
/// List of pairs of evaluation score and Rng used to generate the value.
|
||||||
values: Vec<(Rng, f32)>
|
pub values: Vec<(Rng, f32)>
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn optimize<T: Varo>(evaluator: impl Fn(T) -> Score, rng: &mut Rng, rounds: u32) -> OptimizationResult {
|
pub fn optimize<T: Varo>(evaluator: impl Fn(T) -> Score, rng: &mut Rng, rounds: u32) -> OptimizationResult {
|
||||||
// TODO:
|
let mut values: Vec<(Rng, f32)> = Vec::with_capacity(rounds as usize);
|
||||||
// `for i in 0..rounds`: create a clone of `rng` and feed it `i`.
|
for i in 0..rounds {
|
||||||
// Call T::next() and pass it to the evaluator.
|
let mut child = rng.clone();
|
||||||
// Return a sorted list, highest scores first.
|
rng_set_stream(&mut child, i as u64);
|
||||||
|
let t = T::next(&mut child);
|
||||||
todo!()
|
let score = evaluator(t).value;
|
||||||
|
values.push((child, score));
|
||||||
|
}
|
||||||
|
values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
|
OptimizationResult { values }
|
||||||
}
|
}
|
||||||
|
|
30
crates/varo/tests/integration/main.rs
Normal file
30
crates/varo/tests/integration/main.rs
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
use varo::{optimize, Rng, Score, Varo};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_optimize() {
|
||||||
|
struct Foo {
|
||||||
|
x: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Varo for Foo {
|
||||||
|
fn next(digest: &mut varo::Rng) -> Self {
|
||||||
|
let x = varo::rng_gen_f32(digest) * 10.0;
|
||||||
|
Foo { x }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn evaluate(foo: Foo) -> Score {
|
||||||
|
let x = foo.x;
|
||||||
|
let score = -0.9 * x.powi(3) + 2.6 * x.powi(2) - 4.0 * x;
|
||||||
|
score.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut rng = varo::rng_new();
|
||||||
|
let optimization_result = optimize(evaluate, &mut rng, 10);
|
||||||
|
assert_eq!(optimization_result.values.len(), 10);
|
||||||
|
let scores: Vec<f32> = optimization_result.values.iter().map(|pair| pair.1).collect();
|
||||||
|
for i in 0..scores.len() - 1 {
|
||||||
|
assert!(scores[i] > scores[i + 1]);
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue