diff --git a/.cargo/audit.toml b/.cargo/audit.toml index 84c6551d7..b3a4539d3 100644 --- a/.cargo/audit.toml +++ b/.cargo/audit.toml @@ -11,6 +11,7 @@ ignore = [ "RUSTSEC-2024-0436", # Paste used to generate macro, should be removed at some point. "RUSTSEC-2025-0119", # `number_prefix` used by `tokenizers`, only in the examples. "RUSTSEC-2025-0141", # `bincode` is no longer maintained. + "RUSTSEC-2024-0388", # `derivative` dependancy in the DQN example is unmaintained. ] # advisory IDs to ignore e.g. ["RUSTSEC-2019-0001", ...] informational_warnings = [ "unmaintained", diff --git a/Cargo.lock b/Cargo.lock index 0aa85d65f..c29f2235b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -146,6 +146,15 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "ar_archive_writer" version = "0.5.1" @@ -661,6 +670,7 @@ dependencies = [ "burn-nn", "burn-optim", "burn-remote", + "burn-rl", "burn-rocm", "burn-router", "burn-store", @@ -1067,6 +1077,18 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "burn-rl" +version = "0.21.0" +dependencies = [ + "burn-core", + "burn-ndarray", + "burn-optim", + "derive-new", + "log", + "rand 0.9.2", +] + [[package]] name = "burn-rocm" version = "0.21.0" @@ -1183,9 +1205,11 @@ dependencies = [ "burn-core", "burn-ndarray", "burn-optim", + "burn-rl", "derive-new", "log", "nvml-wrapper", + "rand 0.9.2", "ratatui", "rstest", "serde", @@ -1331,6 +1355,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "c_vec" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd7a427adc0135366d99db65b36dae9237130997e560ed61118041fb72be6e8" + [[package]] name = "candle-core" version = "0.9.2" @@ -2709,6 +2739,17 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive-new" version = "0.7.0" @@ -2911,6 +2952,16 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "dqn-agent" +version = "0.21.0" +dependencies = [ + "burn", + "derive-new", + "gym-rs", + "rand 0.9.2", +] + [[package]] name = "drawille" version = "0.3.0" @@ -4045,6 +4096,23 @@ dependencies = [ "serde", ] +[[package]] +name = "gym-rs" +version = "0.3.1" +source = "git+https://github.com/MathisWellmann/gym-rs?branch=main#5283afaa86a3a7c45c46c882cfad459f02539b62" +dependencies = [ + "derivative", + "derive-new", + "log", + "nalgebra", + "num-traits", + "ordered-float 4.6.0", + "rand 0.8.5", + "rand_pcg", + "sdl2", + "serde", +] + [[package]] name = "h2" version = "0.4.13" @@ -5252,6 +5320,33 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "nalgebra" +version = "0.33.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "named-tensor" version = "0.21.0" @@ -5907,6 +6002,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" dependencies = [ "num-traits", + "rand 0.8.5", + "serde", ] [[package]] @@ -7014,6 +7111,7 @@ dependencies = [ "libc", "rand_chacha 0.3.1", "rand_core 0.6.4", + "serde", ] [[package]] @@ -7053,6 +7151,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ "getrandom 0.2.17", + "serde", ] [[package]] @@ -7074,6 +7173,15 @@ dependencies = [ "rand 0.9.2", ] +[[package]] +name = "rand_pcg" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e" +dependencies = [ + "rand_core 0.6.4", +] + [[package]] name = "range-alloc" version = "0.1.4" @@ -7604,6 +7712,15 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" +[[package]] +name = "safe_arch" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96b02de82ddbe1b636e6170c21be622223aea188ef2e139be0a5b219ec215323" +dependencies = [ + "bytemuck", +] + [[package]] name = "safetensors" version = "0.3.3" @@ -7704,6 +7821,31 @@ version = "3.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" +[[package]] +name = "sdl2" +version = "0.37.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b498da7d14d1ad6c839729bd4ad6fc11d90a57583605f3b4df2cd709a9cd380" +dependencies = [ + "bitflags 1.3.2", + "c_vec", + "lazy_static", + "libc", + "sdl2-sys", +] + +[[package]] +name = "sdl2-sys" +version = "0.37.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "951deab27af08ed9c6068b7b0d05a93c91f0a8eb16b6b816a5e73452a43521d3" +dependencies = [ + "cfg-if", + "cmake", + "libc", + "version-compare", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -7971,6 +8113,19 @@ dependencies = [ "libc", ] +[[package]] +name = "simba" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c99284beb21666094ba2b75bbceda012e610f5479dfcc2d6e2426f53197ffd95" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + [[package]] name = "simd-adler32" version = "0.3.8" @@ -9440,6 +9595,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "version-compare" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579a42fc0b8e0c63b76519a339be31bed574929511fa53c1a3acae26eb258f29" + [[package]] name = "version_check" version = "0.9.5" @@ -9856,6 +10017,16 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wide" +version = "0.7.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce5da8ecb62bcd8ec8b7ea19f69a51275e91299be594ea5cc6ef7819e16cd03" +dependencies = [ + "bytemuck", + "safe_arch", +] + [[package]] name = "widestring" version = "1.2.1" diff --git a/crates/burn-core/src/record/base.rs b/crates/burn-core/src/record/base.rs index 844a80414..a213384e5 100644 --- a/crates/burn-core/src/record/base.rs +++ b/crates/burn-core/src/record/base.rs @@ -7,7 +7,7 @@ use serde::{Serialize, de::DeserializeOwned}; /// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings). pub trait Record: Send { /// Type of the item that can be serialized and deserialized. - type Item: Serialize + DeserializeOwned; + type Item: Serialize + DeserializeOwned + Clone; /// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings). fn into_item(self) -> Self::Item; diff --git a/crates/burn-core/src/record/recorder.rs b/crates/burn-core/src/record/recorder.rs index 74ca7f26d..66dd49fb2 100644 --- a/crates/burn-core/src/record/recorder.rs +++ b/crates/burn-core/src/record/recorder.rs @@ -285,7 +285,7 @@ mod tests { #[test] #[should_panic] fn err_when_invalid_item() { - #[derive(new, Serialize, Deserialize)] + #[derive(new, Serialize, Deserialize, Clone)] struct Item { value: S::FloatElem, } diff --git a/crates/burn-optim/src/lr_scheduler/composed.rs b/crates/burn-optim/src/lr_scheduler/composed.rs index 8d15a2afa..b6b686604 100644 --- a/crates/burn-optim/src/lr_scheduler/composed.rs +++ b/crates/burn-optim/src/lr_scheduler/composed.rs @@ -102,7 +102,7 @@ enum LrSchedulerItem { } #[derive(Record)] -/// Record item for the [componsed learning rate scheduler](ComposedLrScheduler). +/// Record item for the [composed learning rate scheduler](ComposedLrScheduler). pub enum LrSchedulerRecord { /// The linear variant. Linear(::Record), @@ -115,7 +115,7 @@ pub enum LrSchedulerRecord { } #[derive(Record)] -/// Records for the [componsed learning rate scheduler](ComposedLrScheduler). +/// Records for the [composed learning rate scheduler](ComposedLrScheduler). pub struct ComposedLrSchedulerRecord { schedulers: Vec>, } diff --git a/crates/burn-optim/src/optim/simple/record/base.rs b/crates/burn-optim/src/optim/simple/record/base.rs index 1d9c8cded..cd58b764e 100644 --- a/crates/burn-optim/src/optim/simple/record/base.rs +++ b/crates/burn-optim/src/optim/simple/record/base.rs @@ -19,7 +19,7 @@ where } /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] #[serde(bound = "")] pub enum AdaptorRecordItem< O: SimpleOptimizer, diff --git a/crates/burn-optim/src/optim/simple/record/v1.rs b/crates/burn-optim/src/optim/simple/record/v1.rs index 1e5ca84a2..913120fa3 100644 --- a/crates/burn-optim/src/optim/simple/record/v1.rs +++ b/crates/burn-optim/src/optim/simple/record/v1.rs @@ -56,7 +56,7 @@ impl, B: Backend> Clone for AdaptorRecordV1 { } /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] #[serde(bound = "")] pub enum AdaptorRecordItemV1, B: Backend, S: PrecisionSettings> { /// Rank 0. diff --git a/crates/burn-rl/Cargo.toml b/crates/burn-rl/Cargo.toml new file mode 100644 index 000000000..72fb696d1 --- /dev/null +++ b/crates/burn-rl/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "burn-rl" +edition.workspace = true +license.workspace = true +readme.workspace = true +version.workspace = true + +[dependencies] +burn-core = { path = "../burn-core", version = "=0.21.0", features = [ + "dataset", + "std", +], default-features = false } +burn-optim = { path = "../burn-optim", version = "=0.21.0", features = [ + "std", +], default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0" } +derive-new.workspace = true +rand.workspace = true + +log = { workspace = true } + +[lints] +workspace = true diff --git a/crates/burn-rl/LICENSE-APACHE b/crates/burn-rl/LICENSE-APACHE new file mode 120000 index 000000000..1cd601d0a --- /dev/null +++ b/crates/burn-rl/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/burn-rl/LICENSE-MIT b/crates/burn-rl/LICENSE-MIT new file mode 120000 index 000000000..b2cfbdc7b --- /dev/null +++ b/crates/burn-rl/LICENSE-MIT @@ -0,0 +1 @@ +../../LICENSE-MIT \ No newline at end of file diff --git a/crates/burn-rl/README.md b/crates/burn-rl/README.md new file mode 100644 index 000000000..637cdda77 --- /dev/null +++ b/crates/burn-rl/README.md @@ -0,0 +1,6 @@ +# Burn RL + + + + diff --git a/crates/burn-rl/src/environment/base.rs b/crates/burn-rl/src/environment/base.rs new file mode 100644 index 000000000..16af86975 --- /dev/null +++ b/crates/burn-rl/src/environment/base.rs @@ -0,0 +1,46 @@ +/// The result of taking a step in an environment. +pub struct StepResult { + /// The updated state. + pub next_state: S, + /// The reward. + pub reward: f64, + /// If the environment reached a terminal state. + pub done: bool, + /// If the environment reached its max length. + pub truncated: bool, +} + +/// Trait to be implemented for a RL environment. +pub trait Environment { + /// The type of the state. + type State; + /// The type of actions. + type Action; + + /// The maximum number of step for one episode. + const MAX_STEPS: usize; + + /// Returns the current state. + fn state(&self) -> Self::State; + /// Take a step in the environment given an action. + fn step(&mut self, action: Self::Action) -> StepResult; + /// Reset the environment to an initial state. + fn reset(&mut self); +} + +/// Trait to define how to initialize an environment. +/// By default, any function returning an environment implements it. +pub trait EnvironmentInit: Clone { + /// Initialize the environment. + fn init(&self) -> E; +} + +impl EnvironmentInit for F +where + F: Fn() -> E + Clone, + E: Environment, +{ + fn init(&self) -> E { + (self)() + } +} diff --git a/crates/burn-rl/src/environment/mod.rs b/crates/burn-rl/src/environment/mod.rs new file mode 100644 index 000000000..cbcb6ac7e --- /dev/null +++ b/crates/burn-rl/src/environment/mod.rs @@ -0,0 +1,3 @@ +mod base; + +pub use base::*; diff --git a/crates/burn-rl/src/lib.rs b/crates/burn-rl/src/lib.rs new file mode 100644 index 000000000..dd5736b87 --- /dev/null +++ b/crates/burn-rl/src/lib.rs @@ -0,0 +1,21 @@ +#![warn(missing_docs)] +#![cfg_attr(docsrs, feature(doc_cfg))] + +//! A library for training reinforcement learning agents. + +/// Module for implementing an environment. +pub mod environment; +/// Module for implementing a policy. +pub mod policy; +/// Transition buffer. +pub mod transition_buffer; + +pub use environment::*; +pub use policy::*; +pub use transition_buffer::*; + +#[cfg(test)] +pub(crate) type TestBackend = burn_ndarray::NdArray; + +#[cfg(test)] +pub(crate) mod tests {} diff --git a/crates/burn-rl/src/policy/async_policy.rs b/crates/burn-rl/src/policy/async_policy.rs new file mode 100644 index 000000000..b1946fd83 --- /dev/null +++ b/crates/burn-rl/src/policy/async_policy.rs @@ -0,0 +1,314 @@ +use std::{ + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + mpsc::{self, Sender}, + }, + thread::spawn, +}; + +use burn_core::prelude::Backend; + +use crate::{ActionContext, Batchable, Policy, PolicyState}; + +#[derive(Clone)] +struct PolicyInferenceServer> { + // `num_agents` used to make sure autobatching doesn't block the agents if they are less than the autobatch size. + num_agents: Arc, + max_autobatch_size: usize, + inner_policy: P, + batch_action: Vec>, + batch_logits: Vec>, +} + +impl PolicyInferenceServer +where + B: Backend, + P: Policy, + P::Observation: Clone + Batchable, + P::ActionDistribution: Clone + Batchable, + P::Action: Clone + Batchable, + P::ActionContext: Clone, +{ + pub fn new(max_autobatch_size: usize, inner_policy: P) -> Self { + Self { + num_agents: Arc::new(AtomicUsize::new(0)), + max_autobatch_size, + inner_policy, + batch_action: vec![], + batch_logits: vec![], + } + } + + pub fn push_action(&mut self, item: ActionItem) { + self.batch_action.push(item); + if self.len_actions() + >= self + .num_agents + .load(Ordering::Relaxed) + .min(self.max_autobatch_size) + { + self.flush_actions(); + } + } + + pub fn push_logits(&mut self, item: ForwardItem) { + self.batch_logits.push(item); + if self.len_logits() + >= self + .num_agents + .load(Ordering::Relaxed) + .min(self.max_autobatch_size) + { + self.flush_logits(); + } + } + + pub fn len_actions(&self) -> usize { + self.batch_action.len() + } + + pub fn len_logits(&self) -> usize { + self.batch_action.len() + } + + pub fn flush_actions(&mut self) { + if self.len_actions() == 0 { + return; + } + let input: Vec<_> = self + .batch_action + .iter() + .map(|m| m.inference_state.clone()) + .collect(); + // Only deterministic if all actions are requested as deterministic. + let deterministic = self.batch_action.iter().all(|item| item.deterministic); + let (actions, context) = self + .inner_policy + .action(P::Observation::batch(input), deterministic); + let actions: Vec<_> = actions.unbatch(); + + for (i, item) in self.batch_action.iter().enumerate() { + item.sender + .send(ActionContext { + context: vec![context[i].clone()], + action: actions[i].clone(), + }) + .expect("Autobatcher should be able to send resulting actions."); + } + self.batch_action.clear(); + } + + pub fn flush_logits(&mut self) { + if self.len_logits() == 0 { + return; + } + let input: Vec<_> = self + .batch_logits + .iter() + .map(|m| m.inference_state.clone()) + .collect(); + let output = self.inner_policy.forward(P::Observation::batch(input)); + let logits: Vec<_> = output.unbatch(); + for (i, item) in self.batch_logits.iter().enumerate() { + item.sender + .send(logits[i].clone()) + .expect("Autobatcher should be able to send resulting probabilities."); + } + self.batch_action.clear(); + } + + pub fn update_policy(&mut self, policy_update: P::PolicyState) { + if self.len_actions() > 0 { + self.flush_actions(); + } + if self.len_logits() > 0 { + self.flush_logits(); + } + self.inner_policy.update(policy_update); + } + + pub fn state(&self) -> P::PolicyState { + self.inner_policy.state() + } + + pub fn increment_agents(&mut self, num: usize) { + self.num_agents.fetch_add(num, Ordering::Relaxed); + } + + pub fn decrement_agents(&mut self, num: usize) { + self.num_agents.fetch_sub(num, Ordering::Relaxed); + if self.len_actions() + >= self + .num_agents + .load(Ordering::Relaxed) + .min(self.max_autobatch_size) + { + self.flush_actions(); + } + if self.len_logits() + >= self + .num_agents + .load(Ordering::Relaxed) + .min(self.max_autobatch_size) + { + self.flush_logits(); + } + } +} + +enum InferenceMessage> { + ActionMessage(ActionItem), + ForwardMessage(ForwardItem), + PolicyUpdate(P::PolicyState), + PolicyRequest(Sender), + IncrementAgents(usize), + DecrementAgents(usize), +} + +#[derive(Clone)] +struct ActionItem { + sender: Sender>>, + inference_state: S, + deterministic: bool, +} + +#[derive(Clone)] +struct ForwardItem { + sender: Sender, + inference_state: S, +} + +/// An asynchronous policy using an inference server with autobatching. +#[derive(Clone)] +pub struct AsyncPolicy> { + inference_state_sender: Sender>, +} + +impl AsyncPolicy +where + B: Backend, + P: Policy + Clone + Send + 'static, + P::ActionContext: Clone + Send, + P::PolicyState: Send, + P::Observation: Clone + Send + Batchable, + P::ActionDistribution: Clone + Send + Batchable, + P::Action: Clone + Send + Batchable, +{ + /// Create the policy. + /// + /// # Arguments + /// + /// * `autobatch_size` - Number of observations to accumulate before running a pass of inference. + /// * `inner_policy` - The policy used to take actions. + pub fn new(autobatch_size: usize, inner_policy: P) -> Self { + let (sender, receiver) = std::sync::mpsc::channel(); + let mut autobatcher = PolicyInferenceServer::new(autobatch_size, inner_policy.clone()); + spawn(move || { + loop { + match receiver.recv() { + Ok(msg) => match msg { + InferenceMessage::ActionMessage(item) => autobatcher.push_action(item), + InferenceMessage::ForwardMessage(item) => autobatcher.push_logits(item), + InferenceMessage::PolicyUpdate(update) => autobatcher.update_policy(update), + InferenceMessage::PolicyRequest(sender) => sender + .send(autobatcher.state()) + .expect("Autobatcher should be able to send current policy state."), + InferenceMessage::IncrementAgents(num) => autobatcher.increment_agents(num), + InferenceMessage::DecrementAgents(num) => autobatcher.decrement_agents(num), + }, + Err(err) => { + log::error!("Error in AsyncPolicy : {}", err); + break; + } + } + } + }); + + Self { + inference_state_sender: sender, + } + } + + /// Increment the number of agents using the inference server. + pub fn increment_agents(&self, num: usize) { + self.inference_state_sender + .send(InferenceMessage::IncrementAgents(num)) + .expect("Can send message to autobatcher.") + } + + /// Decrement the number of agents using the inference server. + pub fn decrement_agents(&self, num: usize) { + self.inference_state_sender + .send(InferenceMessage::DecrementAgents(num)) + .expect("Can send message to autobatcher.") + } +} + +impl Policy for AsyncPolicy +where + B: Backend, + P: Policy + Send + 'static, +{ + type ActionContext = P::ActionContext; + type PolicyState = P::PolicyState; + + type Observation = P::Observation; + type ActionDistribution = P::ActionDistribution; + type Action = P::Action; + + fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution { + let (action_sender, action_receiver) = std::sync::mpsc::channel(); + let item = ForwardItem { + sender: action_sender, + inference_state: states, + }; + self.inference_state_sender + .send(InferenceMessage::ForwardMessage(item)) + .expect("Should be able to send message to inference_server"); + action_receiver + .recv() + .expect("AsyncPolicy should receive queued probabilities.") + } + + fn action( + &mut self, + states: Self::Observation, + deterministic: bool, + ) -> (Self::Action, Vec) { + let (action_sender, action_receiver) = std::sync::mpsc::channel(); + let item = ActionItem { + sender: action_sender, + inference_state: states, + deterministic, + }; + self.inference_state_sender + .send(InferenceMessage::ActionMessage(item)) + .expect("should be able to send message to inference_server."); + let action = action_receiver + .recv() + .expect("AsyncPolicy should receive queued actions."); + (action.action, action.context) + } + + fn update(&mut self, update: Self::PolicyState) { + self.inference_state_sender + .send(InferenceMessage::PolicyUpdate(update)) + .expect("AsyncPolicy should be able to send policy state.") + } + + fn state(&self) -> Self::PolicyState { + let (sender, receiver) = mpsc::channel(); + self.inference_state_sender + .send(InferenceMessage::PolicyRequest(sender)) + .expect("should be able to send message to inference_server."); + receiver + .recv() + .expect("AsyncPolicy should be able to receive policy state.") + } + + fn load_record(self, _record: >::Record) -> Self { + // Not needed for now + todo!() + } +} diff --git a/crates/burn-rl/src/policy/base.rs b/crates/burn-rl/src/policy/base.rs new file mode 100644 index 000000000..26873f267 --- /dev/null +++ b/crates/burn-rl/src/policy/base.rs @@ -0,0 +1,108 @@ +use derive_new::new; + +use burn_core::{prelude::*, record::Record, tensor::backend::AutodiffBackend}; + +use crate::TransitionBatch; + +/// An action along with additional context about the decision. +#[derive(Clone, new)] +pub struct ActionContext { + /// The context. + pub context: C, + /// The action. + pub action: A, +} + +/// The state of a policy. +pub trait PolicyState { + /// The type of the record. + type Record: Record; + + /// Convert the state to a record. + fn into_record(self) -> Self::Record; + /// Load the state from a record. + fn load_record(&self, record: Self::Record) -> Self; +} + +/// Trait for a RL policy. +pub trait Policy: Clone { + /// The observation given as input to the policy. + type Observation; + /// The action distribution parameters defining how the action will be sampled. + type ActionDistribution; + /// The action. + type Action; + + /// Additional context on the policy's decision. + type ActionContext; + /// The current parameterization of the policy. + type PolicyState: PolicyState; + + /// Produces the action distribution from a batch of observations. + fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution; + /// Gives the action from a batch of observations. + fn action( + &mut self, + obs: Self::Observation, + deterministic: bool, + ) -> (Self::Action, Vec); + + /// Update the policy's parameters. + fn update(&mut self, update: Self::PolicyState); + /// Returns the current parameterization. + fn state(&self) -> Self::PolicyState; + + /// Loads the policy parameters from a record. + fn load_record(self, record: >::Record) -> Self; +} + +/// Trait for a type that can be batched and unbatched (split). +pub trait Batchable: Sized { + /// Create a batch from a list of items. + fn batch(value: Vec) -> Self; + /// Create a list from batched items. + fn unbatch(self) -> Vec; +} + +/// A training output. +pub struct RLTrainOutput { + /// The policy. + pub policy: P, + /// The item. + pub item: TO, +} + +/// Batched transitions for a PolicyLearner. +pub type LearnerTransitionBatch = + TransitionBatch>::Observation,

>::Action>; + +/// Learner for a policy. +pub trait PolicyLearner +where + B: AutodiffBackend, + >::Observation: Clone + Batchable, + >::ActionDistribution: Clone + Batchable, + >::Action: Clone + Batchable, +{ + /// Additional context of a training step. + type TrainContext; + /// The policy to train. + type InnerPolicy: Policy; + /// The record of the learner. + type Record: Record; + + /// Execute a training step on the policy. + fn train( + &mut self, + input: LearnerTransitionBatch, + ) -> RLTrainOutput>::PolicyState>; + /// Returns the learner's current policy. + fn policy(&self) -> Self::InnerPolicy; + /// Update the learner's policy. + fn update_policy(&mut self, update: Self::InnerPolicy); + + /// Convert the learner's state into a record. + fn record(&self) -> Self::Record; + /// Load the learner's state from a record. + fn load_record(self, record: Self::Record) -> Self; +} diff --git a/crates/burn-rl/src/policy/mod.rs b/crates/burn-rl/src/policy/mod.rs new file mode 100644 index 000000000..fa7e96755 --- /dev/null +++ b/crates/burn-rl/src/policy/mod.rs @@ -0,0 +1,5 @@ +mod async_policy; +mod base; + +pub use async_policy::*; +pub use base::*; diff --git a/crates/burn-rl/src/transition_buffer/base.rs b/crates/burn-rl/src/transition_buffer/base.rs new file mode 100644 index 000000000..a937c10fb --- /dev/null +++ b/crates/burn-rl/src/transition_buffer/base.rs @@ -0,0 +1,234 @@ +use burn_core::{Tensor, prelude::Backend, tensor::backend::AutodiffBackend}; +use derive_new::new; +use rand::{rng, seq::index::sample}; + +use crate::Batchable; + +/// A state transition in an environment. +#[derive(Clone, new)] +pub struct Transition { + /// The initial state. + pub state: S, + /// The state after the step was taken. + pub next_state: S, + /// The action taken in the step. + pub action: A, + /// The reward. + pub reward: Tensor, + /// If the environment has reached a terminal state. + pub done: Tensor, +} + +/// A batch of transitions. +pub struct TransitionBatch { + /// Batched initial states. + pub states: SB, + /// Batched resulting states. + pub next_states: SB, + /// Batched actions. + pub actions: AB, + /// Batched rewards. + pub rewards: Tensor, + /// Batched flags for terminal states. + pub dones: Tensor, +} + +impl From>> for TransitionBatch +where + BT: Backend, + B: AutodiffBackend, + S: Into + Clone, + A: Into + Clone, + SB: Batchable, + AB: Batchable, +{ + fn from(value: Vec<&Transition>) -> Self { + let states: Vec<_> = value.iter().map(|t| t.state.clone().into()).collect(); + let next_states: Vec<_> = value.iter().map(|t| t.next_state.clone().into()).collect(); + let actions: Vec<_> = value.iter().map(|t| t.action.clone().into()).collect(); + let rewards: Vec<_> = value.iter().map(|t| t.reward.clone()).collect(); + let dones: Vec<_> = value.iter().map(|t| t.done.clone()).collect(); + + let rewards = Tensor::stack::<2>(rewards, 0); + let dones = Tensor::stack::<2>(dones, 0); + + Self { + states: SB::batch(states), + next_states: SB::batch(next_states), + actions: AB::batch(actions), + rewards: Tensor::from_data(rewards.to_data(), &Default::default()), + dones: Tensor::from_data(dones.to_data(), &Default::default()), + } + } +} + +/// A circular buffer for transitions. +pub struct TransitionBuffer { + buffer: Vec, + capacity: usize, + cursor: usize, +} + +impl TransitionBuffer { + /// Creates a new circular buffer with a fixed capacity. + pub fn new(capacity: usize) -> Self { + Self { + buffer: Vec::with_capacity(capacity), + capacity, + cursor: 0, + } + } + + /// Add an item, overwriting the oldest if full. + pub fn push(&mut self, item: T) { + if self.buffer.len() < self.capacity { + self.buffer.push(item); + } else { + self.buffer[self.cursor] = item; + self.cursor = (self.cursor + 1) % self.capacity; + } + } + + /// Append a list of items to the current buffer. + pub fn append(&mut self, items: &mut Vec) { + let n = items.len(); + let mut is_overflow = false; + if n > self.capacity { + self.cursor = self.capacity - (n % self.capacity); + items.drain(0..n - self.capacity); + is_overflow = true; + } + let n = items.len(); + + let first_part = n.min(self.capacity - self.cursor); + let second_part = n - first_part; + + if is_overflow { + if self.capacity > self.len() { + self.buffer + .extend(items.drain(first_part..second_part + first_part)); + } else { + self.buffer[..second_part] + .iter_mut() + .zip(items.drain(first_part..second_part + first_part)) + .for_each(|(slot, item)| *slot = item); + } + } + + if self.capacity > self.len() { + self.buffer.extend(items.drain(..first_part)); + } else { + self.buffer[self.cursor..self.cursor + first_part] + .iter_mut() + .zip(items.drain(..first_part)) + .for_each(|(slot, item)| *slot = item); + } + + if !is_overflow { + self.buffer[..second_part] + .iter_mut() + .zip(items.drain(..second_part)) + .for_each(|(slot, item)| *slot = item); + } + + self.cursor = (self.cursor + n) % self.capacity + } + + /// Returns the current number of items stored. + pub fn len(&self) -> usize { + self.buffer.len() + } + + /// Returns the current number of items stored. + pub fn is_empty(&self) -> bool { + self.buffer.is_empty() + } + + /// Clear the buffer. + pub fn clear(&mut self) { + self.buffer.clear(); + } + + /// Sample the buffer at the given indices. + pub fn sample(&self, indices: Vec) -> Vec<&T> { + let mut items = Vec::with_capacity(indices.len()); + + for &idx in indices.iter() { + if let Some(item) = self.buffer.get(idx) { + items.push(item); + } + } + items + } + + /// Sample `batch_size` transitions at random. + pub fn random_sample(&self, batch_size: usize) -> Vec<&T> { + assert!(batch_size <= self.len()); + let mut rng = rng(); + let indices = sample(&mut rng, self.len(), batch_size).into_vec(); + self.sample(indices) + } +} + +#[cfg(test)] +mod tests { + use burn_core::tensor::TensorData; + + use crate::TestBackend; + + use super::*; + + fn transition() -> Transition, Tensor> { + Transition::new( + Tensor::from_data(TensorData::from([1.0, 2.0]), &Default::default()), + Tensor::from_data(TensorData::from([1.0, 2.0]), &Default::default()), + Tensor::from_data(TensorData::from([1.0]), &Default::default()), + Tensor::from_data(TensorData::from([1.0]), &Default::default()), + Tensor::from_data(TensorData::from([1.0]), &Default::default()), + ) + } + + #[test] + fn len_returns_number_of_elements() { + let mut buffer: TransitionBuffer< + Transition, Tensor>, + > = TransitionBuffer::new(2); + assert_eq!(buffer.len(), 0); + + buffer.push(transition()); + assert_eq!(buffer.len(), 1); + + buffer.push(transition()); + assert_eq!(buffer.len(), 2); + + buffer.push(transition()); + assert_eq!(buffer.len(), 2) + } + + #[test] + fn append_works() { + let mut buffer = TransitionBuffer::new(4); + assert_eq!(buffer.len(), 0); + + buffer.append(&mut vec![0, 1]); + assert_eq!(buffer.len(), 2); + assert_eq!(buffer.buffer, vec![0, 1]); + + buffer.append(&mut vec![2, 3, 4, 5]); + assert_eq!(buffer.len(), 4); + assert_eq!(buffer.buffer, vec![4, 5, 2, 3]); + + let mut buffer = TransitionBuffer::new(4); + buffer.append(&mut vec![0, 1, 2, 3, 4, 5]); + assert_eq!(buffer.len(), 4); + assert_eq!(buffer.buffer, vec![4, 5, 2, 3]); + + buffer.append(&mut vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]); + assert_eq!(buffer.len(), 4); + assert_eq!(buffer.buffer, vec![20, 17, 18, 19]); + + buffer.append(&mut vec![21, 22]); + assert_eq!(buffer.len(), 4); + assert_eq!(buffer.buffer, vec![20, 21, 22, 19]); + } +} diff --git a/crates/burn-rl/src/transition_buffer/mod.rs b/crates/burn-rl/src/transition_buffer/mod.rs new file mode 100644 index 000000000..cbcb6ac7e --- /dev/null +++ b/crates/burn-rl/src/transition_buffer/mod.rs @@ -0,0 +1,3 @@ +mod base; + +pub use base::*; diff --git a/crates/burn-train/Cargo.toml b/crates/burn-train/Cargo.toml index 5d52c5bc9..343e8c6b5 100644 --- a/crates/burn-train/Cargo.toml +++ b/crates/burn-train/Cargo.toml @@ -37,6 +37,7 @@ burn-core = { path = "../burn-core", version = "=0.21.0", features = [ burn-optim = { path = "../burn-optim", version = "=0.21.0", features = [ "std", ], default-features = false } +burn-rl = { path = "../burn-rl", version = "=0.21.0", default-features = false } burn-collective = { path = "../burn-collective", version = "=0.21.0", optional = true } log = { workspace = true } @@ -62,6 +63,7 @@ async-channel = { workspace = true } burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0" } rstest.workspace = true thiserror.workspace = true +rand.workspace = true [dev-dependencies] burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0" } diff --git a/crates/burn-train/src/evaluator/base.rs b/crates/burn-train/src/evaluator/base.rs index 32d5894c9..435dcb74c 100644 --- a/crates/burn-train/src/evaluator/base.rs +++ b/crates/burn-train/src/evaluator/base.rs @@ -1,7 +1,8 @@ use crate::{ - AsyncProcessorEvaluation, FullEventProcessorEvaluation, InferenceStep, Interrupter, + AsyncProcessorEvaluation, EvaluationItem, FullEventProcessorEvaluation, InferenceStep, + Interrupter, evaluator::components::EvaluatorComponentTypes, - metric::processor::{EvaluatorEvent, EventProcessorEvaluation, LearnerItem}, + metric::processor::{EvaluatorEvent, EventProcessorEvaluation}, renderer::{EvaluationName, MetricsRenderer}, }; use burn_core::{data::dataloader::DataLoader, module::Module}; @@ -43,7 +44,7 @@ impl Evaluator { iteration += 1; let item = self.model.step(item); - let item = LearnerItem::new(item, progress, 0, 1, iteration, None); + let item = EvaluationItem::new(item, progress, Some(iteration)); self.event_processor .process_test(EvaluatorEvent::ProcessedItem(name.clone(), item)); diff --git a/crates/burn-train/src/learner/mod.rs b/crates/burn-train/src/learner/mod.rs index c2af276cc..876320812 100644 --- a/crates/burn-train/src/learner/mod.rs +++ b/crates/burn-train/src/learner/mod.rs @@ -3,6 +3,7 @@ mod base; mod classification; mod early_stopping; mod regression; +mod rl; mod summary; mod supervised; mod train_val; @@ -12,6 +13,7 @@ pub use base::*; pub use classification::*; pub use early_stopping::*; pub use regression::*; +pub use rl::*; pub use summary::*; pub use supervised::*; pub use train_val::*; diff --git a/crates/burn-train/src/learner/rl/checkpointer.rs b/crates/burn-train/src/learner/rl/checkpointer.rs new file mode 100644 index 000000000..0c0bbe890 --- /dev/null +++ b/crates/burn-train/src/learner/rl/checkpointer.rs @@ -0,0 +1,75 @@ +use burn_core::tensor::Device; +use burn_rl::{Policy, PolicyLearner, PolicyState}; + +use crate::RLAgentRecord; +use crate::{ + RLComponentsTypes, RLPolicyRecord, + checkpoint::Checkpointer, + checkpoint::{AsyncCheckpointer, CheckpointingAction, CheckpointingStrategy}, + metric::store::EventStoreClient, +}; + +#[derive(new)] +/// Used to create, delete, or load checkpoints of the training process. +pub struct RLCheckpointer { + policy: AsyncCheckpointer, RLC::Backend>, + learning_agent: AsyncCheckpointer, RLC::Backend>, + strategy: Box, +} + +impl RLCheckpointer { + /// Create checkpoint for the training process. + pub fn checkpoint( + &mut self, + policy: &RLC::PolicyState, + learning_agent: &RLC::LearningAgent, + epoch: usize, + store: &EventStoreClient, + ) { + let actions = self.strategy.checkpointing(epoch, store); + + for action in actions { + match action { + CheckpointingAction::Delete(epoch) => { + self.policy + .delete(epoch) + .expect("Can delete policy checkpoint."); + self.learning_agent + .delete(epoch) + .expect("Can delete learning agent checkpoint.") + } + CheckpointingAction::Save => { + self.policy + .save(epoch, policy.clone().into_record()) + .expect("Can save policy checkpoint."); + self.learning_agent + .save(epoch, learning_agent.record()) + .expect("Can save learning agent checkpoint."); + } + } + } + } + + /// Load a training checkpoint. + pub fn load_checkpoint( + &self, + learning_agent: RLC::LearningAgent, + device: &Device, + epoch: usize, + ) -> RLC::LearningAgent { + let record = self + .policy + .restore(epoch, device) + .expect("Can load model checkpoint."); + let policy = learning_agent.policy().load_record(record); + + let record = self + .learning_agent + .restore(epoch, device) + .expect("Can load learning agent checkpoint."); + let mut learning_agent = learning_agent.load_record(record); + learning_agent.update_policy(policy); + + learning_agent + } +} diff --git a/crates/burn-train/src/learner/rl/components.rs b/crates/burn-train/src/learner/rl/components.rs new file mode 100644 index 000000000..3377dedab --- /dev/null +++ b/crates/burn-train/src/learner/rl/components.rs @@ -0,0 +1,115 @@ +use std::marker::PhantomData; + +use burn_core::tensor::backend::AutodiffBackend; +use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, PolicyState}; + +use crate::{AgentEvaluationEvent, AsyncProcessorTraining, ItemLazy, RLEvent}; + +/// All components used by the reinforcement learning paradigm, grouped in one trait. +pub trait RLComponentsTypes { + /// The backend used for training. + type Backend: AutodiffBackend; + /// The learning environment. + type Env: Environment + 'static; + /// Specifies how to initialize the environment. + type EnvInit: EnvironmentInit + Send + 'static; + /// The type of the environment state. + type State: Into<>::Observation> + Clone + Send + 'static; + /// The type of the environment action. + type Action: From<>::Action> + + Into<>::Action> + + Clone + + Send + + 'static; + + /// The policy used to take actions in the environment. + type Policy: Policy< + Self::Backend, + Observation = Self::PolicyObs, + ActionDistribution = Self::PolicyAD, + Action = Self::PolicyAction, + ActionContext = Self::ActionContext, + PolicyState = Self::PolicyState, + > + Send + + 'static; + /// The policy's observation type. + type PolicyObs: Clone + Send + Batchable + 'static; + /// The policy's action distribution type. + type PolicyAD: Clone + Send + Batchable; + /// The policy's action type. + type PolicyAction: Clone + Send + Batchable; + /// Additional data as context for an agent's action. + type ActionContext: ItemLazy + Clone + Send + 'static; + /// The state of the parameterized policy. + type PolicyState: Clone + Send + PolicyState + 'static; + + /// The learning agent. + type LearningAgent: PolicyLearner< + Self::Backend, + TrainContext = Self::TrainingOutput, + InnerPolicy = Self::Policy, + > + Send + + 'static; + /// The output data of a training step. + type TrainingOutput: ItemLazy + Clone + Send; +} + +/// Concrete type that implements the [RLComponentsTypes](RLComponentsTypes) trait. +pub struct RLComponentsMarker { + _backend: PhantomData, + _env: PhantomData, + _env_init: PhantomData, + _agent: PhantomData, +} + +impl RLComponentsTypes for RLComponentsMarker +where + B: AutodiffBackend, + E: Environment + 'static, + EI: EnvironmentInit + Send + 'static, + A: PolicyLearner + Send + 'static, + A::TrainContext: ItemLazy + Clone + Send, + A::InnerPolicy: Policy + Send, + >::Observation: Batchable + Clone + Send, + >::ActionDistribution: Batchable + Clone + Send, + >::Action: Batchable + Clone + Send, + >::ActionContext: ItemLazy + Clone + Send + 'static, + >::PolicyState: Clone + Send, + E::State: Into<>::Observation> + Clone + Send + 'static, + E::Action: From<>::Action> + + Into<>::Action> + + Clone + + Send + + 'static, +{ + type Backend = B; + type Env = E; + type EnvInit = EI; + type LearningAgent = A; + type Policy = A::InnerPolicy; + type PolicyObs = >::Observation; + type PolicyAD = >::ActionDistribution; + type PolicyAction = >::Action; + type ActionContext = >::ActionContext; + type PolicyState = >::PolicyState; + type TrainingOutput = A::TrainContext; + type State = E::State; + type Action = E::Action; +} + +pub(crate) type RlPolicy = <::LearningAgent as PolicyLearner< + ::Backend, +>>::InnerPolicy; +/// The event processor type for reinforcement learning. +pub type RLEventProcessorType = AsyncProcessorTraining< + RLEvent<::TrainingOutput, ::ActionContext>, + AgentEvaluationEvent<::ActionContext>, +>; +/// The record of the policy. +pub type RLPolicyRecord = <<::Policy as Policy< + ::Backend, +>>::PolicyState as PolicyState<::Backend>>::Record; +/// The record of the learning agent. +pub type RLAgentRecord = <::LearningAgent as PolicyLearner< + ::Backend, +>>::Record; diff --git a/crates/burn-train/src/learner/rl/env_runner/async_runner.rs b/crates/burn-train/src/learner/rl/env_runner/async_runner.rs new file mode 100644 index 000000000..257417ea4 --- /dev/null +++ b/crates/burn-train/src/learner/rl/env_runner/async_runner.rs @@ -0,0 +1,504 @@ +use rand::prelude::SliceRandom; +use std::{ + sync::mpsc::{Receiver, Sender}, + thread::spawn, +}; + +use burn_core::{Tensor, data::dataloader::Progress, prelude::Backend, tensor::Device}; +use burn_rl::EnvironmentInit; +use burn_rl::Policy; +use burn_rl::Transition; +use burn_rl::{AsyncPolicy, Environment}; + +use crate::{ + AgentEnvLoop, AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining, + Interrupter, RLComponentsTypes, RLEvent, RLEventProcessorType, RLTimeStep, RLTrajectory, + RlPolicy, TimeStep, Trajectory, +}; + +enum RequestMessage { + Step(), + Episode(), +} + +/// An asynchronous agent/environement interface. +pub struct AgentEnvAsyncLoop { + env_init: RLC::EnvInit, + id: usize, + eval: bool, + agent: AsyncPolicy>, + deterministic: bool, + transition_device: Device, + transition_receiver: Receiver>, + transition_sender: Sender>, + trajectory_receiver: Receiver>, + trajectory_sender: Sender>, + request_sender: Option>, +} + +impl AgentEnvAsyncLoop { + /// Create a new asynchronous runner. + pub fn new( + env_init: RLC::EnvInit, + id: usize, + eval: bool, + agent: AsyncPolicy>, + deterministic: bool, + transition_device: &Device, + ) -> Self { + let (transition_sender, transition_receiver) = std::sync::mpsc::channel(); + let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel(); + Self { + env_init, + id, + eval, + agent: agent.clone(), + deterministic, + transition_device: transition_device.clone(), + transition_receiver, + transition_sender, + trajectory_receiver, + trajectory_sender, + request_sender: None, + } + } +} + +impl AgentEnvLoop for AgentEnvAsyncLoop +where + BT: Backend, + RLC: RLComponentsTypes, +{ + fn start(&mut self) { + let id = self.id; + let mut agent = self.agent.clone(); + let deterministic = self.deterministic; + let transition_sender = self.transition_sender.clone(); + let trajectory_sender = self.trajectory_sender.clone(); + let device = self.transition_device.clone(); + let env_init = self.env_init.clone(); + + let (request_sender, request_receiver) = std::sync::mpsc::channel(); + self.request_sender = Some(request_sender); + + let mut current_steps = vec![]; + let mut current_reward = 0.0; + let mut step_num = 0; + + spawn(move || { + let mut env = env_init.init(); + env.reset(); + + let mut request_episode = false; + loop { + let state = env.state(); + let (action, context) = agent.action(state.clone().into(), deterministic); + + let env_action = RLC::Action::from(action); + let step_result = env.step(env_action.clone()); + + current_reward += step_result.reward; + step_num += 1; + + let transition = Transition::new( + state.clone(), + step_result.next_state, + env_action, + Tensor::from_data([step_result.reward], &device), + Tensor::from_data( + [(step_result.done || step_result.truncated) as i32 as f64], + &device, + ), + ); + + if !request_episode { + agent.decrement_agents(1); + let request = match request_receiver.recv() { + Ok(req) => req, + Err(err) => { + log::error!("Error in env runner : {}", err); + break; + } + }; + agent.increment_agents(1); + + match request { + RequestMessage::Step() => (), + RequestMessage::Episode() => request_episode = true, + } + } + + let time_step = TimeStep { + env_id: id, + transition, + done: step_result.done, + ep_len: step_num, + cum_reward: current_reward, + action_context: context[0].clone(), + }; + current_steps.push(time_step.clone()); + + if !request_episode && let Err(err) = transition_sender.send(time_step) { + log::error!("Error in env runner : {}", err); + break; + } + + if step_result.done || step_result.truncated { + if request_episode { + request_episode = false; + trajectory_sender + .send(Trajectory { + timesteps: current_steps.clone(), + }) + .expect("Can send trajectory to main thread."); + } + current_steps.clear(); + + env.reset(); + current_reward = 0.; + step_num = 0; + } + } + }); + } + + fn run_steps( + &mut self, + num_steps: usize, + processor: &mut RLEventProcessorType, + interrupter: &Interrupter, + progress: &mut Progress, + ) -> Vec> { + let mut items = vec![]; + for _ in 0..num_steps { + self.request_sender + .as_ref() + .expect("Call start before running steps.") + .send(RequestMessage::Step()) + .expect("Can request transitions."); + let transition = self + .transition_receiver + .recv() + .expect("Can receive transitions."); + items.push(transition.clone()); + + if !self.eval { + progress.items_processed += 1; + processor.process_train(RLEvent::TimeStep(EvaluationItem::new( + transition.action_context, + progress.clone(), + None, + ))); + + if transition.done { + processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new( + EpisodeSummary { + episode_length: transition.ep_len, + cum_reward: transition.cum_reward, + }, + progress.clone(), + None, + ))); + } + } + + if interrupter.should_stop() { + break; + } + } + items + } + + fn run_episodes( + &mut self, + num_episodes: usize, + processor: &mut RLEventProcessorType, + interrupter: &Interrupter, + _progress: &mut Progress, + ) -> Vec> { + let mut items = vec![]; + self.agent.increment_agents(1); + for episode_num in 0..num_episodes { + self.request_sender + .as_ref() + .expect("Call start before running episodes.") + .send(RequestMessage::Episode()) + .expect("Can request episodes."); + let trajectory = self + .trajectory_receiver + .recv() + .expect("Main thread can receive trajectory."); + + for (i, step) in trajectory.timesteps.iter().enumerate() { + // TODO : clean this. + if self.eval { + processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new( + step.action_context.clone(), + Progress::new(i, i), + None, + ))); + + if step.done { + processor.process_valid(AgentEvaluationEvent::EpisodeEnd( + EvaluationItem::new( + EpisodeSummary { + episode_length: step.ep_len, + cum_reward: step.cum_reward, + }, + Progress::new(episode_num + 1, num_episodes), + None, + ), + )); + } + } else { + processor.process_train(RLEvent::TimeStep(EvaluationItem::new( + step.action_context.clone(), + Progress::new(i, i), + None, + ))); + + if step.done { + processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new( + EpisodeSummary { + episode_length: step.ep_len, + cum_reward: step.cum_reward, + }, + Progress::new(episode_num + 1, num_episodes), + None, + ))); + } + } + } + + items.push(trajectory); + if interrupter.should_stop() { + break; + } + } + self.agent.decrement_agents(1); + items + } + + fn update_policy(&mut self, update: RLC::PolicyState) { + self.agent.update(update); + } + + fn policy(&self) -> RLC::PolicyState { + self.agent.state() + } +} + +/// An asynchronous runner for multiple agent/environement interfaces. +pub struct MultiAgentEnvLoop { + env_init: RLC::EnvInit, + num_envs: usize, + eval: bool, + agent: AsyncPolicy, + deterministic: bool, + device: Device, + transition_receiver: Receiver>, + transition_sender: Sender>, + trajectory_receiver: Receiver>, + trajectory_sender: Sender>, + request_senders: Vec>, +} + +impl MultiAgentEnvLoop { + /// Create a new asynchronous runner for multiple agent/environement interfaces. + pub fn new( + env_init: RLC::EnvInit, + num_envs: usize, + eval: bool, + agent: AsyncPolicy, + deterministic: bool, + device: &Device, + ) -> Self { + let (transition_sender, transition_receiver) = std::sync::mpsc::channel(); + let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel(); + Self { + env_init, + num_envs, + eval, + agent: agent.clone(), + deterministic, + device: device.clone(), + transition_receiver, + transition_sender, + trajectory_receiver, + trajectory_sender, + request_senders: Vec::with_capacity(num_envs), + } + } +} + +impl AgentEnvLoop for MultiAgentEnvLoop +where + BT: Backend, + RLC: RLComponentsTypes, +{ + // TODO: start() shouldn't exist. + fn start(&mut self) { + // Double batching : The environments are always one step ahead of requests. This allows inference for the first batch of steps. + self.agent.increment_agents(self.num_envs); + + for i in 0..self.num_envs { + let mut runner = AgentEnvAsyncLoop::::new( + self.env_init.clone(), + i, + self.eval, + self.agent.clone(), + self.deterministic, + &self.device, + ); + runner.transition_sender = self.transition_sender.clone(); + runner.trajectory_sender = self.trajectory_sender.clone(); + runner.start(); + self.request_senders + .push(runner.request_sender.clone().unwrap()); + } + + // Double batching : The environments are always one step ahead. + self.request_senders.iter().for_each(|s| { + s.send(RequestMessage::Step()) + .expect("Main thread can send step requests.") + }); + } + + fn run_steps( + &mut self, + num_steps: usize, + processor: &mut RLEventProcessorType, + interrupter: &Interrupter, + progress: &mut Progress, + ) -> Vec> { + let mut items = vec![]; + for _ in 0..num_steps { + let transition = self + .transition_receiver + .recv() + .expect("Can receive transitions."); + items.push(transition.clone()); + + self.request_senders[transition.env_id] + .send(RequestMessage::Step()) + .expect("Main thread can request steps."); + + if !self.eval { + progress.items_processed += 1; + processor.process_train(RLEvent::TimeStep(EvaluationItem::new( + transition.action_context, + progress.clone(), + None, + ))); + + if transition.done { + processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new( + EpisodeSummary { + episode_length: transition.ep_len, + cum_reward: transition.cum_reward, + }, + progress.clone(), + None, + ))); + } + } + + if interrupter.should_stop() { + break; + } + } + items + } + + fn update_policy(&mut self, update: RLC::PolicyState) { + self.agent.update(update); + } + + fn run_episodes( + &mut self, + num_episodes: usize, + processor: &mut RLEventProcessorType, + interrupter: &Interrupter, + _progress: &mut Progress, + ) -> Vec> { + // Send `num_episodes` initial requests. + let mut idx = vec![]; + if num_episodes < self.num_envs { + let mut rng = rand::rng(); + let mut vec: Vec = (0..self.num_envs).collect(); + vec.shuffle(&mut rng); + idx = vec.into_iter().take(num_episodes).collect(); + } else { + idx = (0..self.num_envs).collect(); + } + let num_requests = self.num_envs.min(num_episodes); + idx.into_iter().for_each(|i| { + self.request_senders[i] + .send(RequestMessage::Episode()) + .expect("Main thread can request steps."); + }); + + let mut items = vec![]; + for episode_num in 0..num_episodes { + let trajectory = self + .trajectory_receiver + .recv() + .expect("Can receive trajectory."); + items.push(trajectory.clone()); + if items.len() + num_requests <= num_episodes { + self.request_senders[trajectory.timesteps[0].env_id] + .send(RequestMessage::Episode()) + .expect("Main thread can request steps."); + } + for (i, step) in trajectory.timesteps.iter().enumerate() { + if self.eval { + processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new( + step.action_context.clone(), + Progress::new(i, i), + None, + ))); + + if step.done { + processor.process_valid(AgentEvaluationEvent::EpisodeEnd( + EvaluationItem::new( + EpisodeSummary { + episode_length: step.ep_len, + cum_reward: step.cum_reward, + }, + Progress::new(episode_num + 1, num_episodes), + None, + ), + )); + } + } else { + processor.process_train(RLEvent::TimeStep(EvaluationItem::new( + step.action_context.clone(), + Progress::new(i, i), + None, + ))); + + if step.done { + processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new( + EpisodeSummary { + episode_length: step.ep_len, + cum_reward: step.cum_reward, + }, + Progress::new(episode_num + 1, num_episodes), + None, + ))); + } + } + } + + if interrupter.should_stop() { + break; + } + } + + items + } + + fn policy(&self) -> RLC::PolicyState { + self.agent.state() + } +} diff --git a/crates/burn-train/src/learner/rl/env_runner/base.rs b/crates/burn-train/src/learner/rl/env_runner/base.rs new file mode 100644 index 000000000..1ba2ed64e --- /dev/null +++ b/crates/burn-train/src/learner/rl/env_runner/base.rs @@ -0,0 +1,273 @@ +use std::marker::PhantomData; + +use burn_core::data::dataloader::Progress; +use burn_core::{Tensor, prelude::Backend}; +use burn_rl::Policy; +use burn_rl::Transition; +use burn_rl::{Environment, EnvironmentInit}; + +use crate::RLEvent; +use crate::{ + AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining, + RLEventProcessorType, +}; +use crate::{Interrupter, RLComponentsTypes}; + +/// A trajectory, i.e. a list of ordered [TimeStep](TimeStep). +#[derive(Clone, new)] +pub struct Trajectory { + /// A list of ordered [TimeStep](TimeStep)s. + pub timesteps: Vec>, +} + +/// A timestep debscribing an iteration of the state/decision process. +#[derive(Clone)] +pub struct TimeStep { + /// The environment id. + pub env_id: usize, + /// The [burn_rl::Transition](burn_rl::Transition). + pub transition: Transition, + /// True if the environment reaches a terminal state. + pub done: bool, + /// The running length of the current episode. + pub ep_len: usize, + /// The running cumulative reward. + pub cum_reward: f64, + /// The action's context for this timestep. + pub action_context: C, +} + +pub(crate) type RLTimeStep = TimeStep< + B, + ::State, + ::Action, + ::ActionContext, +>; + +pub(crate) type RLTrajectory = Trajectory< + B, + ::State, + ::Action, + ::ActionContext, +>; + +/// Trait for a structure that implements an agent/environement interface. +pub trait AgentEnvLoop { + /// Start the loop. + fn start(&mut self); + /// Run a certain number of timesteps. + /// + /// # Arguments + /// + /// * `num_steps` - The number of time_steps to run. + /// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining). + /// * `interrupter` - An [crate::Interrupter](crate::Interrupter). + /// * `num_steps` - The number of time_steps to run. + /// * `progress` - A mutable reference to the learning progress. + /// + /// # Returns + /// + /// A list of ordered timesteps. + fn run_steps( + &mut self, + num_steps: usize, + processor: &mut RLEventProcessorType, + interrupter: &Interrupter, + progress: &mut Progress, + ) -> Vec>; + /// Run a certain number of episodes. + /// + /// # Arguments + /// + /// * `num_episodes` - The number of episodes to run. + /// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining). + /// * `interrupter` - An [crate::Interrupter](crate::Interrupter). + /// * `progress` - A mutable reference to the learning progress. + /// + /// # Returns + /// + /// A list of ordered timesteps. + fn run_episodes( + &mut self, + num_episodes: usize, + processor: &mut RLEventProcessorType, + interrupter: &Interrupter, + progress: &mut Progress, + ) -> Vec>; + /// Update the runner's agent. + fn update_policy(&mut self, update: RLC::PolicyState); + /// Get the state of the runner's agent. + fn policy(&self) -> RLC::PolicyState; +} + +/// A simple, synchronized agent/environement interface. +pub struct AgentEnvBaseLoop { + env: RLC::Env, + eval: bool, + agent: RLC::Policy, + deterministic: bool, + current_reward: f64, + run_num: usize, + step_num: usize, + _backend: PhantomData, +} + +impl AgentEnvBaseLoop { + /// Create a new base runner. + pub fn new( + env_init: RLC::EnvInit, + agent: RLC::Policy, + eval: bool, + deterministic: bool, + ) -> Self { + Self { + env: env_init.init(), + eval, + agent: agent.clone(), + deterministic, + current_reward: 0.0, + run_num: 0, + step_num: 0, + _backend: PhantomData, + } + } +} + +impl AgentEnvLoop for AgentEnvBaseLoop +where + BT: Backend, + RLC: RLComponentsTypes, +{ + fn start(&mut self) { + self.env.reset(); + } + + fn run_steps( + &mut self, + num_steps: usize, + processor: &mut RLEventProcessorType, + interrupter: &Interrupter, + progress: &mut Progress, + ) -> Vec> { + let mut items = vec![]; + let device = Default::default(); + for _ in 0..num_steps { + let state = self.env.state(); + let (action, context) = self.agent.action(state.clone().into(), self.deterministic); + + let step_result = self.env.step(RLC::Action::from(action.clone())); + + self.current_reward += step_result.reward; + self.step_num += 1; + + let transition = Transition::new( + state.clone(), + step_result.next_state, + RLC::Action::from(action), + Tensor::from_data([step_result.reward], &device), + Tensor::from_data( + [(step_result.done || step_result.truncated) as i32 as f64], + &device, + ), + ); + items.push(TimeStep { + env_id: 0, + transition, + done: step_result.done, + ep_len: self.step_num, + cum_reward: self.current_reward, + action_context: context[0].clone(), + }); + + if !self.eval { + progress.items_processed += 1; + processor.process_train(RLEvent::TimeStep(EvaluationItem::new( + context[0].clone(), + progress.clone(), + None, + ))); + + if step_result.done { + processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new( + EpisodeSummary { + episode_length: self.step_num, + cum_reward: self.current_reward, + }, + progress.clone(), + None, + ))); + } + } + + if interrupter.should_stop() { + break; + } + + if step_result.done || step_result.truncated { + self.env.reset(); + self.current_reward = 0.; + self.step_num = 0; + self.run_num += 1; + } + } + items + } + + fn update_policy(&mut self, update: RLC::PolicyState) { + self.agent.update(update); + } + + fn run_episodes( + &mut self, + num_episodes: usize, + processor: &mut RLEventProcessorType, + interrupter: &Interrupter, + progress: &mut Progress, + ) -> Vec> { + self.env.reset(); + + let mut items = vec![]; + for ep in 0..num_episodes { + let mut steps = vec![]; + loop { + let step = self.run_steps(1, processor, interrupter, progress)[0].clone(); + steps.push(step.clone()); + + if self.eval { + processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new( + step.action_context.clone(), + Progress::new(steps.len() + 1, steps.len() + 1), + None, + ))); + + if step.done { + processor.process_valid(AgentEvaluationEvent::EpisodeEnd( + EvaluationItem::new( + EpisodeSummary { + episode_length: step.ep_len, + cum_reward: step.cum_reward, + }, + Progress::new(ep + 1, num_episodes), + None, + ), + )); + } + } + + if interrupter.should_stop() || step.done { + break; + } + } + items.push(Trajectory::new(steps)); + + if interrupter.should_stop() { + break; + } + } + items + } + + fn policy(&self) -> RLC::PolicyState { + self.agent.state() + } +} diff --git a/crates/burn-train/src/learner/rl/env_runner/mod.rs b/crates/burn-train/src/learner/rl/env_runner/mod.rs new file mode 100644 index 000000000..ec150c7d0 --- /dev/null +++ b/crates/burn-train/src/learner/rl/env_runner/mod.rs @@ -0,0 +1,5 @@ +mod async_runner; +mod base; + +pub use async_runner::*; +pub use base::*; diff --git a/crates/burn-train/src/learner/rl/mod.rs b/crates/burn-train/src/learner/rl/mod.rs new file mode 100644 index 000000000..34874dc7c --- /dev/null +++ b/crates/burn-train/src/learner/rl/mod.rs @@ -0,0 +1,15 @@ +mod checkpointer; +mod components; +mod env_runner; +mod off_policy; +mod output; +mod paradigm; +mod strategy; + +pub use checkpointer::*; +pub use components::*; +pub use env_runner::*; +pub use off_policy::*; +pub use output::*; +pub use paradigm::*; +pub use strategy::*; diff --git a/crates/burn-train/src/learner/rl/off_policy.rs b/crates/burn-train/src/learner/rl/off_policy.rs new file mode 100644 index 000000000..7266e4274 --- /dev/null +++ b/crates/burn-train/src/learner/rl/off_policy.rs @@ -0,0 +1,172 @@ +use std::marker::PhantomData; + +use crate::{ + AgentEnvAsyncLoop, AgentEnvLoop, EvaluationItem, EventProcessorTraining, MultiAgentEnvLoop, + RLComponents, RLComponentsTypes, RLEvent, RLEventProcessorType, RLStrategy, +}; +use burn_core::{self as burn}; +use burn_core::{config::Config, data::dataloader::Progress}; +use burn_ndarray::NdArray; +use burn_rl::{AsyncPolicy, Policy, PolicyLearner, Transition, TransitionBuffer}; + +/// Parameters of an on policy training with multi environments and double-batching. +#[derive(Config, Debug)] +pub struct OffPolicyConfig { + /// The number of environments to run simultaneously for experience collection. + #[config(default = 1)] + pub num_envs: usize, + /// Number of environment state to accumulate before running one step of inference with the policy. + /// Must be equal or less than the number of simultaneous environments. + #[config(default = 1)] + pub autobatch_size: usize, + /// Max number of transitions stored in the replay buffer. + #[config(default = 1024)] + pub replay_buffer_size: usize, + /// The number of steps to collect between each step of training. + #[config(default = 1)] + pub train_interval: usize, + /// Number of optimization steps done each `train_interval`. + #[config(default = 1)] + pub train_steps: usize, + /// The number of steps to collect between each evaluation. + #[config(default = 10_000)] + pub eval_interval: usize, + /// The number of episodes to run for each evaluation. + #[config(default = 1)] + pub eval_episodes: usize, + /// The number of transition to train on. + #[config(default = 32)] + pub train_batch_size: usize, + /// Number of steps to collect before starting to train. + #[config(default = 0)] + pub warmup_steps: usize, +} + +/// Off-policy reinforcement learning strategy with multi-env experience collection and double-batching. +pub struct OffPolicyStrategy { + config: OffPolicyConfig, + _components: PhantomData, +} +impl OffPolicyStrategy { + /// Create a new off-policy base strategy. + pub fn new(config: OffPolicyConfig) -> Self { + Self { + config, + _components: PhantomData, + } + } +} + +impl RLStrategy for OffPolicyStrategy +where + RLC: RLComponentsTypes, +{ + fn learn( + &self, + training_components: RLComponents, + learner_agent: &mut RLC::LearningAgent, + starting_epoch: usize, + env_init: RLC::EnvInit, + ) -> (RLC::Policy, RLEventProcessorType) { + let mut event_processor = training_components.event_processor; + let mut checkpointer = training_components.checkpointer; + let num_steps_total = training_components.num_steps; + + let mut env_runner = MultiAgentEnvLoop::::new( + env_init.clone(), + self.config.num_envs, + false, + AsyncPolicy::new( + self.config.num_envs.min(self.config.autobatch_size), + learner_agent.policy(), + ), + false, + &Default::default(), + ); + env_runner.start(); + let mut env_runner_valid = AgentEnvAsyncLoop::::new( + env_init, + 0, + true, + AsyncPolicy::new(1, learner_agent.policy()), + true, + &Default::default(), + ); + env_runner_valid.start(); + let mut transition_buffer = + TransitionBuffer::>::new( + self.config.replay_buffer_size, + ); + + let mut valid_next = self.config.eval_interval + starting_epoch - 1; + let mut progress = Progress { + items_processed: starting_epoch, + items_total: num_steps_total, + }; + + let mut intermediary_update: Option<>::PolicyState> = + None; + while progress.items_processed < num_steps_total { + if training_components.interrupter.should_stop() { + let reason = training_components + .interrupter + .get_message() + .unwrap_or(String::from("Reason unknown")); + log::info!("Training interrupted: {reason}"); + break; + } + + let previous_steps = progress.items_processed; + let items = env_runner.run_steps( + self.config.train_interval, + &mut event_processor, + &training_components.interrupter, + &mut progress, + ); + + transition_buffer.append(&mut items.iter().map(|i| i.transition.clone()).collect()); + + if transition_buffer.len() >= self.config.train_batch_size + && progress.items_processed >= self.config.warmup_steps + { + if let Some(ref u) = intermediary_update { + env_runner.update_policy(u.clone()); + } + for _ in 0..self.config.train_steps { + let transitions = transition_buffer.random_sample(self.config.train_batch_size); + let train_item = learner_agent.train(transitions.into()); + intermediary_update = Some(train_item.policy); + + event_processor.process_train(RLEvent::TrainStep(EvaluationItem::new( + train_item.item, + progress.clone(), + None, + ))); + } + } + + if valid_next > previous_steps && valid_next <= progress.items_processed { + env_runner_valid.update_policy(learner_agent.policy().state()); + env_runner_valid.run_episodes( + self.config.eval_episodes, + &mut event_processor, + &training_components.interrupter, + &mut progress, + ); + + if let Some(checkpointer) = &mut checkpointer { + checkpointer.checkpoint( + &env_runner.policy(), + learner_agent, + valid_next, + &training_components.event_store, + ); + } + + valid_next += self.config.eval_interval; + } + } + + (learner_agent.policy(), event_processor) + } +} diff --git a/crates/burn-train/src/learner/rl/output.rs b/crates/burn-train/src/learner/rl/output.rs new file mode 100644 index 000000000..01177e04c --- /dev/null +++ b/crates/burn-train/src/learner/rl/output.rs @@ -0,0 +1,32 @@ +use crate::{ + ItemLazy, + metric::{Adaptor, CumulativeRewardInput, EpisodeLengthInput}, +}; + +/// Summary of an episode. +pub struct EpisodeSummary { + /// The total length of the episode. + pub episode_length: usize, + /// The final cumulative reward. + pub cum_reward: f64, +} + +impl ItemLazy for EpisodeSummary { + type ItemSync = EpisodeSummary; + + fn sync(self) -> Self::ItemSync { + self + } +} + +impl Adaptor for EpisodeSummary { + fn adapt(&self) -> EpisodeLengthInput { + EpisodeLengthInput::new(self.episode_length as f64) + } +} + +impl Adaptor for EpisodeSummary { + fn adapt(&self) -> CumulativeRewardInput { + CumulativeRewardInput::new(self.cum_reward) + } +} diff --git a/crates/burn-train/src/learner/rl/paradigm.rs b/crates/burn-train/src/learner/rl/paradigm.rs new file mode 100644 index 000000000..672941576 --- /dev/null +++ b/crates/burn-train/src/learner/rl/paradigm.rs @@ -0,0 +1,521 @@ +use crate::checkpoint::{ + AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer, + KeepLastNCheckpoints, MetricCheckpointingStrategy, +}; +use crate::learner::base::Interrupter; +use crate::logger::{FileMetricLogger, MetricLogger}; +use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split}; +use crate::metric::{Adaptor, EpisodeLengthMetric, Metric, Numeric}; +use crate::renderer::{MetricsRenderer, default_renderer}; +use crate::{ + ApplicationLoggerInstaller, AsyncProcessorTraining, FileApplicationLoggerInstaller, ItemLazy, + LearnerSummaryConfig, OffPolicyConfig, OffPolicyStrategy, RLAgentRecord, RLCheckpointer, + RLComponents, RLComponentsMarker, RLComponentsTypes, RLEventProcessor, RLMetrics, + RLPolicyRecord, RLStrategy, +}; +use crate::{EpisodeSummary, RLStrategies}; +use burn_core::record::FileRecorder; +use burn_core::tensor::backend::AutodiffBackend; +use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner}; +use std::collections::BTreeSet; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +/// Structure to configure and launch reinforcement learning trainings. +pub struct RLTraining { + // Not that complex. Extracting into yet another type would only make it more confusing. + #[allow(clippy::type_complexity)] + checkpointers: Option<( + AsyncCheckpointer, RLC::Backend>, + AsyncCheckpointer, RLC::Backend>, + )>, + num_steps: usize, + checkpoint: Option, + directory: PathBuf, + grad_accumulation: Option, + renderer: Option>, + metrics: RLMetrics, + event_store: LogEventStore, + interrupter: Interrupter, + tracing_logger: Option>, + checkpointer_strategy: Box, + learning_strategy: RLStrategies, + // Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order + summary_metrics: BTreeSet, + summary: bool, + env_initializer: RLC::EnvInit, +} + +impl RLTraining> +where + B: AutodiffBackend, + E: Environment + 'static, + EI: EnvironmentInit + Send + 'static, + A: PolicyLearner + Send + 'static, + A::TrainContext: ItemLazy + Clone + Send, + A::InnerPolicy: Policy + Send, + >::Observation: Batchable + Clone + Send, + >::ActionDistribution: Batchable + Clone + Send, + >::Action: Batchable + Clone + Send, + >::ActionContext: ItemLazy + Clone + Send + 'static, + >::PolicyState: Clone + Send, + E::State: Into<>::Observation> + Clone + Send + 'static, + E::Action: From<>::Action> + + Into<>::Action> + + Clone + + Send + + 'static, +{ + /// Creates a new runner for reinforcement learning. + /// + /// # Arguments + /// + /// * `directory` - The directory to save the checkpoints. + /// * `env_init` - Specifies how to initialize the environment. + pub fn new(directory: impl AsRef, env_initializer: EI) -> Self { + let directory = directory.as_ref().to_path_buf(); + let experiment_log_file = directory.join("experiment.log"); + Self { + num_steps: 1, + checkpoint: None, + checkpointers: None, + directory, + grad_accumulation: None, + metrics: RLMetrics::default(), + event_store: LogEventStore::default(), + renderer: None, + interrupter: Interrupter::new(), + tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new( + experiment_log_file, + ))), + checkpointer_strategy: Box::new( + ComposedCheckpointingStrategy::builder() + .add(KeepLastNCheckpoints::new(2)) + .add(MetricCheckpointingStrategy::new( + &EpisodeLengthMetric::new(), // default to evaluations' cumulative reward. + Aggregate::Mean, + Direction::Lowest, + Split::Valid, + )) + .build(), + ), + learning_strategy: RLStrategies::OffPolicyStrategy(OffPolicyConfig::new()), + summary_metrics: BTreeSet::new(), + summary: false, + env_initializer, + } + } +} + +impl RLTraining { + /// Replace the default learning strategy (Off Policy learning) with the provided one. + /// + /// # Arguments + /// + /// * `training_strategy` - The training strategy. + pub fn with_learning_strategy(mut self, learning_strategy: RLStrategies) -> Self { + self.learning_strategy = learning_strategy; + self + } + + /// Replace the default metric loggers with the provided ones. + /// + /// # Arguments + /// + /// * `logger` - The training logger. + pub fn with_metric_logger(mut self, logger: ML) -> Self + where + ML: MetricLogger + 'static, + { + self.event_store.register_logger(logger); + self + } + + /// Update the checkpointing_strategy. + pub fn with_checkpointing_strategy( + mut self, + strategy: CS, + ) -> Self { + self.checkpointer_strategy = Box::new(strategy); + self + } + + /// Replace the default CLI renderer with a custom one. + /// + /// # Arguments + /// + /// * `renderer` - The custom renderer. + pub fn renderer(mut self, renderer: MR) -> Self + where + MR: MetricsRenderer + 'static, + { + self.renderer = Some(Box::new(renderer)); + self + } + + /// Register numerical metrics for a training step of the agent. + pub fn metrics_train>(self, metrics: Me) -> Self { + metrics.register(self) + } + + /// Register textual metrics for a training step of the agent. + pub fn text_metrics_train>(self, metrics: Me) -> Self { + metrics.register(self) + } + + /// Register numerical metrics for each action of the agent. + pub fn metrics_agent>(self, metrics: Me) -> Self { + metrics.register(self) + } + + /// Register textual metrics for each action of the agent. + pub fn text_metrics_agent>(self, metrics: Me) -> Self { + metrics.register(self) + } + + /// Register numerical metrics for a completed episode. + pub fn metrics_episode>(self, metrics: Me) -> Self { + metrics.register(self) + } + + /// Register textual metrics for a completed episode. + pub fn text_metrics_episode>(self, metrics: Me) -> Self { + metrics.register(self) + } + + /// Register a textual metric for a training step. + pub fn text_metric_train(mut self, metric: Me) -> Self + where + ::ItemSync: Adaptor, + { + self.metrics.register_text_metric_train(metric); + self + } + + /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a training step. + pub fn metric_train(mut self, metric: Me) -> Self + where + Me: Metric + Numeric + 'static, + ::ItemSync: Adaptor, + { + self.summary_metrics.insert(metric.name().to_string()); + self.metrics.register_metric_train(metric); + self + } + + /// Register a textual metric for each action taken by the agent. + pub fn text_metric_agent(mut self, metric: Me) -> Self + where + ::ItemSync: Adaptor, + { + self.metrics.register_text_metric_agent(metric.clone()); + self.metrics.register_text_metric_agent_valid(metric); + self + } + + /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for each action taken by the agent. + pub fn metric_agent(mut self, metric: Me) -> Self + where + Me: Metric + Numeric + 'static, + ::ItemSync: Adaptor, + { + self.summary_metrics.insert(metric.name().to_string()); + self.metrics.register_agent_metric(metric.clone()); + self.metrics.register_agent_metric_valid(metric); + self + } + + /// Register a textual metric for a completed episode. + pub fn text_metric_episode(mut self, metric: Me) -> Self + where + EpisodeSummary: Adaptor + 'static, + { + self.metrics.register_text_metric_episode(metric.clone()); + self.metrics.register_text_metric_episode_valid(metric); + self + } + + /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a completed episode. + pub fn metric_episode(mut self, metric: Me) -> Self + where + Me: Metric + Numeric + 'static, + EpisodeSummary: Adaptor + 'static, + { + self.summary_metrics.insert(metric.name().to_string()); + self.metrics.register_episode_metric(metric.clone()); + self.metrics.register_episode_metric_valid(metric); + self + } + + /// The number of environment steps to train for. + pub fn num_steps(mut self, num_steps: usize) -> Self { + self.num_steps = num_steps; + self + } + + /// The step from which the training must resume. + pub fn checkpoint(mut self, checkpoint: usize) -> Self { + self.checkpoint = Some(checkpoint); + self + } + + /// Provides a handle that can be used to interrupt training. + pub fn interrupter(&self) -> Interrupter { + self.interrupter.clone() + } + + /// Override the handle for stopping training with an externally provided handle + pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self { + self.interrupter = interrupter; + self + } + + /// By default, Rust logs are captured and written into + /// `experiment.log`. If disabled, standard Rust log handling + /// will apply. + pub fn with_application_logger( + mut self, + logger: Option>, + ) -> Self { + self.tracing_logger = logger; + self + } + + /// Register a checkpointer that will save the environment runner's [policy](Policy) + /// and the [PolicyLearner](PolicyLearner) state to different files. + pub fn with_file_checkpointer(mut self, recorder: FR) -> Self + where + FR: FileRecorder + 'static, + FR: FileRecorder<::InnerBackend> + 'static, + { + let checkpoint_dir = self.directory.join("checkpoint"); + let checkpointer_policy = + FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "policy"); + let checkpointer_learning = + FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "learning-agent"); + + self.checkpointers = Some(( + AsyncCheckpointer::new(checkpointer_policy), + AsyncCheckpointer::new(checkpointer_learning), + )); + + self + } + + /// Enable the training summary report. + /// + /// The summary will be displayed after `.launch()`, when the renderer is dropped. + pub fn summary(mut self) -> Self { + self.summary = true; + self + } + + /// Launch the training with the specified [PolicyLearner](PolicyLearner) on the specified environment. + pub fn launch(mut self, learner_agent: RLC::LearningAgent) -> RLResult { + if self.tracing_logger.is_some() + && let Err(e) = self.tracing_logger.as_ref().unwrap().install() + { + log::warn!("Failed to install the experiment logger: {e}"); + } + let renderer = self + .renderer + .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint)); + + if !self.event_store.has_loggers() { + self.event_store + .register_logger(FileMetricLogger::new(self.directory.clone())); + } + + let event_store = Arc::new(EventStoreClient::new(self.event_store)); + let event_processor = AsyncProcessorTraining::new(RLEventProcessor::new( + self.metrics, + renderer, + event_store.clone(), + )); + + let checkpointer = self.checkpointers.map(|(policy, learning_agent)| { + RLCheckpointer::new(policy, learning_agent, self.checkpointer_strategy) + }); + + let summary = if self.summary { + Some(LearnerSummaryConfig { + directory: self.directory, + metrics: self.summary_metrics.into_iter().collect::>(), + }) + } else { + None + }; + + let components = RLComponents:: { + checkpoint: self.checkpoint, + checkpointer, + interrupter: self.interrupter, + event_processor, + event_store, + num_steps: self.num_steps, + grad_accumulation: self.grad_accumulation, + summary, + }; + + match self.learning_strategy { + RLStrategies::OffPolicyStrategy(config) => { + let strategy = OffPolicyStrategy::new(config); + strategy.train(learner_agent, components, self.env_initializer) + } + RLStrategies::Custom(strategy) => { + strategy.train(learner_agent, components, self.env_initializer) + } + } + } +} + +/// The result of reinforcement learning, containing the final policy along with the [renderer](MetricsRenderer). +pub struct RLResult

{ + /// The learned policy. + pub policy: P, + /// The renderer that can be used for follow up training and evaluation. + pub renderer: Box, +} + +/// Trait to fake variadic generics for train step metrics. +pub trait AgentMetricRegistration: Sized { + /// Register the metrics. + fn register(self, builder: RLTraining) -> RLTraining; +} + +/// Trait to fake variadic generics for train step text metrics. +pub trait AgentTextMetricRegistration: Sized { + /// Register the metrics. + fn register(self, builder: RLTraining) -> RLTraining; +} + +/// Trait to fake variadic generics for env step metrics. +pub trait TrainMetricRegistration: Sized { + /// Register the metrics. + fn register(self, builder: RLTraining) -> RLTraining; +} + +/// Trait to fake variadic generics for env step text metrics. +pub trait TrainTextMetricRegistration: Sized { + /// Register the metrics. + fn register(self, builder: RLTraining) -> RLTraining; +} + +/// Trait to fake variadic generics for episode metrics. +pub trait EpisodeMetricRegistration: Sized { + /// Register the metrics. + fn register(self, builder: RLTraining) -> RLTraining; +} + +/// Trait to fake variadic generics for episode text metrics. +pub trait EpisodeTextMetricRegistration: Sized { + /// Register the metrics. + fn register(self, builder: RLTraining) -> RLTraining; +} + +macro_rules! gen_tuple { + ($($M:ident),*) => { + impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainTextMetricRegistration for ($($M,)*) + where + $(::ItemSync: Adaptor<$M::Input>,)* + $($M: Metric + 'static,)* + { + #[allow(non_snake_case)] + fn register( + self, + builder: RLTraining, + ) -> RLTraining { + let ($($M,)*) = self; + $(let builder = builder.text_metric_train($M.clone());)* + builder + } + } + + impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainMetricRegistration for ($($M,)*) + where + $(::ItemSync: Adaptor<$M::Input>,)* + $($M: Metric + Numeric + 'static,)* + { + #[allow(non_snake_case)] + fn register( + self, + builder: RLTraining, + ) -> RLTraining { + let ($($M,)*) = self; + $(let builder = builder.metric_train($M.clone());)* + builder + } + } + + impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentTextMetricRegistration for ($($M,)*) + where + $(::ItemSync: Adaptor<$M::Input>,)* + $($M: Metric + 'static,)* + { + #[allow(non_snake_case)] + fn register( + self, + builder: RLTraining, + ) -> RLTraining { + let ($($M,)*) = self; + $(let builder = builder.text_metric_agent($M.clone());)* + builder + } + } + + impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentMetricRegistration for ($($M,)*) + where + $(::ItemSync: Adaptor<$M::Input>,)* + $($M: Metric + Numeric + 'static,)* + { + #[allow(non_snake_case)] + fn register( + self, + builder: RLTraining, + ) -> RLTraining { + let ($($M,)*) = self; + $(let builder = builder.metric_agent($M.clone());)* + builder + } + } + + impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeTextMetricRegistration for ($($M,)*) + where + $(EpisodeSummary: Adaptor<$M::Input> + 'static,)* + $($M: Metric + 'static,)* + { + #[allow(non_snake_case)] + fn register( + self, + builder: RLTraining, + ) -> RLTraining { + let ($($M,)*) = self; + $(let builder = builder.text_metric_episode($M.clone());)* + builder + } + } + + impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeMetricRegistration for ($($M,)*) + where + $(EpisodeSummary: Adaptor<$M::Input> + 'static,)* + $($M: Metric + Numeric + 'static,)* + { + #[allow(non_snake_case)] + fn register( + self, + builder: RLTraining, + ) -> RLTraining { + let ($($M,)*) = self; + $(let builder = builder.metric_episode($M.clone());)* + builder + } + } + }; +} + +gen_tuple!(M1); +gen_tuple!(M1, M2); +gen_tuple!(M1, M2, M3); +gen_tuple!(M1, M2, M3, M4); +gen_tuple!(M1, M2, M3, M4, M5); +gen_tuple!(M1, M2, M3, M4, M5, M6); diff --git a/crates/burn-train/src/learner/rl/strategy.rs b/crates/burn-train/src/learner/rl/strategy.rs new file mode 100644 index 000000000..adb791f6f --- /dev/null +++ b/crates/burn-train/src/learner/rl/strategy.rs @@ -0,0 +1,99 @@ +use std::sync::Arc; + +use crate::{ + Interrupter, LearnerSummaryConfig, OffPolicyConfig, RLCheckpointer, RLComponentsTypes, RLEvent, + RLEventProcessorType, RLResult, + metric::{processor::EventProcessorTraining, store::EventStoreClient}, +}; + +/// Struct to minimise parameters passed to [RLStrategy::train]. +pub struct RLComponents { + /// The total number of environment steps. + pub num_steps: usize, + /// The step number from which to continue the training. + pub checkpoint: Option, + /// A checkpointer used to load and save learning checkpoints. + pub checkpointer: Option>, + /// Enables gradients accumulation. + pub grad_accumulation: Option, + /// An [Interupter](Interrupter) that allows aborting the training/evaluation process early. + pub interrupter: Interrupter, + /// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and evaluation. + pub event_processor: RLEventProcessorType, + /// A reference to an [EventStoreClient](EventStoreClient). + pub event_store: Arc, + /// Config for creating a summary of the learning + pub summary: Option, +} + +/// The strategy for reinforcement learning. +#[derive(Clone)] +pub enum RLStrategies { + /// Training on one device + OffPolicyStrategy(OffPolicyConfig), + /// Training using a custom learning strategy + Custom(CustomRLStrategy), +} + +/// A reference to an implementation of [RLStrategy]. +pub type CustomRLStrategy = Arc>; + +/// Provides the `fit` function for any learning strategy +pub trait RLStrategy { + /// Train the learner agent with this strategy. + fn train( + &self, + mut learner_agent: RLC::LearningAgent, + mut training_components: RLComponents, + env_init: RLC::EnvInit, + ) -> RLResult { + let starting_epoch = match training_components.checkpoint { + Some(checkpoint) => { + if let Some(checkpointer) = &mut training_components.checkpointer { + learner_agent = checkpointer.load_checkpoint( + learner_agent, + &Default::default(), + checkpoint, + ); + } + checkpoint + 1 + } + None => 1, + }; + + let summary_config = training_components.summary.clone(); + + // Event processor start training + training_components + .event_processor + .process_train(RLEvent::Start); + + // Training loop + let (policy, mut event_processor) = self.learn( + training_components, + &mut learner_agent, + starting_epoch, + env_init, + ); + + let summary = summary_config.and_then(|summary| summary.init().ok()); + + // Signal training end. For the TUI renderer, this handles the exit & return to main screen. + // TODO: summary makes sense for RL? + event_processor.process_train(RLEvent::End(summary)); + + // let model = model.valid(); + let renderer = event_processor.renderer(); + + RLResult { policy, renderer } + } + + /// Training loop for this strategy + fn learn( + &self, + training_components: RLComponents, + learner_agent: &mut RLC::LearningAgent, + starting_epoch: usize, + env_init: RLC::EnvInit, + ) -> (RLC::Policy, RLEventProcessorType); +} diff --git a/crates/burn-train/src/learner/supervised/paradigm.rs b/crates/burn-train/src/learner/supervised/paradigm.rs index 86f7fcb2b..0919e91d0 100644 --- a/crates/burn-train/src/learner/supervised/paradigm.rs +++ b/crates/burn-train/src/learner/supervised/paradigm.rs @@ -16,10 +16,10 @@ use crate::renderer::{MetricsRenderer, default_renderer}; use crate::single::SingleDevicetrainingStrategy; use crate::{ ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller, - InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerModelRecord, - LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig, LearningCheckpointer, - LearningComponentsMarker, LearningComponentsTypes, LearningResult, TrainStep, TrainingBackend, - TrainingComponents, TrainingModelInput, TrainingStrategy, + InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerEvent, + LearnerModelRecord, LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig, + LearningCheckpointer, LearningComponentsMarker, LearningComponentsTypes, LearningResult, + TrainStep, TrainingBackend, TrainingComponents, TrainingModelInput, TrainingStrategy, }; use crate::{Learner, SupervisedLearningStrategy}; use burn_core::data::dataloader::DataLoader; @@ -38,7 +38,8 @@ pub type TrainLoader = Arc, TrainingModel pub type ValidLoader = Arc, InferenceModelInput>>; /// The event processor type for supervised learning. pub type SupervisedTrainingEventProcessor = AsyncProcessorTraining< - FullEventProcessorTraining, InferenceModelOutput>, + LearnerEvent>, + LearnerEvent>, >; /// Structure to configure and launch supervised learning trainings. @@ -181,7 +182,7 @@ impl SupervisedTraining { metrics.register(self) } - /// Register all metrics as numeric for the training and validation set. + /// Register all metrics as text for the training and validation set. pub fn metrics_text>(self, metrics: Me) -> Self { metrics.register(self) } diff --git a/crates/burn-train/src/learner/supervised/strategies/ddp/epoch.rs b/crates/burn-train/src/learner/supervised/strategies/ddp/epoch.rs index e46a40cbb..588dde51c 100644 --- a/crates/burn-train/src/learner/supervised/strategies/ddp/epoch.rs +++ b/crates/burn-train/src/learner/supervised/strategies/ddp/epoch.rs @@ -1,4 +1,5 @@ use burn_collective::{PeerId, ReduceOperation}; +use burn_core::data::dataloader::Progress; use burn_core::module::AutodiffModule; use burn_core::tensor::backend::AutodiffBackend; use burn_optim::GradientsAccumulator; @@ -9,7 +10,7 @@ use std::sync::{Arc, Mutex}; use crate::SupervisedTrainingEventProcessor; use crate::learner::base::Interrupter; -use crate::metric::processor::{EventProcessorTraining, LearnerEvent, LearnerItem}; +use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem}; use crate::{ InferenceStep, Learner, LearningComponentsTypes, TrainLoader, TrainingBackend, ValidLoader, }; @@ -18,14 +19,12 @@ use crate::{ #[derive(new)] pub struct DdpValidEpoch { dataloader: ValidLoader, - epoch_total: usize, } /// A training epoch. #[derive(new)] pub struct DdpTrainEpoch { dataloader: TrainLoader, - epoch_total: usize, grad_accumulation: Option, } @@ -39,10 +38,11 @@ impl DdpValidEpoch { pub fn run( &self, model: &::TrainingModel, - epoch: usize, + global_progress: &Progress, processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, ) { + let epoch = global_progress.items_processed; log::info!("Executing validation step for epoch {}", epoch); let model = model.valid(); @@ -54,7 +54,13 @@ impl DdpValidEpoch { iteration += 1; let item = model.step(item); - let item = LearnerItem::new(item, progress, epoch, self.epoch_total, iteration, None); + let item = TrainingItem::new( + item, + progress, + global_progress.clone(), + Some(iteration), + None, + ); processor.process_valid(LearnerEvent::ProcessedItem(item)); @@ -84,13 +90,14 @@ impl DdpTrainEpoch { pub fn run( &self, learner: &mut Learner, - epoch: usize, + global_progress: &Progress, processor: Arc>>, interrupter: &Interrupter, peer_id: PeerId, peer_count: usize, is_main: bool, ) { + let epoch = global_progress.items_processed; log::info!("Executing training step for epoch {}", epoch,); let mut iterator = self.dataloader.iter(); @@ -143,12 +150,11 @@ impl DdpTrainEpoch { } } - let item = LearnerItem::new( + let item = TrainingItem::new( item.item, progress, - epoch, - self.epoch_total, - iteration, + global_progress.clone(), + Some(iteration), Some(learner.lr_current()), ); diff --git a/crates/burn-train/src/learner/supervised/strategies/ddp/worker.rs b/crates/burn-train/src/learner/supervised/strategies/ddp/worker.rs index ef11c2e15..1b9ee9636 100644 --- a/crates/burn-train/src/learner/supervised/strategies/ddp/worker.rs +++ b/crates/burn-train/src/learner/supervised/strategies/ddp/worker.rs @@ -1,5 +1,6 @@ use crate::ddp::epoch::{DdpTrainEpoch, DdpValidEpoch}; use crate::ddp::strategy::WorkerComponents; +use crate::single::TrainingLoop; use crate::{ Learner, LearningCheckpointer, LearningComponentsTypes, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, ValidLoader, @@ -83,18 +84,19 @@ where // Changed the train epoch to keep the dataloaders let epoch_train = DdpTrainEpoch::::new( self.dataloader_train.clone(), - num_epochs, self.components.grad_accumulation, ); let epoch_valid = self .dataloader_valid - .map(|dataloader| DdpValidEpoch::::new(dataloader, num_epochs)); + .map(|dataloader| DdpValidEpoch::::new(dataloader)); self.learner.fork(&self.device); - for epoch in self.starting_epoch..num_epochs + 1 { + for training_progress in TrainingLoop::new(self.starting_epoch, num_epochs) { + let epoch = training_progress.items_processed; + epoch_train.run( &mut self.learner, - epoch, + &training_progress, self.event_processor.clone(), &interrupter, self.peer_id, @@ -111,7 +113,7 @@ where let mut event_processor = self.event_processor.lock().unwrap(); runner.run( &self.learner.model(), - epoch, + &training_progress, &mut event_processor, &interrupter, ); diff --git a/crates/burn-train/src/learner/supervised/strategies/multi/epoch.rs b/crates/burn-train/src/learner/supervised/strategies/multi/epoch.rs index 9abbfdb71..18d758e1a 100644 --- a/crates/burn-train/src/learner/supervised/strategies/multi/epoch.rs +++ b/crates/burn-train/src/learner/supervised/strategies/multi/epoch.rs @@ -1,10 +1,11 @@ use crate::learner::base::Interrupter; -use crate::metric::processor::{EventProcessorTraining, LearnerEvent, LearnerItem}; +use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem}; use crate::train::MultiDevicesTrainStep; use crate::{ Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, }; +use burn_core::data::dataloader::Progress; use burn_core::prelude::DeviceOps; use burn_core::tensor::Device; use burn_core::tensor::backend::DeviceId; @@ -16,7 +17,6 @@ use std::collections::HashMap; #[derive(new)] pub struct MultiDeviceTrainEpoch { dataloaders: Vec>, - epoch_total: usize, grad_accumulation: Option, } @@ -38,30 +38,39 @@ impl MultiDeviceTrainEpoch { pub fn run( &self, learner: &mut Learner, - epoch: usize, + global_progress: &Progress, event_processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, devices: Vec>>, strategy: MultiDeviceOptim, ) { match strategy { - MultiDeviceOptim::OptimMainDevice => { - self.run_optim_main(learner, epoch, event_processor, interrupter, devices) - } - MultiDeviceOptim::OptimSharded => { - self.run_optim_distr(learner, epoch, event_processor, interrupter, devices) - } + MultiDeviceOptim::OptimMainDevice => self.run_optim_main( + learner, + global_progress, + event_processor, + interrupter, + devices, + ), + MultiDeviceOptim::OptimSharded => self.run_optim_distr( + learner, + global_progress, + event_processor, + interrupter, + devices, + ), } } fn run_optim_main( &self, learner: &mut Learner, - epoch: usize, + global_progress: &Progress, event_processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, devices: Vec>>, ) { + let epoch = global_progress.items_processed; log::info!( "Executing training step for epoch {} on devices {:?}", epoch, @@ -108,12 +117,11 @@ impl MultiDeviceTrainEpoch { for item in progress_items { iteration += 1; - let item = LearnerItem::new( + let item = TrainingItem::new( item, progress.clone(), - epoch, - self.epoch_total, - iteration, + global_progress.clone(), + Some(iteration), Some(learner.lr_current()), ); @@ -131,11 +139,12 @@ impl MultiDeviceTrainEpoch { fn run_optim_distr( &self, learner: &mut Learner, - epoch: usize, + global_progress: &Progress, event_processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, devices: Vec>>, ) { + let epoch = global_progress.items_processed; log::info!( "Executing training step for epoch {} on devices {:?}", epoch, @@ -189,12 +198,11 @@ impl MultiDeviceTrainEpoch { for item in progress_items { iteration += 1; - let item = LearnerItem::new( + let item = TrainingItem::new( item, progress.clone(), - epoch, - self.epoch_total, - iteration, + global_progress.clone(), + Some(iteration), Some(learner.lr_current()), ); diff --git a/crates/burn-train/src/learner/supervised/strategies/multi/strategy.rs b/crates/burn-train/src/learner/supervised/strategies/multi/strategy.rs index 2dc67795b..dd627d8df 100644 --- a/crates/burn-train/src/learner/supervised/strategies/multi/strategy.rs +++ b/crates/burn-train/src/learner/supervised/strategies/multi/strategy.rs @@ -1,8 +1,9 @@ use crate::{ Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedLearningStrategy, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, TrainingComponents, - TrainingModel, ValidLoader, multi::epoch::MultiDeviceTrainEpoch, - single::epoch::SingleDeviceValidEpoch, + TrainingModel, ValidLoader, + multi::epoch::MultiDeviceTrainEpoch, + single::{TrainingLoop, epoch::SingleDeviceValidEpoch}, }; use burn_core::{data::dataloader::split::split_dataloader, tensor::Device}; @@ -39,20 +40,19 @@ impl SupervisedLearningStrategy let mut event_processor = training_components.event_processor; let mut checkpointer = training_components.checkpointer; let mut early_stopping = training_components.early_stopping; - let num_epochs = training_components.num_epochs; let epoch_train = MultiDeviceTrainEpoch::::new( dataloader_train.clone(), - num_epochs, training_components.grad_accumulation, ); let epoch_valid: SingleDeviceValidEpoch = - SingleDeviceValidEpoch::new(dataloader_valid.clone(), num_epochs); + SingleDeviceValidEpoch::new(dataloader_valid.clone()); - for epoch in starting_epoch..training_components.num_epochs + 1 { + for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) { + let epoch = training_progress.items_processed; epoch_train.run( &mut learner, - epoch, + &training_progress, &mut event_processor, &training_components.interrupter, self.devices.to_vec(), @@ -70,7 +70,7 @@ impl SupervisedLearningStrategy epoch_valid.run( &learner, - epoch, + &training_progress, &mut event_processor, &training_components.interrupter, ); diff --git a/crates/burn-train/src/learner/supervised/strategies/single/epoch.rs b/crates/burn-train/src/learner/supervised/strategies/single/epoch.rs index f25311880..ad86beca9 100644 --- a/crates/burn-train/src/learner/supervised/strategies/single/epoch.rs +++ b/crates/burn-train/src/learner/supervised/strategies/single/epoch.rs @@ -1,9 +1,10 @@ use crate::learner::base::Interrupter; -use crate::metric::processor::{EventProcessorTraining, LearnerEvent, LearnerItem}; +use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem}; use crate::{ InferenceStep, Learner, LearningComponentsTypes, SupervisedTrainingEventProcessor, TrainLoader, ValidLoader, }; +use burn_core::data::dataloader::Progress; use burn_core::module::AutodiffModule; use burn_optim::GradientsAccumulator; @@ -11,14 +12,12 @@ use burn_optim::GradientsAccumulator; #[derive(new)] pub struct SingleDeviceValidEpoch { dataloader: ValidLoader, - epoch_total: usize, } /// A training epoch. #[derive(new)] pub struct SingleDeviceTrainEpoch { dataloader: TrainLoader, - epoch_total: usize, grad_accumulation: Option, } @@ -32,10 +31,11 @@ impl SingleDeviceValidEpoch { pub fn run( &self, learner: &Learner, - epoch: usize, + global_progress: &Progress, processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, ) { + let epoch = global_progress.items_processed; log::info!("Executing validation step for epoch {}", epoch); let model = learner.model().valid(); @@ -47,7 +47,13 @@ impl SingleDeviceValidEpoch { iteration += 1; let item = model.step(item); - let item = LearnerItem::new(item, progress, epoch, self.epoch_total, iteration, None); + let item = TrainingItem::new( + item, + progress, + global_progress.clone(), + Some(iteration), + None, + ); processor.process_valid(LearnerEvent::ProcessedItem(item)); @@ -75,10 +81,11 @@ impl SingleDeviceTrainEpoch { pub fn run( &self, learner: &mut Learner, - epoch: usize, + global_progress: &Progress, processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, ) { + let epoch = global_progress.items_processed; log::info!("Executing training step for epoch {}", epoch,); // Single device / dataloader @@ -110,12 +117,11 @@ impl SingleDeviceTrainEpoch { None => learner.optimizer_step(item.grads), } - let item = LearnerItem::new( + let item = TrainingItem::new( item.item, progress, - epoch, - self.epoch_total, - iteration, + global_progress.clone(), + Some(iteration), Some(learner.lr_current()), ); diff --git a/crates/burn-train/src/learner/supervised/strategies/single/strategy.rs b/crates/burn-train/src/learner/supervised/strategies/single/strategy.rs index bdbcd3026..f5e7832ea 100644 --- a/crates/burn-train/src/learner/supervised/strategies/single/strategy.rs +++ b/crates/burn-train/src/learner/supervised/strategies/single/strategy.rs @@ -3,7 +3,7 @@ use crate::{ TrainLoader, TrainingBackend, TrainingComponents, TrainingModel, ValidLoader, single::epoch::{SingleDeviceTrainEpoch, SingleDeviceValidEpoch}, }; -use burn_core::tensor::Device; +use burn_core::{data::dataloader::Progress, tensor::Device}; /// Simplest learning strategy possible, with only a single devices doing both the training and /// validation. @@ -16,6 +16,31 @@ impl SingleDevicetrainingStrategy { } } +#[derive(new)] +pub(crate) struct TrainingLoop { + next_iteration: usize, + total_iteration: usize, +} + +/// An iterator that returns the progress of the training. +impl Iterator for TrainingLoop { + type Item = Progress; + + fn next(&mut self) -> Option { + if self.next_iteration > self.total_iteration { + return None; + } + + let progress = Progress { + items_processed: self.next_iteration, + items_total: self.total_iteration, + }; + + self.next_iteration += 1; + Some(progress) + } +} + impl SupervisedLearningStrategy for SingleDevicetrainingStrategy { @@ -33,20 +58,17 @@ impl SupervisedLearningStrategy let mut event_processor = training_components.event_processor; let mut checkpointer = training_components.checkpointer; let mut early_stopping = training_components.early_stopping; - let num_epochs = training_components.num_epochs; - let epoch_train: SingleDeviceTrainEpoch = SingleDeviceTrainEpoch::new( - dataloader_train, - num_epochs, - training_components.grad_accumulation, - ); + let epoch_train: SingleDeviceTrainEpoch = + SingleDeviceTrainEpoch::new(dataloader_train, training_components.grad_accumulation); let epoch_valid: SingleDeviceValidEpoch = - SingleDeviceValidEpoch::new(dataloader_valid.clone(), num_epochs); + SingleDeviceValidEpoch::new(dataloader_valid.clone()); - for epoch in starting_epoch..training_components.num_epochs + 1 { + for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) { + let epoch = training_progress.items_processed; epoch_train.run( &mut learner, - epoch, + &training_progress, &mut event_processor, &training_components.interrupter, ); @@ -62,7 +84,7 @@ impl SupervisedLearningStrategy epoch_valid.run( &learner, - epoch, + &training_progress, &mut event_processor, &training_components.interrupter, ); diff --git a/crates/burn-train/src/metric/base.rs b/crates/burn-train/src/metric/base.rs index e73c5421f..185598796 100644 --- a/crates/burn-train/src/metric/base.rs +++ b/crates/burn-train/src/metric/base.rs @@ -8,14 +8,11 @@ pub struct MetricMetadata { /// The current progress. pub progress: Progress, - /// The current epoch. - pub epoch: usize, - - /// The total number of epochs. - pub epoch_total: usize, + /// The global progress of the training (e.g. epochs). + pub global_progress: Progress, /// The current iteration. - pub iteration: usize, + pub iteration: Option, /// The current learning rate. pub lr: Option, @@ -30,9 +27,11 @@ impl MetricMetadata { items_processed: 1, items_total: 1, }, - epoch: 0, - epoch_total: 1, - iteration: 0, + global_progress: Progress { + items_processed: 0, + items_total: 1, + }, + iteration: Some(0), lr: None, } } diff --git a/crates/burn-train/src/metric/iteration.rs b/crates/burn-train/src/metric/iteration.rs index 16fb3e7b9..077ddde0b 100644 --- a/crates/burn-train/src/metric/iteration.rs +++ b/crates/burn-train/src/metric/iteration.rs @@ -38,7 +38,14 @@ impl Metric for IterationSpeedMetric { fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry { let raw = match self.instant { - Some(val) => metadata.iteration as f64 / val.elapsed().as_secs_f64(), + Some(val) => { + // If iteration is not logged, compute the speed over the number of items processed. + // 1 iteration should equal 1 item when iteration is not logged. + metadata + .iteration + .unwrap_or(metadata.progress.items_processed) as f64 + / val.elapsed().as_secs_f64() + } None => { self.instant = Some(std::time::Instant::now()); 0.0 diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs index b299bcb6f..a192b67b9 100644 --- a/crates/burn-train/src/metric/mod.rs +++ b/crates/burn-train/src/metric/mod.rs @@ -37,6 +37,7 @@ mod loss; mod perplexity; mod precision; mod recall; +mod rl; mod top_k_acc; mod wer; @@ -53,6 +54,7 @@ pub use loss::*; pub use perplexity::*; pub use precision::*; pub use recall::*; +pub use rl::*; pub use top_k_acc::*; pub use wer::*; diff --git a/crates/burn-train/src/metric/processor/async_wrapper.rs b/crates/burn-train/src/metric/processor/async_wrapper.rs index debd4ce6c..f6ae8f214 100644 --- a/crates/burn-train/src/metric/processor/async_wrapper.rs +++ b/crates/burn-train/src/metric/processor/async_wrapper.rs @@ -1,11 +1,11 @@ use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation}; -use super::{EventProcessorTraining, LearnerEvent}; +use super::EventProcessorTraining; use async_channel::{Receiver, Sender}; /// Event processor for the training process. -pub struct AsyncProcessorTraining { - sender: Sender>, +pub struct AsyncProcessorTraining { + sender: Sender>, } /// Event processor for the model evaluation. @@ -13,9 +13,9 @@ pub struct AsyncProcessorEvaluation { sender: Sender>, } -struct WorkerTraining { +struct WorkerTraining> { processor: P, - rec: Receiver>, + rec: Receiver>, } struct WorkerEvaluation { @@ -23,8 +23,10 @@ struct WorkerEvaluation { rec: Receiver>, } -impl WorkerTraining

{ - pub fn start(processor: P, rec: Receiver>) { +impl + 'static> + WorkerTraining +{ + pub fn start(processor: P, rec: Receiver>) { let mut worker = Self { processor, rec }; std::thread::spawn(move || { @@ -59,9 +61,9 @@ impl WorkerEvaluation

{ } } -impl AsyncProcessorTraining

{ +impl AsyncProcessorTraining { /// Create an event processor for training. - pub fn new(processor: P) -> Self { + pub fn new + 'static>(processor: P) -> Self { let (sender, rec) = async_channel::bounded(1); WorkerTraining::start(processor, rec); @@ -81,9 +83,9 @@ impl AsyncProcessorEvaluation

{ } } -enum Message { - Train(LearnerEvent), - Valid(LearnerEvent), +enum Message { + Train(EventTrain), + Valid(EventValid), Renderer(Sender>), } @@ -92,15 +94,12 @@ enum EvalMessage { Renderer(Sender>), } -impl EventProcessorTraining for AsyncProcessorTraining

{ - type ItemTrain = P::ItemTrain; - type ItemValid = P::ItemValid; - - fn process_train(&mut self, event: LearnerEvent) { +impl EventProcessorTraining for AsyncProcessorTraining { + fn process_train(&mut self, event: ET) { self.sender.send_blocking(Message::Train(event)).unwrap(); } - fn process_valid(&mut self, event: LearnerEvent) { + fn process_valid(&mut self, event: EV) { self.sender.send_blocking(Message::Valid(event)).unwrap(); } diff --git a/crates/burn-train/src/metric/processor/base.rs b/crates/burn-train/src/metric/processor/base.rs index a547eccd3..5d2064f4a 100644 --- a/crates/burn-train/src/metric/processor/base.rs +++ b/crates/burn-train/src/metric/processor/base.rs @@ -2,7 +2,7 @@ use burn_core::data::dataloader::Progress; use burn_optim::LearningRate; use crate::{ - LearnerSummary, + EpisodeSummary, LearnerSummary, renderer::{EvaluationName, MetricsRenderer}, }; @@ -11,19 +11,45 @@ pub enum LearnerEvent { /// Signal the start of the process (e.g., training start) Start, /// Signal that an item have been processed. - ProcessedItem(LearnerItem), + ProcessedItem(TrainingItem), /// Signal the end of an epoch. EndEpoch(usize), /// Signal the end of the process (e.g., training end). End(Option), } +/// Event happening during reinforcement learning. +pub enum RLEvent { + /// Signal the start of the process (e.g., learning starts). + Start, + /// Signal an agent's training step. + TrainStep(EvaluationItem), + /// Signal a timestep of the agent-environment interface. + TimeStep(EvaluationItem), + /// Signal an episode end. + EpisodeEnd(EvaluationItem), + /// Signal the end of the process (e.g., learning ends). + End(Option), +} + +/// Event happening during evaluation of a reinforcement learning's agent. +pub enum AgentEvaluationEvent { + /// Signal the start of the process (e.g., training start) + Start, + /// Signal a timestep of the agent-environment interface. + TimeStep(EvaluationItem), + /// Signal an episode end. + EpisodeEnd(EvaluationItem), + /// Signal the end of the process (e.g., training end). + End, +} + /// Event happening during the evaluation process. pub enum EvaluatorEvent { /// Signal the start of the process (e.g., training start) Start, /// Signal that an item have been processed. - ProcessedItem(EvaluationName, LearnerItem), + ProcessedItem(EvaluationName, EvaluationItem), /// Signal the end of the process (e.g., training end). End, } @@ -40,16 +66,11 @@ pub trait ItemLazy: Send { } /// Process events happening during training and validation. -pub trait EventProcessorTraining: Send { - /// The training item. - type ItemTrain: ItemLazy; - /// The validation item. - type ItemValid: ItemLazy; - +pub trait EventProcessorTraining: Send { /// Collect a training event. - fn process_train(&mut self, event: LearnerEvent); + fn process_train(&mut self, event: TrainEvent); /// Collect a validation event. - fn process_valid(&mut self, event: LearnerEvent); + fn process_valid(&mut self, event: ValidEvent); /// Returns the renderer used for training. fn renderer(self) -> Box; } @@ -68,41 +89,62 @@ pub trait EventProcessorEvaluation: Send { /// A learner item. #[derive(new)] -pub struct LearnerItem { +pub struct TrainingItem { /// The item. pub item: T, /// The progress. pub progress: Progress, - /// The epoch. - pub epoch: usize, + /// The global progress of the training (e.g. epochs). + pub global_progress: Progress, - /// The total number of epochs. - pub epoch_total: usize, - - /// The iteration. - pub iteration: usize, + /// The iteration, if it it different from the items processed. + pub iteration: Option, /// The learning rate. pub lr: Option, } -impl ItemLazy for LearnerItem { - type ItemSync = LearnerItem; +impl ItemLazy for TrainingItem { + type ItemSync = TrainingItem; fn sync(self) -> Self::ItemSync { - LearnerItem { + TrainingItem { item: self.item.sync(), progress: self.progress, - epoch: self.epoch, - epoch_total: self.epoch_total, + global_progress: self.global_progress, iteration: self.iteration, lr: self.lr, } } } +/// An evaluation item. +#[derive(new)] +pub struct EvaluationItem { + /// The item. + pub item: T, + + /// The progress. + pub progress: Progress, + + /// The iteration, if it it different from the items processed. + pub iteration: Option, +} + +impl ItemLazy for EvaluationItem { + type ItemSync = EvaluationItem; + + fn sync(self) -> Self::ItemSync { + EvaluationItem { + item: self.item.sync(), + progress: self.progress, + iteration: self.iteration, + } + } +} + impl ItemLazy for () { type ItemSync = (); diff --git a/crates/burn-train/src/metric/processor/full.rs b/crates/burn-train/src/metric/processor/full.rs index 1fae0a2aa..9cd791915 100644 --- a/crates/burn-train/src/metric/processor/full.rs +++ b/crates/burn-train/src/metric/processor/full.rs @@ -1,7 +1,9 @@ use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining}; use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation, MetricsEvaluation}; use crate::metric::store::{EpochSummary, EventStoreClient, Split}; -use crate::renderer::{MetricState, MetricsRenderer}; +use crate::renderer::{ + EvaluationProgress, MetricState, MetricsRenderer, ProgressType, TrainingProgress, +}; use std::sync::Arc; /// An [event processor](EventProcessorTraining) that handles: @@ -34,6 +36,30 @@ impl FullEventProcessorTraining { store, } } + + fn progress_indicators(&self, progress: &TrainingProgress) -> Vec { + let mut indicators = vec![]; + indicators.push(ProgressType::Detailed { + tag: String::from("Epoch"), + progress: progress.global_progress.clone(), + }); + + if let Some(iteration) = progress.iteration { + indicators.push(ProgressType::Value { + tag: String::from("Iteration"), + value: iteration, + }); + }; + + if let Some(p) = &progress.progress { + indicators.push(ProgressType::Detailed { + tag: String::from("Items"), + progress: p.clone(), + }); + }; + + indicators + } } impl FullEventProcessorEvaluation { @@ -48,6 +74,23 @@ impl FullEventProcessorEvaluation { store, } } + + fn progress_indicators(&self, progress: &EvaluationProgress) -> Vec { + let mut indicators = vec![]; + if let Some(iteration) = progress.iteration { + indicators.push(ProgressType::Value { + tag: String::from("Iteration"), + value: iteration, + }); + }; + + indicators.push(ProgressType::Detailed { + tag: String::from("Items"), + progress: progress.progress.clone(), + }); + + indicators + } } impl EventProcessorEvaluation for FullEventProcessorEvaluation { @@ -95,7 +138,8 @@ impl EventProcessorEvaluation for FullEventProcessorEvaluation { ) }); - self.renderer.render_test(progress); + let indicators = self.progress_indicators(&progress); + self.renderer.render_test(progress, indicators); } EvaluatorEvent::End => { self.renderer.on_test_end().ok(); @@ -108,11 +152,10 @@ impl EventProcessorEvaluation for FullEventProcessorEvaluation { } } -impl EventProcessorTraining for FullEventProcessorTraining { - type ItemTrain = T; - type ItemValid = V; - - fn process_train(&mut self, event: LearnerEvent) { +impl EventProcessorTraining, LearnerEvent> + for FullEventProcessorTraining +{ + fn process_train(&mut self, event: LearnerEvent) { match event { LearnerEvent::Start => { let definitions = self.metrics.metric_definitions(); @@ -149,7 +192,8 @@ impl EventProcessorTraining for FullEventProcessorTrai )) }); - self.renderer.render_train(progress); + let indicators = self.progress_indicators(&progress); + self.renderer.render_train(progress, indicators); } LearnerEvent::EndEpoch(epoch) => { self.store @@ -165,7 +209,7 @@ impl EventProcessorTraining for FullEventProcessorTrai } } - fn process_valid(&mut self, event: LearnerEvent) { + fn process_valid(&mut self, event: LearnerEvent) { match event { LearnerEvent::Start => {} // no-op for now LearnerEvent::ProcessedItem(item) => { @@ -193,7 +237,8 @@ impl EventProcessorTraining for FullEventProcessorTrai )) }); - self.renderer.render_valid(progress); + let indicators = self.progress_indicators(&progress); + self.renderer.render_valid(progress, indicators); } LearnerEvent::EndEpoch(epoch) => { self.store @@ -206,7 +251,7 @@ impl EventProcessorTraining for FullEventProcessorTrai LearnerEvent::End(_) => {} // no-op for now } } - fn renderer(self) -> Box { + fn renderer(self) -> Box { self.renderer } } diff --git a/crates/burn-train/src/metric/processor/metrics.rs b/crates/burn-train/src/metric/processor/metrics.rs index 78d0211e2..935e0cdee 100644 --- a/crates/burn-train/src/metric/processor/metrics.rs +++ b/crates/burn-train/src/metric/processor/metrics.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; -use super::{ItemLazy, LearnerItem}; +use super::{ItemLazy, TrainingItem}; use crate::{ + EvaluationItem, metric::{ Adaptor, Metric, MetricDefinition, MetricEntry, MetricId, MetricMetadata, Numeric, store::{MetricsUpdate, NumericMetricUpdate}, @@ -83,19 +84,19 @@ impl MetricsEvaluation { /// Update the testing information from the testing item. pub(crate) fn update_test( &mut self, - item: &LearnerItem, + item: &EvaluationItem, metadata: &MetricMetadata, ) -> MetricsUpdate { let mut entries = Vec::with_capacity(self.test.len()); let mut entries_numeric = Vec::with_capacity(self.test_numeric.len()); for metric in self.test.iter_mut() { - let state = metric.update(item, metadata); + let state = metric.update(&item.item, metadata); entries.push(state); } for metric in self.test_numeric.iter_mut() { - let numeric_update = metric.update(item, metadata); + let numeric_update = metric.update(&item.item, metadata); entries_numeric.push(numeric_update); } @@ -162,19 +163,19 @@ impl MetricsTraining { /// Update the training information from the training item. pub(crate) fn update_train( &mut self, - item: &LearnerItem, + item: &TrainingItem, metadata: &MetricMetadata, ) -> MetricsUpdate { let mut entries = Vec::with_capacity(self.train.len()); let mut entries_numeric = Vec::with_capacity(self.train_numeric.len()); for metric in self.train.iter_mut() { - let state = metric.update(item, metadata); + let state = metric.update(&item.item, metadata); entries.push(state); } for metric in self.train_numeric.iter_mut() { - let numeric_update = metric.update(item, metadata); + let numeric_update = metric.update(&item.item, metadata); entries_numeric.push(numeric_update); } @@ -184,19 +185,19 @@ impl MetricsTraining { /// Update the training information from the validation item. pub(crate) fn update_valid( &mut self, - item: &LearnerItem, + item: &TrainingItem, metadata: &MetricMetadata, ) -> MetricsUpdate { let mut entries = Vec::with_capacity(self.valid.len()); let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len()); for metric in self.valid.iter_mut() { - let state = metric.update(item, metadata); + let state = metric.update(&item.item, metadata); entries.push(state); } for metric in self.valid_numeric.iter_mut() { - let numeric_update = metric.update(item, metadata); + let numeric_update = metric.update(&item.item, metadata); entries_numeric.push(numeric_update); } @@ -224,19 +225,28 @@ impl MetricsTraining { } } -impl From<&LearnerItem> for TrainingProgress { - fn from(item: &LearnerItem) -> Self { +impl From<&TrainingItem> for TrainingProgress { + fn from(item: &TrainingItem) -> Self { Self { - progress: item.progress.clone(), - epoch: item.epoch, - epoch_total: item.epoch_total, + progress: Some(item.progress.clone()), + global_progress: item.global_progress.clone(), iteration: item.iteration, } } } -impl From<&LearnerItem> for EvaluationProgress { - fn from(item: &LearnerItem) -> Self { +impl From<&EvaluationItem> for TrainingProgress { + fn from(item: &EvaluationItem) -> Self { + Self { + progress: None, + global_progress: item.progress.clone(), + iteration: item.iteration, + } + } +} + +impl From<&EvaluationItem> for EvaluationProgress { + fn from(item: &EvaluationItem) -> Self { Self { progress: item.progress.clone(), iteration: item.iteration, @@ -244,31 +254,41 @@ impl From<&LearnerItem> for EvaluationProgress { } } -impl From<&LearnerItem> for MetricMetadata { - fn from(item: &LearnerItem) -> Self { +impl From<&TrainingItem> for MetricMetadata { + fn from(item: &TrainingItem) -> Self { Self { progress: item.progress.clone(), - epoch: item.epoch, - epoch_total: item.epoch_total, + global_progress: item.global_progress.clone(), iteration: item.iteration, lr: item.lr, } } } -trait NumericMetricUpdater: Send + Sync { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> NumericMetricUpdate; +impl From<&EvaluationItem> for MetricMetadata { + fn from(item: &EvaluationItem) -> Self { + Self { + progress: item.progress.clone(), + global_progress: item.progress.clone(), + iteration: item.iteration, + lr: None, + } + } +} + +pub(crate) trait NumericMetricUpdater: Send + Sync { + fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate; fn clear(&mut self); } -trait MetricUpdater: Send + Sync { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry; +pub(crate) trait MetricUpdater: Send + Sync { + fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry; fn clear(&mut self); } -struct MetricWrapper { - id: MetricId, - metric: M, +pub(crate) struct MetricWrapper { + pub id: MetricId, + pub metric: M, } impl MetricWrapper { @@ -286,8 +306,8 @@ where M: Metric + Numeric + 'static, T: Adaptor, { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> NumericMetricUpdate { - let serialized_entry = self.metric.update(&item.item.adapt(), metadata); + fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate { + let serialized_entry = self.metric.update(&item.adapt(), metadata); let update = MetricEntry::new(self.id.clone(), serialized_entry); let numeric = self.metric.value(); let running = self.metric.running_value(); @@ -310,8 +330,8 @@ where M: Metric + 'static, T: Adaptor, { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry { - let serialized_entry = self.metric.update(&item.item.adapt(), metadata); + fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry { + let serialized_entry = self.metric.update(&item.adapt(), metadata); MetricEntry::new(self.id.clone(), serialized_entry) } diff --git a/crates/burn-train/src/metric/processor/minimal.rs b/crates/burn-train/src/metric/processor/minimal.rs index 9a5c657a3..b68d0d556 100644 --- a/crates/burn-train/src/metric/processor/minimal.rs +++ b/crates/burn-train/src/metric/processor/minimal.rs @@ -14,11 +14,10 @@ pub(crate) struct MinimalEventProcessor { store: Arc, } -impl EventProcessorTraining for MinimalEventProcessor { - type ItemTrain = T; - type ItemValid = V; - - fn process_train(&mut self, event: LearnerEvent) { +impl EventProcessorTraining, LearnerEvent> + for MinimalEventProcessor +{ + fn process_train(&mut self, event: LearnerEvent) { match event { LearnerEvent::Start => { let definitions = self.metrics.metric_definitions(); @@ -47,7 +46,7 @@ impl EventProcessorTraining for MinimalEventProcessor< } } - fn process_valid(&mut self, event: LearnerEvent) { + fn process_valid(&mut self, event: LearnerEvent) { match event { LearnerEvent::Start => {} // no-op for now LearnerEvent::ProcessedItem(item) => { diff --git a/crates/burn-train/src/metric/processor/mod.rs b/crates/burn-train/src/metric/processor/mod.rs index 66afcd2c4..4b53d4913 100644 --- a/crates/burn-train/src/metric/processor/mod.rs +++ b/crates/burn-train/src/metric/processor/mod.rs @@ -3,10 +3,14 @@ mod base; mod full; mod metrics; mod minimal; +mod rl_metrics; +mod rl_processor; pub use base::*; pub(crate) use full::*; pub(crate) use metrics::*; +pub(crate) use rl_metrics::*; +pub(crate) use rl_processor::*; #[cfg(test)] pub(crate) use minimal::*; @@ -17,7 +21,7 @@ pub use async_wrapper::{AsyncProcessorEvaluation, AsyncProcessorTraining}; pub(crate) mod test_utils { use crate::metric::{ Adaptor, LossInput, - processor::{EventProcessorTraining, LearnerEvent, LearnerItem, MinimalEventProcessor}, + processor::{EventProcessorTraining, LearnerEvent, MinimalEventProcessor, TrainingItem}, }; use burn_core::tensor::{ElementConversion, Tensor, backend::Backend}; @@ -47,14 +51,16 @@ pub(crate) mod test_utils { items_processed: 1, items_total: 10, }; - let num_epochs = 3; - let dummy_iteration = 1; + let dummy_global_progress = burn_core::data::dataloader::Progress { + items_processed: epoch, + items_total: 3, + }; + let dummy_iteration = Some(1); - processor.process_train(LearnerEvent::ProcessedItem(LearnerItem::new( + processor.process_train(LearnerEvent::ProcessedItem(TrainingItem::new( value, dummy_progress, - epoch, - num_epochs, + dummy_global_progress, dummy_iteration, None, ))); diff --git a/crates/burn-train/src/metric/processor/rl_metrics.rs b/crates/burn-train/src/metric/processor/rl_metrics.rs new file mode 100644 index 000000000..d520acacb --- /dev/null +++ b/crates/burn-train/src/metric/processor/rl_metrics.rs @@ -0,0 +1,268 @@ +use std::collections::HashMap; + +use crate::{ + EpisodeSummary, EvaluationItem, ItemLazy, MetricUpdater, MetricWrapper, NumericMetricUpdater, + metric::{ + Adaptor, Metric, MetricDefinition, MetricId, MetricMetadata, Numeric, store::MetricsUpdate, + }, +}; + +pub(crate) struct RLMetrics { + train_step: Vec>>, + env_step: Vec>>, + env_step_valid: Vec>>, + episode_end: Vec>>, + episode_end_valid: Vec>>, + + train_step_numeric: Vec>>, + env_step_numeric: Vec>>, + env_step_valid_numeric: Vec>>, + episode_end_numeric: Vec>>, + episode_end_valid_numeric: Vec>>, + + metric_definitions: HashMap, +} + +impl Default for RLMetrics { + fn default() -> Self { + Self { + train_step: Vec::default(), + env_step: Vec::default(), + env_step_valid: Vec::default(), + episode_end: Vec::default(), + episode_end_valid: Vec::default(), + train_step_numeric: Vec::default(), + env_step_numeric: Vec::default(), + env_step_valid_numeric: Vec::default(), + episode_end_numeric: Vec::default(), + episode_end_valid_numeric: Vec::default(), + metric_definitions: HashMap::default(), + } + } +} + +impl RLMetrics { + /// Register a training metric. + pub(crate) fn register_text_metric_agent(&mut self, metric: Me) + where + ES::ItemSync: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.register_definition(&metric); + self.env_step.push(Box::new(metric)) + } + + /// Register a training metric. + pub(crate) fn register_agent_metric(&mut self, metric: Me) + where + ES::ItemSync: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.register_definition(&metric); + self.env_step_numeric.push(Box::new(metric)) + } + + /// Register a training metric. + pub(crate) fn register_text_metric_train(&mut self, metric: Me) + where + TS::ItemSync: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.register_definition(&metric); + self.train_step.push(Box::new(metric)) + } + + /// Register a training metric. + pub(crate) fn register_metric_train(&mut self, metric: Me) + where + TS::ItemSync: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.register_definition(&metric); + self.train_step_numeric.push(Box::new(metric)) + } + + /// Register a validation env-step metric. + pub(crate) fn register_text_metric_agent_valid(&mut self, metric: Me) + where + ES::ItemSync: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.register_definition(&metric); + self.env_step_valid.push(Box::new(metric)) + } + + /// Register a validation env-step numeric metric. + pub(crate) fn register_agent_metric_valid(&mut self, metric: Me) + where + ES::ItemSync: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.register_definition(&metric); + self.env_step_valid_numeric.push(Box::new(metric)) + } + + /// Register an episode-end metric. + pub(crate) fn register_text_metric_episode(&mut self, metric: Me) + where + EpisodeSummary: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.register_definition(&metric); + self.episode_end.push(Box::new(metric)) + } + + /// Register an episode-end numeric metric. + pub(crate) fn register_episode_metric(&mut self, metric: Me) + where + EpisodeSummary: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.register_definition(&metric); + self.episode_end_numeric.push(Box::new(metric)) + } + + /// Register an episode-end metric for validation. + pub(crate) fn register_text_metric_episode_valid(&mut self, metric: Me) + where + EpisodeSummary: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.register_definition(&metric); + self.episode_end_valid.push(Box::new(metric)) + } + + /// Register an episode-end numeric metric for validation. + pub(crate) fn register_episode_metric_valid( + &mut self, + metric: Me, + ) where + EpisodeSummary: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.register_definition(&metric); + self.episode_end_valid_numeric.push(Box::new(metric)) + } + + fn register_definition(&mut self, metric: &MetricWrapper) { + self.metric_definitions.insert( + metric.id.clone(), + MetricDefinition::new(metric.id.clone(), &metric.metric), + ); + } + + /// Get metric definitions for all splits + pub(crate) fn metric_definitions(&mut self) -> Vec { + self.metric_definitions.values().cloned().collect() + } + + /// Update the training information from the training item. + pub(crate) fn update_train_step( + &mut self, + item: &EvaluationItem, + metadata: &MetricMetadata, + ) -> MetricsUpdate { + let mut entries = Vec::with_capacity(self.train_step.len()); + let mut entries_numeric = Vec::with_capacity(self.train_step_numeric.len()); + + for metric in self.train_step.iter_mut() { + let state = metric.update(&item.item, metadata); + entries.push(state); + } + + for metric in self.train_step_numeric.iter_mut() { + let numeric_update = metric.update(&item.item, metadata); + entries_numeric.push(numeric_update); + } + + MetricsUpdate::new(entries, entries_numeric) + } + + /// Update the env-step metrics from an environment step item. + pub(crate) fn update_env_step( + &mut self, + item: &EvaluationItem, + metadata: &MetricMetadata, + ) -> MetricsUpdate { + let mut entries = Vec::with_capacity(self.env_step.len()); + let mut entries_numeric = Vec::with_capacity(self.env_step_numeric.len()); + + for metric in self.env_step.iter_mut() { + let state = metric.update(&item.item, metadata); + entries.push(state); + } + + for metric in self.env_step_numeric.iter_mut() { + let numeric_update = metric.update(&item.item, metadata); + entries_numeric.push(numeric_update); + } + + MetricsUpdate::new(entries, entries_numeric) + } + + /// Update the env-step metrics for validation from an environment step item. + pub(crate) fn update_env_step_valid( + &mut self, + item: &EvaluationItem, + metadata: &MetricMetadata, + ) -> MetricsUpdate { + let mut entries = Vec::with_capacity(self.env_step_valid.len()); + let mut entries_numeric = Vec::with_capacity(self.env_step_valid_numeric.len()); + + for metric in self.env_step_valid.iter_mut() { + let state = metric.update(&item.item, metadata); + entries.push(state); + } + + for metric in self.env_step_valid_numeric.iter_mut() { + let numeric_update = metric.update(&item.item, metadata); + entries_numeric.push(numeric_update); + } + + MetricsUpdate::new(entries, entries_numeric) + } + + /// Update the episode-end metrics from an episode summary. + pub(crate) fn update_episode_end( + &mut self, + item: &EvaluationItem, + metadata: &MetricMetadata, + ) -> MetricsUpdate { + let mut entries = Vec::with_capacity(self.episode_end.len()); + let mut entries_numeric = Vec::with_capacity(self.episode_end_numeric.len()); + + for metric in self.episode_end.iter_mut() { + let state = metric.update(&item.item, metadata); + entries.push(state); + } + + for metric in self.episode_end_numeric.iter_mut() { + let numeric_update = metric.update(&item.item, metadata); + entries_numeric.push(numeric_update); + } + + MetricsUpdate::new(entries, entries_numeric) + } + + /// Update the episode-end metrics for validation from an episode summary. + pub(crate) fn update_episode_end_valid( + &mut self, + item: &EvaluationItem, + metadata: &MetricMetadata, + ) -> MetricsUpdate { + let mut entries = Vec::with_capacity(self.episode_end_valid.len()); + let mut entries_numeric = Vec::with_capacity(self.episode_end_valid_numeric.len()); + + for metric in self.episode_end_valid.iter_mut() { + let state = metric.update(&item.item, metadata); + entries.push(state); + } + + for metric in self.episode_end_valid_numeric.iter_mut() { + let numeric_update = metric.update(&item.item, metadata); + entries_numeric.push(numeric_update); + } + + MetricsUpdate::new(entries, entries_numeric) + } +} diff --git a/crates/burn-train/src/metric/processor/rl_processor.rs b/crates/burn-train/src/metric/processor/rl_processor.rs new file mode 100644 index 000000000..c4a602d86 --- /dev/null +++ b/crates/burn-train/src/metric/processor/rl_processor.rs @@ -0,0 +1,151 @@ +use std::sync::Arc; + +use crate::{ + AgentEvaluationEvent, EventProcessorTraining, ItemLazy, RLEvent, RLMetrics, + metric::store::{Event, EventStoreClient, MetricsUpdate}, + renderer::{MetricState, MetricsRenderer, ProgressType, TrainingProgress}, +}; + +/// An [event processor](EventProcessorTraining) that handles: +/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). +/// - Render metrics using a [metrics renderer](MetricsRenderer). +#[derive(new)] +pub struct RLEventProcessor { + metrics: RLMetrics, + renderer: Box, + store: Arc, +} + +impl RLEventProcessor { + fn progress_indicators(&self, progress: &TrainingProgress) -> Vec { + let indicators = vec![ProgressType::Detailed { + tag: String::from("Step"), + progress: progress.global_progress.clone(), + }]; + + indicators + } + + fn progress_indicators_eval(&self, progress: &TrainingProgress) -> Vec { + let indicators = vec![ProgressType::Detailed { + tag: String::from("Step"), + progress: progress.global_progress.clone(), + }]; + + indicators + } +} + +impl RLEventProcessor { + fn process_update_train(&mut self, update: MetricsUpdate) { + self.store + .add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone())); + + update + .entries + .into_iter() + .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry))); + + update + .entries_numeric + .into_iter() + .for_each(|numeric_update| { + self.renderer.update_train(MetricState::Numeric( + numeric_update.entry, + numeric_update.numeric_entry, + )) + }); + } + + fn process_update_valid(&mut self, update: MetricsUpdate) { + self.store + .add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone())); + + update + .entries + .into_iter() + .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry))); + + update + .entries_numeric + .into_iter() + .for_each(|numeric_update| { + self.renderer.update_valid(MetricState::Numeric( + numeric_update.entry, + numeric_update.numeric_entry, + )) + }); + } +} + +impl EventProcessorTraining, AgentEvaluationEvent> + for RLEventProcessor +{ + fn process_train(&mut self, event: RLEvent) { + match event { + RLEvent::Start => { + let definitions = self.metrics.metric_definitions(); + self.store + .add_event_train(Event::MetricsInit(definitions.clone())); + definitions + .iter() + .for_each(|definition| self.renderer.register_metric(definition.clone())); + } + RLEvent::TrainStep(item) => { + let item = item.sync(); + let metadata = (&item).into(); + + let update = self.metrics.update_train_step(&item, &metadata); + self.process_update_train(update); + } + RLEvent::TimeStep(item) => { + let item = item.sync(); + let progress = (&item).into(); + let metadata = (&item).into(); + + let update = self.metrics.update_env_step(&item, &metadata); + self.process_update_train(update); + let status = self.progress_indicators(&progress); + self.renderer.render_train(progress, status); + } + RLEvent::EpisodeEnd(item) => { + let item = item.sync(); + let metadata = (&item).into(); + + let update = self.metrics.update_episode_end(&item, &metadata); + self.process_update_train(update); + } + RLEvent::End(learner_summary) => { + self.renderer.on_train_end(learner_summary).ok(); + } + } + } + + fn process_valid(&mut self, event: AgentEvaluationEvent) { + match event { + AgentEvaluationEvent::Start => {} // no-op for now + AgentEvaluationEvent::TimeStep(item) => { + let item = item.sync(); + let metadata = (&item).into(); + + let update = self.metrics.update_env_step_valid(&item, &metadata); + self.process_update_valid(update); + } + AgentEvaluationEvent::EpisodeEnd(item) => { + let item = item.sync(); + let progress = (&item).into(); + let metadata = (&item).into(); + + let update = self.metrics.update_episode_end_valid(&item, &metadata); + self.process_update_valid(update); + let status = self.progress_indicators_eval(&progress); + self.renderer.render_valid(progress, status); + } + AgentEvaluationEvent::End => {} // no-op for now + } + } + + fn renderer(self) -> Box { + self.renderer + } +} diff --git a/crates/burn-train/src/metric/rl/cum_reward.rs b/crates/burn-train/src/metric/rl/cum_reward.rs new file mode 100644 index 000000000..28505e3b9 --- /dev/null +++ b/crates/burn-train/src/metric/rl/cum_reward.rs @@ -0,0 +1,78 @@ +use std::sync::Arc; + +use super::super::{ + MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, + state::{FormatOptions, NumericMetricState}, +}; +use crate::metric::{Metric, MetricName, Numeric, SerializedEntry}; + +/// Metric for the cumulative reward of the last completed episode. +#[derive(Clone)] +pub struct CumulativeRewardMetric { + name: MetricName, + state: NumericMetricState, +} + +impl CumulativeRewardMetric { + /// Creates a new episode length metric. + pub fn new() -> Self { + Self { + name: Arc::new("Cum. Reward".to_string()), + state: NumericMetricState::new(), + } + } +} + +impl Default for CumulativeRewardMetric { + fn default() -> Self { + Self::new() + } +} + +/// The [CumulativeRewardMetric](CumulativeRewardMetric) input type. +#[derive(new)] +pub struct CumulativeRewardInput { + cum_reward: f64, +} + +impl Metric for CumulativeRewardMetric { + type Input = CumulativeRewardInput; + + fn update( + &mut self, + item: &CumulativeRewardInput, + _metadata: &MetricMetadata, + ) -> SerializedEntry { + self.state.update( + item.cum_reward, + 1, + FormatOptions::new(self.name()).precision(2), + ) + } + + fn clear(&mut self) { + self.state.reset() + } + + fn name(&self) -> MetricName { + self.name.clone() + } + + fn attributes(&self) -> MetricAttributes { + NumericAttributes { + unit: None, + higher_is_better: true, + } + .into() + } +} + +impl Numeric for CumulativeRewardMetric { + fn value(&self) -> NumericEntry { + self.state.current_value() + } + + fn running_value(&self) -> NumericEntry { + self.state.running_value() + } +} diff --git a/crates/burn-train/src/metric/rl/ep_len.rs b/crates/burn-train/src/metric/rl/ep_len.rs new file mode 100644 index 000000000..90d90c47f --- /dev/null +++ b/crates/burn-train/src/metric/rl/ep_len.rs @@ -0,0 +1,71 @@ +use std::sync::Arc; + +use super::super::{ + MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, + state::{FormatOptions, NumericMetricState}, +}; +use crate::metric::{Metric, MetricName, Numeric, SerializedEntry}; + +/// Metric for the length of the last completed episode. +#[derive(Clone)] +pub struct EpisodeLengthMetric { + name: MetricName, + state: NumericMetricState, +} + +impl EpisodeLengthMetric { + /// Creates a new episode length metric. + pub fn new() -> Self { + Self { + name: Arc::new("Episode length".to_string()), + state: NumericMetricState::new(), + } + } +} + +impl Default for EpisodeLengthMetric { + fn default() -> Self { + Self::new() + } +} + +/// The [EpisodeLengthMetric](EpisodeLengthMetric) input type. +#[derive(new)] +pub struct EpisodeLengthInput { + ep_len: f64, +} + +impl Metric for EpisodeLengthMetric { + type Input = EpisodeLengthInput; + + fn update(&mut self, item: &EpisodeLengthInput, _metadata: &MetricMetadata) -> SerializedEntry { + self.state + .update(item.ep_len, 1, FormatOptions::new(self.name()).precision(0)) + } + + fn clear(&mut self) { + self.state.reset() + } + + fn name(&self) -> MetricName { + self.name.clone() + } + + fn attributes(&self) -> MetricAttributes { + NumericAttributes { + unit: Some(String::from("steps")), + higher_is_better: true, + } + .into() + } +} + +impl Numeric for EpisodeLengthMetric { + fn value(&self) -> NumericEntry { + self.state.current_value() + } + + fn running_value(&self) -> NumericEntry { + self.state.running_value() + } +} diff --git a/crates/burn-train/src/metric/rl/exploration_rate.rs b/crates/burn-train/src/metric/rl/exploration_rate.rs new file mode 100644 index 000000000..66a7acf92 --- /dev/null +++ b/crates/burn-train/src/metric/rl/exploration_rate.rs @@ -0,0 +1,78 @@ +use std::sync::Arc; + +use super::super::{ + MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, + state::{FormatOptions, NumericMetricState}, +}; +use crate::metric::{Metric, MetricName, Numeric, SerializedEntry}; + +/// Metric for the length of the last completed episode. +#[derive(Clone)] +pub struct ExplorationRateMetric { + name: MetricName, + state: NumericMetricState, +} + +impl ExplorationRateMetric { + /// Creates a new episode length metric. + pub fn new() -> Self { + Self { + name: Arc::new("Exploration rate".to_string()), + state: NumericMetricState::new(), + } + } +} + +impl Default for ExplorationRateMetric { + fn default() -> Self { + Self::new() + } +} + +/// The [ExplorationRateMetric](ExplorationRateMetric) input type. +#[derive(new)] +pub struct ExplorationRateInput { + exploration_rate: f64, +} + +impl Metric for ExplorationRateMetric { + type Input = ExplorationRateInput; + + fn update( + &mut self, + item: &ExplorationRateInput, + _metadata: &MetricMetadata, + ) -> SerializedEntry { + self.state.update( + item.exploration_rate, + 1, + FormatOptions::new(self.name()).precision(3), + ) + } + + fn clear(&mut self) { + self.state.reset() + } + + fn name(&self) -> MetricName { + self.name.clone() + } + + fn attributes(&self) -> MetricAttributes { + NumericAttributes { + unit: Some(String::from("%")), + higher_is_better: false, + } + .into() + } +} + +impl Numeric for ExplorationRateMetric { + fn value(&self) -> NumericEntry { + self.state.current_value() + } + + fn running_value(&self) -> NumericEntry { + self.state.running_value() + } +} diff --git a/crates/burn-train/src/metric/rl/mod.rs b/crates/burn-train/src/metric/rl/mod.rs new file mode 100644 index 000000000..f2a2d22b7 --- /dev/null +++ b/crates/burn-train/src/metric/rl/mod.rs @@ -0,0 +1,7 @@ +mod cum_reward; +mod ep_len; +mod exploration_rate; + +pub use cum_reward::*; +pub use ep_len::*; +pub use exploration_rate::*; diff --git a/crates/burn-train/src/metric/state.rs b/crates/burn-train/src/metric/state.rs index e6dfc43f9..8791fe0a0 100644 --- a/crates/burn-train/src/metric/state.rs +++ b/crates/burn-train/src/metric/state.rs @@ -109,6 +109,7 @@ impl NumericMetricState { None => (format!("{value_current}"), format!("{value_running}")), }; + // TODO: naming inconsistent with RL. let formatted = match format.unit { Some(unit) => { format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}") diff --git a/crates/burn-train/src/renderer/base.rs b/crates/burn-train/src/renderer/base.rs index b82be2861..ec3ad5cd8 100644 --- a/crates/burn-train/src/renderer/base.rs +++ b/crates/burn-train/src/renderer/base.rs @@ -27,14 +27,14 @@ pub trait MetricsRendererTraining: Send + Sync { /// # Arguments /// /// * `item` - The training progress. - fn render_train(&mut self, item: TrainingProgress); + fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec); /// Renders the validation progress. /// /// # Arguments /// /// * `item` - The validation progress. - fn render_valid(&mut self, item: TrainingProgress); + fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec); /// Callback method invoked when training ends, whether it /// completed successfully or was interrupted. @@ -58,7 +58,7 @@ pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining { /// Keep the renderer from automatically closing, requiring manual action to close it. fn manual_close(&mut self); /// Register a new metric. - fn register_metric(&mut self, _definition: MetricDefinition); + fn register_metric(&mut self, definition: MetricDefinition); } #[derive(Clone)] @@ -102,7 +102,7 @@ pub trait MetricsRendererEvaluation: Send + Sync { /// # Arguments /// /// * `item` - The training progress. - fn render_test(&mut self, item: EvaluationProgress); + fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec); /// Callback method invoked when testing ends, whether it /// completed successfully or was interrupted. @@ -128,16 +128,13 @@ pub enum MetricState { #[derive(Debug)] pub struct TrainingProgress { /// The progress. - pub progress: Progress, + pub progress: Option, - /// The epoch. - pub epoch: usize, + /// The progress of the whole training. + pub global_progress: Progress, - /// The total number of epochs. - pub epoch_total: usize, - - /// The iteration. - pub iteration: usize, + /// The iteration, if it differs from the items processed. + pub iteration: Option, } /// Evaluation progress. @@ -146,21 +143,48 @@ pub struct EvaluationProgress { /// The progress. pub progress: Progress, - /// The iteration. - pub iteration: usize, + /// The iteration, if it is different from the processed items. + pub iteration: Option, +} + +impl From<&EvaluationProgress> for TrainingProgress { + fn from(value: &EvaluationProgress) -> Self { + TrainingProgress { + progress: None, + global_progress: value.progress.clone(), + iteration: value.iteration, + } + } } impl TrainingProgress { /// Creates a new empty training progress. pub fn none() -> Self { Self { - progress: Progress { + progress: None, + global_progress: Progress { items_processed: 0, items_total: 0, }, - epoch: 0, - epoch_total: 0, - iteration: 0, + iteration: None, } } } + +/// Type of progress indicators. +pub enum ProgressType { + /// Detailed progress. + Detailed { + /// The tag. + tag: String, + /// The progress. + progress: Progress, + }, + /// Simple value. + Value { + /// The tag. + tag: String, + /// The value. + value: usize, + }, +} diff --git a/crates/burn-train/src/renderer/cli.rs b/crates/burn-train/src/renderer/cli.rs index cc598f2ee..9f92bf4d2 100644 --- a/crates/burn-train/src/renderer/cli.rs +++ b/crates/burn-train/src/renderer/cli.rs @@ -1,6 +1,6 @@ use crate::renderer::{ EvaluationProgress, MetricState, MetricsRenderer, MetricsRendererEvaluation, - MetricsRendererTraining, TrainingProgress, + MetricsRendererTraining, ProgressType, TrainingProgress, }; /// A simple renderer for when the cli feature is not enabled. @@ -19,17 +19,17 @@ impl MetricsRendererTraining for CliMetricsRenderer { fn update_valid(&mut self, _state: MetricState) {} - fn render_train(&mut self, item: TrainingProgress) { + fn render_train(&mut self, item: TrainingProgress, _progress_indicators: Vec) { println!("{item:?}"); } - fn render_valid(&mut self, item: TrainingProgress) { + fn render_valid(&mut self, item: TrainingProgress, _progress_indicators: Vec) { println!("{item:?}"); } } impl MetricsRendererEvaluation for CliMetricsRenderer { - fn render_test(&mut self, item: EvaluationProgress) { + fn render_test(&mut self, item: EvaluationProgress, _progress_indicators: Vec) { println!("{item:?}"); } diff --git a/crates/burn-train/src/renderer/mod.rs b/crates/burn-train/src/renderer/mod.rs index ee8bbc4a4..6e441fa88 100644 --- a/crates/burn-train/src/renderer/mod.rs +++ b/crates/burn-train/src/renderer/mod.rs @@ -27,7 +27,7 @@ pub(crate) fn default_renderer( ) -> Box { #[cfg(feature = "tui")] if std::io::stdout().is_terminal() { - return Box::new(tui::TuiMetricsRenderer::new(interuptor, checkpoint)); + return Box::new(tui::TuiMetricsRendererWrapper::new(interuptor, checkpoint)); } Box::new(CliMetricsRenderer::new()) diff --git a/crates/burn-train/src/renderer/tui/metric_numeric.rs b/crates/burn-train/src/renderer/tui/metric_numeric.rs index 5b74bd5f2..f88b0ad46 100644 --- a/crates/burn-train/src/renderer/tui/metric_numeric.rs +++ b/crates/burn-train/src/renderer/tui/metric_numeric.rs @@ -65,13 +65,14 @@ impl NumericMetricsState { /// Update the state with the training progress. pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) { - self.epoch = progress.epoch; + self.epoch = progress.global_progress.items_processed; if self.num_samples_train.is_some() { return; } - self.num_samples_train = Some(progress.progress.items_total); + // If the training only has the notion of global progress, num_samples_train remains None. + self.num_samples_train = progress.progress.as_ref().map(|p| p.items_total); } /// Update the state with the validation progress. @@ -80,16 +81,20 @@ impl NumericMetricsState { return; } + // If num_samples_train is None, keep the default max_samples for validation. if let Some(num_sample_train) = self.num_samples_train { for (_, (_recent, full)) in self.data.iter_mut() { - let ratio = progress.progress.items_total as f64 / num_sample_train as f64; + let ratio = match &progress.progress { + Some(p) => p.items_total as f64 / num_sample_train as f64, + None => progress.global_progress.items_total as f64 / num_sample_train as f64, + }; full.update_max_sample(TuiSplit::Valid, ratio); } } - self.epoch = progress.epoch; - self.num_samples_valid = Some(progress.progress.items_total); + self.epoch = progress.global_progress.items_processed; + self.num_samples_valid = progress.progress.as_ref().map(|p| p.items_total); } /// Update the state with the testing progress. diff --git a/crates/burn-train/src/renderer/tui/progress.rs b/crates/burn-train/src/renderer/tui/progress.rs index a615e1704..8e32a2e9e 100644 --- a/crates/burn-train/src/renderer/tui/progress.rs +++ b/crates/burn-train/src/renderer/tui/progress.rs @@ -36,8 +36,12 @@ impl ProgressBarState { /// Update the training progress. pub(crate) fn update_train(&mut self, progress: &TrainingProgress) { self.progress_total = calculate_progress(progress, 0, 0); + let local_progress = progress + .progress + .as_ref() + .unwrap_or(&progress.global_progress); self.progress_task = - progress.progress.items_processed as f64 / progress.progress.items_total as f64; + local_progress.items_processed as f64 / local_progress.items_total as f64; self.estimate.update(progress, self.starting_epoch); self.split = TuiSplit::Train; } @@ -45,8 +49,12 @@ impl ProgressBarState { /// Update the validation progress. pub(crate) fn update_valid(&mut self, progress: &TrainingProgress) { // We don't use the validation for the total progress yet. + let local_progress = progress + .progress + .as_ref() + .unwrap_or(&progress.global_progress); self.progress_task = - progress.progress.items_processed as f64 / progress.progress.items_total as f64; + local_progress.items_processed as f64 / local_progress.items_total as f64; self.split = TuiSplit::Valid; } @@ -187,7 +195,7 @@ impl ProgressEstimate { } // When the training has started since at least 10 seconds and completed 10 iterations. - if progress.iteration >= WARMUP_NUM_ITERATION + if progress.iteration >= Some(WARMUP_NUM_ITERATION) && self.started.elapsed() > Duration::from_secs(10) { self.init(progress, starting_epoch); @@ -195,11 +203,17 @@ impl ProgressEstimate { } fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) { - let epoch = progress.epoch - starting_epoch; - let epoch_items = (epoch - 1) * progress.progress.items_total; - let iteration_items = progress.progress.items_processed; + let epoch = progress.global_progress.items_processed - starting_epoch; + + self.warmup_num_items = match &progress.progress { + Some(local_progress) => { + let epoch_items = (epoch - 1) * local_progress.items_total; + let iteration_items = local_progress.items_processed; + epoch_items + iteration_items + } + None => epoch, + }; - self.warmup_num_items = epoch_items + iteration_items; self.started_after_warmup = Some(Instant::now()); self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items); } @@ -210,15 +224,19 @@ fn calculate_progress( starting_epoch: usize, ignore_num_items: usize, ) -> f64 { - let epoch_total = progress.epoch_total - starting_epoch; - let epoch = progress.epoch - starting_epoch; + let epoch_total = progress.global_progress.items_total - starting_epoch; + let epoch = progress.global_progress.items_processed - starting_epoch; + match &progress.progress { + Some(local_progress) => { + let total_items = local_progress.items_total * epoch_total; + let epoch_items = (epoch - 1) * local_progress.items_total; + let iteration_items = local_progress.items_processed; + let num_items = epoch_items + iteration_items - ignore_num_items; - let total_items = progress.progress.items_total * epoch_total; - let epoch_items = (epoch - 1) * progress.progress.items_total; - let iteration_items = progress.progress.items_processed; - let num_items = epoch_items + iteration_items - ignore_num_items; - - num_items as f64 / total_items as f64 + num_items as f64 / total_items as f64 + } + None => epoch as f64 / epoch_total as f64, + } } fn format_eta(eta_secs: u64) -> String { @@ -268,11 +286,14 @@ mod tests { items_processed: 5, items_total: 10, }; + let global_progress = Progress { + items_processed: 9, + items_total: 10, + }; let progress = TrainingProgress { - progress: half, - epoch: 9, - epoch_total: 10, - iteration: 500, + progress: Some(half), + global_progress: global_progress, + iteration: Some(500), }; let starting_epoch = 8; @@ -288,11 +309,14 @@ mod tests { items_processed: 110, items_total: 1000, }; + let global_progress = Progress { + items_processed: 9, + items_total: 10, + }; let progress = TrainingProgress { - progress: half, - epoch: 9, - epoch_total: 10, - iteration: 500, + progress: Some(half), + global_progress: global_progress, + iteration: Some(500), }; let starting_epoch = 8; diff --git a/crates/burn-train/src/renderer/tui/renderer.rs b/crates/burn-train/src/renderer/tui/renderer.rs index de78a38d2..58105a73e 100644 --- a/crates/burn-train/src/renderer/tui/renderer.rs +++ b/crates/burn-train/src/renderer/tui/renderer.rs @@ -2,7 +2,7 @@ use crate::metric::{MetricDefinition, MetricId}; use crate::renderer::tui::TuiSplit; use crate::renderer::{ EvaluationName, EvaluationProgress, MetricState, MetricsRenderer, MetricsRendererEvaluation, - TrainingProgress, + ProgressType, TrainingProgress, }; use crate::renderer::{MetricsRendererTraining, tui::NumericMetricsState}; use crate::{Interrupter, LearnerSummary}; @@ -17,7 +17,9 @@ use ratatui::{ }; use std::collections::HashMap; use std::panic::{set_hook, take_hook}; -use std::sync::Arc; +use std::sync::mpsc::{Receiver, Sender}; +use std::sync::{Arc, Mutex, mpsc}; +use std::thread::{JoinHandle, spawn}; use std::{ error::Error, io::{self, Stdout}, @@ -38,8 +40,86 @@ type PanicHook = Box) + 'static + Sync + S const MAX_REFRESH_RATE_MILLIS: u64 = 100; +enum TuiRendererEvent { + MetricRegistration(MetricDefinition), + MetricsUpdate((TuiSplit, TuiGroup, MetricState)), + StatusUpdateTrain((TuiSplit, TrainingProgress, Vec)), + StatusUpdateTest((EvaluationProgress, Vec)), + TrainEnd(Option), + ManualClose(), + Close(), + Persistent(), +} + /// The terminal UI metrics renderer. -pub struct TuiMetricsRenderer { +pub struct TuiMetricsRendererWrapper { + sender: mpsc::Sender, + interrupter: Interrupter, + handle_join: Option>, + kill_signal: Arc>>, +} + +impl TuiMetricsRendererWrapper { + /// Create a new terminal UI renderer. + pub fn new(interrupter: Interrupter, checkpoint: Option) -> Self { + let (sender, receiver) = mpsc::channel(); + let (kill_signal_sender, kill_signal_receiver) = mpsc::channel(); + + let interrupter_clone = interrupter.clone(); + let handle_join = spawn(move || { + let mut renderer = + TuiMetricsRenderer::new(interrupter_clone, checkpoint, kill_signal_sender); + + let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS); + loop { + match receiver.try_recv() { + Ok(event) => renderer.handle_event(event), + Err(mpsc::TryRecvError::Empty) => (), + Err(mpsc::TryRecvError::Disconnected) => { + log::error!("Renderer thread disconnected."); + break; + } + } + + // Render + if renderer.last_update.elapsed() >= tick_rate + && let Err(err) = renderer.render() + { + log::error!("Render error: {err}"); + break; + } + + if (renderer.manual_close && renderer.interrupter.should_stop()) || renderer.close { + break; + } + } + }); + + Self { + sender, + interrupter, + handle_join: Some(handle_join), + kill_signal: Arc::new(Mutex::new(kill_signal_receiver)), + } + } + + fn send_event(&self, event: TuiRendererEvent) { + if self.kill_signal.lock().unwrap().try_recv().is_ok() { + panic!("Killing training from user input.") + } + if let Err(e) = self.sender.send(event) { + log::warn!("Failed to send TUI event: {e}"); + } + } + + /// Set the renderer to persistent mode. + pub fn persistent(self) -> Self { + self.send_event(TuiRendererEvent::Persistent()); + self + } +} + +struct TuiMetricsRenderer { terminal: Terminal, last_update: std::time::Instant, progress: ProgressBarState, @@ -47,75 +127,95 @@ pub struct TuiMetricsRenderer { metrics_numeric: NumericMetricsState, metrics_text: TextMetricsState, status: StatusState, - interuptor: Interrupter, + interrupter: Interrupter, popup: PopupState, previous_panic_hook: Option>, persistent: bool, + manual_close: bool, + close: bool, summary: Option, + kill_signal: Sender<()>, } -impl MetricsRendererEvaluation for TuiMetricsRenderer { +impl MetricsRendererEvaluation for TuiMetricsRendererWrapper { fn update_test(&mut self, name: EvaluationName, state: MetricState) { - self.update_metric(TuiSplit::Test, TuiGroup::Named(name.name), state); + self.send_event(TuiRendererEvent::MetricsUpdate(( + TuiSplit::Test, + TuiGroup::Named(name.name), + state, + ))); } - fn render_test(&mut self, item: EvaluationProgress) { - self.progress.update_test(&item); - self.metrics_numeric.update_progress_test(&item); - self.status.update_test(item); - self.render().unwrap(); + fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec) { + self.send_event(TuiRendererEvent::StatusUpdateTest(( + item, + progress_indicators, + ))); } } -impl MetricsRenderer for TuiMetricsRenderer { +impl MetricsRenderer for TuiMetricsRendererWrapper { fn manual_close(&mut self) { - loop { - self.render().unwrap(); - if self.interuptor.should_stop() { - return; - } - std::thread::sleep(Duration::from_millis(100)); - } + self.send_event(TuiRendererEvent::ManualClose()); + let _ = self.handle_join.take().unwrap().join(); } fn register_metric(&mut self, definition: MetricDefinition) { - self.metric_definitions - .insert(definition.metric_id.clone(), definition); + self.send_event(TuiRendererEvent::MetricRegistration(definition)); } } -impl MetricsRendererTraining for TuiMetricsRenderer { +impl MetricsRendererTraining for TuiMetricsRendererWrapper { fn update_train(&mut self, state: MetricState) { - self.update_metric(TuiSplit::Train, TuiGroup::Default, state); + self.send_event(TuiRendererEvent::MetricsUpdate(( + TuiSplit::Train, + TuiGroup::Default, + state, + ))); } fn update_valid(&mut self, state: MetricState) { - self.update_metric(TuiSplit::Valid, TuiGroup::Default, state); + self.send_event(TuiRendererEvent::MetricsUpdate(( + TuiSplit::Valid, + TuiGroup::Default, + state, + ))); } - fn render_train(&mut self, item: TrainingProgress) { - self.progress.update_train(&item); - self.metrics_numeric.update_progress_train(&item); - self.status.update_train(item); - self.render().unwrap(); + fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec) { + self.send_event(TuiRendererEvent::StatusUpdateTrain(( + TuiSplit::Train, + item, + progress_indicators, + ))); } - fn render_valid(&mut self, item: TrainingProgress) { - self.progress.update_valid(&item); - self.metrics_numeric.update_progress_valid(&item); - self.status.update_valid(item); - self.render().unwrap(); + fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec) { + self.send_event(TuiRendererEvent::StatusUpdateTrain(( + TuiSplit::Valid, + item, + progress_indicators, + ))); } fn on_train_end(&mut self, summary: Option) -> Result<(), Box> { // Reset for following steps. - self.interuptor.reset(); + self.interrupter.reset(); // Update the summary - self.summary = summary; + self.send_event(TuiRendererEvent::TrainEnd(summary)); Ok(()) } } +impl Drop for TuiMetricsRendererWrapper { + fn drop(&mut self) { + if !std::thread::panicking() { + self.send_event(TuiRendererEvent::Close()); + let _ = self.handle_join.take().unwrap().join(); + } + } +} + impl TuiMetricsRenderer { fn update_metric(&mut self, split: TuiSplit, group: TuiGroup, state: MetricState) { match state { @@ -144,8 +244,11 @@ impl TuiMetricsRenderer { }; } - /// Create a new terminal UI renderer. - pub fn new(interuptor: Interrupter, checkpoint: Option) -> Self { + pub fn new( + interrupter: Interrupter, + checkpoint: Option, + kill_signal: Sender<()>, + ) -> Self { let mut stdout = io::stdout(); execute!(stdout, EnterAlternateScreen).unwrap(); enable_raw_mode().unwrap(); @@ -171,28 +274,57 @@ impl TuiMetricsRenderer { metrics_numeric: NumericMetricsState::default(), metrics_text: TextMetricsState::default(), status: StatusState::default(), - interuptor, + interrupter, popup: PopupState::Empty, previous_panic_hook: Some(previous_panic_hook), persistent: false, + manual_close: false, + close: false, summary: None, + kill_signal, } } - /// Set the renderer to persistent mode. - pub fn persistent(mut self) -> Self { - self.persistent = true; - self + fn handle_event(&mut self, event: TuiRendererEvent) { + match event { + TuiRendererEvent::MetricRegistration(definition) => { + self.metric_definitions + .insert(definition.metric_id.clone(), definition); + } + TuiRendererEvent::MetricsUpdate((split, group, state)) => { + self.update_metric(split, group, state); + } + TuiRendererEvent::StatusUpdateTrain((split, item, status)) => match split { + TuiSplit::Train => { + self.progress.update_train(&item); + self.metrics_numeric.update_progress_train(&item); + self.status.update_train(status); + } + TuiSplit::Valid => { + self.progress.update_valid(&item); + self.metrics_numeric.update_progress_valid(&item); + self.status.update_valid(status); + } + _ => (), + }, + TuiRendererEvent::StatusUpdateTest((item, status)) => { + self.progress.update_test(&item); + self.metrics_numeric.update_progress_test(&item); + self.status.update_test(status); + } + TuiRendererEvent::TrainEnd(learner_summary) => { + self.interrupter.reset(); + self.summary = learner_summary; + } + TuiRendererEvent::ManualClose() => self.manual_close = true, + TuiRendererEvent::Persistent() => self.persistent = true, + TuiRendererEvent::Close() => self.close = true, + } } fn render(&mut self) -> Result<(), Box> { - let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS); - if self.last_update.elapsed() < tick_rate { - return Ok(()); - } - self.draw()?; - self.handle_events()?; + self.handle_user_input()?; self.last_update = Instant::now(); @@ -222,7 +354,7 @@ impl TuiMetricsRenderer { Ok(()) } - fn handle_events(&mut self) -> Result<(), Box> { + fn handle_user_input(&mut self) -> Result<(), Box> { while event::poll(Duration::from_secs(0))? { let event = event::read()?; self.popup.on_event(&event); @@ -242,7 +374,7 @@ impl TuiMetricsRenderer { training loop, but any remaining code after the loop will be \ executed.", 's', - QuitPopupAccept(self.interuptor.clone()), + QuitPopupAccept(self.interrupter.clone()), ), Callback::new( "Stop the training immediately.", @@ -250,7 +382,7 @@ impl TuiMetricsRenderer { the current training fails. Any code following the training \ won't be executed.", 'k', - KillPopupAccept, + KillPopupAccept(self.kill_signal.clone()), ), Callback::new( "Cancel", @@ -337,11 +469,12 @@ impl TuiMetricsRenderer { } struct QuitPopupAccept(Interrupter); -struct KillPopupAccept; +struct KillPopupAccept(Sender<()>); struct PopupCancel; impl CallbackFn for KillPopupAccept { fn call(&self) -> bool { + self.0.send(()).unwrap(); panic!("Killing training from user input."); } } diff --git a/crates/burn-train/src/renderer/tui/status.rs b/crates/burn-train/src/renderer/tui/status.rs index b39a09e27..98d9d61c9 100644 --- a/crates/burn-train/src/renderer/tui/status.rs +++ b/crates/burn-train/src/renderer/tui/status.rs @@ -1,5 +1,6 @@ +use crate::renderer::ProgressType; + use super::TerminalFrame; -use crate::renderer::{EvaluationProgress, TrainingProgress}; use ratatui::{ prelude::{Alignment, Rect}, style::{Color, Style, Stylize}, @@ -9,7 +10,7 @@ use ratatui::{ /// Show the training status with various information. pub(crate) struct StatusState { - progress: TrainingProgress, + progress_indicators: Vec, mode: Mode, } @@ -22,7 +23,7 @@ enum Mode { impl Default for StatusState { fn default() -> Self { Self { - progress: TrainingProgress::none(), + progress_indicators: vec![], mode: Mode::Train, } } @@ -30,24 +31,23 @@ impl Default for StatusState { impl StatusState { /// Update the training information. - pub(crate) fn update_train(&mut self, progress: TrainingProgress) { - self.progress = progress; + pub(crate) fn update_train(&mut self, progress_indicators: Vec) { + self.progress_indicators = progress_indicators; self.mode = Mode::Train; } /// Update the validation information. - pub(crate) fn update_valid(&mut self, progress: TrainingProgress) { - self.progress = progress; + pub(crate) fn update_valid(&mut self, progress_indicators: Vec) { + self.progress_indicators = progress_indicators; self.mode = Mode::Valid; } /// Update the testing information. - pub(crate) fn update_test(&mut self, _progress: EvaluationProgress) { - // TODO: Use the progress here. - // self.progress = progress; + pub(crate) fn update_test(&mut self, progress_indicators: Vec) { + self.progress_indicators = progress_indicators; self.mode = Mode::Evaluation; } /// Create a view. pub(crate) fn view(&self) -> StatusView { - StatusView::new(&self.progress, &self.mode) + StatusView::new(&self.progress_indicators, &self.mode) } } @@ -56,7 +56,7 @@ pub(crate) struct StatusView { } impl StatusView { - fn new(progress: &TrainingProgress, mode: &Mode) -> Self { + fn new(progress_indicators: &[ProgressType], mode: &Mode) -> Self { let title = |title: &str| Span::from(format!(" {title} ")).bold().yellow(); let value = |value: String| Span::from(value).italic(); let mode = match mode { @@ -65,26 +65,38 @@ impl StatusView { Mode::Evaluation => "Evaluation", }; - Self { - lines: vec![ - vec![title("Mode :"), value(mode.to_string())], - vec![ - title("Epoch :"), - value(format!("{}/{}", progress.epoch, progress.epoch_total)), - ], - vec![ - title("Iteration :"), - value(format!("{}", progress.iteration)), - ], - vec![ - title("Items :"), - value(format!( - "{}/{}", - progress.progress.items_processed, progress.progress.items_total - )), - ], - ], - } + let width = progress_indicators + .iter() + .map(|p| match p { + ProgressType::Detailed { tag, .. } => tag.len(), + ProgressType::Value { tag, .. } => tag.len(), + }) + .max() + .unwrap_or(4); + + let mut lines = vec![vec![ + title(&format!("{: lines.push(vec![ + title(&format!("{: lines.push(vec![ + title(&format!("{: , size: Rect) { diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 252ad3bfe..0ed21037d 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -118,6 +118,7 @@ cubecl = ["dep:cubecl"] audio = ["burn-core/audio"] vision = ["burn-core/vision"] +rl = ["dep:burn-rl"] # Backend ir = ["burn-ir"] @@ -177,6 +178,7 @@ burn-collective = { path = "../burn-collective", version = "=0.21.0", optional = burn-store = { path = "../burn-store", version = "=0.21.0", optional = true, default-features = false } burn-nn = { path = "../burn-nn", version = "=0.21.0", default-features = false } burn-optim = { path = "../burn-optim", version = "=0.21.0", default-features = false } +burn-rl = { path = "../burn-rl", version = "=0.21.0", optional = true, default-features = false } # Backends burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0", optional = true, default-features = false } diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs index 8373df007..89105aaef 100644 --- a/crates/burn/src/lib.rs +++ b/crates/burn/src/lib.rs @@ -124,6 +124,12 @@ pub mod train { pub use burn_train::*; } +/// Module for reinforcement learning. +#[cfg(feature = "rl")] +pub mod rl { + pub use burn_rl::*; +} + /// Backend module. pub mod backend; diff --git a/examples/custom-learning-strategy/Cargo.toml b/examples/custom-learning-strategy/Cargo.toml index e7f1f5a34..f431f58ab 100644 --- a/examples/custom-learning-strategy/Cargo.toml +++ b/examples/custom-learning-strategy/Cargo.toml @@ -9,7 +9,7 @@ publish = false workspace = true [dependencies] -burn = {path = "../../crates/burn", features=["autodiff", "webgpu", "vision"]} +burn = {path = "../../crates/burn", default-features = false, features=["autodiff", "webgpu", "vision"]} guide = {path = "../guide"} derive-new = { workspace = true } -log = { workspace = true } +log = { workspace = true } \ No newline at end of file diff --git a/examples/custom-learning-strategy/src/training.rs b/examples/custom-learning-strategy/src/training.rs index 2eb68fb9d..ba6325fb2 100644 --- a/examples/custom-learning-strategy/src/training.rs +++ b/examples/custom-learning-strategy/src/training.rs @@ -1,4 +1,5 @@ use crate::model::ModelConfig; +use burn::data::dataloader::Progress; use burn::record::NoStdTrainingRecorder; use burn::train::{ EventProcessorTraining, Learner, LearningComponentsTypes, SupervisedLearningStrategy, @@ -20,7 +21,7 @@ use burn::{ record::CompactRecorder, tensor::{Device, backend::AutodiffBackend}, train::{ - InferenceStep, LearnerEvent, LearnerItem, MetricEarlyStoppingStrategy, StoppingCondition, + InferenceStep, LearnerEvent, MetricEarlyStoppingStrategy, StoppingCondition, TrainingItem, metric::{ AccuracyMetric, LossMetric, store::{Aggregate, Direction, Split}, @@ -167,12 +168,11 @@ impl SupervisedLearningStrategy for MyCustomLea let item = learner.train_step(item); learner.optimizer_step(item.grads); - let item = LearnerItem::new( + let item = TrainingItem::new( item.item, progress, - epoch, - num_epochs, - iteration, + Progress::new(epoch, num_epochs), + Some(iteration), Some(learner.lr_current()), ); @@ -198,7 +198,13 @@ impl SupervisedLearningStrategy for MyCustomLea iteration += 1; let item = model_valid.step(item); - let item = LearnerItem::new(item, progress, epoch, num_epochs, iteration, None); + let item = TrainingItem::new( + item, + progress, + Progress::new(epoch, num_epochs), + Some(iteration), + None, + ); event_processor.process_valid(LearnerEvent::ProcessedItem(item)); } diff --git a/examples/custom-renderer/src/lib.rs b/examples/custom-renderer/src/lib.rs index 956939c78..44f16089a 100644 --- a/examples/custom-renderer/src/lib.rs +++ b/examples/custom-renderer/src/lib.rs @@ -7,7 +7,7 @@ use burn::{ Learner, SupervisedTraining, renderer::{ EvaluationName, EvaluationProgress, MetricState, MetricsRenderer, - MetricsRendererEvaluation, MetricsRendererTraining, TrainingProgress, + MetricsRendererEvaluation, MetricsRendererTraining, ProgressType, TrainingProgress, }, }, }; @@ -36,11 +36,11 @@ impl MetricsRendererTraining for CustomRenderer { fn update_valid(&mut self, _state: MetricState) {} - fn render_train(&mut self, item: TrainingProgress) { + fn render_train(&mut self, item: TrainingProgress, _progress_indicators: Vec) { dbg!(item); } - fn render_valid(&mut self, item: TrainingProgress) { + fn render_valid(&mut self, item: TrainingProgress, _progress_indicators: Vec) { dbg!(item); } } @@ -56,7 +56,7 @@ impl MetricsRenderer for CustomRenderer { impl MetricsRendererEvaluation for CustomRenderer { fn update_test(&mut self, _name: EvaluationName, _state: MetricState) {} - fn render_test(&mut self, item: EvaluationProgress) { + fn render_test(&mut self, item: EvaluationProgress, _progress_indicators: Vec) { dbg!(item); } } diff --git a/examples/dqn-agent/Cargo.toml b/examples/dqn-agent/Cargo.toml new file mode 100644 index 000000000..499d73b2b --- /dev/null +++ b/examples/dqn-agent/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "dqn-agent" +edition.workspace = true +license.workspace = true +readme.workspace = true +version.workspace = true + +[features] +default = ["burn/tui"] +ndarray = ["burn/ndarray"] +ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] +ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] +ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] +tch-cpu = ["burn/tch"] +tch-gpu = ["burn/tch"] +remote = ["burn/remote"] +wgpu = ["burn/wgpu", "burn/default"] +metal = ["burn/metal", "burn/default"] +cuda = ["burn/cuda"] +vulkan = ["burn/vulkan", "burn/default"] +rocm = ["burn/rocm", "burn/default"] + +[dependencies] +# Disable autotune default for now (convolutions not optimized) +burn = { path = "../../crates/burn", features = [ + "train", + # "vision", + "metrics", + "std", + "rl", + # "fusion", + "ndarray", + # "autotune", +], default-features = false } +# Just for this example. +gym-rs = { version = "0.3.1", branch = "main", git = "https://github.com/MathisWellmann/gym-rs" } +rand.workspace = true +derive-new = { workspace = true } + +[lints] +workspace = true diff --git a/examples/dqn-agent/examples/dqn-agent.rs b/examples/dqn-agent/examples/dqn-agent.rs new file mode 100644 index 000000000..896be3c48 --- /dev/null +++ b/examples/dqn-agent/examples/dqn-agent.rs @@ -0,0 +1,118 @@ +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::{ + Autodiff, + ndarray::{NdArray, NdArrayDevice}, + }; + use dqn_agent::training; + + pub fn run() { + let device = NdArrayDevice::Cpu; + training::run::>(device); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::{ + Autodiff, + libtorch::{LibTorch, LibTorchDevice}, + }; + use dqn_agent::training; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + training::run::>(device); + } +} + +#[cfg(any(feature = "wgpu", feature = "metal", feature = "vulkan"))] +mod wgpu { + use burn::backend::{ + Autodiff, + wgpu::{Wgpu, WgpuDevice}, + }; + use dqn_agent::training; + + pub fn run() { + let device = WgpuDevice::default(); + training::run::>(device); + } +} + +#[cfg(feature = "cuda")] +mod cuda { + use burn::backend::{Autodiff, Cuda}; + use dqn_agent::training; + + pub fn run() { + let device = Default::default(); + training::run::>(device); + } +} + +#[cfg(feature = "rocm")] +mod rocm { + use burn::backend::{Autodiff, Rocm}; + use dqn_agent::training; + + pub fn run() { + let device = Default::default(); + training::run::>(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::{ + Autodiff, + libtorch::{LibTorch, LibTorchDevice}, + }; + use dqn_agent::training; + + pub fn run() { + let device = LibTorchDevice::Cpu; + training::run::>(device); + } +} + +#[cfg(feature = "remote")] +mod remote { + use burn::backend::{Autodiff, RemoteBackend}; + use dqn_agent::training; + + pub fn run() { + training::run::>(Default::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(any(feature = "wgpu", feature = "metal", feature = "vulkan"))] + wgpu::run(); + #[cfg(feature = "cuda")] + cuda::run(); + #[cfg(feature = "rocm")] + rocm::run(); + #[cfg(feature = "remote")] + remote::run(); +} diff --git a/examples/dqn-agent/src/agent.rs b/examples/dqn-agent/src/agent.rs new file mode 100644 index 000000000..602fbbf55 --- /dev/null +++ b/examples/dqn-agent/src/agent.rs @@ -0,0 +1,477 @@ +use std::marker::PhantomData; + +use burn::backend::NdArray; +use burn::module::Module; +use burn::record::Record; +use burn::rl::{ + Batchable, LearnerTransitionBatch, Policy, PolicyLearner, PolicyState, RLTrainOutput, +}; +use burn::tensor::Transaction; +use burn::tensor::activation::softmax; +use burn::train::ItemLazy; +use burn::train::metric::{Adaptor, LossInput}; +use burn::{ + Tensor, + config::Config, + module::AutodiffModule, + nn::{self, loss::MseLoss}, + optim::{GradientsParams, Optimizer}, + prelude::Backend, + tensor::backend::AutodiffBackend, +}; +use rand::distr::Distribution; +use rand::distr::weighted::WeightedIndex; +use rand::rng; + +use crate::utils::{ + EpsilonGreedyPolicy, EpsilonGreedyPolicyState, create_lin_layers, soft_update_linear, +}; + +pub trait DiscreteActionModel: Module { + type Input: Clone + Send + Batchable; + + fn forward(&self, input: Self::Input) -> DiscreteLogitsTensor; +} + +#[derive(Config, Debug)] +pub struct MlpNetConfig { + /// The number of layers. + #[config(default = 3)] + pub num_layers: usize, + /// The dropout rate. + #[config(default = 0.)] + pub dropout: f64, + /// The input dimension. + #[config(default = 4)] + pub d_input: usize, + /// The output dimension. + #[config(default = 2)] + pub d_output: usize, + /// The size of hidden layers. + #[config(default = 256)] + pub d_hidden: usize, +} + +/// Multilayer Perceptron Network. +#[derive(Module, Debug)] +pub struct MlpNet { + pub linears: Vec>, + pub dropout: nn::Dropout, + pub activation: nn::Relu, +} + +impl MlpNet { + /// Create the module from the given configuration. + pub fn new(config: &MlpNetConfig, device: &B::Device) -> Self { + Self { + linears: create_lin_layers( + config.num_layers, + config.d_input, + config.d_hidden, + config.d_output, + device, + ), + dropout: nn::DropoutConfig::new(config.dropout).init(), + activation: nn::Relu::new(), + } + } +} + +#[derive(Clone)] +pub struct ObservationTensor { + pub state: Tensor, +} + +impl Batchable for ObservationTensor { + fn batch(value: Vec) -> Self { + let tensors = value.iter().map(|v| v.state.clone()).collect(); + Self { + state: Tensor::cat(tensors, 0), + } + } + + fn unbatch(self) -> Vec { + self.state + .split(1, 0) + .iter() + .map(|s| ObservationTensor { state: s.clone() }) + .collect() + } +} + +impl DiscreteActionModel for MlpNet { + type Input = ObservationTensor; + + /// Applies the forward pass on the input tensor. + /// + /// # Shapes + /// + /// - input: `[batch_size, d_input]` + /// - output: `[batch_size, d_output]` + fn forward(&self, input: Self::Input) -> DiscreteLogitsTensor { + let mut x = input.state; + + for (i, linear) in self.linears.iter().enumerate() { + x = linear.forward(x); + x = self.dropout.forward(x); + if i < self.linears.len() - 1 { + x = self.activation.forward(x); + } + } + + DiscreteLogitsTensor { logits: x } + } +} + +#[derive(Config, Debug)] +pub struct DqnAgentConfig { + /// Discount factor (How to value long-term vs short-term rewards) + #[config(default = 0.99)] + pub gamma: f64, + /// The learning rate + #[config(default = 3e-4)] + pub learning_rate: f64, + /// The soft update rate of the target network + #[config(default = 0.005)] + pub tau: f64, + /// Initial value of epsilon (Probability to choose a random action) + #[config(default = 0.9)] + pub epsilon_start: f64, + /// Final value of epsilon (Probability to choose a random action) + #[config(default = 0.01)] + pub epsilon_end: f64, + /// The exponential rate at which the epsilon value decays. Higher = slower decay + #[config(default = 2500.0)] + pub epsilon_decay: f64, +} + +pub trait TargetModel { + fn soft_update(&self, that: &Self, tau: f64) -> Self; +} + +impl TargetModel for MlpNet { + fn soft_update(&self, that: &Self, tau: f64) -> Self { + let mut linears = Vec::with_capacity(self.linears.len()); + for i in 0..self.linears.len() { + let layer = soft_update_linear(self.linears[i].clone(), &that.linears[i].clone(), tau); + linears.insert(i, layer); + } + Self { + linears, + dropout: self.dropout.clone(), + activation: self.activation.clone(), + } + } +} + +#[derive(Clone)] +pub struct DqnState> { + model: M, + _backend: PhantomData, +} + +impl> PolicyState for DqnState { + type Record = M::Record; + + fn into_record(self) -> Self::Record { + self.model.clone().into_record() + } + + fn load_record(&self, record: Self::Record) -> Self { + Self { + model: self.model.clone().load_record(record), + _backend: PhantomData, + } + } +} + +#[derive(Clone)] +pub struct DQN> { + model: M, + _backend: PhantomData, +} + +impl> DQN { + pub fn new(policy: M) -> Self { + Self { + model: policy, + _backend: PhantomData, + } + } +} + +#[derive(Clone)] +pub struct DiscreteLogitsTensor { + pub logits: Tensor, +} + +impl Batchable for DiscreteLogitsTensor { + fn batch(value: Vec) -> Self { + let tensors = value.iter().map(|v| v.logits.clone()).collect(); + Self { + logits: Tensor::cat(tensors, 0), + } + } + + fn unbatch(self) -> Vec { + self.logits + .split(1, 0) + .iter() + .map(|l| DiscreteLogitsTensor { logits: l.clone() }) + .collect() + } +} + +#[derive(Clone)] +pub struct DiscreteActionTensor { + pub actions: Tensor, +} + +impl Batchable for DiscreteActionTensor { + fn batch(value: Vec) -> Self { + let tensors = value.iter().map(|v| v.actions.clone()).collect(); + Self { + actions: Tensor::cat(tensors, 0), + } + } + + fn unbatch(self) -> Vec { + self.actions + .split(1, 0) + .iter() + .map(|a| DiscreteActionTensor { actions: a.clone() }) + .collect() + } +} + +impl> Policy for DQN { + type Observation = M::Input; + type ActionDistribution = DiscreteLogitsTensor; + type Action = DiscreteActionTensor; + + type ActionContext = (); + type PolicyState = DqnState; + + fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution { + self.model.forward(states) + } + + fn action( + &mut self, + states: Self::Observation, + deterministic: bool, + ) -> (Self::Action, Vec) { + let logits = self.forward(states).logits; + if deterministic { + let output = DiscreteActionTensor { + actions: logits.argmax(1).float(), + }; + return (output, vec![]); + } + + let mut actions = vec![]; + let probs = softmax(logits, 1); + let probs = probs.split(1, 0); + let mut rng = rng(); + for p in probs { + let dist = WeightedIndex::new(p.to_data().to_vec::().unwrap()).unwrap(); + let action = dist.sample(&mut rng); + actions.push(Tensor::::from_floats([action], &p.device())); + } + + let output = DiscreteActionTensor { + actions: Tensor::stack(actions, 1), + }; + (output, vec![]) + } + + fn update(&mut self, update: Self::PolicyState) { + self.model = update.model; + } + + fn state(&self) -> Self::PolicyState { + DqnState { + model: self.model.clone(), + _backend: PhantomData, + } + } + + fn load_record(self, record: >::Record) -> Self { + let state = self.state().load_record(record); + Self { + model: state.model, + _backend: PhantomData, + } + } +} + +#[derive(Record)] +pub struct DqnLearningRecord, O: Optimizer> { + policy_model: M::Record, + target_model: M::Record, + optimizer: O::Record, +} + +#[derive(Clone)] +pub struct DqnLearningAgent +where + B: AutodiffBackend, + M: DiscreteActionModel + AutodiffModule + TargetModel + 'static, + M::InnerModule: DiscreteActionModel + TargetModel, + O: Optimizer + 'static, +{ + policy_model: M, + target_model: M, + agent: EpsilonGreedyPolicy>, + optimizer: O, + config: DqnAgentConfig, +} + +impl DqnLearningAgent +where + B: AutodiffBackend, + M: DiscreteActionModel + AutodiffModule + TargetModel + 'static, + M::InnerModule: DiscreteActionModel + TargetModel, + O: Optimizer + 'static, +{ + pub fn new(model: M, optimizer: O, config: DqnAgentConfig) -> Self { + let agent = EpsilonGreedyPolicy::new( + DQN::new(model.clone()), + config.epsilon_start, + config.epsilon_end, + config.epsilon_decay, + ); + Self { + policy_model: model.clone(), + target_model: model, + agent, + optimizer, + config, + } + } +} + +#[derive(Clone)] +pub struct SimpleTrainOutput { + pub policy_model_loss: Tensor, +} + +impl ItemLazy for SimpleTrainOutput { + type ItemSync = SimpleTrainOutput; + + fn sync(self) -> Self::ItemSync { + let [loss] = Transaction::default() + .register(self.policy_model_loss) + .execute() + .try_into() + .expect("Correct amount of tensor data"); + + let device = &Default::default(); + + SimpleTrainOutput { + policy_model_loss: Tensor::from_data(loss, device), + } + } +} + +impl Adaptor> for SimpleTrainOutput { + fn adapt(&self) -> LossInput { + LossInput::new(self.policy_model_loss.clone()) + } +} + +impl PolicyLearner for DqnLearningAgent +where + B: AutodiffBackend, + M: DiscreteActionModel + AutodiffModule + TargetModel + 'static, + M::Input: Clone, + M::InnerModule: DiscreteActionModel + TargetModel, + O: Optimizer + 'static, +{ + type TrainContext = SimpleTrainOutput; + type InnerPolicy = EpsilonGreedyPolicy>; + type Record = DqnLearningRecord; + + fn train( + &mut self, + input: LearnerTransitionBatch, + ) -> RLTrainOutput>::PolicyState> { + let states_batch = input.states; + let next_states_batch = input.next_states; + let actions_batch = input.actions.actions; + let rewards_batch = input.rewards; + let dones_batch = input.dones; + + // Optimize + let logits = self.policy_model.forward(states_batch).logits; + let state_action_values = logits.gather(1, actions_batch.int()); + + let next_state_values = self.target_model.forward(next_states_batch.clone()); + let next_state_values = next_state_values.logits.max_dim(1).squeeze::<1>(); + + let not_done_batch = Tensor::ones_like(&dones_batch) - dones_batch; + let expected_state_action_values = (next_state_values * not_done_batch.squeeze()) + .mul_scalar(self.config.gamma) + + rewards_batch.squeeze(); + let expected_state_action_values = expected_state_action_values.unsqueeze_dim::<2>(1); + + let loss = MseLoss::new().forward( + state_action_values, + expected_state_action_values, + nn::loss::Reduction::Mean, + ); + let gradients = loss.backward(); + let gradient_params = GradientsParams::from_grads(gradients, &self.policy_model); + self.policy_model = self.optimizer.step( + self.config.learning_rate, + self.policy_model.clone(), + gradient_params, + ); + self.target_model = self + .target_model + .soft_update(&self.policy_model, self.config.tau); + let policy_update = EpsilonGreedyPolicyState::new( + DqnState { + model: self.policy_model.clone(), + _backend: PhantomData, + }, + self.agent.state().step, + ); + self.agent.update(policy_update.clone()); + RLTrainOutput { + policy: policy_update, + item: SimpleTrainOutput { + policy_model_loss: loss, + }, + } + } + + fn policy(&self) -> Self::InnerPolicy { + self.agent.clone() + } + + fn update_policy(&mut self, update: Self::InnerPolicy) { + self.agent = update; + } + + fn record(&self) -> Self::Record { + DqnLearningRecord { + policy_model: self.policy_model.clone().into_record(), + target_model: self.target_model.clone().into_record(), + optimizer: self.optimizer.to_record(), + } + } + + fn load_record(self, record: Self::Record) -> Self { + let policy_model = self.policy_model.load_record(record.policy_model); + let target_model = self.target_model.load_record(record.target_model); + let optimizer = self.optimizer.load_record(record.optimizer); + Self { + policy_model, + target_model, + agent: self.agent, + optimizer, + config: self.config, + } + } +} diff --git a/examples/dqn-agent/src/env.rs b/examples/dqn-agent/src/env.rs new file mode 100644 index 000000000..14ac60174 --- /dev/null +++ b/examples/dqn-agent/src/env.rs @@ -0,0 +1,101 @@ +use burn::rl::{Environment, StepResult}; +use burn::{ + Tensor, + prelude::{Backend, ToElement}, +}; +use gym_rs::{ + core::Env, + envs::classical_control::cartpole::{CartPoleEnv, CartPoleObservation}, +}; + +use crate::agent::{DiscreteActionTensor, ObservationTensor}; + +#[derive(Clone)] +pub struct CartPoleAction { + action: usize, +} + +impl From> for CartPoleAction { + fn from(value: DiscreteActionTensor) -> Self { + Self { + action: value.actions.int().into_scalar().to_usize(), + } + } +} + +impl From for DiscreteActionTensor { + fn from(value: CartPoleAction) -> Self { + DiscreteActionTensor { + actions: Tensor::::from_data([value.action], &Default::default()).unsqueeze(), + } + } +} + +#[derive(Clone)] +pub struct CartPoleState { + pub state: [f64; 4], +} + +impl From for CartPoleState { + fn from(observation: CartPoleObservation) -> Self { + let vec = Vec::::from(observation); + Self { + state: [vec[0], vec[1], vec[2], vec[3]], + } + } +} +impl From for ObservationTensor { + fn from(val: CartPoleState) -> Self { + ObservationTensor { + state: Tensor::::from_floats(val.state, &Default::default()).unsqueeze(), + } + } +} + +#[derive(Clone)] +pub struct CartPoleWrapper { + gym_env: CartPoleEnv, + step_index: usize, +} + +impl Default for CartPoleWrapper { + fn default() -> Self { + Self::new() + } +} + +impl CartPoleWrapper { + pub fn new() -> Self { + Self { + gym_env: CartPoleEnv::new(gym_rs::utils::renderer::RenderMode::None), + step_index: 0, + } + } +} + +impl Environment for CartPoleWrapper { + type State = CartPoleState; + type Action = CartPoleAction; + + const MAX_STEPS: usize = 500; + + fn state(&self) -> Self::State { + CartPoleState::from(self.gym_env.state) + } + + fn step(&mut self, action: Self::Action) -> StepResult { + let action_reward = self.gym_env.step(action.action); + self.step_index += 1; + StepResult { + next_state: CartPoleState::from(action_reward.observation), + reward: action_reward.reward.into_inner(), + done: action_reward.done, + truncated: self.step_index >= Self::MAX_STEPS, + } + } + + fn reset(&mut self) { + self.gym_env.reset(None, false, None); + self.step_index = 0; + } +} diff --git a/examples/dqn-agent/src/lib.rs b/examples/dqn-agent/src/lib.rs new file mode 100644 index 000000000..fb262e004 --- /dev/null +++ b/examples/dqn-agent/src/lib.rs @@ -0,0 +1,4 @@ +pub mod agent; +pub mod env; +pub mod training; +pub mod utils; diff --git a/examples/dqn-agent/src/training.rs b/examples/dqn-agent/src/training.rs new file mode 100644 index 000000000..691530e25 --- /dev/null +++ b/examples/dqn-agent/src/training.rs @@ -0,0 +1,64 @@ +use burn::{ + grad_clipping::GradientClippingConfig, + optim::AdamWConfig, + record::CompactRecorder, + tensor::backend::AutodiffBackend, + train::{ + OffPolicyConfig, RLTraining, + metric::{CumulativeRewardMetric, EpisodeLengthMetric, ExplorationRateMetric, LossMetric}, + }, +}; + +use crate::{ + agent::{DqnAgentConfig, DqnLearningAgent, MlpNet, MlpNetConfig}, + env::CartPoleWrapper, +}; + +static ARTIFACT_DIR: &str = "/tmp/burn-example-dqn-agent"; + +pub fn run(device: B::Device) { + let dqn_config = DqnAgentConfig { + gamma: 0.99, + learning_rate: 3e-4, + tau: 0.005, + epsilon_start: 0.99, + epsilon_end: 0.05, + epsilon_decay: 6000.0, + }; + let model_config = MlpNetConfig { + num_layers: 3, + dropout: 0.0, + d_input: 4, + d_output: 2, + d_hidden: 64, + }; + let learning_config = OffPolicyConfig { + num_envs: 8, + autobatch_size: 8, + replay_buffer_size: 50_000, + train_interval: 8, + eval_interval: 4_000, + eval_episodes: 5, + train_batch_size: 128, + train_steps: 4, + warmup_steps: 0, + }; + + let policy_model = MlpNet::::new(&model_config, &device); + let optimizer = AdamWConfig::new() + .with_grad_clipping(Some(GradientClippingConfig::Value(100.0))) + .init(); + let agent = DqnLearningAgent::new(policy_model, optimizer, dqn_config); + let learner = RLTraining::new(ARTIFACT_DIR, CartPoleWrapper::new) + .metrics_train((LossMetric::new(),)) + .metrics_agent((ExplorationRateMetric::new(),)) + .metrics_episode((EpisodeLengthMetric::new(), CumulativeRewardMetric::new())) + .with_file_checkpointer(CompactRecorder::new()) + .num_steps(40_000) + .with_learning_strategy(burn::train::RLStrategies::OffPolicyStrategy( + learning_config, + )) + .summary(); + + let _result = learner.launch(agent); +} diff --git a/examples/dqn-agent/src/utils.rs b/examples/dqn-agent/src/utils.rs new file mode 100644 index 000000000..7fe477aba --- /dev/null +++ b/examples/dqn-agent/src/utils.rs @@ -0,0 +1,226 @@ +use std::marker::PhantomData; + +use burn::{ + Tensor, + module::{Param, ParamId}, + nn::{self, Linear}, + prelude::Backend, + record::Record, + rl::{Policy, PolicyState}, + tensor::Device, + train::{ + ItemLazy, + metric::{Adaptor, ExplorationRateInput}, + }, +}; +use derive_new::new; +use rand::{random, random_range}; + +use crate::agent::{DiscreteActionTensor, DiscreteLogitsTensor}; + +pub fn create_lin_layers( + num_layers: usize, + d_input: usize, + d_hidden: usize, + d_output: usize, + device: &Device, +) -> Vec> { + let mut linears = Vec::with_capacity(num_layers); + + if num_layers == 1 { + linears.push(nn::LinearConfig::new(d_input, d_output).init(device)); + return linears; + } + for i in 0..num_layers { + if i == 0 { + linears.push(nn::LinearConfig::new(d_input, d_hidden).init(device)); + } else if i == num_layers - 1 { + linears.push(nn::LinearConfig::new(d_hidden, d_output).init(device)); + } else { + linears.push(nn::LinearConfig::new(d_hidden, d_hidden).init(device)); + } + } + linears +} + +pub fn soft_update_linear(this: Linear, that: &Linear, tau: f64) -> Linear { + let weight = soft_update_tensor(&this.weight, &that.weight, tau); + let bias = match (&this.bias, &that.bias) { + (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), + _ => None, + }; + + Linear:: { weight, bias } +} + +fn soft_update_tensor( + this: &Param>, + that: &Param>, + tau: f64, +) -> Param> { + let that_weight = that.val(); + let this_weight = this.val(); + let new_weight = this_weight * (1.0 - tau) + that_weight * tau; + + Param::initialized(ParamId::new(), new_weight) +} + +#[derive(Clone)] +pub struct EpsilonGreedyPolicyOutput { + pub epsilon: f64, +} + +impl ItemLazy for EpsilonGreedyPolicyOutput { + type ItemSync = EpsilonGreedyPolicyOutput; + + fn sync(self) -> Self::ItemSync { + self + } +} + +impl Adaptor for EpsilonGreedyPolicyOutput { + fn adapt(&self) -> ExplorationRateInput { + ExplorationRateInput::new(self.epsilon) + } +} + +#[derive(Record)] +pub struct EpsilonGreedyPolicyRecord> { + pub inner_state: >::Record, + pub step: usize, +} + +#[derive(Clone, new)] +pub struct EpsilonGreedyPolicyState> { + pub inner_state: P::PolicyState, + pub step: usize, +} + +impl> PolicyState for EpsilonGreedyPolicyState { + type Record = EpsilonGreedyPolicyRecord; + + fn into_record(self) -> Self::Record { + EpsilonGreedyPolicyRecord { + inner_state: self.inner_state.into_record(), + step: self.step, + } + } + + fn load_record(&self, record: Self::Record) -> Self { + let inner_state = self.inner_state.load_record(record.inner_state); + Self { + inner_state, + step: record.step, + } + } +} + +#[derive(Clone, Debug)] +pub struct EpsilonGreedyPolicy> { + inner_policy: P, + eps_start: f64, + eps_end: f64, + eps_decay: f64, + step: usize, + _backend: PhantomData, +} + +impl> EpsilonGreedyPolicy { + pub fn new(inner_policy: P, eps_start: f64, eps_end: f64, eps_decay: f64) -> Self { + Self { + inner_policy, + eps_start, + eps_end, + eps_decay, + step: 0, + _backend: PhantomData, + } + } + + fn get_threshold(&self) -> f64 { + self.eps_end + + (self.eps_start - self.eps_end) * f64::exp(-(self.step as f64) / self.eps_decay) + } + + fn step(&mut self) -> f64 { + let thresh = self.get_threshold(); + self.step += 1; + thresh + } +} + +impl Policy for EpsilonGreedyPolicy +where + B: Backend, + P: Policy< + B, + ActionDistribution = DiscreteLogitsTensor, + Action = DiscreteActionTensor, + >, +{ + type ActionContext = EpsilonGreedyPolicyOutput; + type PolicyState = EpsilonGreedyPolicyState; + + type Observation = P::Observation; + type ActionDistribution = DiscreteLogitsTensor; + type Action = DiscreteActionTensor; + + fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution { + self.inner_policy.forward(states) + } + + fn action( + &mut self, + states: Self::Observation, + deterministic: bool, + ) -> (Self::Action, Vec) { + let logits = self.inner_policy.forward(states).logits; + let greedy_actions = logits.argmax(1); + let greedy_actions = greedy_actions.split(1, 0); + + let mut contexts = vec![]; + let mut actions = vec![]; + for a in greedy_actions { + let threshold = self.step(); + let threshold = if deterministic { 0.0 } else { threshold }; + contexts.push(EpsilonGreedyPolicyOutput { epsilon: threshold }); + if random::() > threshold { + actions.push(a.clone().float()); + } else { + actions.push( + Tensor::::from_floats([random_range(0..2)], &a.device()).unsqueeze(), + ); + } + } + + let output = Tensor::cat(actions, 0); + (DiscreteActionTensor { actions: output }, contexts) + } + + fn update(&mut self, update: Self::PolicyState) { + // Note : updating an epsilon greedy policy doesn't change the step. + self.inner_policy.update(update.inner_state); + } + + fn state(&self) -> Self::PolicyState { + EpsilonGreedyPolicyState { + inner_state: self.inner_policy.state(), + step: self.step, + } + } + + fn load_record(self, record: >::Record) -> Self { + let state = self.state().load_record(record); + let inner_policy = self + .inner_policy + .load_record(state.inner_state.into_record()); + EpsilonGreedyPolicy { + inner_policy, + eps_start: self.eps_start, + eps_end: self.eps_end, + eps_decay: self.eps_decay, + step: state.step, + _backend: PhantomData, + } + } +} diff --git a/xtask/src/commands/test.rs b/xtask/src/commands/test.rs index eebe4394d..26b8de217 100644 --- a/xtask/src/commands/test.rs +++ b/xtask/src/commands/test.rs @@ -118,6 +118,9 @@ pub(crate) fn handle_command( "burn-router".to_string(), "burn-tch".to_string(), "burn-wgpu".to_string(), + // dqn-agent example relies on gym-rs dependency which requires SDL2. + // It would be good not to remove the gym-rs dependency in the future. + "dqn-agent".to_string(), ]); // Burn remote tests don't work on windows for now