mirror of
https://github.com/tracel-ai/burn.git
synced 2026-05-31 19:49:48 +09:00
Burn rl (#4447)
* wip burn-rl * clean up types traits, remove paradigm trait, improve learner * code quality * book * doc * error in doc * revert TrainStep rename and remove paradigm components * trainstep refactor with assoc types * fix docs * fix warning * comment * remove evaluator step * add more rl stuff and change Adaptor trait * metrics for rl and revert adaptor change * docs + cleanup event processor for rl * policy vs agent refactor * rework renderer * policy no longer generic over env * early version of checkpointing * versioning * off policy training loop * naming consistency and loading from record * naming * add other backends * strategies and configs * env initializer and cum reward metric * transition backend thing compiles * mostly docs and naming * reorganize files * batchable trait and re-arrange some bounds * start env runner refactor * rework autobatcher and env_runnner episodes * change autobatcher dynamic and tweak dqn params * fix typos and warnings * params tweeking * file naming * fix docs * metrics api and dependency bump * address PR comments and fix lint * kill process when user kills training * handle kill signal renderer * format * remove dqn-agent from ci tests bc of dependency issue
This commit is contained in:
@@ -11,6 +11,7 @@ ignore = [
|
||||
"RUSTSEC-2024-0436", # Paste used to generate macro, should be removed at some point.
|
||||
"RUSTSEC-2025-0119", # `number_prefix` used by `tokenizers`, only in the examples.
|
||||
"RUSTSEC-2025-0141", # `bincode` is no longer maintained.
|
||||
"RUSTSEC-2024-0388", # `derivative` dependency in the DQN example is unmaintained.
|
||||
] # advisory IDs to ignore e.g. ["RUSTSEC-2019-0001", ...]
|
||||
informational_warnings = [
|
||||
"unmaintained",
|
||||
|
||||
171
Cargo.lock
generated
171
Cargo.lock
generated
@@ -146,6 +146,15 @@ version = "1.0.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ar_archive_writer"
|
||||
version = "0.5.1"
|
||||
@@ -661,6 +670,7 @@ dependencies = [
|
||||
"burn-nn",
|
||||
"burn-optim",
|
||||
"burn-remote",
|
||||
"burn-rl",
|
||||
"burn-rocm",
|
||||
"burn-router",
|
||||
"burn-store",
|
||||
@@ -1067,6 +1077,18 @@ dependencies = [
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-rl"
|
||||
version = "0.21.0"
|
||||
dependencies = [
|
||||
"burn-core",
|
||||
"burn-ndarray",
|
||||
"burn-optim",
|
||||
"derive-new",
|
||||
"log",
|
||||
"rand 0.9.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-rocm"
|
||||
version = "0.21.0"
|
||||
@@ -1183,9 +1205,11 @@ dependencies = [
|
||||
"burn-core",
|
||||
"burn-ndarray",
|
||||
"burn-optim",
|
||||
"burn-rl",
|
||||
"derive-new",
|
||||
"log",
|
||||
"nvml-wrapper",
|
||||
"rand 0.9.2",
|
||||
"ratatui",
|
||||
"rstest",
|
||||
"serde",
|
||||
@@ -1331,6 +1355,12 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "c_vec"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fdd7a427adc0135366d99db65b36dae9237130997e560ed61118041fb72be6e8"
|
||||
|
||||
[[package]]
|
||||
name = "candle-core"
|
||||
version = "0.9.2"
|
||||
@@ -2709,6 +2739,17 @@ dependencies = [
|
||||
"powerfmt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derivative"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive-new"
|
||||
version = "0.7.0"
|
||||
@@ -2911,6 +2952,16 @@ version = "0.15.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
|
||||
|
||||
[[package]]
|
||||
name = "dqn-agent"
|
||||
version = "0.21.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"derive-new",
|
||||
"gym-rs",
|
||||
"rand 0.9.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "drawille"
|
||||
version = "0.3.0"
|
||||
@@ -4045,6 +4096,23 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gym-rs"
|
||||
version = "0.3.1"
|
||||
source = "git+https://github.com/MathisWellmann/gym-rs?branch=main#5283afaa86a3a7c45c46c882cfad459f02539b62"
|
||||
dependencies = [
|
||||
"derivative",
|
||||
"derive-new",
|
||||
"log",
|
||||
"nalgebra",
|
||||
"num-traits",
|
||||
"ordered-float 4.6.0",
|
||||
"rand 0.8.5",
|
||||
"rand_pcg",
|
||||
"sdl2",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.13"
|
||||
@@ -5252,6 +5320,33 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra"
|
||||
version = "0.33.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"matrixmultiply",
|
||||
"nalgebra-macros",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"simba",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra-macros"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "named-tensor"
|
||||
version = "0.21.0"
|
||||
@@ -5907,6 +6002,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7014,6 +7111,7 @@ dependencies = [
|
||||
"libc",
|
||||
"rand_chacha 0.3.1",
|
||||
"rand_core 0.6.4",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7053,6 +7151,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom 0.2.17",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7074,6 +7173,15 @@ dependencies = [
|
||||
"rand 0.9.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_pcg"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e"
|
||||
dependencies = [
|
||||
"rand_core 0.6.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "range-alloc"
|
||||
version = "0.1.4"
|
||||
@@ -7604,6 +7712,15 @@ version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984"
|
||||
|
||||
[[package]]
|
||||
name = "safe_arch"
|
||||
version = "0.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96b02de82ddbe1b636e6170c21be622223aea188ef2e139be0a5b219ec215323"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "safetensors"
|
||||
version = "0.3.3"
|
||||
@@ -7704,6 +7821,31 @@ version = "3.0.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca"
|
||||
|
||||
[[package]]
|
||||
name = "sdl2"
|
||||
version = "0.37.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b498da7d14d1ad6c839729bd4ad6fc11d90a57583605f3b4df2cd709a9cd380"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"c_vec",
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"sdl2-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sdl2-sys"
|
||||
version = "0.37.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "951deab27af08ed9c6068b7b0d05a93c91f0a8eb16b6b816a5e73452a43521d3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cmake",
|
||||
"libc",
|
||||
"version-compare",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.11.1"
|
||||
@@ -7971,6 +8113,19 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simba"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c99284beb21666094ba2b75bbceda012e610f5479dfcc2d6e2426f53197ffd95"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"wide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simd-adler32"
|
||||
version = "0.3.8"
|
||||
@@ -9440,6 +9595,12 @@ version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||
|
||||
[[package]]
|
||||
name = "version-compare"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "579a42fc0b8e0c63b76519a339be31bed574929511fa53c1a3acae26eb258f29"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.5"
|
||||
@@ -9856,6 +10017,16 @@ dependencies = [
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wide"
|
||||
version = "0.7.33"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ce5da8ecb62bcd8ec8b7ea19f69a51275e91299be594ea5cc6ef7819e16cd03"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"safe_arch",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "widestring"
|
||||
version = "1.2.1"
|
||||
|
||||
@@ -7,7 +7,7 @@ use serde::{Serialize, de::DeserializeOwned};
|
||||
/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
|
||||
pub trait Record<B: Backend>: Send {
|
||||
/// Type of the item that can be serialized and deserialized.
|
||||
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;
|
||||
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned + Clone;
|
||||
|
||||
/// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings).
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;
|
||||
|
||||
@@ -285,7 +285,7 @@ mod tests {
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn err_when_invalid_item() {
|
||||
#[derive(new, Serialize, Deserialize)]
|
||||
#[derive(new, Serialize, Deserialize, Clone)]
|
||||
struct Item<S: PrecisionSettings> {
|
||||
value: S::FloatElem,
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ enum LrSchedulerItem {
|
||||
}
|
||||
|
||||
#[derive(Record)]
|
||||
/// Record item for the [componsed learning rate scheduler](ComposedLrScheduler).
|
||||
/// Record item for the [composed learning rate scheduler](ComposedLrScheduler).
|
||||
pub enum LrSchedulerRecord<B: Backend> {
|
||||
/// The linear variant.
|
||||
Linear(<LinearLrScheduler as LrScheduler>::Record<B>),
|
||||
@@ -115,7 +115,7 @@ pub enum LrSchedulerRecord<B: Backend> {
|
||||
}
|
||||
|
||||
#[derive(Record)]
|
||||
/// Records for the [componsed learning rate scheduler](ComposedLrScheduler).
|
||||
/// Records for the [composed learning rate scheduler](ComposedLrScheduler).
|
||||
pub struct ComposedLrSchedulerRecord<B: Backend> {
|
||||
schedulers: Vec<LrSchedulerRecord<B>>,
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ where
|
||||
}
|
||||
|
||||
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
#[serde(bound = "")]
|
||||
pub enum AdaptorRecordItem<
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
|
||||
@@ -56,7 +56,7 @@ impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
|
||||
}
|
||||
|
||||
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
#[serde(bound = "")]
|
||||
pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
|
||||
/// Rank 0.
|
||||
|
||||
23
crates/burn-rl/Cargo.toml
Normal file
23
crates/burn-rl/Cargo.toml
Normal file
@@ -0,0 +1,23 @@
|
||||
[package]
|
||||
name = "burn-rl"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
readme.workspace = true
|
||||
version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
burn-core = { path = "../burn-core", version = "=0.21.0", features = [
|
||||
"dataset",
|
||||
"std",
|
||||
], default-features = false }
|
||||
burn-optim = { path = "../burn-optim", version = "=0.21.0", features = [
|
||||
"std",
|
||||
], default-features = false }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0" }
|
||||
derive-new.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
log = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
1
crates/burn-rl/LICENSE-APACHE
Symbolic link
1
crates/burn-rl/LICENSE-APACHE
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
1
crates/burn-rl/LICENSE-MIT
Symbolic link
1
crates/burn-rl/LICENSE-MIT
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
6
crates/burn-rl/README.md
Normal file
6
crates/burn-rl/README.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# Burn RL
|
||||
|
||||
<!-- This crate should be used with [burn](https://github.com/tracel-ai/burn). -->
|
||||
|
||||
<!-- [](https://crates.io/crates/burn-train)
|
||||
[](https://github.com/tracel-ai/burn-train/blob/master/README.md) -->
|
||||
46
crates/burn-rl/src/environment/base.rs
Normal file
46
crates/burn-rl/src/environment/base.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
/// The result of taking a single step in an environment.
#[derive(Debug, Clone, PartialEq)]
pub struct StepResult<S> {
    /// The state of the environment after the step.
    pub next_state: S,
    /// The reward obtained for the step.
    pub reward: f64,
    /// Whether the environment reached a terminal state.
    pub done: bool,
    /// Whether the environment reached its maximum length.
    pub truncated: bool,
}

/// Trait to be implemented by a reinforcement learning environment.
pub trait Environment {
    /// The type of the state.
    type State;
    /// The type of the actions.
    type Action;

    /// The maximum number of steps in one episode.
    const MAX_STEPS: usize;

    /// Returns the current state.
    fn state(&self) -> Self::State;

    /// Takes a step in the environment given an action.
    ///
    /// Returns the resulting state together with the reward and the termination flags.
    fn step(&mut self, action: Self::Action) -> StepResult<Self::State>;

    /// Resets the environment to an initial state.
    fn reset(&mut self);
}

/// Trait to define how to initialize an environment.
///
/// Any cloneable function or closure that takes no argument and returns an environment
/// implements it through the blanket implementation below.
pub trait EnvironmentInit<E: Environment>: Clone {
    /// Initialize the environment.
    fn init(&self) -> E;
}

impl<F, E> EnvironmentInit<E> for F
where
    F: Fn() -> E + Clone,
    E: Environment,
{
    /// Builds the environment by calling the wrapped function.
    fn init(&self) -> E {
        self()
    }
}
|
||||
3
crates/burn-rl/src/environment/mod.rs
Normal file
3
crates/burn-rl/src/environment/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
21
crates/burn-rl/src/lib.rs
Normal file
21
crates/burn-rl/src/lib.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! A library for training reinforcement learning agents.
|
||||
|
||||
/// Module for implementing an environment.
|
||||
pub mod environment;
|
||||
/// Module for implementing a policy.
|
||||
pub mod policy;
|
||||
/// Transition buffer.
|
||||
pub mod transition_buffer;
|
||||
|
||||
pub use environment::*;
|
||||
pub use policy::*;
|
||||
pub use transition_buffer::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {}
|
||||
314
crates/burn-rl/src/policy/async_policy.rs
Normal file
314
crates/burn-rl/src/policy/async_policy.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use std::{
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
mpsc::{self, Sender},
|
||||
},
|
||||
thread::spawn,
|
||||
};
|
||||
|
||||
use burn_core::prelude::Backend;
|
||||
|
||||
use crate::{ActionContext, Batchable, Policy, PolicyState};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct PolicyInferenceServer<B: Backend, P: Policy<B>> {
|
||||
// `num_agents` used to make sure autobatching doesn't block the agents if they are less than the autobatch size.
|
||||
num_agents: Arc<AtomicUsize>,
|
||||
max_autobatch_size: usize,
|
||||
inner_policy: P,
|
||||
batch_action: Vec<ActionItem<P::Observation, P::Action, P::ActionContext>>,
|
||||
batch_logits: Vec<ForwardItem<P::Observation, P::ActionDistribution>>,
|
||||
}
|
||||
|
||||
impl<B, P> PolicyInferenceServer<B, P>
|
||||
where
|
||||
B: Backend,
|
||||
P: Policy<B>,
|
||||
P::Observation: Clone + Batchable,
|
||||
P::ActionDistribution: Clone + Batchable,
|
||||
P::Action: Clone + Batchable,
|
||||
P::ActionContext: Clone,
|
||||
{
|
||||
pub fn new(max_autobatch_size: usize, inner_policy: P) -> Self {
|
||||
Self {
|
||||
num_agents: Arc::new(AtomicUsize::new(0)),
|
||||
max_autobatch_size,
|
||||
inner_policy,
|
||||
batch_action: vec![],
|
||||
batch_logits: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_action(&mut self, item: ActionItem<P::Observation, P::Action, P::ActionContext>) {
|
||||
self.batch_action.push(item);
|
||||
if self.len_actions()
|
||||
>= self
|
||||
.num_agents
|
||||
.load(Ordering::Relaxed)
|
||||
.min(self.max_autobatch_size)
|
||||
{
|
||||
self.flush_actions();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_logits(&mut self, item: ForwardItem<P::Observation, P::ActionDistribution>) {
|
||||
self.batch_logits.push(item);
|
||||
if self.len_logits()
|
||||
>= self
|
||||
.num_agents
|
||||
.load(Ordering::Relaxed)
|
||||
.min(self.max_autobatch_size)
|
||||
{
|
||||
self.flush_logits();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len_actions(&self) -> usize {
|
||||
self.batch_action.len()
|
||||
}
|
||||
|
||||
pub fn len_logits(&self) -> usize {
|
||||
self.batch_action.len()
|
||||
}
|
||||
|
||||
pub fn flush_actions(&mut self) {
|
||||
if self.len_actions() == 0 {
|
||||
return;
|
||||
}
|
||||
let input: Vec<_> = self
|
||||
.batch_action
|
||||
.iter()
|
||||
.map(|m| m.inference_state.clone())
|
||||
.collect();
|
||||
// Only deterministic if all actions are requested as deterministic.
|
||||
let deterministic = self.batch_action.iter().all(|item| item.deterministic);
|
||||
let (actions, context) = self
|
||||
.inner_policy
|
||||
.action(P::Observation::batch(input), deterministic);
|
||||
let actions: Vec<_> = actions.unbatch();
|
||||
|
||||
for (i, item) in self.batch_action.iter().enumerate() {
|
||||
item.sender
|
||||
.send(ActionContext {
|
||||
context: vec![context[i].clone()],
|
||||
action: actions[i].clone(),
|
||||
})
|
||||
.expect("Autobatcher should be able to send resulting actions.");
|
||||
}
|
||||
self.batch_action.clear();
|
||||
}
|
||||
|
||||
pub fn flush_logits(&mut self) {
|
||||
if self.len_logits() == 0 {
|
||||
return;
|
||||
}
|
||||
let input: Vec<_> = self
|
||||
.batch_logits
|
||||
.iter()
|
||||
.map(|m| m.inference_state.clone())
|
||||
.collect();
|
||||
let output = self.inner_policy.forward(P::Observation::batch(input));
|
||||
let logits: Vec<_> = output.unbatch();
|
||||
for (i, item) in self.batch_logits.iter().enumerate() {
|
||||
item.sender
|
||||
.send(logits[i].clone())
|
||||
.expect("Autobatcher should be able to send resulting probabilities.");
|
||||
}
|
||||
self.batch_action.clear();
|
||||
}
|
||||
|
||||
pub fn update_policy(&mut self, policy_update: P::PolicyState) {
|
||||
if self.len_actions() > 0 {
|
||||
self.flush_actions();
|
||||
}
|
||||
if self.len_logits() > 0 {
|
||||
self.flush_logits();
|
||||
}
|
||||
self.inner_policy.update(policy_update);
|
||||
}
|
||||
|
||||
pub fn state(&self) -> P::PolicyState {
|
||||
self.inner_policy.state()
|
||||
}
|
||||
|
||||
pub fn increment_agents(&mut self, num: usize) {
|
||||
self.num_agents.fetch_add(num, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn decrement_agents(&mut self, num: usize) {
|
||||
self.num_agents.fetch_sub(num, Ordering::Relaxed);
|
||||
if self.len_actions()
|
||||
>= self
|
||||
.num_agents
|
||||
.load(Ordering::Relaxed)
|
||||
.min(self.max_autobatch_size)
|
||||
{
|
||||
self.flush_actions();
|
||||
}
|
||||
if self.len_logits()
|
||||
>= self
|
||||
.num_agents
|
||||
.load(Ordering::Relaxed)
|
||||
.min(self.max_autobatch_size)
|
||||
{
|
||||
self.flush_logits();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Messages handled by the inference server thread.
enum InferenceMessage<B: Backend, P: Policy<B>> {
    /// Request for an action; queued in the action batch.
    ActionMessage(ActionItem<P::Observation, P::Action, P::ActionContext>),
    /// Request for an action distribution; queued in the forward batch.
    ForwardMessage(ForwardItem<P::Observation, P::ActionDistribution>),
    /// New policy state to apply to the inner policy.
    PolicyUpdate(P::PolicyState),
    /// Request for the current policy state, which is sent back on the given sender.
    PolicyRequest(Sender<P::PolicyState>),
    /// Increase the number of agents using the server.
    IncrementAgents(usize),
    /// Decrease the number of agents using the server.
    DecrementAgents(usize),
}
|
||||
|
||||
/// A pending action request.
#[derive(Clone)]
struct ActionItem<S, A, C> {
    // Channel on which the resulting action and its context are sent back.
    sender: Sender<ActionContext<A, Vec<C>>>,
    // The observation to run the policy on.
    inference_state: S,
    // Whether the requester asked for a deterministic action.
    deterministic: bool,
}
|
||||
|
||||
/// A pending forward (action distribution) request.
#[derive(Clone)]
struct ForwardItem<S, O> {
    // Channel on which the policy output is sent back.
    sender: Sender<O>,
    // The observation to run the policy on.
    inference_state: S,
}
|
||||
|
||||
/// An asynchronous policy using an inference server with autobatching.
|
||||
#[derive(Clone)]
|
||||
pub struct AsyncPolicy<B: Backend, P: Policy<B>> {
|
||||
inference_state_sender: Sender<InferenceMessage<B, P>>,
|
||||
}
|
||||
|
||||
impl<B, P> AsyncPolicy<B, P>
|
||||
where
|
||||
B: Backend,
|
||||
P: Policy<B> + Clone + Send + 'static,
|
||||
P::ActionContext: Clone + Send,
|
||||
P::PolicyState: Send,
|
||||
P::Observation: Clone + Send + Batchable,
|
||||
P::ActionDistribution: Clone + Send + Batchable,
|
||||
P::Action: Clone + Send + Batchable,
|
||||
{
|
||||
/// Create the policy.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `autobatch_size` - Number of observations to accumulate before running a pass of inference.
|
||||
/// * `inner_policy` - The policy used to take actions.
|
||||
pub fn new(autobatch_size: usize, inner_policy: P) -> Self {
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
let mut autobatcher = PolicyInferenceServer::new(autobatch_size, inner_policy.clone());
|
||||
spawn(move || {
|
||||
loop {
|
||||
match receiver.recv() {
|
||||
Ok(msg) => match msg {
|
||||
InferenceMessage::ActionMessage(item) => autobatcher.push_action(item),
|
||||
InferenceMessage::ForwardMessage(item) => autobatcher.push_logits(item),
|
||||
InferenceMessage::PolicyUpdate(update) => autobatcher.update_policy(update),
|
||||
InferenceMessage::PolicyRequest(sender) => sender
|
||||
.send(autobatcher.state())
|
||||
.expect("Autobatcher should be able to send current policy state."),
|
||||
InferenceMessage::IncrementAgents(num) => autobatcher.increment_agents(num),
|
||||
InferenceMessage::DecrementAgents(num) => autobatcher.decrement_agents(num),
|
||||
},
|
||||
Err(err) => {
|
||||
log::error!("Error in AsyncPolicy : {}", err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
inference_state_sender: sender,
|
||||
}
|
||||
}
|
||||
|
||||
/// Increment the number of agents using the inference server.
|
||||
pub fn increment_agents(&self, num: usize) {
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::IncrementAgents(num))
|
||||
.expect("Can send message to autobatcher.")
|
||||
}
|
||||
|
||||
/// Decrement the number of agents using the inference server.
|
||||
pub fn decrement_agents(&self, num: usize) {
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::DecrementAgents(num))
|
||||
.expect("Can send message to autobatcher.")
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, P> Policy<B> for AsyncPolicy<B, P>
|
||||
where
|
||||
B: Backend,
|
||||
P: Policy<B> + Send + 'static,
|
||||
{
|
||||
type ActionContext = P::ActionContext;
|
||||
type PolicyState = P::PolicyState;
|
||||
|
||||
type Observation = P::Observation;
|
||||
type ActionDistribution = P::ActionDistribution;
|
||||
type Action = P::Action;
|
||||
|
||||
fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {
|
||||
let (action_sender, action_receiver) = std::sync::mpsc::channel();
|
||||
let item = ForwardItem {
|
||||
sender: action_sender,
|
||||
inference_state: states,
|
||||
};
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::ForwardMessage(item))
|
||||
.expect("Should be able to send message to inference_server");
|
||||
action_receiver
|
||||
.recv()
|
||||
.expect("AsyncPolicy should receive queued probabilities.")
|
||||
}
|
||||
|
||||
fn action(
|
||||
&mut self,
|
||||
states: Self::Observation,
|
||||
deterministic: bool,
|
||||
) -> (Self::Action, Vec<Self::ActionContext>) {
|
||||
let (action_sender, action_receiver) = std::sync::mpsc::channel();
|
||||
let item = ActionItem {
|
||||
sender: action_sender,
|
||||
inference_state: states,
|
||||
deterministic,
|
||||
};
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::ActionMessage(item))
|
||||
.expect("should be able to send message to inference_server.");
|
||||
let action = action_receiver
|
||||
.recv()
|
||||
.expect("AsyncPolicy should receive queued actions.");
|
||||
(action.action, action.context)
|
||||
}
|
||||
|
||||
fn update(&mut self, update: Self::PolicyState) {
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::PolicyUpdate(update))
|
||||
.expect("AsyncPolicy should be able to send policy state.")
|
||||
}
|
||||
|
||||
fn state(&self) -> Self::PolicyState {
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::PolicyRequest(sender))
|
||||
.expect("should be able to send message to inference_server.");
|
||||
receiver
|
||||
.recv()
|
||||
.expect("AsyncPolicy should be able to receive policy state.")
|
||||
}
|
||||
|
||||
fn load_record(self, _record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {
|
||||
// Not needed for now
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
108
crates/burn-rl/src/policy/base.rs
Normal file
108
crates/burn-rl/src/policy/base.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use derive_new::new;
|
||||
|
||||
use burn_core::{prelude::*, record::Record, tensor::backend::AutodiffBackend};
|
||||
|
||||
use crate::TransitionBatch;
|
||||
|
||||
/// An action along with additional context about the decision.
///
/// `A` is the type of the action and `C` the type of the context.
#[derive(Clone, new)]
pub struct ActionContext<A, C> {
    /// The context.
    pub context: C,
    /// The action.
    pub action: A,
}
|
||||
|
||||
/// The state of a policy.
///
/// A state can be converted into a [record](Record) and rebuilt from one.
pub trait PolicyState<B: Backend> {
    /// The type of the record.
    type Record: Record<B>;

    /// Convert the state to a record, consuming the state.
    fn into_record(self) -> Self::Record;
    /// Load the state from a record, returning the resulting state.
    fn load_record(&self, record: Self::Record) -> Self;
}
|
||||
|
||||
/// Trait for a RL policy.
pub trait Policy<B: Backend>: Clone {
    /// The observation given as input to the policy.
    type Observation;
    /// The action distribution parameters defining how the action will be sampled.
    type ActionDistribution;
    /// The action.
    type Action;

    /// Additional context on the policy's decision.
    type ActionContext;
    /// The current parameterization of the policy.
    type PolicyState: PolicyState<B>;

    /// Produces the action distribution from a batch of observations.
    fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution;
    /// Gives the action from a batch of observations.
    ///
    /// # Arguments
    ///
    /// * `obs` - The batch of observations.
    /// * `deterministic` - Whether a deterministic action is requested.
    ///
    /// # Returns
    ///
    /// The batched actions along with a list of contexts.
    // NOTE(review): presumably one context per observation of the batch — confirm with implementors.
    fn action(
        &mut self,
        obs: Self::Observation,
        deterministic: bool,
    ) -> (Self::Action, Vec<Self::ActionContext>);

    /// Update the policy's parameters.
    fn update(&mut self, update: Self::PolicyState);
    /// Returns the current parameterization.
    fn state(&self) -> Self::PolicyState;

    /// Loads the policy parameters from a record, consuming and returning the policy.
    fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self;
}
|
||||
|
||||
/// Trait for a type that can be batched and unbatched (split).
///
/// The batch has the same type as the items it is built from.
pub trait Batchable: Sized {
    /// Create a batch from a list of items.
    fn batch(value: Vec<Self>) -> Self;
    /// Create a list from batched items, consuming the batch.
    fn unbatch(self) -> Vec<Self>;
}
|
||||
|
||||
/// A training output.
///
/// `TO` is the type of the training item and `P` the type of the policy.
pub struct RLTrainOutput<TO, P> {
    /// The policy.
    pub policy: P,
    /// The item.
    pub item: TO,
}
|
||||
|
||||
/// Batched transitions for a PolicyLearner.
///
/// The states use the policy's observation type and the actions use its action type.
pub type LearnerTransitionBatch<B, P> =
    TransitionBatch<B, <P as Policy<B>>::Observation, <P as Policy<B>>::Action>;
|
||||
|
||||
/// Learner for a policy.
///
/// The observations, action distributions and actions of the inner policy must be
/// cloneable and [batchable](Batchable).
pub trait PolicyLearner<B>
where
    B: AutodiffBackend,
    <Self::InnerPolicy as Policy<B>>::Observation: Clone + Batchable,
    <Self::InnerPolicy as Policy<B>>::ActionDistribution: Clone + Batchable,
    <Self::InnerPolicy as Policy<B>>::Action: Clone + Batchable,
{
    /// Additional context of a training step.
    type TrainContext;
    /// The policy to train.
    type InnerPolicy: Policy<B>;
    /// The record of the learner.
    type Record: Record<B>;

    /// Execute a training step on the policy.
    ///
    /// Takes a batch of transitions and returns the training context along with a
    /// state of the inner policy.
    fn train(
        &mut self,
        input: LearnerTransitionBatch<B, Self::InnerPolicy>,
    ) -> RLTrainOutput<Self::TrainContext, <Self::InnerPolicy as Policy<B>>::PolicyState>;
    /// Returns the learner's current policy.
    fn policy(&self) -> Self::InnerPolicy;
    /// Update the learner's policy.
    fn update_policy(&mut self, update: Self::InnerPolicy);

    /// Convert the learner's state into a record.
    fn record(&self) -> Self::Record;
    /// Load the learner's state from a record, consuming and returning the learner.
    fn load_record(self, record: Self::Record) -> Self;
}
|
||||
5
crates/burn-rl/src/policy/mod.rs
Normal file
5
crates/burn-rl/src/policy/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod async_policy;
|
||||
mod base;
|
||||
|
||||
pub use async_policy::*;
|
||||
pub use base::*;
|
||||
234
crates/burn-rl/src/transition_buffer/base.rs
Normal file
234
crates/burn-rl/src/transition_buffer/base.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
use burn_core::{Tensor, prelude::Backend, tensor::backend::AutodiffBackend};
|
||||
use derive_new::new;
|
||||
use rand::{rng, seq::index::sample};
|
||||
|
||||
use crate::Batchable;
|
||||
|
||||
/// A state transition in an environment.
///
/// `S` is the type of the states and `A` the type of the action. The reward and the
/// terminal flag are stored as rank 1 tensors.
#[derive(Clone, new)]
pub struct Transition<B: Backend, S, A> {
    /// The initial state.
    pub state: S,
    /// The state after the step was taken.
    pub next_state: S,
    /// The action taken in the step.
    pub action: A,
    /// The reward.
    pub reward: Tensor<B, 1>,
    /// If the environment has reached a terminal state.
    pub done: Tensor<B, 1>,
}
|
||||
|
||||
/// A batch of transitions.
|
||||
pub struct TransitionBatch<B: Backend, SB, AB> {
|
||||
/// Batched initial states.
|
||||
pub states: SB,
|
||||
/// Batched resulting states.
|
||||
pub next_states: SB,
|
||||
/// Batched actions.
|
||||
pub actions: AB,
|
||||
/// Batched rewards.
|
||||
pub rewards: Tensor<B, 2>,
|
||||
/// Batched flags for terminal states.
|
||||
pub dones: Tensor<B, 2>,
|
||||
}
|
||||
|
||||
impl<BT, B, S, A, SB, AB> From<Vec<&Transition<BT, S, A>>> for TransitionBatch<B, SB, AB>
|
||||
where
|
||||
BT: Backend,
|
||||
B: AutodiffBackend,
|
||||
S: Into<SB> + Clone,
|
||||
A: Into<AB> + Clone,
|
||||
SB: Batchable,
|
||||
AB: Batchable,
|
||||
{
|
||||
fn from(value: Vec<&Transition<BT, S, A>>) -> Self {
|
||||
let states: Vec<_> = value.iter().map(|t| t.state.clone().into()).collect();
|
||||
let next_states: Vec<_> = value.iter().map(|t| t.next_state.clone().into()).collect();
|
||||
let actions: Vec<_> = value.iter().map(|t| t.action.clone().into()).collect();
|
||||
let rewards: Vec<_> = value.iter().map(|t| t.reward.clone()).collect();
|
||||
let dones: Vec<_> = value.iter().map(|t| t.done.clone()).collect();
|
||||
|
||||
let rewards = Tensor::stack::<2>(rewards, 0);
|
||||
let dones = Tensor::stack::<2>(dones, 0);
|
||||
|
||||
Self {
|
||||
states: SB::batch(states),
|
||||
next_states: SB::batch(next_states),
|
||||
actions: AB::batch(actions),
|
||||
rewards: Tensor::from_data(rewards.to_data(), &Default::default()),
|
||||
dones: Tensor::from_data(dones.to_data(), &Default::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A circular buffer for transitions.
|
||||
pub struct TransitionBuffer<T> {
|
||||
buffer: Vec<T>,
|
||||
capacity: usize,
|
||||
cursor: usize,
|
||||
}
|
||||
|
||||
impl<T> TransitionBuffer<T> {
|
||||
/// Creates a new circular buffer with a fixed capacity.
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
buffer: Vec::with_capacity(capacity),
|
||||
capacity,
|
||||
cursor: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an item, overwriting the oldest if full.
|
||||
pub fn push(&mut self, item: T) {
|
||||
if self.buffer.len() < self.capacity {
|
||||
self.buffer.push(item);
|
||||
} else {
|
||||
self.buffer[self.cursor] = item;
|
||||
self.cursor = (self.cursor + 1) % self.capacity;
|
||||
}
|
||||
}
|
||||
|
||||
/// Append a list of items to the current buffer.
|
||||
pub fn append(&mut self, items: &mut Vec<T>) {
|
||||
let n = items.len();
|
||||
let mut is_overflow = false;
|
||||
if n > self.capacity {
|
||||
self.cursor = self.capacity - (n % self.capacity);
|
||||
items.drain(0..n - self.capacity);
|
||||
is_overflow = true;
|
||||
}
|
||||
let n = items.len();
|
||||
|
||||
let first_part = n.min(self.capacity - self.cursor);
|
||||
let second_part = n - first_part;
|
||||
|
||||
if is_overflow {
|
||||
if self.capacity > self.len() {
|
||||
self.buffer
|
||||
.extend(items.drain(first_part..second_part + first_part));
|
||||
} else {
|
||||
self.buffer[..second_part]
|
||||
.iter_mut()
|
||||
.zip(items.drain(first_part..second_part + first_part))
|
||||
.for_each(|(slot, item)| *slot = item);
|
||||
}
|
||||
}
|
||||
|
||||
if self.capacity > self.len() {
|
||||
self.buffer.extend(items.drain(..first_part));
|
||||
} else {
|
||||
self.buffer[self.cursor..self.cursor + first_part]
|
||||
.iter_mut()
|
||||
.zip(items.drain(..first_part))
|
||||
.for_each(|(slot, item)| *slot = item);
|
||||
}
|
||||
|
||||
if !is_overflow {
|
||||
self.buffer[..second_part]
|
||||
.iter_mut()
|
||||
.zip(items.drain(..second_part))
|
||||
.for_each(|(slot, item)| *slot = item);
|
||||
}
|
||||
|
||||
self.cursor = (self.cursor + n) % self.capacity
|
||||
}
|
||||
|
||||
/// Returns the current number of items stored.
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer.len()
|
||||
}
|
||||
|
||||
/// Returns the current number of items stored.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.buffer.is_empty()
|
||||
}
|
||||
|
||||
/// Clear the buffer.
|
||||
pub fn clear(&mut self) {
|
||||
self.buffer.clear();
|
||||
}
|
||||
|
||||
/// Sample the buffer at the given indices.
|
||||
pub fn sample(&self, indices: Vec<usize>) -> Vec<&T> {
|
||||
let mut items = Vec::with_capacity(indices.len());
|
||||
|
||||
for &idx in indices.iter() {
|
||||
if let Some(item) = self.buffer.get(idx) {
|
||||
items.push(item);
|
||||
}
|
||||
}
|
||||
items
|
||||
}
|
||||
|
||||
/// Sample `batch_size` transitions at random.
|
||||
pub fn random_sample(&self, batch_size: usize) -> Vec<&T> {
|
||||
assert!(batch_size <= self.len());
|
||||
let mut rng = rng();
|
||||
let indices = sample(&mut rng, self.len(), batch_size).into_vec();
|
||||
self.sample(indices)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use burn_core::tensor::TensorData;

    use crate::TestBackend;

    use super::*;

    type TestTensor = Tensor<TestBackend, 1>;
    type TestTransition = Transition<TestBackend, TestTensor, TestTensor>;

    /// Builds a dummy transition with two-element states and scalar action/reward/done.
    fn transition() -> TestTransition {
        let state = Tensor::from_data(TensorData::from([1.0, 2.0]), &Default::default());
        let next_state = Tensor::from_data(TensorData::from([1.0, 2.0]), &Default::default());
        let action = Tensor::from_data(TensorData::from([1.0]), &Default::default());
        let reward = Tensor::from_data(TensorData::from([1.0]), &Default::default());
        let done = Tensor::from_data(TensorData::from([1.0]), &Default::default());

        Transition::new(state, next_state, action, reward, done)
    }

    /// Asserts both the length and the raw storage layout of an integer buffer.
    fn assert_contents(buffer: &TransitionBuffer<i32>, expected: Vec<i32>) {
        assert_eq!(buffer.len(), expected.len());
        assert_eq!(buffer.buffer, expected);
    }

    #[test]
    fn len_returns_number_of_elements() {
        let mut buffer: TransitionBuffer<TestTransition> = TransitionBuffer::new(2);
        assert_eq!(buffer.len(), 0);

        // The third push overwrites instead of growing past the capacity of 2.
        for expected_len in [1, 2, 2] {
            buffer.push(transition());
            assert_eq!(buffer.len(), expected_len);
        }
    }

    #[test]
    fn append_works() {
        let mut buffer = TransitionBuffer::new(4);
        assert_eq!(buffer.len(), 0);

        buffer.append(&mut vec![0, 1]);
        assert_contents(&buffer, vec![0, 1]);

        buffer.append(&mut vec![2, 3, 4, 5]);
        assert_contents(&buffer, vec![4, 5, 2, 3]);

        let mut buffer = TransitionBuffer::new(4);
        buffer.append(&mut vec![0, 1, 2, 3, 4, 5]);
        assert_contents(&buffer, vec![4, 5, 2, 3]);

        buffer.append(&mut vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]);
        assert_contents(&buffer, vec![20, 17, 18, 19]);

        buffer.append(&mut vec![21, 22]);
        assert_contents(&buffer, vec![20, 21, 22, 19]);
    }
}
|
||||
3
crates/burn-rl/src/transition_buffer/mod.rs
Normal file
3
crates/burn-rl/src/transition_buffer/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
@@ -37,6 +37,7 @@ burn-core = { path = "../burn-core", version = "=0.21.0", features = [
|
||||
burn-optim = { path = "../burn-optim", version = "=0.21.0", features = [
|
||||
"std",
|
||||
], default-features = false }
|
||||
burn-rl = { path = "../burn-rl", version = "=0.21.0", default-features = false }
|
||||
burn-collective = { path = "../burn-collective", version = "=0.21.0", optional = true }
|
||||
|
||||
log = { workspace = true }
|
||||
@@ -62,6 +63,7 @@ async-channel = { workspace = true }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0" }
|
||||
rstest.workspace = true
|
||||
thiserror.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0" }
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use crate::{
|
||||
AsyncProcessorEvaluation, FullEventProcessorEvaluation, InferenceStep, Interrupter,
|
||||
AsyncProcessorEvaluation, EvaluationItem, FullEventProcessorEvaluation, InferenceStep,
|
||||
Interrupter,
|
||||
evaluator::components::EvaluatorComponentTypes,
|
||||
metric::processor::{EvaluatorEvent, EventProcessorEvaluation, LearnerItem},
|
||||
metric::processor::{EvaluatorEvent, EventProcessorEvaluation},
|
||||
renderer::{EvaluationName, MetricsRenderer},
|
||||
};
|
||||
use burn_core::{data::dataloader::DataLoader, module::Module};
|
||||
@@ -43,7 +44,7 @@ impl<EC: EvaluatorComponentTypes> Evaluator<EC> {
|
||||
iteration += 1;
|
||||
|
||||
let item = self.model.step(item);
|
||||
let item = LearnerItem::new(item, progress, 0, 1, iteration, None);
|
||||
let item = EvaluationItem::new(item, progress, Some(iteration));
|
||||
|
||||
self.event_processor
|
||||
.process_test(EvaluatorEvent::ProcessedItem(name.clone(), item));
|
||||
|
||||
@@ -3,6 +3,7 @@ mod base;
|
||||
mod classification;
|
||||
mod early_stopping;
|
||||
mod regression;
|
||||
mod rl;
|
||||
mod summary;
|
||||
mod supervised;
|
||||
mod train_val;
|
||||
@@ -12,6 +13,7 @@ pub use base::*;
|
||||
pub use classification::*;
|
||||
pub use early_stopping::*;
|
||||
pub use regression::*;
|
||||
pub use rl::*;
|
||||
pub use summary::*;
|
||||
pub use supervised::*;
|
||||
pub use train_val::*;
|
||||
|
||||
75
crates/burn-train/src/learner/rl/checkpointer.rs
Normal file
75
crates/burn-train/src/learner/rl/checkpointer.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
use burn_core::tensor::Device;
|
||||
use burn_rl::{Policy, PolicyLearner, PolicyState};
|
||||
|
||||
use crate::RLAgentRecord;
|
||||
use crate::{
|
||||
RLComponentsTypes, RLPolicyRecord,
|
||||
checkpoint::Checkpointer,
|
||||
checkpoint::{AsyncCheckpointer, CheckpointingAction, CheckpointingStrategy},
|
||||
metric::store::EventStoreClient,
|
||||
};
|
||||
|
||||
#[derive(new)]
|
||||
/// Used to create, delete, or load checkpoints of the training process.
|
||||
pub struct RLCheckpointer<RLC: RLComponentsTypes> {
|
||||
policy: AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
|
||||
learning_agent: AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
|
||||
strategy: Box<dyn CheckpointingStrategy>,
|
||||
}
|
||||
|
||||
impl<RLC: RLComponentsTypes> RLCheckpointer<RLC> {
|
||||
/// Create checkpoint for the training process.
|
||||
pub fn checkpoint(
|
||||
&mut self,
|
||||
policy: &RLC::PolicyState,
|
||||
learning_agent: &RLC::LearningAgent,
|
||||
epoch: usize,
|
||||
store: &EventStoreClient,
|
||||
) {
|
||||
let actions = self.strategy.checkpointing(epoch, store);
|
||||
|
||||
for action in actions {
|
||||
match action {
|
||||
CheckpointingAction::Delete(epoch) => {
|
||||
self.policy
|
||||
.delete(epoch)
|
||||
.expect("Can delete policy checkpoint.");
|
||||
self.learning_agent
|
||||
.delete(epoch)
|
||||
.expect("Can delete learning agent checkpoint.")
|
||||
}
|
||||
CheckpointingAction::Save => {
|
||||
self.policy
|
||||
.save(epoch, policy.clone().into_record())
|
||||
.expect("Can save policy checkpoint.");
|
||||
self.learning_agent
|
||||
.save(epoch, learning_agent.record())
|
||||
.expect("Can save learning agent checkpoint.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a training checkpoint.
|
||||
pub fn load_checkpoint(
|
||||
&self,
|
||||
learning_agent: RLC::LearningAgent,
|
||||
device: &Device<RLC::Backend>,
|
||||
epoch: usize,
|
||||
) -> RLC::LearningAgent {
|
||||
let record = self
|
||||
.policy
|
||||
.restore(epoch, device)
|
||||
.expect("Can load model checkpoint.");
|
||||
let policy = learning_agent.policy().load_record(record);
|
||||
|
||||
let record = self
|
||||
.learning_agent
|
||||
.restore(epoch, device)
|
||||
.expect("Can load learning agent checkpoint.");
|
||||
let mut learning_agent = learning_agent.load_record(record);
|
||||
learning_agent.update_policy(policy);
|
||||
|
||||
learning_agent
|
||||
}
|
||||
}
|
||||
115
crates/burn-train/src/learner/rl/components.rs
Normal file
115
crates/burn-train/src/learner/rl/components.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn_core::tensor::backend::AutodiffBackend;
|
||||
use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, PolicyState};
|
||||
|
||||
use crate::{AgentEvaluationEvent, AsyncProcessorTraining, ItemLazy, RLEvent};
|
||||
|
||||
/// All components used by the reinforcement learning paradigm, grouped in one trait.
///
/// The associated types are tied together by their bounds: the environment's state and action
/// convert to and from the policy's observation and action, and the learning agent trains the
/// same policy type.
|
||||
pub trait RLComponentsTypes {
|
||||
/// The backend used for training.
|
||||
type Backend: AutodiffBackend;
|
||||
/// The learning environment.
|
||||
type Env: Environment<State = Self::State, Action = Self::Action> + 'static;
|
||||
/// Specifies how to initialize the environment.
|
||||
type EnvInit: EnvironmentInit<Self::Env> + Send + 'static;
|
||||
/// The type of the environment state, convertible into the policy's observation.
|
||||
type State: Into<<Self::Policy as Policy<Self::Backend>>::Observation> + Clone + Send + 'static;
|
||||
/// The type of the environment action, convertible from and into the policy's action.
|
||||
type Action: From<<Self::Policy as Policy<Self::Backend>>::Action>
|
||||
+ Into<<Self::Policy as Policy<Self::Backend>>::Action>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static;
|
||||
|
||||
/// The policy used to take actions in the environment.
|
||||
type Policy: Policy<
|
||||
Self::Backend,
|
||||
Observation = Self::PolicyObs,
|
||||
ActionDistribution = Self::PolicyAD,
|
||||
Action = Self::PolicyAction,
|
||||
ActionContext = Self::ActionContext,
|
||||
PolicyState = Self::PolicyState,
|
||||
> + Send
|
||||
+ 'static;
|
||||
/// The policy's observation type.
|
||||
type PolicyObs: Clone + Send + Batchable + 'static;
|
||||
/// The policy's action distribution type.
|
||||
type PolicyAD: Clone + Send + Batchable;
|
||||
/// The policy's action type.
|
||||
type PolicyAction: Clone + Send + Batchable;
|
||||
/// Additional data as context for an agent's action.
|
||||
type ActionContext: ItemLazy + Clone + Send + 'static;
|
||||
/// The state of the parameterized policy.
|
||||
type PolicyState: Clone + Send + PolicyState<Self::Backend> + 'static;
|
||||
|
||||
/// The learning agent, which trains [`Policy`](Self::Policy) and reports
/// [`TrainingOutput`](Self::TrainingOutput) as its training context.
|
||||
type LearningAgent: PolicyLearner<
|
||||
Self::Backend,
|
||||
TrainContext = Self::TrainingOutput,
|
||||
InnerPolicy = Self::Policy,
|
||||
> + Send
|
||||
+ 'static;
|
||||
/// The output data of a training step.
|
||||
type TrainingOutput: ItemLazy + Clone + Send;
|
||||
}
|
||||
|
||||
/// Concrete type that implements the [RLComponentsTypes](RLComponentsTypes) trait.
|
||||
pub struct RLComponentsMarker<B, E, EI, A> {
|
||||
_backend: PhantomData<B>,
|
||||
_env: PhantomData<E>,
|
||||
_env_init: PhantomData<EI>,
|
||||
_agent: PhantomData<A>,
|
||||
}
|
||||
|
||||
impl<B, E, EI, A> RLComponentsTypes for RLComponentsMarker<B, E, EI, A>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
E: Environment + 'static,
|
||||
EI: EnvironmentInit<E> + Send + 'static,
|
||||
A: PolicyLearner<B> + Send + 'static,
|
||||
A::TrainContext: ItemLazy + Clone + Send,
|
||||
A::InnerPolicy: Policy<B> + Send,
|
||||
<A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
|
||||
<A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
|
||||
E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
|
||||
E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
|
||||
+ Into<<A::InnerPolicy as Policy<B>>::Action>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
{
|
||||
type Backend = B;
|
||||
type Env = E;
|
||||
type EnvInit = EI;
|
||||
type LearningAgent = A;
|
||||
type Policy = A::InnerPolicy;
|
||||
type PolicyObs = <A::InnerPolicy as Policy<B>>::Observation;
|
||||
type PolicyAD = <A::InnerPolicy as Policy<B>>::ActionDistribution;
|
||||
type PolicyAction = <A::InnerPolicy as Policy<B>>::Action;
|
||||
type ActionContext = <A::InnerPolicy as Policy<B>>::ActionContext;
|
||||
type PolicyState = <A::InnerPolicy as Policy<B>>::PolicyState;
|
||||
type TrainingOutput = A::TrainContext;
|
||||
type State = E::State;
|
||||
type Action = E::Action;
|
||||
}
|
||||
|
||||
/// The inner policy type of the learning agent.
pub(crate) type RlPolicy<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
|
||||
<RLC as RLComponentsTypes>::Backend,
|
||||
>>::InnerPolicy;
|
||||
/// The event processor type for reinforcement learning.
///
/// Training events are [`RLEvent`]s and evaluation events are [`AgentEvaluationEvent`]s.
|
||||
pub type RLEventProcessorType<RLC> = AsyncProcessorTraining<
|
||||
RLEvent<<RLC as RLComponentsTypes>::TrainingOutput, <RLC as RLComponentsTypes>::ActionContext>,
|
||||
AgentEvaluationEvent<<RLC as RLComponentsTypes>::ActionContext>,
|
||||
>;
|
||||
/// The record of the policy, i.e. the record type of the policy's state.
|
||||
pub type RLPolicyRecord<RLC> = <<<RLC as RLComponentsTypes>::Policy as Policy<
|
||||
<RLC as RLComponentsTypes>::Backend,
|
||||
>>::PolicyState as PolicyState<<RLC as RLComponentsTypes>::Backend>>::Record;
|
||||
/// The record of the learning agent.
|
||||
pub type RLAgentRecord<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
|
||||
<RLC as RLComponentsTypes>::Backend,
|
||||
>>::Record;
|
||||
504
crates/burn-train/src/learner/rl/env_runner/async_runner.rs
Normal file
504
crates/burn-train/src/learner/rl/env_runner/async_runner.rs
Normal file
@@ -0,0 +1,504 @@
|
||||
use rand::prelude::SliceRandom;
|
||||
use std::{
|
||||
sync::mpsc::{Receiver, Sender},
|
||||
thread::spawn,
|
||||
};
|
||||
|
||||
use burn_core::{Tensor, data::dataloader::Progress, prelude::Backend, tensor::Device};
|
||||
use burn_rl::EnvironmentInit;
|
||||
use burn_rl::Policy;
|
||||
use burn_rl::Transition;
|
||||
use burn_rl::{AsyncPolicy, Environment};
|
||||
|
||||
use crate::{
|
||||
AgentEnvLoop, AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
|
||||
Interrupter, RLComponentsTypes, RLEvent, RLEventProcessorType, RLTimeStep, RLTrajectory,
|
||||
RlPolicy, TimeStep, Trajectory,
|
||||
};
|
||||
|
||||
// Requests sent from the main thread to an environment worker thread.
enum RequestMessage {
|
||||
// Ask for a single time step, delivered on the transition channel.
Step(),
|
||||
// Ask for an episode, delivered as a whole trajectory when it ends.
Episode(),
|
||||
}
|
||||
|
||||
/// An asynchronous agent/environement interface.
|
||||
pub struct AgentEnvAsyncLoop<BT: Backend, RLC: RLComponentsTypes> {
|
||||
env_init: RLC::EnvInit,
|
||||
id: usize,
|
||||
eval: bool,
|
||||
agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
|
||||
deterministic: bool,
|
||||
transition_device: Device<BT>,
|
||||
transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
|
||||
transition_sender: Sender<RLTimeStep<BT, RLC>>,
|
||||
trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
|
||||
trajectory_sender: Sender<RLTrajectory<BT, RLC>>,
|
||||
request_sender: Option<Sender<RequestMessage>>,
|
||||
}
|
||||
|
||||
impl<BT: Backend, RLC: RLComponentsTypes> AgentEnvAsyncLoop<BT, RLC> {
|
||||
/// Create a new asynchronous runner.
|
||||
pub fn new(
|
||||
env_init: RLC::EnvInit,
|
||||
id: usize,
|
||||
eval: bool,
|
||||
agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
|
||||
deterministic: bool,
|
||||
transition_device: &Device<BT>,
|
||||
) -> Self {
|
||||
let (transition_sender, transition_receiver) = std::sync::mpsc::channel();
|
||||
let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();
|
||||
Self {
|
||||
env_init,
|
||||
id,
|
||||
eval,
|
||||
agent: agent.clone(),
|
||||
deterministic,
|
||||
transition_device: transition_device.clone(),
|
||||
transition_receiver,
|
||||
transition_sender,
|
||||
trajectory_receiver,
|
||||
trajectory_sender,
|
||||
request_sender: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvAsyncLoop<BT, RLC>
|
||||
where
|
||||
BT: Backend,
|
||||
RLC: RLComponentsTypes,
|
||||
{
|
||||
fn start(&mut self) {
|
||||
let id = self.id;
|
||||
let mut agent = self.agent.clone();
|
||||
let deterministic = self.deterministic;
|
||||
let transition_sender = self.transition_sender.clone();
|
||||
let trajectory_sender = self.trajectory_sender.clone();
|
||||
let device = self.transition_device.clone();
|
||||
let env_init = self.env_init.clone();
|
||||
|
||||
let (request_sender, request_receiver) = std::sync::mpsc::channel();
|
||||
self.request_sender = Some(request_sender);
|
||||
|
||||
let mut current_steps = vec![];
|
||||
let mut current_reward = 0.0;
|
||||
let mut step_num = 0;
|
||||
|
||||
spawn(move || {
|
||||
let mut env = env_init.init();
|
||||
env.reset();
|
||||
|
||||
let mut request_episode = false;
|
||||
loop {
|
||||
let state = env.state();
|
||||
let (action, context) = agent.action(state.clone().into(), deterministic);
|
||||
|
||||
let env_action = RLC::Action::from(action);
|
||||
let step_result = env.step(env_action.clone());
|
||||
|
||||
current_reward += step_result.reward;
|
||||
step_num += 1;
|
||||
|
||||
let transition = Transition::new(
|
||||
state.clone(),
|
||||
step_result.next_state,
|
||||
env_action,
|
||||
Tensor::from_data([step_result.reward], &device),
|
||||
Tensor::from_data(
|
||||
[(step_result.done || step_result.truncated) as i32 as f64],
|
||||
&device,
|
||||
),
|
||||
);
|
||||
|
||||
if !request_episode {
|
||||
agent.decrement_agents(1);
|
||||
let request = match request_receiver.recv() {
|
||||
Ok(req) => req,
|
||||
Err(err) => {
|
||||
log::error!("Error in env runner : {}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
agent.increment_agents(1);
|
||||
|
||||
match request {
|
||||
RequestMessage::Step() => (),
|
||||
RequestMessage::Episode() => request_episode = true,
|
||||
}
|
||||
}
|
||||
|
||||
let time_step = TimeStep {
|
||||
env_id: id,
|
||||
transition,
|
||||
done: step_result.done,
|
||||
ep_len: step_num,
|
||||
cum_reward: current_reward,
|
||||
action_context: context[0].clone(),
|
||||
};
|
||||
current_steps.push(time_step.clone());
|
||||
|
||||
if !request_episode && let Err(err) = transition_sender.send(time_step) {
|
||||
log::error!("Error in env runner : {}", err);
|
||||
break;
|
||||
}
|
||||
|
||||
if step_result.done || step_result.truncated {
|
||||
if request_episode {
|
||||
request_episode = false;
|
||||
trajectory_sender
|
||||
.send(Trajectory {
|
||||
timesteps: current_steps.clone(),
|
||||
})
|
||||
.expect("Can send trajectory to main thread.");
|
||||
}
|
||||
current_steps.clear();
|
||||
|
||||
env.reset();
|
||||
current_reward = 0.;
|
||||
step_num = 0;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn run_steps(
|
||||
&mut self,
|
||||
num_steps: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTimeStep<BT, RLC>> {
|
||||
let mut items = vec![];
|
||||
for _ in 0..num_steps {
|
||||
self.request_sender
|
||||
.as_ref()
|
||||
.expect("Call start before running steps.")
|
||||
.send(RequestMessage::Step())
|
||||
.expect("Can request transitions.");
|
||||
let transition = self
|
||||
.transition_receiver
|
||||
.recv()
|
||||
.expect("Can receive transitions.");
|
||||
items.push(transition.clone());
|
||||
|
||||
if !self.eval {
|
||||
progress.items_processed += 1;
|
||||
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
|
||||
transition.action_context,
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
|
||||
if transition.done {
|
||||
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: transition.ep_len,
|
||||
cum_reward: transition.cum_reward,
|
||||
},
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
items
|
||||
}
|
||||
|
||||
fn run_episodes(
|
||||
&mut self,
|
||||
num_episodes: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
_progress: &mut Progress,
|
||||
) -> Vec<RLTrajectory<BT, RLC>> {
|
||||
let mut items = vec![];
|
||||
self.agent.increment_agents(1);
|
||||
for episode_num in 0..num_episodes {
|
||||
self.request_sender
|
||||
.as_ref()
|
||||
.expect("Call start before running episodes.")
|
||||
.send(RequestMessage::Episode())
|
||||
.expect("Can request episodes.");
|
||||
let trajectory = self
|
||||
.trajectory_receiver
|
||||
.recv()
|
||||
.expect("Main thread can receive trajectory.");
|
||||
|
||||
for (i, step) in trajectory.timesteps.iter().enumerate() {
|
||||
// TODO : clean this.
|
||||
if self.eval {
|
||||
processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
|
||||
step.action_context.clone(),
|
||||
Progress::new(i, i),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step.done {
|
||||
processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
|
||||
EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: step.ep_len,
|
||||
cum_reward: step.cum_reward,
|
||||
},
|
||||
Progress::new(episode_num + 1, num_episodes),
|
||||
None,
|
||||
),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
|
||||
step.action_context.clone(),
|
||||
Progress::new(i, i),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step.done {
|
||||
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: step.ep_len,
|
||||
cum_reward: step.cum_reward,
|
||||
},
|
||||
Progress::new(episode_num + 1, num_episodes),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
items.push(trajectory);
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.agent.decrement_agents(1);
|
||||
items
|
||||
}
|
||||
|
||||
fn update_policy(&mut self, update: RLC::PolicyState) {
|
||||
self.agent.update(update);
|
||||
}
|
||||
|
||||
fn policy(&self) -> RLC::PolicyState {
|
||||
self.agent.state()
|
||||
}
|
||||
}
|
||||
|
||||
/// An asynchronous runner for multiple agent/environement interfaces.
|
||||
pub struct MultiAgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
|
||||
env_init: RLC::EnvInit,
|
||||
num_envs: usize,
|
||||
eval: bool,
|
||||
agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
|
||||
deterministic: bool,
|
||||
device: Device<BT>,
|
||||
transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
|
||||
transition_sender: Sender<RLTimeStep<BT, RLC>>,
|
||||
trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
|
||||
trajectory_sender: Sender<RLTrajectory<BT, RLC>>,
|
||||
request_senders: Vec<Sender<RequestMessage>>,
|
||||
}
|
||||
|
||||
impl<BT: Backend, RLC: RLComponentsTypes> MultiAgentEnvLoop<BT, RLC> {
|
||||
/// Create a new asynchronous runner for multiple agent/environement interfaces.
|
||||
pub fn new(
|
||||
env_init: RLC::EnvInit,
|
||||
num_envs: usize,
|
||||
eval: bool,
|
||||
agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
|
||||
deterministic: bool,
|
||||
device: &Device<BT>,
|
||||
) -> Self {
|
||||
let (transition_sender, transition_receiver) = std::sync::mpsc::channel();
|
||||
let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();
|
||||
Self {
|
||||
env_init,
|
||||
num_envs,
|
||||
eval,
|
||||
agent: agent.clone(),
|
||||
deterministic,
|
||||
device: device.clone(),
|
||||
transition_receiver,
|
||||
transition_sender,
|
||||
trajectory_receiver,
|
||||
trajectory_sender,
|
||||
request_senders: Vec::with_capacity(num_envs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BT, RLC> AgentEnvLoop<BT, RLC> for MultiAgentEnvLoop<BT, RLC>
|
||||
where
|
||||
BT: Backend,
|
||||
RLC: RLComponentsTypes,
|
||||
{
|
||||
// TODO: start() shouldn't exist.
|
||||
fn start(&mut self) {
|
||||
// Double batching : The environments are always one step ahead of requests. This allows inference for the first batch of steps.
|
||||
self.agent.increment_agents(self.num_envs);
|
||||
|
||||
for i in 0..self.num_envs {
|
||||
let mut runner = AgentEnvAsyncLoop::<BT, RLC>::new(
|
||||
self.env_init.clone(),
|
||||
i,
|
||||
self.eval,
|
||||
self.agent.clone(),
|
||||
self.deterministic,
|
||||
&self.device,
|
||||
);
|
||||
runner.transition_sender = self.transition_sender.clone();
|
||||
runner.trajectory_sender = self.trajectory_sender.clone();
|
||||
runner.start();
|
||||
self.request_senders
|
||||
.push(runner.request_sender.clone().unwrap());
|
||||
}
|
||||
|
||||
// Double batching : The environments are always one step ahead.
|
||||
self.request_senders.iter().for_each(|s| {
|
||||
s.send(RequestMessage::Step())
|
||||
.expect("Main thread can send step requests.")
|
||||
});
|
||||
}
|
||||
|
||||
fn run_steps(
|
||||
&mut self,
|
||||
num_steps: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTimeStep<BT, RLC>> {
|
||||
let mut items = vec![];
|
||||
for _ in 0..num_steps {
|
||||
let transition = self
|
||||
.transition_receiver
|
||||
.recv()
|
||||
.expect("Can receive transitions.");
|
||||
items.push(transition.clone());
|
||||
|
||||
self.request_senders[transition.env_id]
|
||||
.send(RequestMessage::Step())
|
||||
.expect("Main thread can request steps.");
|
||||
|
||||
if !self.eval {
|
||||
progress.items_processed += 1;
|
||||
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
|
||||
transition.action_context,
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
|
||||
if transition.done {
|
||||
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: transition.ep_len,
|
||||
cum_reward: transition.cum_reward,
|
||||
},
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
items
|
||||
}
|
||||
|
||||
fn update_policy(&mut self, update: RLC::PolicyState) {
|
||||
self.agent.update(update);
|
||||
}
|
||||
|
||||
fn run_episodes(
|
||||
&mut self,
|
||||
num_episodes: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
_progress: &mut Progress,
|
||||
) -> Vec<RLTrajectory<BT, RLC>> {
|
||||
// Send `num_episodes` initial requests.
|
||||
let mut idx = vec![];
|
||||
if num_episodes < self.num_envs {
|
||||
let mut rng = rand::rng();
|
||||
let mut vec: Vec<usize> = (0..self.num_envs).collect();
|
||||
vec.shuffle(&mut rng);
|
||||
idx = vec.into_iter().take(num_episodes).collect();
|
||||
} else {
|
||||
idx = (0..self.num_envs).collect();
|
||||
}
|
||||
let num_requests = self.num_envs.min(num_episodes);
|
||||
idx.into_iter().for_each(|i| {
|
||||
self.request_senders[i]
|
||||
.send(RequestMessage::Episode())
|
||||
.expect("Main thread can request steps.");
|
||||
});
|
||||
|
||||
let mut items = vec![];
|
||||
for episode_num in 0..num_episodes {
|
||||
let trajectory = self
|
||||
.trajectory_receiver
|
||||
.recv()
|
||||
.expect("Can receive trajectory.");
|
||||
items.push(trajectory.clone());
|
||||
if items.len() + num_requests <= num_episodes {
|
||||
self.request_senders[trajectory.timesteps[0].env_id]
|
||||
.send(RequestMessage::Episode())
|
||||
.expect("Main thread can request steps.");
|
||||
}
|
||||
for (i, step) in trajectory.timesteps.iter().enumerate() {
|
||||
if self.eval {
|
||||
processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
|
||||
step.action_context.clone(),
|
||||
Progress::new(i, i),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step.done {
|
||||
processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
|
||||
EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: step.ep_len,
|
||||
cum_reward: step.cum_reward,
|
||||
},
|
||||
Progress::new(episode_num + 1, num_episodes),
|
||||
None,
|
||||
),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
|
||||
step.action_context.clone(),
|
||||
Progress::new(i, i),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step.done {
|
||||
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: step.ep_len,
|
||||
cum_reward: step.cum_reward,
|
||||
},
|
||||
Progress::new(episode_num + 1, num_episodes),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
items
|
||||
}
|
||||
|
||||
/// Returns the state of the policy currently held by this runner's agent.
fn policy(&self) -> RLC::PolicyState {
    let agent = &self.agent;
    agent.state()
}
|
||||
}
|
||||
273
crates/burn-train/src/learner/rl/env_runner/base.rs
Normal file
273
crates/burn-train/src/learner/rl/env_runner/base.rs
Normal file
@@ -0,0 +1,273 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_core::{Tensor, prelude::Backend};
|
||||
use burn_rl::Policy;
|
||||
use burn_rl::Transition;
|
||||
use burn_rl::{Environment, EnvironmentInit};
|
||||
|
||||
use crate::RLEvent;
|
||||
use crate::{
|
||||
AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
|
||||
RLEventProcessorType,
|
||||
};
|
||||
use crate::{Interrupter, RLComponentsTypes};
|
||||
|
||||
/// A trajectory, i.e. a list of ordered [TimeStep](TimeStep).
|
||||
#[derive(Clone, new)]
|
||||
pub struct Trajectory<B: Backend, S, A, C> {
|
||||
/// A list of ordered [TimeStep](TimeStep)s.
|
||||
pub timesteps: Vec<TimeStep<B, S, A, C>>,
|
||||
}
|
||||
|
||||
/// A timestep debscribing an iteration of the state/decision process.
|
||||
#[derive(Clone)]
|
||||
pub struct TimeStep<B: Backend, S, A, C> {
|
||||
/// The environment id.
|
||||
pub env_id: usize,
|
||||
/// The [burn_rl::Transition](burn_rl::Transition).
|
||||
pub transition: Transition<B, S, A>,
|
||||
/// True if the environment reaches a terminal state.
|
||||
pub done: bool,
|
||||
/// The running length of the current episode.
|
||||
pub ep_len: usize,
|
||||
/// The running cumulative reward.
|
||||
pub cum_reward: f64,
|
||||
/// The action's context for this timestep.
|
||||
pub action_context: C,
|
||||
}
|
||||
|
||||
pub(crate) type RLTimeStep<B, RLC> = TimeStep<
|
||||
B,
|
||||
<RLC as RLComponentsTypes>::State,
|
||||
<RLC as RLComponentsTypes>::Action,
|
||||
<RLC as RLComponentsTypes>::ActionContext,
|
||||
>;
|
||||
|
||||
pub(crate) type RLTrajectory<B, RLC> = Trajectory<
|
||||
B,
|
||||
<RLC as RLComponentsTypes>::State,
|
||||
<RLC as RLComponentsTypes>::Action,
|
||||
<RLC as RLComponentsTypes>::ActionContext,
|
||||
>;
|
||||
|
||||
/// Trait for a structure that implements an agent/environement interface.
|
||||
pub trait AgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
|
||||
/// Start the loop.
|
||||
fn start(&mut self);
|
||||
/// Run a certain number of timesteps.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `num_steps` - The number of time_steps to run.
|
||||
/// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).
|
||||
/// * `interrupter` - An [crate::Interrupter](crate::Interrupter).
|
||||
/// * `num_steps` - The number of time_steps to run.
|
||||
/// * `progress` - A mutable reference to the learning progress.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A list of ordered timesteps.
|
||||
fn run_steps(
|
||||
&mut self,
|
||||
num_steps: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTimeStep<BT, RLC>>;
|
||||
/// Run a certain number of episodes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `num_episodes` - The number of episodes to run.
|
||||
/// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).
|
||||
/// * `interrupter` - An [crate::Interrupter](crate::Interrupter).
|
||||
/// * `progress` - A mutable reference to the learning progress.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A list of ordered timesteps.
|
||||
fn run_episodes(
|
||||
&mut self,
|
||||
num_episodes: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTrajectory<BT, RLC>>;
|
||||
/// Update the runner's agent.
|
||||
fn update_policy(&mut self, update: RLC::PolicyState);
|
||||
/// Get the state of the runner's agent.
|
||||
fn policy(&self) -> RLC::PolicyState;
|
||||
}
|
||||
|
||||
/// A simple, synchronized agent/environement interface.
|
||||
pub struct AgentEnvBaseLoop<B: Backend, RLC: RLComponentsTypes> {
|
||||
env: RLC::Env,
|
||||
eval: bool,
|
||||
agent: RLC::Policy,
|
||||
deterministic: bool,
|
||||
current_reward: f64,
|
||||
run_num: usize,
|
||||
step_num: usize,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend, RLC: RLComponentsTypes> AgentEnvBaseLoop<B, RLC> {
|
||||
/// Create a new base runner.
|
||||
pub fn new(
|
||||
env_init: RLC::EnvInit,
|
||||
agent: RLC::Policy,
|
||||
eval: bool,
|
||||
deterministic: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
env: env_init.init(),
|
||||
eval,
|
||||
agent: agent.clone(),
|
||||
deterministic,
|
||||
current_reward: 0.0,
|
||||
run_num: 0,
|
||||
step_num: 0,
|
||||
_backend: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvBaseLoop<BT, RLC>
|
||||
where
|
||||
BT: Backend,
|
||||
RLC: RLComponentsTypes,
|
||||
{
|
||||
fn start(&mut self) {
|
||||
self.env.reset();
|
||||
}
|
||||
|
||||
fn run_steps(
|
||||
&mut self,
|
||||
num_steps: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTimeStep<BT, RLC>> {
|
||||
let mut items = vec![];
|
||||
let device = Default::default();
|
||||
for _ in 0..num_steps {
|
||||
let state = self.env.state();
|
||||
let (action, context) = self.agent.action(state.clone().into(), self.deterministic);
|
||||
|
||||
let step_result = self.env.step(RLC::Action::from(action.clone()));
|
||||
|
||||
self.current_reward += step_result.reward;
|
||||
self.step_num += 1;
|
||||
|
||||
let transition = Transition::new(
|
||||
state.clone(),
|
||||
step_result.next_state,
|
||||
RLC::Action::from(action),
|
||||
Tensor::from_data([step_result.reward], &device),
|
||||
Tensor::from_data(
|
||||
[(step_result.done || step_result.truncated) as i32 as f64],
|
||||
&device,
|
||||
),
|
||||
);
|
||||
items.push(TimeStep {
|
||||
env_id: 0,
|
||||
transition,
|
||||
done: step_result.done,
|
||||
ep_len: self.step_num,
|
||||
cum_reward: self.current_reward,
|
||||
action_context: context[0].clone(),
|
||||
});
|
||||
|
||||
if !self.eval {
|
||||
progress.items_processed += 1;
|
||||
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
|
||||
context[0].clone(),
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step_result.done {
|
||||
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: self.step_num,
|
||||
cum_reward: self.current_reward,
|
||||
},
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
|
||||
if step_result.done || step_result.truncated {
|
||||
self.env.reset();
|
||||
self.current_reward = 0.;
|
||||
self.step_num = 0;
|
||||
self.run_num += 1;
|
||||
}
|
||||
}
|
||||
items
|
||||
}
|
||||
|
||||
/// Applies the given policy state to the runner's agent.
fn update_policy(&mut self, update: RLC::PolicyState) {
    let agent = &mut self.agent;
    agent.update(update);
}
|
||||
|
||||
fn run_episodes(
|
||||
&mut self,
|
||||
num_episodes: usize,
|
||||
processor: &mut RLEventProcessorType<RLC>,
|
||||
interrupter: &Interrupter,
|
||||
progress: &mut Progress,
|
||||
) -> Vec<RLTrajectory<BT, RLC>> {
|
||||
self.env.reset();
|
||||
|
||||
let mut items = vec![];
|
||||
for ep in 0..num_episodes {
|
||||
let mut steps = vec![];
|
||||
loop {
|
||||
let step = self.run_steps(1, processor, interrupter, progress)[0].clone();
|
||||
steps.push(step.clone());
|
||||
|
||||
if self.eval {
|
||||
processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
|
||||
step.action_context.clone(),
|
||||
Progress::new(steps.len() + 1, steps.len() + 1),
|
||||
None,
|
||||
)));
|
||||
|
||||
if step.done {
|
||||
processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
|
||||
EvaluationItem::new(
|
||||
EpisodeSummary {
|
||||
episode_length: step.ep_len,
|
||||
cum_reward: step.cum_reward,
|
||||
},
|
||||
Progress::new(ep + 1, num_episodes),
|
||||
None,
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if interrupter.should_stop() || step.done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
items.push(Trajectory::new(steps));
|
||||
|
||||
if interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
items
|
||||
}
|
||||
|
||||
/// Returns the state of the policy currently held by this runner's agent.
fn policy(&self) -> RLC::PolicyState {
    let agent = &self.agent;
    agent.state()
}
|
||||
}
|
||||
5
crates/burn-train/src/learner/rl/env_runner/mod.rs
Normal file
5
crates/burn-train/src/learner/rl/env_runner/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod async_runner;
|
||||
mod base;
|
||||
|
||||
pub use async_runner::*;
|
||||
pub use base::*;
|
||||
15
crates/burn-train/src/learner/rl/mod.rs
Normal file
15
crates/burn-train/src/learner/rl/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
mod checkpointer;
|
||||
mod components;
|
||||
mod env_runner;
|
||||
mod off_policy;
|
||||
mod output;
|
||||
mod paradigm;
|
||||
mod strategy;
|
||||
|
||||
pub use checkpointer::*;
|
||||
pub use components::*;
|
||||
pub use env_runner::*;
|
||||
pub use off_policy::*;
|
||||
pub use output::*;
|
||||
pub use paradigm::*;
|
||||
pub use strategy::*;
|
||||
172
crates/burn-train/src/learner/rl/off_policy.rs
Normal file
172
crates/burn-train/src/learner/rl/off_policy.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
AgentEnvAsyncLoop, AgentEnvLoop, EvaluationItem, EventProcessorTraining, MultiAgentEnvLoop,
|
||||
RLComponents, RLComponentsTypes, RLEvent, RLEventProcessorType, RLStrategy,
|
||||
};
|
||||
use burn_core::{self as burn};
|
||||
use burn_core::{config::Config, data::dataloader::Progress};
|
||||
use burn_ndarray::NdArray;
|
||||
use burn_rl::{AsyncPolicy, Policy, PolicyLearner, Transition, TransitionBuffer};
|
||||
|
||||
/// Parameters of an on policy training with multi environments and double-batching.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct OffPolicyConfig {
|
||||
/// The number of environments to run simultaneously for experience collection.
|
||||
#[config(default = 1)]
|
||||
pub num_envs: usize,
|
||||
/// Number of environment state to accumulate before running one step of inference with the policy.
|
||||
/// Must be equal or less than the number of simultaneous environments.
|
||||
#[config(default = 1)]
|
||||
pub autobatch_size: usize,
|
||||
/// Max number of transitions stored in the replay buffer.
|
||||
#[config(default = 1024)]
|
||||
pub replay_buffer_size: usize,
|
||||
/// The number of steps to collect between each step of training.
|
||||
#[config(default = 1)]
|
||||
pub train_interval: usize,
|
||||
/// Number of optimization steps done each `train_interval`.
|
||||
#[config(default = 1)]
|
||||
pub train_steps: usize,
|
||||
/// The number of steps to collect between each evaluation.
|
||||
#[config(default = 10_000)]
|
||||
pub eval_interval: usize,
|
||||
/// The number of episodes to run for each evaluation.
|
||||
#[config(default = 1)]
|
||||
pub eval_episodes: usize,
|
||||
/// The number of transition to train on.
|
||||
#[config(default = 32)]
|
||||
pub train_batch_size: usize,
|
||||
/// Number of steps to collect before starting to train.
|
||||
#[config(default = 0)]
|
||||
pub warmup_steps: usize,
|
||||
}
|
||||
|
||||
/// Off-policy reinforcement learning strategy with multi-env experience collection and double-batching.
|
||||
pub struct OffPolicyStrategy<RLC: RLComponentsTypes> {
|
||||
config: OffPolicyConfig,
|
||||
_components: PhantomData<RLC>,
|
||||
}
|
||||
impl<RLC: RLComponentsTypes> OffPolicyStrategy<RLC> {
|
||||
/// Create a new off-policy base strategy.
|
||||
pub fn new(config: OffPolicyConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
_components: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<RLC> RLStrategy<RLC> for OffPolicyStrategy<RLC>
|
||||
where
|
||||
RLC: RLComponentsTypes,
|
||||
{
|
||||
fn learn(
|
||||
&self,
|
||||
training_components: RLComponents<RLC>,
|
||||
learner_agent: &mut RLC::LearningAgent,
|
||||
starting_epoch: usize,
|
||||
env_init: RLC::EnvInit,
|
||||
) -> (RLC::Policy, RLEventProcessorType<RLC>) {
|
||||
let mut event_processor = training_components.event_processor;
|
||||
let mut checkpointer = training_components.checkpointer;
|
||||
let num_steps_total = training_components.num_steps;
|
||||
|
||||
let mut env_runner = MultiAgentEnvLoop::<NdArray, RLC>::new(
|
||||
env_init.clone(),
|
||||
self.config.num_envs,
|
||||
false,
|
||||
AsyncPolicy::new(
|
||||
self.config.num_envs.min(self.config.autobatch_size),
|
||||
learner_agent.policy(),
|
||||
),
|
||||
false,
|
||||
&Default::default(),
|
||||
);
|
||||
env_runner.start();
|
||||
let mut env_runner_valid = AgentEnvAsyncLoop::<NdArray, RLC>::new(
|
||||
env_init,
|
||||
0,
|
||||
true,
|
||||
AsyncPolicy::new(1, learner_agent.policy()),
|
||||
true,
|
||||
&Default::default(),
|
||||
);
|
||||
env_runner_valid.start();
|
||||
let mut transition_buffer =
|
||||
TransitionBuffer::<Transition<NdArray, RLC::State, RLC::Action>>::new(
|
||||
self.config.replay_buffer_size,
|
||||
);
|
||||
|
||||
let mut valid_next = self.config.eval_interval + starting_epoch - 1;
|
||||
let mut progress = Progress {
|
||||
items_processed: starting_epoch,
|
||||
items_total: num_steps_total,
|
||||
};
|
||||
|
||||
let mut intermediary_update: Option<<RLC::Policy as Policy<RLC::Backend>>::PolicyState> =
|
||||
None;
|
||||
while progress.items_processed < num_steps_total {
|
||||
if training_components.interrupter.should_stop() {
|
||||
let reason = training_components
|
||||
.interrupter
|
||||
.get_message()
|
||||
.unwrap_or(String::from("Reason unknown"));
|
||||
log::info!("Training interrupted: {reason}");
|
||||
break;
|
||||
}
|
||||
|
||||
let previous_steps = progress.items_processed;
|
||||
let items = env_runner.run_steps(
|
||||
self.config.train_interval,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
&mut progress,
|
||||
);
|
||||
|
||||
transition_buffer.append(&mut items.iter().map(|i| i.transition.clone()).collect());
|
||||
|
||||
if transition_buffer.len() >= self.config.train_batch_size
|
||||
&& progress.items_processed >= self.config.warmup_steps
|
||||
{
|
||||
if let Some(ref u) = intermediary_update {
|
||||
env_runner.update_policy(u.clone());
|
||||
}
|
||||
for _ in 0..self.config.train_steps {
|
||||
let transitions = transition_buffer.random_sample(self.config.train_batch_size);
|
||||
let train_item = learner_agent.train(transitions.into());
|
||||
intermediary_update = Some(train_item.policy);
|
||||
|
||||
event_processor.process_train(RLEvent::TrainStep(EvaluationItem::new(
|
||||
train_item.item,
|
||||
progress.clone(),
|
||||
None,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if valid_next > previous_steps && valid_next <= progress.items_processed {
|
||||
env_runner_valid.update_policy(learner_agent.policy().state());
|
||||
env_runner_valid.run_episodes(
|
||||
self.config.eval_episodes,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
&mut progress,
|
||||
);
|
||||
|
||||
if let Some(checkpointer) = &mut checkpointer {
|
||||
checkpointer.checkpoint(
|
||||
&env_runner.policy(),
|
||||
learner_agent,
|
||||
valid_next,
|
||||
&training_components.event_store,
|
||||
);
|
||||
}
|
||||
|
||||
valid_next += self.config.eval_interval;
|
||||
}
|
||||
}
|
||||
|
||||
(learner_agent.policy(), event_processor)
|
||||
}
|
||||
}
|
||||
32
crates/burn-train/src/learner/rl/output.rs
Normal file
32
crates/burn-train/src/learner/rl/output.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use crate::{
|
||||
ItemLazy,
|
||||
metric::{Adaptor, CumulativeRewardInput, EpisodeLengthInput},
|
||||
};
|
||||
|
||||
/// Summary of a finished episode.
pub struct EpisodeSummary {
    /// Total number of steps in the episode.
    pub episode_length: usize,
    /// Cumulative reward at the end of the episode.
    pub cum_reward: f64,
}
|
||||
|
||||
impl ItemLazy for EpisodeSummary {
    type ItemSync = Self;

    /// The summary is its own synced form, so it is returned unchanged.
    fn sync(self) -> Self::ItemSync {
        self
    }
}
|
||||
|
||||
impl Adaptor<EpisodeLengthInput> for EpisodeSummary {
|
||||
fn adapt(&self) -> EpisodeLengthInput {
|
||||
EpisodeLengthInput::new(self.episode_length as f64)
|
||||
}
|
||||
}
|
||||
|
||||
impl Adaptor<CumulativeRewardInput> for EpisodeSummary {
|
||||
fn adapt(&self) -> CumulativeRewardInput {
|
||||
CumulativeRewardInput::new(self.cum_reward)
|
||||
}
|
||||
}
|
||||
521
crates/burn-train/src/learner/rl/paradigm.rs
Normal file
521
crates/burn-train/src/learner/rl/paradigm.rs
Normal file
@@ -0,0 +1,521 @@
|
||||
use crate::checkpoint::{
|
||||
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
|
||||
KeepLastNCheckpoints, MetricCheckpointingStrategy,
|
||||
};
|
||||
use crate::learner::base::Interrupter;
|
||||
use crate::logger::{FileMetricLogger, MetricLogger};
|
||||
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
|
||||
use crate::metric::{Adaptor, EpisodeLengthMetric, Metric, Numeric};
|
||||
use crate::renderer::{MetricsRenderer, default_renderer};
|
||||
use crate::{
|
||||
ApplicationLoggerInstaller, AsyncProcessorTraining, FileApplicationLoggerInstaller, ItemLazy,
|
||||
LearnerSummaryConfig, OffPolicyConfig, OffPolicyStrategy, RLAgentRecord, RLCheckpointer,
|
||||
RLComponents, RLComponentsMarker, RLComponentsTypes, RLEventProcessor, RLMetrics,
|
||||
RLPolicyRecord, RLStrategy,
|
||||
};
|
||||
use crate::{EpisodeSummary, RLStrategies};
|
||||
use burn_core::record::FileRecorder;
|
||||
use burn_core::tensor::backend::AutodiffBackend;
|
||||
use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner};
|
||||
use std::collections::BTreeSet;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Structure to configure and launch reinforcement learning trainings.
|
||||
pub struct RLTraining<RLC: RLComponentsTypes> {
|
||||
// Not that complex. Extracting into yet another type would only make it more confusing.
|
||||
#[allow(clippy::type_complexity)]
|
||||
checkpointers: Option<(
|
||||
AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
|
||||
AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
|
||||
)>,
|
||||
num_steps: usize,
|
||||
checkpoint: Option<usize>,
|
||||
directory: PathBuf,
|
||||
grad_accumulation: Option<usize>,
|
||||
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
|
||||
metrics: RLMetrics<RLC::TrainingOutput, RLC::ActionContext>,
|
||||
event_store: LogEventStore,
|
||||
interrupter: Interrupter,
|
||||
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
|
||||
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
|
||||
learning_strategy: RLStrategies<RLC>,
|
||||
// Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order
|
||||
summary_metrics: BTreeSet<String>,
|
||||
summary: bool,
|
||||
env_initializer: RLC::EnvInit,
|
||||
}
|
||||
|
||||
impl<B, E, EI, A> RLTraining<RLComponentsMarker<B, E, EI, A>>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
E: Environment + 'static,
|
||||
EI: EnvironmentInit<E> + Send + 'static,
|
||||
A: PolicyLearner<B> + Send + 'static,
|
||||
A::TrainContext: ItemLazy + Clone + Send,
|
||||
A::InnerPolicy: Policy<B> + Send,
|
||||
<A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
|
||||
<A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
|
||||
<A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
|
||||
E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
|
||||
E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
|
||||
+ Into<<A::InnerPolicy as Policy<B>>::Action>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
{
|
||||
/// Creates a new runner for reinforcement learning.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `directory` - The directory to save the checkpoints.
|
||||
/// * `env_init` - Specifies how to initialize the environment.
|
||||
pub fn new(directory: impl AsRef<Path>, env_initializer: EI) -> Self {
|
||||
let directory = directory.as_ref().to_path_buf();
|
||||
let experiment_log_file = directory.join("experiment.log");
|
||||
Self {
|
||||
num_steps: 1,
|
||||
checkpoint: None,
|
||||
checkpointers: None,
|
||||
directory,
|
||||
grad_accumulation: None,
|
||||
metrics: RLMetrics::default(),
|
||||
event_store: LogEventStore::default(),
|
||||
renderer: None,
|
||||
interrupter: Interrupter::new(),
|
||||
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
|
||||
experiment_log_file,
|
||||
))),
|
||||
checkpointer_strategy: Box::new(
|
||||
ComposedCheckpointingStrategy::builder()
|
||||
.add(KeepLastNCheckpoints::new(2))
|
||||
.add(MetricCheckpointingStrategy::new(
|
||||
&EpisodeLengthMetric::new(), // default to evaluations' cumulative reward.
|
||||
Aggregate::Mean,
|
||||
Direction::Lowest,
|
||||
Split::Valid,
|
||||
))
|
||||
.build(),
|
||||
),
|
||||
learning_strategy: RLStrategies::OffPolicyStrategy(OffPolicyConfig::new()),
|
||||
summary_metrics: BTreeSet::new(),
|
||||
summary: false,
|
||||
env_initializer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<RLC: RLComponentsTypes + 'static> RLTraining<RLC> {
|
||||
/// Replace the default learning strategy (Off Policy learning) with the provided one.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `training_strategy` - The training strategy.
|
||||
pub fn with_learning_strategy(mut self, learning_strategy: RLStrategies<RLC>) -> Self {
|
||||
self.learning_strategy = learning_strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Replace the default metric loggers with the provided ones.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `logger` - The training logger.
|
||||
pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self
|
||||
where
|
||||
ML: MetricLogger + 'static,
|
||||
{
|
||||
self.event_store.register_logger(logger);
|
||||
self
|
||||
}
|
||||
|
||||
/// Update the checkpointing_strategy.
|
||||
pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
|
||||
mut self,
|
||||
strategy: CS,
|
||||
) -> Self {
|
||||
self.checkpointer_strategy = Box::new(strategy);
|
||||
self
|
||||
}
|
||||
|
||||
/// Replace the default CLI renderer with a custom one.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `renderer` - The custom renderer.
|
||||
pub fn renderer<MR>(mut self, renderer: MR) -> Self
|
||||
where
|
||||
MR: MetricsRenderer + 'static,
|
||||
{
|
||||
self.renderer = Some(Box::new(renderer));
|
||||
self
|
||||
}
|
||||
|
||||
/// Register numerical metrics for a training step of the agent.
|
||||
pub fn metrics_train<Me: TrainMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register textual metrics for a training step of the agent.
|
||||
pub fn text_metrics_train<Me: TrainTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register numerical metrics for each action of the agent.
|
||||
pub fn metrics_agent<Me: AgentMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register textual metrics for each action of the agent.
|
||||
pub fn text_metrics_agent<Me: AgentTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register numerical metrics for a completed episode.
|
||||
pub fn metrics_episode<Me: EpisodeMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register textual metrics for a completed episode.
|
||||
pub fn text_metrics_episode<Me: EpisodeTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register a textual metric for a training step.
|
||||
pub fn text_metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
|
||||
where
|
||||
<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.metrics.register_text_metric_train(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a training step.
|
||||
pub fn metric_train<Me>(mut self, metric: Me) -> Self
|
||||
where
|
||||
Me: Metric + Numeric + 'static,
|
||||
<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.summary_metrics.insert(metric.name().to_string());
|
||||
self.metrics.register_metric_train(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a textual metric for each action taken by the agent.
|
||||
pub fn text_metric_agent<Me: Metric + 'static>(mut self, metric: Me) -> Self
|
||||
where
|
||||
<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.metrics.register_text_metric_agent(metric.clone());
|
||||
self.metrics.register_text_metric_agent_valid(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for each action taken by the agent.
|
||||
pub fn metric_agent<Me>(mut self, metric: Me) -> Self
|
||||
where
|
||||
Me: Metric + Numeric + 'static,
|
||||
<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
|
||||
{
|
||||
self.summary_metrics.insert(metric.name().to_string());
|
||||
self.metrics.register_agent_metric(metric.clone());
|
||||
self.metrics.register_agent_metric_valid(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a textual metric for a completed episode.
|
||||
pub fn text_metric_episode<Me: Metric + 'static>(mut self, metric: Me) -> Self
|
||||
where
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
self.metrics.register_text_metric_episode(metric.clone());
|
||||
self.metrics.register_text_metric_episode_valid(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a completed episode.
|
||||
pub fn metric_episode<Me>(mut self, metric: Me) -> Self
|
||||
where
|
||||
Me: Metric + Numeric + 'static,
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
self.summary_metrics.insert(metric.name().to_string());
|
||||
self.metrics.register_episode_metric(metric.clone());
|
||||
self.metrics.register_episode_metric_valid(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// The number of environment steps to train for.
|
||||
pub fn num_steps(mut self, num_steps: usize) -> Self {
|
||||
self.num_steps = num_steps;
|
||||
self
|
||||
}
|
||||
|
||||
/// The step from which the training must resume.
|
||||
pub fn checkpoint(mut self, checkpoint: usize) -> Self {
|
||||
self.checkpoint = Some(checkpoint);
|
||||
self
|
||||
}
|
||||
|
||||
/// Provides a handle that can be used to interrupt training.
|
||||
pub fn interrupter(&self) -> Interrupter {
|
||||
self.interrupter.clone()
|
||||
}
|
||||
|
||||
/// Override the handle for stopping training with an externally provided handle
|
||||
pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
|
||||
self.interrupter = interrupter;
|
||||
self
|
||||
}
|
||||
|
||||
/// By default, Rust logs are captured and written into
|
||||
/// `experiment.log`. If disabled, standard Rust log handling
|
||||
/// will apply.
|
||||
pub fn with_application_logger(
|
||||
mut self,
|
||||
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
|
||||
) -> Self {
|
||||
self.tracing_logger = logger;
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a checkpointer that will save the environment runner's [policy](Policy)
|
||||
/// and the [PolicyLearner](PolicyLearner) state to different files.
|
||||
pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
|
||||
where
|
||||
FR: FileRecorder<RLC::Backend> + 'static,
|
||||
FR: FileRecorder<<RLC::Backend as AutodiffBackend>::InnerBackend> + 'static,
|
||||
{
|
||||
let checkpoint_dir = self.directory.join("checkpoint");
|
||||
let checkpointer_policy =
|
||||
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "policy");
|
||||
let checkpointer_learning =
|
||||
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "learning-agent");
|
||||
|
||||
self.checkpointers = Some((
|
||||
AsyncCheckpointer::new(checkpointer_policy),
|
||||
AsyncCheckpointer::new(checkpointer_learning),
|
||||
));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable the training summary report.
|
||||
///
|
||||
/// The summary will be displayed after `.launch()`, when the renderer is dropped.
|
||||
pub fn summary(mut self) -> Self {
|
||||
self.summary = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Launch the training with the specified [PolicyLearner](PolicyLearner) on the specified environment.
|
||||
pub fn launch(mut self, learner_agent: RLC::LearningAgent) -> RLResult<RLC::Policy> {
|
||||
if self.tracing_logger.is_some()
|
||||
&& let Err(e) = self.tracing_logger.as_ref().unwrap().install()
|
||||
{
|
||||
log::warn!("Failed to install the experiment logger: {e}");
|
||||
}
|
||||
let renderer = self
|
||||
.renderer
|
||||
.unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));
|
||||
|
||||
if !self.event_store.has_loggers() {
|
||||
self.event_store
|
||||
.register_logger(FileMetricLogger::new(self.directory.clone()));
|
||||
}
|
||||
|
||||
let event_store = Arc::new(EventStoreClient::new(self.event_store));
|
||||
let event_processor = AsyncProcessorTraining::new(RLEventProcessor::new(
|
||||
self.metrics,
|
||||
renderer,
|
||||
event_store.clone(),
|
||||
));
|
||||
|
||||
let checkpointer = self.checkpointers.map(|(policy, learning_agent)| {
|
||||
RLCheckpointer::new(policy, learning_agent, self.checkpointer_strategy)
|
||||
});
|
||||
|
||||
let summary = if self.summary {
|
||||
Some(LearnerSummaryConfig {
|
||||
directory: self.directory,
|
||||
metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let components = RLComponents::<RLC> {
|
||||
checkpoint: self.checkpoint,
|
||||
checkpointer,
|
||||
interrupter: self.interrupter,
|
||||
event_processor,
|
||||
event_store,
|
||||
num_steps: self.num_steps,
|
||||
grad_accumulation: self.grad_accumulation,
|
||||
summary,
|
||||
};
|
||||
|
||||
match self.learning_strategy {
|
||||
RLStrategies::OffPolicyStrategy(config) => {
|
||||
let strategy = OffPolicyStrategy::new(config);
|
||||
strategy.train(learner_agent, components, self.env_initializer)
|
||||
}
|
||||
RLStrategies::Custom(strategy) => {
|
||||
strategy.train(learner_agent, components, self.env_initializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The result of reinforcement learning, containing the final policy along with the [renderer](MetricsRenderer).
|
||||
pub struct RLResult<P> {
|
||||
/// The learned policy.
|
||||
pub policy: P,
|
||||
/// The renderer that can be used for follow up training and evaluation.
|
||||
pub renderer: Box<dyn MetricsRenderer>,
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for train step metrics.
|
||||
pub trait AgentMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for train step text metrics.
|
||||
pub trait AgentTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for env step metrics.
|
||||
pub trait TrainMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for train step text metrics (adapted from the training output).
|
||||
pub trait TrainTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for episode metrics.
|
||||
pub trait EpisodeMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
/// Trait to fake variadic generics for episode text metrics.
|
||||
pub trait EpisodeTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
|
||||
/// Register the metrics.
|
||||
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
|
||||
}
|
||||
|
||||
macro_rules! gen_tuple {
|
||||
($($M:ident),*) => {
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainTextMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.text_metric_train($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + Numeric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.metric_train($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentTextMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.text_metric_agent($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
|
||||
$($M: Metric + Numeric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.metric_agent($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeTextMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
|
||||
$($M: Metric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.text_metric_episode($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeMetricRegistration<RLC> for ($($M,)*)
|
||||
where
|
||||
$(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
|
||||
$($M: Metric + Numeric + 'static,)*
|
||||
{
|
||||
#[allow(non_snake_case)]
|
||||
fn register(
|
||||
self,
|
||||
builder: RLTraining<RLC>,
|
||||
) -> RLTraining<RLC> {
|
||||
let ($($M,)*) = self;
|
||||
$(let builder = builder.metric_episode($M.clone());)*
|
||||
builder
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
gen_tuple!(M1);
|
||||
gen_tuple!(M1, M2);
|
||||
gen_tuple!(M1, M2, M3);
|
||||
gen_tuple!(M1, M2, M3, M4);
|
||||
gen_tuple!(M1, M2, M3, M4, M5);
|
||||
gen_tuple!(M1, M2, M3, M4, M5, M6);
|
||||
99
crates/burn-train/src/learner/rl/strategy.rs
Normal file
99
crates/burn-train/src/learner/rl/strategy.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
Interrupter, LearnerSummaryConfig, OffPolicyConfig, RLCheckpointer, RLComponentsTypes, RLEvent,
|
||||
RLEventProcessorType, RLResult,
|
||||
metric::{processor::EventProcessorTraining, store::EventStoreClient},
|
||||
};
|
||||
|
||||
/// Struct to minimise parameters passed to [RLStrategy::train].
|
||||
pub struct RLComponents<RLC: RLComponentsTypes> {
|
||||
/// The total number of environment steps.
|
||||
pub num_steps: usize,
|
||||
/// The step number from which to continue the training.
|
||||
pub checkpoint: Option<usize>,
|
||||
/// A checkpointer used to load and save learning checkpoints.
|
||||
pub checkpointer: Option<RLCheckpointer<RLC>>,
|
||||
/// Enables gradient accumulation.
|
||||
pub grad_accumulation: Option<usize>,
|
||||
/// An [Interrupter](Interrupter) that allows aborting the training/evaluation process early.
|
||||
pub interrupter: Interrupter,
|
||||
/// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and evaluation.
|
||||
pub event_processor: RLEventProcessorType<RLC>,
|
||||
/// A reference to an [EventStoreClient](EventStoreClient).
|
||||
pub event_store: Arc<EventStoreClient>,
|
||||
/// Config for creating a summary of the learning.
|
||||
pub summary: Option<LearnerSummaryConfig>,
|
||||
}
|
||||
|
||||
/// The strategy for reinforcement learning.
|
||||
#[derive(Clone)]
|
||||
pub enum RLStrategies<RLC: RLComponentsTypes> {
|
||||
/// Training using the off-policy strategy, built from the given configuration.
|
||||
OffPolicyStrategy(OffPolicyConfig),
|
||||
/// Training using a custom learning strategy
|
||||
Custom(CustomRLStrategy<RLC>),
|
||||
}
|
||||
|
||||
/// A reference to an implementation of [RLStrategy].
|
||||
pub type CustomRLStrategy<LC> = Arc<dyn RLStrategy<LC>>;
|
||||
|
||||
/// Provides the `fit` function for any learning strategy
|
||||
pub trait RLStrategy<RLC: RLComponentsTypes> {
|
||||
/// Train the learner agent with this strategy.
|
||||
fn train(
|
||||
&self,
|
||||
mut learner_agent: RLC::LearningAgent,
|
||||
mut training_components: RLComponents<RLC>,
|
||||
env_init: RLC::EnvInit,
|
||||
) -> RLResult<RLC::Policy> {
|
||||
let starting_epoch = match training_components.checkpoint {
|
||||
Some(checkpoint) => {
|
||||
if let Some(checkpointer) = &mut training_components.checkpointer {
|
||||
learner_agent = checkpointer.load_checkpoint(
|
||||
learner_agent,
|
||||
&Default::default(),
|
||||
checkpoint,
|
||||
);
|
||||
}
|
||||
checkpoint + 1
|
||||
}
|
||||
None => 1,
|
||||
};
|
||||
|
||||
let summary_config = training_components.summary.clone();
|
||||
|
||||
// Event processor start training
|
||||
training_components
|
||||
.event_processor
|
||||
.process_train(RLEvent::Start);
|
||||
|
||||
// Training loop
|
||||
let (policy, mut event_processor) = self.learn(
|
||||
training_components,
|
||||
&mut learner_agent,
|
||||
starting_epoch,
|
||||
env_init,
|
||||
);
|
||||
|
||||
let summary = summary_config.and_then(|summary| summary.init().ok());
|
||||
|
||||
// Signal training end. For the TUI renderer, this handles the exit & return to main screen.
|
||||
// TODO: summary makes sense for RL?
|
||||
event_processor.process_train(RLEvent::End(summary));
|
||||
|
||||
// let model = model.valid();
|
||||
let renderer = event_processor.renderer();
|
||||
|
||||
RLResult { policy, renderer }
|
||||
}
|
||||
|
||||
/// Training loop for this strategy
|
||||
fn learn(
|
||||
&self,
|
||||
training_components: RLComponents<RLC>,
|
||||
learner_agent: &mut RLC::LearningAgent,
|
||||
starting_epoch: usize,
|
||||
env_init: RLC::EnvInit,
|
||||
) -> (RLC::Policy, RLEventProcessorType<RLC>);
|
||||
}
|
||||
@@ -16,10 +16,10 @@ use crate::renderer::{MetricsRenderer, default_renderer};
|
||||
use crate::single::SingleDevicetrainingStrategy;
|
||||
use crate::{
|
||||
ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller,
|
||||
InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerModelRecord,
|
||||
LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig, LearningCheckpointer,
|
||||
LearningComponentsMarker, LearningComponentsTypes, LearningResult, TrainStep, TrainingBackend,
|
||||
TrainingComponents, TrainingModelInput, TrainingStrategy,
|
||||
InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerEvent,
|
||||
LearnerModelRecord, LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig,
|
||||
LearningCheckpointer, LearningComponentsMarker, LearningComponentsTypes, LearningResult,
|
||||
TrainStep, TrainingBackend, TrainingComponents, TrainingModelInput, TrainingStrategy,
|
||||
};
|
||||
use crate::{Learner, SupervisedLearningStrategy};
|
||||
use burn_core::data::dataloader::DataLoader;
|
||||
@@ -38,7 +38,8 @@ pub type TrainLoader<LC> = Arc<dyn DataLoader<TrainingBackend<LC>, TrainingModel
|
||||
pub type ValidLoader<LC> = Arc<dyn DataLoader<InferenceBackend<LC>, InferenceModelInput<LC>>>;
|
||||
/// The event processor type for supervised learning.
|
||||
pub type SupervisedTrainingEventProcessor<LC> = AsyncProcessorTraining<
|
||||
FullEventProcessorTraining<TrainingModelOutput<LC>, InferenceModelOutput<LC>>,
|
||||
LearnerEvent<TrainingModelOutput<LC>>,
|
||||
LearnerEvent<InferenceModelOutput<LC>>,
|
||||
>;
|
||||
|
||||
/// Structure to configure and launch supervised learning trainings.
|
||||
@@ -181,7 +182,7 @@ impl<LC: LearningComponentsTypes> SupervisedTraining<LC> {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
/// Register all metrics as numeric for the training and validation set.
|
||||
/// Register all metrics as text for the training and validation set.
|
||||
pub fn metrics_text<Me: TextMetricRegistration<LC>>(self, metrics: Me) -> Self {
|
||||
metrics.register(self)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use burn_collective::{PeerId, ReduceOperation};
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_core::module::AutodiffModule;
|
||||
use burn_core::tensor::backend::AutodiffBackend;
|
||||
use burn_optim::GradientsAccumulator;
|
||||
@@ -9,7 +10,7 @@ use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::SupervisedTrainingEventProcessor;
|
||||
use crate::learner::base::Interrupter;
|
||||
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, LearnerItem};
|
||||
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
|
||||
use crate::{
|
||||
InferenceStep, Learner, LearningComponentsTypes, TrainLoader, TrainingBackend, ValidLoader,
|
||||
};
|
||||
@@ -18,14 +19,12 @@ use crate::{
|
||||
#[derive(new)]
|
||||
pub struct DdpValidEpoch<LC: LearningComponentsTypes> {
|
||||
dataloader: ValidLoader<LC>,
|
||||
epoch_total: usize,
|
||||
}
|
||||
|
||||
/// A training epoch.
|
||||
#[derive(new)]
|
||||
pub struct DdpTrainEpoch<LC: LearningComponentsTypes> {
|
||||
dataloader: TrainLoader<LC>,
|
||||
epoch_total: usize,
|
||||
grad_accumulation: Option<usize>,
|
||||
}
|
||||
|
||||
@@ -39,10 +38,11 @@ impl<LC: LearningComponentsTypes> DdpValidEpoch<LC> {
|
||||
pub fn run(
|
||||
&self,
|
||||
model: &<LC as LearningComponentsTypes>::TrainingModel,
|
||||
epoch: usize,
|
||||
global_progress: &Progress,
|
||||
processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!("Executing validation step for epoch {}", epoch);
|
||||
let model = model.valid();
|
||||
|
||||
@@ -54,7 +54,13 @@ impl<LC: LearningComponentsTypes> DdpValidEpoch<LC> {
|
||||
iteration += 1;
|
||||
|
||||
let item = model.step(item);
|
||||
let item = LearnerItem::new(item, progress, epoch, self.epoch_total, iteration, None);
|
||||
let item = TrainingItem::new(
|
||||
item,
|
||||
progress,
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
None,
|
||||
);
|
||||
|
||||
processor.process_valid(LearnerEvent::ProcessedItem(item));
|
||||
|
||||
@@ -84,13 +90,14 @@ impl<LC: LearningComponentsTypes> DdpTrainEpoch<LC> {
|
||||
pub fn run(
|
||||
&self,
|
||||
learner: &mut Learner<LC>,
|
||||
epoch: usize,
|
||||
global_progress: &Progress,
|
||||
processor: Arc<Mutex<SupervisedTrainingEventProcessor<LC>>>,
|
||||
interrupter: &Interrupter,
|
||||
peer_id: PeerId,
|
||||
peer_count: usize,
|
||||
is_main: bool,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!("Executing training step for epoch {}", epoch,);
|
||||
|
||||
let mut iterator = self.dataloader.iter();
|
||||
@@ -143,12 +150,11 @@ impl<LC: LearningComponentsTypes> DdpTrainEpoch<LC> {
|
||||
}
|
||||
}
|
||||
|
||||
let item = LearnerItem::new(
|
||||
let item = TrainingItem::new(
|
||||
item.item,
|
||||
progress,
|
||||
epoch,
|
||||
self.epoch_total,
|
||||
iteration,
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
Some(learner.lr_current()),
|
||||
);
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::ddp::epoch::{DdpTrainEpoch, DdpValidEpoch};
|
||||
use crate::ddp::strategy::WorkerComponents;
|
||||
use crate::single::TrainingLoop;
|
||||
use crate::{
|
||||
Learner, LearningCheckpointer, LearningComponentsTypes, SupervisedTrainingEventProcessor,
|
||||
TrainLoader, TrainingBackend, ValidLoader,
|
||||
@@ -83,18 +84,19 @@ where
|
||||
// Changed the train epoch to keep the dataloaders
|
||||
let epoch_train = DdpTrainEpoch::<LC>::new(
|
||||
self.dataloader_train.clone(),
|
||||
num_epochs,
|
||||
self.components.grad_accumulation,
|
||||
);
|
||||
let epoch_valid = self
|
||||
.dataloader_valid
|
||||
.map(|dataloader| DdpValidEpoch::<LC>::new(dataloader, num_epochs));
|
||||
.map(|dataloader| DdpValidEpoch::<LC>::new(dataloader));
|
||||
self.learner.fork(&self.device);
|
||||
|
||||
for epoch in self.starting_epoch..num_epochs + 1 {
|
||||
for training_progress in TrainingLoop::new(self.starting_epoch, num_epochs) {
|
||||
let epoch = training_progress.items_processed;
|
||||
|
||||
epoch_train.run(
|
||||
&mut self.learner,
|
||||
epoch,
|
||||
&training_progress,
|
||||
self.event_processor.clone(),
|
||||
&interrupter,
|
||||
self.peer_id,
|
||||
@@ -111,7 +113,7 @@ where
|
||||
let mut event_processor = self.event_processor.lock().unwrap();
|
||||
runner.run(
|
||||
&self.learner.model(),
|
||||
epoch,
|
||||
&training_progress,
|
||||
&mut event_processor,
|
||||
&interrupter,
|
||||
);
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use crate::learner::base::Interrupter;
|
||||
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, LearnerItem};
|
||||
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
|
||||
use crate::train::MultiDevicesTrainStep;
|
||||
use crate::{
|
||||
Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedTrainingEventProcessor,
|
||||
TrainLoader, TrainingBackend,
|
||||
};
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_core::prelude::DeviceOps;
|
||||
use burn_core::tensor::Device;
|
||||
use burn_core::tensor::backend::DeviceId;
|
||||
@@ -16,7 +17,6 @@ use std::collections::HashMap;
|
||||
#[derive(new)]
|
||||
pub struct MultiDeviceTrainEpoch<LC: LearningComponentsTypes> {
|
||||
dataloaders: Vec<TrainLoader<LC>>,
|
||||
epoch_total: usize,
|
||||
grad_accumulation: Option<usize>,
|
||||
}
|
||||
|
||||
@@ -38,30 +38,39 @@ impl<LC: LearningComponentsTypes> MultiDeviceTrainEpoch<LC> {
|
||||
pub fn run(
|
||||
&self,
|
||||
learner: &mut Learner<LC>,
|
||||
epoch: usize,
|
||||
global_progress: &Progress,
|
||||
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
devices: Vec<Device<TrainingBackend<LC>>>,
|
||||
strategy: MultiDeviceOptim,
|
||||
) {
|
||||
match strategy {
|
||||
MultiDeviceOptim::OptimMainDevice => {
|
||||
self.run_optim_main(learner, epoch, event_processor, interrupter, devices)
|
||||
}
|
||||
MultiDeviceOptim::OptimSharded => {
|
||||
self.run_optim_distr(learner, epoch, event_processor, interrupter, devices)
|
||||
}
|
||||
MultiDeviceOptim::OptimMainDevice => self.run_optim_main(
|
||||
learner,
|
||||
global_progress,
|
||||
event_processor,
|
||||
interrupter,
|
||||
devices,
|
||||
),
|
||||
MultiDeviceOptim::OptimSharded => self.run_optim_distr(
|
||||
learner,
|
||||
global_progress,
|
||||
event_processor,
|
||||
interrupter,
|
||||
devices,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn run_optim_main(
|
||||
&self,
|
||||
learner: &mut Learner<LC>,
|
||||
epoch: usize,
|
||||
global_progress: &Progress,
|
||||
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
devices: Vec<Device<TrainingBackend<LC>>>,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!(
|
||||
"Executing training step for epoch {} on devices {:?}",
|
||||
epoch,
|
||||
@@ -108,12 +117,11 @@ impl<LC: LearningComponentsTypes> MultiDeviceTrainEpoch<LC> {
|
||||
|
||||
for item in progress_items {
|
||||
iteration += 1;
|
||||
let item = LearnerItem::new(
|
||||
let item = TrainingItem::new(
|
||||
item,
|
||||
progress.clone(),
|
||||
epoch,
|
||||
self.epoch_total,
|
||||
iteration,
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
Some(learner.lr_current()),
|
||||
);
|
||||
|
||||
@@ -131,11 +139,12 @@ impl<LC: LearningComponentsTypes> MultiDeviceTrainEpoch<LC> {
|
||||
fn run_optim_distr(
|
||||
&self,
|
||||
learner: &mut Learner<LC>,
|
||||
epoch: usize,
|
||||
global_progress: &Progress,
|
||||
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
devices: Vec<Device<TrainingBackend<LC>>>,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!(
|
||||
"Executing training step for epoch {} on devices {:?}",
|
||||
epoch,
|
||||
@@ -189,12 +198,11 @@ impl<LC: LearningComponentsTypes> MultiDeviceTrainEpoch<LC> {
|
||||
|
||||
for item in progress_items {
|
||||
iteration += 1;
|
||||
let item = LearnerItem::new(
|
||||
let item = TrainingItem::new(
|
||||
item,
|
||||
progress.clone(),
|
||||
epoch,
|
||||
self.epoch_total,
|
||||
iteration,
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
Some(learner.lr_current()),
|
||||
);
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use crate::{
|
||||
Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedLearningStrategy,
|
||||
SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, TrainingComponents,
|
||||
TrainingModel, ValidLoader, multi::epoch::MultiDeviceTrainEpoch,
|
||||
single::epoch::SingleDeviceValidEpoch,
|
||||
TrainingModel, ValidLoader,
|
||||
multi::epoch::MultiDeviceTrainEpoch,
|
||||
single::{TrainingLoop, epoch::SingleDeviceValidEpoch},
|
||||
};
|
||||
use burn_core::{data::dataloader::split::split_dataloader, tensor::Device};
|
||||
|
||||
@@ -39,20 +40,19 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
|
||||
let mut event_processor = training_components.event_processor;
|
||||
let mut checkpointer = training_components.checkpointer;
|
||||
let mut early_stopping = training_components.early_stopping;
|
||||
let num_epochs = training_components.num_epochs;
|
||||
|
||||
let epoch_train = MultiDeviceTrainEpoch::<LC>::new(
|
||||
dataloader_train.clone(),
|
||||
num_epochs,
|
||||
training_components.grad_accumulation,
|
||||
);
|
||||
let epoch_valid: SingleDeviceValidEpoch<LC> =
|
||||
SingleDeviceValidEpoch::new(dataloader_valid.clone(), num_epochs);
|
||||
SingleDeviceValidEpoch::new(dataloader_valid.clone());
|
||||
|
||||
for epoch in starting_epoch..training_components.num_epochs + 1 {
|
||||
for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) {
|
||||
let epoch = training_progress.items_processed;
|
||||
epoch_train.run(
|
||||
&mut learner,
|
||||
epoch,
|
||||
&training_progress,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
self.devices.to_vec(),
|
||||
@@ -70,7 +70,7 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
|
||||
|
||||
epoch_valid.run(
|
||||
&learner,
|
||||
epoch,
|
||||
&training_progress,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
);
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use crate::learner::base::Interrupter;
|
||||
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, LearnerItem};
|
||||
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
|
||||
use crate::{
|
||||
InferenceStep, Learner, LearningComponentsTypes, SupervisedTrainingEventProcessor, TrainLoader,
|
||||
ValidLoader,
|
||||
};
|
||||
use burn_core::data::dataloader::Progress;
|
||||
use burn_core::module::AutodiffModule;
|
||||
use burn_optim::GradientsAccumulator;
|
||||
|
||||
@@ -11,14 +12,12 @@ use burn_optim::GradientsAccumulator;
|
||||
#[derive(new)]
|
||||
pub struct SingleDeviceValidEpoch<LC: LearningComponentsTypes> {
|
||||
dataloader: ValidLoader<LC>,
|
||||
epoch_total: usize,
|
||||
}
|
||||
|
||||
/// A training epoch.
|
||||
#[derive(new)]
|
||||
pub struct SingleDeviceTrainEpoch<LC: LearningComponentsTypes> {
|
||||
dataloader: TrainLoader<LC>,
|
||||
epoch_total: usize,
|
||||
grad_accumulation: Option<usize>,
|
||||
}
|
||||
|
||||
@@ -32,10 +31,11 @@ impl<LC: LearningComponentsTypes> SingleDeviceValidEpoch<LC> {
|
||||
pub fn run(
|
||||
&self,
|
||||
learner: &Learner<LC>,
|
||||
epoch: usize,
|
||||
global_progress: &Progress,
|
||||
processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!("Executing validation step for epoch {}", epoch);
|
||||
let model = learner.model().valid();
|
||||
|
||||
@@ -47,7 +47,13 @@ impl<LC: LearningComponentsTypes> SingleDeviceValidEpoch<LC> {
|
||||
iteration += 1;
|
||||
|
||||
let item = model.step(item);
|
||||
let item = LearnerItem::new(item, progress, epoch, self.epoch_total, iteration, None);
|
||||
let item = TrainingItem::new(
|
||||
item,
|
||||
progress,
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
None,
|
||||
);
|
||||
|
||||
processor.process_valid(LearnerEvent::ProcessedItem(item));
|
||||
|
||||
@@ -75,10 +81,11 @@ impl<LC: LearningComponentsTypes> SingleDeviceTrainEpoch<LC> {
|
||||
pub fn run(
|
||||
&self,
|
||||
learner: &mut Learner<LC>,
|
||||
epoch: usize,
|
||||
global_progress: &Progress,
|
||||
processor: &mut SupervisedTrainingEventProcessor<LC>,
|
||||
interrupter: &Interrupter,
|
||||
) {
|
||||
let epoch = global_progress.items_processed;
|
||||
log::info!("Executing training step for epoch {}", epoch,);
|
||||
|
||||
// Single device / dataloader
|
||||
@@ -110,12 +117,11 @@ impl<LC: LearningComponentsTypes> SingleDeviceTrainEpoch<LC> {
|
||||
None => learner.optimizer_step(item.grads),
|
||||
}
|
||||
|
||||
let item = LearnerItem::new(
|
||||
let item = TrainingItem::new(
|
||||
item.item,
|
||||
progress,
|
||||
epoch,
|
||||
self.epoch_total,
|
||||
iteration,
|
||||
global_progress.clone(),
|
||||
Some(iteration),
|
||||
Some(learner.lr_current()),
|
||||
);
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::{
|
||||
TrainLoader, TrainingBackend, TrainingComponents, TrainingModel, ValidLoader,
|
||||
single::epoch::{SingleDeviceTrainEpoch, SingleDeviceValidEpoch},
|
||||
};
|
||||
use burn_core::tensor::Device;
|
||||
use burn_core::{data::dataloader::Progress, tensor::Device};
|
||||
|
||||
/// Simplest learning strategy possible, with only a single device doing both the training and
|
||||
/// validation.
|
||||
@@ -16,6 +16,31 @@ impl<LC: LearningComponentsTypes> SingleDevicetrainingStrategy<LC> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct TrainingLoop {
|
||||
next_iteration: usize,
|
||||
total_iteration: usize,
|
||||
}
|
||||
|
||||
/// An iterator that returns the progress of the training.
|
||||
impl Iterator for TrainingLoop {
|
||||
type Item = Progress;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.next_iteration > self.total_iteration {
|
||||
return None;
|
||||
}
|
||||
|
||||
let progress = Progress {
|
||||
items_processed: self.next_iteration,
|
||||
items_total: self.total_iteration,
|
||||
};
|
||||
|
||||
self.next_iteration += 1;
|
||||
Some(progress)
|
||||
}
|
||||
}
|
||||
|
||||
impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
|
||||
for SingleDevicetrainingStrategy<LC>
|
||||
{
|
||||
@@ -33,20 +58,17 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
|
||||
let mut event_processor = training_components.event_processor;
|
||||
let mut checkpointer = training_components.checkpointer;
|
||||
let mut early_stopping = training_components.early_stopping;
|
||||
let num_epochs = training_components.num_epochs;
|
||||
|
||||
let epoch_train: SingleDeviceTrainEpoch<LC> = SingleDeviceTrainEpoch::new(
|
||||
dataloader_train,
|
||||
num_epochs,
|
||||
training_components.grad_accumulation,
|
||||
);
|
||||
let epoch_train: SingleDeviceTrainEpoch<LC> =
|
||||
SingleDeviceTrainEpoch::new(dataloader_train, training_components.grad_accumulation);
|
||||
let epoch_valid: SingleDeviceValidEpoch<LC> =
|
||||
SingleDeviceValidEpoch::new(dataloader_valid.clone(), num_epochs);
|
||||
SingleDeviceValidEpoch::new(dataloader_valid.clone());
|
||||
|
||||
for epoch in starting_epoch..training_components.num_epochs + 1 {
|
||||
for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) {
|
||||
let epoch = training_progress.items_processed;
|
||||
epoch_train.run(
|
||||
&mut learner,
|
||||
epoch,
|
||||
&training_progress,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
);
|
||||
@@ -62,7 +84,7 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
|
||||
|
||||
epoch_valid.run(
|
||||
&learner,
|
||||
epoch,
|
||||
&training_progress,
|
||||
&mut event_processor,
|
||||
&training_components.interrupter,
|
||||
);
|
||||
|
||||
@@ -8,14 +8,11 @@ pub struct MetricMetadata {
|
||||
/// The current progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The current epoch.
|
||||
pub epoch: usize,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
/// The global progress of the training (e.g. epochs).
|
||||
pub global_progress: Progress,
|
||||
|
||||
/// The current iteration.
|
||||
pub iteration: usize,
|
||||
pub iteration: Option<usize>,
|
||||
|
||||
/// The current learning rate.
|
||||
pub lr: Option<LearningRate>,
|
||||
@@ -30,9 +27,11 @@ impl MetricMetadata {
|
||||
items_processed: 1,
|
||||
items_total: 1,
|
||||
},
|
||||
epoch: 0,
|
||||
epoch_total: 1,
|
||||
iteration: 0,
|
||||
global_progress: Progress {
|
||||
items_processed: 0,
|
||||
items_total: 1,
|
||||
},
|
||||
iteration: Some(0),
|
||||
lr: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +38,14 @@ impl Metric for IterationSpeedMetric {
|
||||
|
||||
fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let raw = match self.instant {
|
||||
Some(val) => metadata.iteration as f64 / val.elapsed().as_secs_f64(),
|
||||
Some(val) => {
|
||||
// If iteration is not logged, compute the speed over the number of items processed.
|
||||
// 1 iteration should equal 1 item when iteration is not logged.
|
||||
metadata
|
||||
.iteration
|
||||
.unwrap_or(metadata.progress.items_processed) as f64
|
||||
/ val.elapsed().as_secs_f64()
|
||||
}
|
||||
None => {
|
||||
self.instant = Some(std::time::Instant::now());
|
||||
0.0
|
||||
|
||||
@@ -37,6 +37,7 @@ mod loss;
|
||||
mod perplexity;
|
||||
mod precision;
|
||||
mod recall;
|
||||
mod rl;
|
||||
mod top_k_acc;
|
||||
mod wer;
|
||||
|
||||
@@ -53,6 +54,7 @@ pub use loss::*;
|
||||
pub use perplexity::*;
|
||||
pub use precision::*;
|
||||
pub use recall::*;
|
||||
pub use rl::*;
|
||||
pub use top_k_acc::*;
|
||||
pub use wer::*;
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation};
|
||||
|
||||
use super::{EventProcessorTraining, LearnerEvent};
|
||||
use super::EventProcessorTraining;
|
||||
use async_channel::{Receiver, Sender};
|
||||
|
||||
/// Event processor for the training process.
|
||||
pub struct AsyncProcessorTraining<P: EventProcessorTraining> {
|
||||
sender: Sender<Message<P>>,
|
||||
pub struct AsyncProcessorTraining<ET, EV> {
|
||||
sender: Sender<Message<ET, EV>>,
|
||||
}
|
||||
|
||||
/// Event processor for the model evaluation.
|
||||
@@ -13,9 +13,9 @@ pub struct AsyncProcessorEvaluation<P: EventProcessorEvaluation> {
|
||||
sender: Sender<EvalMessage<P>>,
|
||||
}
|
||||
|
||||
struct WorkerTraining<P: EventProcessorTraining> {
|
||||
struct WorkerTraining<ET, EV, P: EventProcessorTraining<ET, EV>> {
|
||||
processor: P,
|
||||
rec: Receiver<Message<P>>,
|
||||
rec: Receiver<Message<ET, EV>>,
|
||||
}
|
||||
|
||||
struct WorkerEvaluation<P: EventProcessorEvaluation> {
|
||||
@@ -23,8 +23,10 @@ struct WorkerEvaluation<P: EventProcessorEvaluation> {
|
||||
rec: Receiver<EvalMessage<P>>,
|
||||
}
|
||||
|
||||
impl<P: EventProcessorTraining + 'static> WorkerTraining<P> {
|
||||
pub fn start(processor: P, rec: Receiver<Message<P>>) {
|
||||
impl<ET: Send + 'static, EV: Send + 'static, P: EventProcessorTraining<ET, EV> + 'static>
|
||||
WorkerTraining<ET, EV, P>
|
||||
{
|
||||
pub fn start(processor: P, rec: Receiver<Message<ET, EV>>) {
|
||||
let mut worker = Self { processor, rec };
|
||||
|
||||
std::thread::spawn(move || {
|
||||
@@ -59,9 +61,9 @@ impl<P: EventProcessorEvaluation + 'static> WorkerEvaluation<P> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: EventProcessorTraining + 'static> AsyncProcessorTraining<P> {
|
||||
impl<ET: Send + 'static, EV: Send + 'static> AsyncProcessorTraining<ET, EV> {
|
||||
/// Create an event processor for training.
|
||||
pub fn new(processor: P) -> Self {
|
||||
pub fn new<P: EventProcessorTraining<ET, EV> + 'static>(processor: P) -> Self {
|
||||
let (sender, rec) = async_channel::bounded(1);
|
||||
|
||||
WorkerTraining::start(processor, rec);
|
||||
@@ -81,9 +83,9 @@ impl<P: EventProcessorEvaluation + 'static> AsyncProcessorEvaluation<P> {
|
||||
}
|
||||
}
|
||||
|
||||
enum Message<P: EventProcessorTraining> {
|
||||
Train(LearnerEvent<P::ItemTrain>),
|
||||
Valid(LearnerEvent<P::ItemValid>),
|
||||
enum Message<EventTrain, EventValid> {
|
||||
Train(EventTrain),
|
||||
Valid(EventValid),
|
||||
Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
|
||||
}
|
||||
|
||||
@@ -92,15 +94,12 @@ enum EvalMessage<P: EventProcessorEvaluation> {
|
||||
Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
|
||||
}
|
||||
|
||||
impl<P: EventProcessorTraining> EventProcessorTraining for AsyncProcessorTraining<P> {
|
||||
type ItemTrain = P::ItemTrain;
|
||||
type ItemValid = P::ItemValid;
|
||||
|
||||
fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>) {
|
||||
impl<ET: Send, EV: Send> EventProcessorTraining<ET, EV> for AsyncProcessorTraining<ET, EV> {
|
||||
fn process_train(&mut self, event: ET) {
|
||||
self.sender.send_blocking(Message::Train(event)).unwrap();
|
||||
}
|
||||
|
||||
fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>) {
|
||||
fn process_valid(&mut self, event: EV) {
|
||||
self.sender.send_blocking(Message::Valid(event)).unwrap();
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ use burn_core::data::dataloader::Progress;
|
||||
use burn_optim::LearningRate;
|
||||
|
||||
use crate::{
|
||||
LearnerSummary,
|
||||
EpisodeSummary, LearnerSummary,
|
||||
renderer::{EvaluationName, MetricsRenderer},
|
||||
};
|
||||
|
||||
@@ -11,19 +11,45 @@ pub enum LearnerEvent<T> {
|
||||
/// Signal the start of the process (e.g., training start)
|
||||
Start,
|
||||
/// Signal that an item has been processed.
|
||||
ProcessedItem(LearnerItem<T>),
|
||||
ProcessedItem(TrainingItem<T>),
|
||||
/// Signal the end of an epoch.
|
||||
EndEpoch(usize),
|
||||
/// Signal the end of the process (e.g., training end).
|
||||
End(Option<LearnerSummary>),
|
||||
}
|
||||
|
||||
/// Event happening during reinforcement learning.
|
||||
pub enum RLEvent<TS, ES> {
|
||||
/// Signal the start of the process (e.g., learning starts).
|
||||
Start,
|
||||
/// Signal an agent's training step.
|
||||
TrainStep(EvaluationItem<TS>),
|
||||
/// Signal a timestep of the agent-environment interface.
|
||||
TimeStep(EvaluationItem<ES>),
|
||||
/// Signal an episode end.
|
||||
EpisodeEnd(EvaluationItem<EpisodeSummary>),
|
||||
/// Signal the end of the process (e.g., learning ends).
|
||||
End(Option<LearnerSummary>),
|
||||
}
|
||||
|
||||
/// Event happening during the evaluation of a reinforcement learning agent.
|
||||
pub enum AgentEvaluationEvent<T> {
|
||||
/// Signal the start of the process (e.g., evaluation start).
|
||||
Start,
|
||||
/// Signal a timestep of the agent-environment interface.
|
||||
TimeStep(EvaluationItem<T>),
|
||||
/// Signal an episode end.
|
||||
EpisodeEnd(EvaluationItem<EpisodeSummary>),
|
||||
/// Signal the end of the process (e.g., evaluation end).
|
||||
End,
|
||||
}
|
||||
|
||||
/// Event happening during the evaluation process.
|
||||
pub enum EvaluatorEvent<T> {
|
||||
/// Signal the start of the process (e.g., training start)
|
||||
Start,
|
||||
/// Signal that an item has been processed.
|
||||
ProcessedItem(EvaluationName, LearnerItem<T>),
|
||||
ProcessedItem(EvaluationName, EvaluationItem<T>),
|
||||
/// Signal the end of the process (e.g., training end).
|
||||
End,
|
||||
}
|
||||
@@ -40,16 +66,11 @@ pub trait ItemLazy: Send {
|
||||
}
|
||||
|
||||
/// Process events happening during training and validation.
|
||||
pub trait EventProcessorTraining: Send {
|
||||
/// The training item.
|
||||
type ItemTrain: ItemLazy;
|
||||
/// The validation item.
|
||||
type ItemValid: ItemLazy;
|
||||
|
||||
pub trait EventProcessorTraining<TrainEvent, ValidEvent>: Send {
|
||||
/// Collect a training event.
|
||||
fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>);
|
||||
fn process_train(&mut self, event: TrainEvent);
|
||||
/// Collect a validation event.
|
||||
fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>);
|
||||
fn process_valid(&mut self, event: ValidEvent);
|
||||
/// Returns the renderer used for training.
|
||||
fn renderer(self) -> Box<dyn MetricsRenderer>;
|
||||
}
|
||||
@@ -68,41 +89,62 @@ pub trait EventProcessorEvaluation: Send {
|
||||
|
||||
/// A training item.
|
||||
#[derive(new)]
|
||||
pub struct LearnerItem<T> {
|
||||
pub struct TrainingItem<T> {
|
||||
/// The item.
|
||||
pub item: T,
|
||||
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The epoch.
|
||||
pub epoch: usize,
|
||||
/// The global progress of the training (e.g. epochs).
|
||||
pub global_progress: Progress,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
|
||||
/// The iteration.
|
||||
pub iteration: usize,
|
||||
/// The iteration, if it is different from the items processed.
|
||||
pub iteration: Option<usize>,
|
||||
|
||||
/// The learning rate.
|
||||
pub lr: Option<LearningRate>,
|
||||
}
|
||||
|
||||
impl<T: ItemLazy> ItemLazy for LearnerItem<T> {
|
||||
type ItemSync = LearnerItem<T::ItemSync>;
|
||||
impl<T: ItemLazy> ItemLazy for TrainingItem<T> {
|
||||
type ItemSync = TrainingItem<T::ItemSync>;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
LearnerItem {
|
||||
TrainingItem {
|
||||
item: self.item.sync(),
|
||||
progress: self.progress,
|
||||
epoch: self.epoch,
|
||||
epoch_total: self.epoch_total,
|
||||
global_progress: self.global_progress,
|
||||
iteration: self.iteration,
|
||||
lr: self.lr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An evaluation item.
|
||||
#[derive(new)]
|
||||
pub struct EvaluationItem<T> {
|
||||
/// The item.
|
||||
pub item: T,
|
||||
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The iteration, if it is different from the items processed.
|
||||
pub iteration: Option<usize>,
|
||||
}
|
||||
|
||||
impl<T: ItemLazy> ItemLazy for EvaluationItem<T> {
|
||||
type ItemSync = EvaluationItem<T::ItemSync>;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
EvaluationItem {
|
||||
item: self.item.sync(),
|
||||
progress: self.progress,
|
||||
iteration: self.iteration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ItemLazy for () {
|
||||
type ItemSync = ();
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining};
|
||||
use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation, MetricsEvaluation};
|
||||
use crate::metric::store::{EpochSummary, EventStoreClient, Split};
|
||||
use crate::renderer::{MetricState, MetricsRenderer};
|
||||
use crate::renderer::{
|
||||
EvaluationProgress, MetricState, MetricsRenderer, ProgressType, TrainingProgress,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// An [event processor](EventProcessorTraining) that handles:
|
||||
@@ -34,6 +36,30 @@ impl<T: ItemLazy, V: ItemLazy> FullEventProcessorTraining<T, V> {
|
||||
store,
|
||||
}
|
||||
}
|
||||
|
||||
fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
|
||||
let mut indicators = vec![];
|
||||
indicators.push(ProgressType::Detailed {
|
||||
tag: String::from("Epoch"),
|
||||
progress: progress.global_progress.clone(),
|
||||
});
|
||||
|
||||
if let Some(iteration) = progress.iteration {
|
||||
indicators.push(ProgressType::Value {
|
||||
tag: String::from("Iteration"),
|
||||
value: iteration,
|
||||
});
|
||||
};
|
||||
|
||||
if let Some(p) = &progress.progress {
|
||||
indicators.push(ProgressType::Detailed {
|
||||
tag: String::from("Items"),
|
||||
progress: p.clone(),
|
||||
});
|
||||
};
|
||||
|
||||
indicators
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ItemLazy> FullEventProcessorEvaluation<T> {
|
||||
@@ -48,6 +74,23 @@ impl<T: ItemLazy> FullEventProcessorEvaluation<T> {
|
||||
store,
|
||||
}
|
||||
}
|
||||
|
||||
fn progress_indicators(&self, progress: &EvaluationProgress) -> Vec<ProgressType> {
|
||||
let mut indicators = vec![];
|
||||
if let Some(iteration) = progress.iteration {
|
||||
indicators.push(ProgressType::Value {
|
||||
tag: String::from("Iteration"),
|
||||
value: iteration,
|
||||
});
|
||||
};
|
||||
|
||||
indicators.push(ProgressType::Detailed {
|
||||
tag: String::from("Items"),
|
||||
progress: progress.progress.clone(),
|
||||
});
|
||||
|
||||
indicators
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ItemLazy> EventProcessorEvaluation for FullEventProcessorEvaluation<T> {
|
||||
@@ -95,7 +138,8 @@ impl<T: ItemLazy> EventProcessorEvaluation for FullEventProcessorEvaluation<T> {
|
||||
)
|
||||
});
|
||||
|
||||
self.renderer.render_test(progress);
|
||||
let indicators = self.progress_indicators(&progress);
|
||||
self.renderer.render_test(progress, indicators);
|
||||
}
|
||||
EvaluatorEvent::End => {
|
||||
self.renderer.on_test_end().ok();
|
||||
@@ -108,11 +152,10 @@ impl<T: ItemLazy> EventProcessorEvaluation for FullEventProcessorEvaluation<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for FullEventProcessorTraining<T, V> {
|
||||
type ItemTrain = T;
|
||||
type ItemValid = V;
|
||||
|
||||
fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>) {
|
||||
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>
|
||||
for FullEventProcessorTraining<T, V>
|
||||
{
|
||||
fn process_train(&mut self, event: LearnerEvent<T>) {
|
||||
match event {
|
||||
LearnerEvent::Start => {
|
||||
let definitions = self.metrics.metric_definitions();
|
||||
@@ -149,7 +192,8 @@ impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for FullEventProcessorTrai
|
||||
))
|
||||
});
|
||||
|
||||
self.renderer.render_train(progress);
|
||||
let indicators = self.progress_indicators(&progress);
|
||||
self.renderer.render_train(progress, indicators);
|
||||
}
|
||||
LearnerEvent::EndEpoch(epoch) => {
|
||||
self.store
|
||||
@@ -165,7 +209,7 @@ impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for FullEventProcessorTrai
|
||||
}
|
||||
}
|
||||
|
||||
fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>) {
|
||||
fn process_valid(&mut self, event: LearnerEvent<V>) {
|
||||
match event {
|
||||
LearnerEvent::Start => {} // no-op for now
|
||||
LearnerEvent::ProcessedItem(item) => {
|
||||
@@ -193,7 +237,8 @@ impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for FullEventProcessorTrai
|
||||
))
|
||||
});
|
||||
|
||||
self.renderer.render_valid(progress);
|
||||
let indicators = self.progress_indicators(&progress);
|
||||
self.renderer.render_valid(progress, indicators);
|
||||
}
|
||||
LearnerEvent::EndEpoch(epoch) => {
|
||||
self.store
|
||||
@@ -206,7 +251,7 @@ impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for FullEventProcessorTrai
|
||||
LearnerEvent::End(_) => {} // no-op for now
|
||||
}
|
||||
}
|
||||
fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
|
||||
fn renderer(self) -> Box<dyn MetricsRenderer> {
|
||||
self.renderer
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{ItemLazy, LearnerItem};
|
||||
use super::{ItemLazy, TrainingItem};
|
||||
use crate::{
|
||||
EvaluationItem,
|
||||
metric::{
|
||||
Adaptor, Metric, MetricDefinition, MetricEntry, MetricId, MetricMetadata, Numeric,
|
||||
store::{MetricsUpdate, NumericMetricUpdate},
|
||||
@@ -83,19 +84,19 @@ impl<T: ItemLazy> MetricsEvaluation<T> {
|
||||
/// Update the testing information from the testing item.
|
||||
pub(crate) fn update_test(
|
||||
&mut self,
|
||||
item: &LearnerItem<T::ItemSync>,
|
||||
item: &EvaluationItem<T::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.test.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.test_numeric.len());
|
||||
|
||||
for metric in self.test.iter_mut() {
|
||||
let state = metric.update(item, metadata);
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.test_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(item, metadata);
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
@@ -162,19 +163,19 @@ impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
|
||||
/// Update the training information from the training item.
|
||||
pub(crate) fn update_train(
|
||||
&mut self,
|
||||
item: &LearnerItem<T::ItemSync>,
|
||||
item: &TrainingItem<T::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.train.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.train_numeric.len());
|
||||
|
||||
for metric in self.train.iter_mut() {
|
||||
let state = metric.update(item, metadata);
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.train_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(item, metadata);
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
@@ -184,19 +185,19 @@ impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
|
||||
/// Update the training information from the validation item.
|
||||
pub(crate) fn update_valid(
|
||||
&mut self,
|
||||
item: &LearnerItem<V::ItemSync>,
|
||||
item: &TrainingItem<V::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.valid.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len());
|
||||
|
||||
for metric in self.valid.iter_mut() {
|
||||
let state = metric.update(item, metadata);
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.valid_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(item, metadata);
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
@@ -224,19 +225,28 @@ impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&LearnerItem<T>> for TrainingProgress {
|
||||
fn from(item: &LearnerItem<T>) -> Self {
|
||||
impl<T> From<&TrainingItem<T>> for TrainingProgress {
|
||||
fn from(item: &TrainingItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
epoch: item.epoch,
|
||||
epoch_total: item.epoch_total,
|
||||
progress: Some(item.progress.clone()),
|
||||
global_progress: item.global_progress.clone(),
|
||||
iteration: item.iteration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&LearnerItem<T>> for EvaluationProgress {
|
||||
fn from(item: &LearnerItem<T>) -> Self {
|
||||
impl<T> From<&EvaluationItem<T>> for TrainingProgress {
|
||||
fn from(item: &EvaluationItem<T>) -> Self {
|
||||
Self {
|
||||
progress: None,
|
||||
global_progress: item.progress.clone(),
|
||||
iteration: item.iteration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&EvaluationItem<T>> for EvaluationProgress {
|
||||
fn from(item: &EvaluationItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
iteration: item.iteration,
|
||||
@@ -244,31 +254,41 @@ impl<T> From<&LearnerItem<T>> for EvaluationProgress {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&LearnerItem<T>> for MetricMetadata {
|
||||
fn from(item: &LearnerItem<T>) -> Self {
|
||||
impl<T> From<&TrainingItem<T>> for MetricMetadata {
|
||||
fn from(item: &TrainingItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
epoch: item.epoch,
|
||||
epoch_total: item.epoch_total,
|
||||
global_progress: item.global_progress.clone(),
|
||||
iteration: item.iteration,
|
||||
lr: item.lr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait NumericMetricUpdater<T>: Send + Sync {
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> NumericMetricUpdate;
|
||||
impl<T> From<&EvaluationItem<T>> for MetricMetadata {
|
||||
fn from(item: &EvaluationItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
global_progress: item.progress.clone(),
|
||||
iteration: item.iteration,
|
||||
lr: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait NumericMetricUpdater<T>: Send + Sync {
|
||||
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate;
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
trait MetricUpdater<T>: Send + Sync {
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
|
||||
pub(crate) trait MetricUpdater<T>: Send + Sync {
|
||||
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry;
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
struct MetricWrapper<M> {
|
||||
id: MetricId,
|
||||
metric: M,
|
||||
pub(crate) struct MetricWrapper<M> {
|
||||
pub id: MetricId,
|
||||
pub metric: M,
|
||||
}
|
||||
|
||||
impl<M: Metric> MetricWrapper<M> {
|
||||
@@ -286,8 +306,8 @@ where
|
||||
M: Metric + Numeric + 'static,
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> NumericMetricUpdate {
|
||||
let serialized_entry = self.metric.update(&item.item.adapt(), metadata);
|
||||
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate {
|
||||
let serialized_entry = self.metric.update(&item.adapt(), metadata);
|
||||
let update = MetricEntry::new(self.id.clone(), serialized_entry);
|
||||
let numeric = self.metric.value();
|
||||
let running = self.metric.running_value();
|
||||
@@ -310,8 +330,8 @@ where
|
||||
M: Metric + 'static,
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry {
|
||||
let serialized_entry = self.metric.update(&item.item.adapt(), metadata);
|
||||
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry {
|
||||
let serialized_entry = self.metric.update(&item.adapt(), metadata);
|
||||
MetricEntry::new(self.id.clone(), serialized_entry)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,11 +14,10 @@ pub(crate) struct MinimalEventProcessor<T: ItemLazy, V: ItemLazy> {
|
||||
store: Arc<EventStoreClient>,
|
||||
}
|
||||
|
||||
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for MinimalEventProcessor<T, V> {
|
||||
type ItemTrain = T;
|
||||
type ItemValid = V;
|
||||
|
||||
fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>) {
|
||||
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>
|
||||
for MinimalEventProcessor<T, V>
|
||||
{
|
||||
fn process_train(&mut self, event: LearnerEvent<T>) {
|
||||
match event {
|
||||
LearnerEvent::Start => {
|
||||
let definitions = self.metrics.metric_definitions();
|
||||
@@ -47,7 +46,7 @@ impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for MinimalEventProcessor<
|
||||
}
|
||||
}
|
||||
|
||||
fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>) {
|
||||
fn process_valid(&mut self, event: LearnerEvent<V>) {
|
||||
match event {
|
||||
LearnerEvent::Start => {} // no-op for now
|
||||
LearnerEvent::ProcessedItem(item) => {
|
||||
|
||||
@@ -3,10 +3,14 @@ mod base;
|
||||
mod full;
|
||||
mod metrics;
|
||||
mod minimal;
|
||||
mod rl_metrics;
|
||||
mod rl_processor;
|
||||
|
||||
pub use base::*;
|
||||
pub(crate) use full::*;
|
||||
pub(crate) use metrics::*;
|
||||
pub(crate) use rl_metrics::*;
|
||||
pub(crate) use rl_processor::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) use minimal::*;
|
||||
@@ -17,7 +21,7 @@ pub use async_wrapper::{AsyncProcessorEvaluation, AsyncProcessorTraining};
|
||||
pub(crate) mod test_utils {
|
||||
use crate::metric::{
|
||||
Adaptor, LossInput,
|
||||
processor::{EventProcessorTraining, LearnerEvent, LearnerItem, MinimalEventProcessor},
|
||||
processor::{EventProcessorTraining, LearnerEvent, MinimalEventProcessor, TrainingItem},
|
||||
};
|
||||
use burn_core::tensor::{ElementConversion, Tensor, backend::Backend};
|
||||
|
||||
@@ -47,14 +51,16 @@ pub(crate) mod test_utils {
|
||||
items_processed: 1,
|
||||
items_total: 10,
|
||||
};
|
||||
let num_epochs = 3;
|
||||
let dummy_iteration = 1;
|
||||
let dummy_global_progress = burn_core::data::dataloader::Progress {
|
||||
items_processed: epoch,
|
||||
items_total: 3,
|
||||
};
|
||||
let dummy_iteration = Some(1);
|
||||
|
||||
processor.process_train(LearnerEvent::ProcessedItem(LearnerItem::new(
|
||||
processor.process_train(LearnerEvent::ProcessedItem(TrainingItem::new(
|
||||
value,
|
||||
dummy_progress,
|
||||
epoch,
|
||||
num_epochs,
|
||||
dummy_global_progress,
|
||||
dummy_iteration,
|
||||
None,
|
||||
)));
|
||||
|
||||
268
crates/burn-train/src/metric/processor/rl_metrics.rs
Normal file
268
crates/burn-train/src/metric/processor/rl_metrics.rs
Normal file
@@ -0,0 +1,268 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
EpisodeSummary, EvaluationItem, ItemLazy, MetricUpdater, MetricWrapper, NumericMetricUpdater,
|
||||
metric::{
|
||||
Adaptor, Metric, MetricDefinition, MetricId, MetricMetadata, Numeric, store::MetricsUpdate,
|
||||
},
|
||||
};
|
||||
|
||||
pub(crate) struct RLMetrics<TS: ItemLazy, ES: ItemLazy> {
|
||||
train_step: Vec<Box<dyn MetricUpdater<TS::ItemSync>>>,
|
||||
env_step: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,
|
||||
env_step_valid: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,
|
||||
episode_end: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,
|
||||
episode_end_valid: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,
|
||||
|
||||
train_step_numeric: Vec<Box<dyn NumericMetricUpdater<TS::ItemSync>>>,
|
||||
env_step_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,
|
||||
env_step_valid_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,
|
||||
episode_end_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,
|
||||
episode_end_valid_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,
|
||||
|
||||
metric_definitions: HashMap<MetricId, MetricDefinition>,
|
||||
}
|
||||
|
||||
impl<TS: ItemLazy, ES: ItemLazy> Default for RLMetrics<TS, ES> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
train_step: Vec::default(),
|
||||
env_step: Vec::default(),
|
||||
env_step_valid: Vec::default(),
|
||||
episode_end: Vec::default(),
|
||||
episode_end_valid: Vec::default(),
|
||||
train_step_numeric: Vec::default(),
|
||||
env_step_numeric: Vec::default(),
|
||||
env_step_valid_numeric: Vec::default(),
|
||||
episode_end_numeric: Vec::default(),
|
||||
episode_end_valid_numeric: Vec::default(),
|
||||
metric_definitions: HashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TS: ItemLazy, ES: ItemLazy> RLMetrics<TS, ES> {
|
||||
/// Register an env-step metric.
|
||||
pub(crate) fn register_text_metric_agent<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
ES::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.env_step.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register an env-step numeric metric.
|
||||
pub(crate) fn register_agent_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
ES::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.env_step_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a training metric.
|
||||
pub(crate) fn register_text_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
TS::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.train_step.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a numeric training metric.
|
||||
pub(crate) fn register_metric_train<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
TS::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.train_step_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a validation env-step metric.
|
||||
pub(crate) fn register_text_metric_agent_valid<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
ES::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.env_step_valid.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a validation env-step numeric metric.
|
||||
pub(crate) fn register_agent_metric_valid<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
ES::ItemSync: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.env_step_valid_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register an episode-end metric.
|
||||
pub(crate) fn register_text_metric_episode<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.episode_end.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register an episode-end numeric metric.
|
||||
pub(crate) fn register_episode_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.episode_end_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register an episode-end metric for validation.
|
||||
pub(crate) fn register_text_metric_episode_valid<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.episode_end_valid.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register an episode-end numeric metric for validation.
|
||||
pub(crate) fn register_episode_metric_valid<Me: Metric + Numeric + 'static>(
|
||||
&mut self,
|
||||
metric: Me,
|
||||
) where
|
||||
EpisodeSummary: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.register_definition(&metric);
|
||||
self.episode_end_valid_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
|
||||
self.metric_definitions.insert(
|
||||
metric.id.clone(),
|
||||
MetricDefinition::new(metric.id.clone(), &metric.metric),
|
||||
);
|
||||
}
|
||||
|
||||
/// Get metric definitions for all splits
|
||||
pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
|
||||
self.metric_definitions.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Update the training information from the training item.
|
||||
pub(crate) fn update_train_step(
|
||||
&mut self,
|
||||
item: &EvaluationItem<TS::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.train_step.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.train_step_numeric.len());
|
||||
|
||||
for metric in self.train_step.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.train_step_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Update the env-step metrics from an environment step item.
|
||||
pub(crate) fn update_env_step(
|
||||
&mut self,
|
||||
item: &EvaluationItem<ES::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.env_step.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.env_step_numeric.len());
|
||||
|
||||
for metric in self.env_step.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.env_step_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Update the env-step metrics for validation from an environment step item.
|
||||
pub(crate) fn update_env_step_valid(
|
||||
&mut self,
|
||||
item: &EvaluationItem<ES::ItemSync>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.env_step_valid.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.env_step_valid_numeric.len());
|
||||
|
||||
for metric in self.env_step_valid.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.env_step_valid_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Update the episode-end metrics from an episode summary.
|
||||
pub(crate) fn update_episode_end(
|
||||
&mut self,
|
||||
item: &EvaluationItem<EpisodeSummary>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.episode_end.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.episode_end_numeric.len());
|
||||
|
||||
for metric in self.episode_end.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.episode_end_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Update the episode-end metrics for validation from an episode summary.
|
||||
pub(crate) fn update_episode_end_valid(
|
||||
&mut self,
|
||||
item: &EvaluationItem<EpisodeSummary>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.episode_end_valid.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.episode_end_valid_numeric.len());
|
||||
|
||||
for metric in self.episode_end_valid.iter_mut() {
|
||||
let state = metric.update(&item.item, metadata);
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.episode_end_valid_numeric.iter_mut() {
|
||||
let numeric_update = metric.update(&item.item, metadata);
|
||||
entries_numeric.push(numeric_update);
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
}
|
||||
151
crates/burn-train/src/metric/processor/rl_processor.rs
Normal file
151
crates/burn-train/src/metric/processor/rl_processor.rs
Normal file
@@ -0,0 +1,151 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
AgentEvaluationEvent, EventProcessorTraining, ItemLazy, RLEvent, RLMetrics,
|
||||
metric::store::{Event, EventStoreClient, MetricsUpdate},
|
||||
renderer::{MetricState, MetricsRenderer, ProgressType, TrainingProgress},
|
||||
};
|
||||
|
||||
/// An [event processor](EventProcessorTraining) that handles:
|
||||
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
|
||||
/// - Rendering metrics using a [metrics renderer](MetricsRenderer).
|
||||
#[derive(new)]
|
||||
pub struct RLEventProcessor<TS: ItemLazy, ES: ItemLazy> {
|
||||
metrics: RLMetrics<TS, ES>,
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
store: Arc<EventStoreClient>,
|
||||
}
|
||||
|
||||
impl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {
|
||||
fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
|
||||
let indicators = vec![ProgressType::Detailed {
|
||||
tag: String::from("Step"),
|
||||
progress: progress.global_progress.clone(),
|
||||
}];
|
||||
|
||||
indicators
|
||||
}
|
||||
|
||||
fn progress_indicators_eval(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
|
||||
let indicators = vec![ProgressType::Detailed {
|
||||
tag: String::from("Step"),
|
||||
progress: progress.global_progress.clone(),
|
||||
}];
|
||||
|
||||
indicators
|
||||
}
|
||||
}
|
||||
|
||||
impl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {
|
||||
fn process_update_train(&mut self, update: MetricsUpdate) {
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone()));
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|numeric_update| {
|
||||
self.renderer.update_train(MetricState::Numeric(
|
||||
numeric_update.entry,
|
||||
numeric_update.numeric_entry,
|
||||
))
|
||||
});
|
||||
}
|
||||
|
||||
fn process_update_valid(&mut self, update: MetricsUpdate) {
|
||||
self.store
|
||||
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone()));
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|numeric_update| {
|
||||
self.renderer.update_valid(MetricState::Numeric(
|
||||
numeric_update.entry,
|
||||
numeric_update.numeric_entry,
|
||||
))
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<TS: ItemLazy, ES: ItemLazy> EventProcessorTraining<RLEvent<TS, ES>, AgentEvaluationEvent<ES>>
|
||||
for RLEventProcessor<TS, ES>
|
||||
{
|
||||
fn process_train(&mut self, event: RLEvent<TS, ES>) {
|
||||
match event {
|
||||
RLEvent::Start => {
|
||||
let definitions = self.metrics.metric_definitions();
|
||||
self.store
|
||||
.add_event_train(Event::MetricsInit(definitions.clone()));
|
||||
definitions
|
||||
.iter()
|
||||
.for_each(|definition| self.renderer.register_metric(definition.clone()));
|
||||
}
|
||||
RLEvent::TrainStep(item) => {
|
||||
let item = item.sync();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_train_step(&item, &metadata);
|
||||
self.process_update_train(update);
|
||||
}
|
||||
RLEvent::TimeStep(item) => {
|
||||
let item = item.sync();
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_env_step(&item, &metadata);
|
||||
self.process_update_train(update);
|
||||
let status = self.progress_indicators(&progress);
|
||||
self.renderer.render_train(progress, status);
|
||||
}
|
||||
RLEvent::EpisodeEnd(item) => {
|
||||
let item = item.sync();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_episode_end(&item, &metadata);
|
||||
self.process_update_train(update);
|
||||
}
|
||||
RLEvent::End(learner_summary) => {
|
||||
self.renderer.on_train_end(learner_summary).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_valid(&mut self, event: AgentEvaluationEvent<ES>) {
|
||||
match event {
|
||||
AgentEvaluationEvent::Start => {} // no-op for now
|
||||
AgentEvaluationEvent::TimeStep(item) => {
|
||||
let item = item.sync();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_env_step_valid(&item, &metadata);
|
||||
self.process_update_valid(update);
|
||||
}
|
||||
AgentEvaluationEvent::EpisodeEnd(item) => {
|
||||
let item = item.sync();
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_episode_end_valid(&item, &metadata);
|
||||
self.process_update_valid(update);
|
||||
let status = self.progress_indicators_eval(&progress);
|
||||
self.renderer.render_valid(progress, status);
|
||||
}
|
||||
AgentEvaluationEvent::End => {} // no-op for now
|
||||
}
|
||||
}
|
||||
|
||||
fn renderer(self) -> Box<dyn MetricsRenderer> {
|
||||
self.renderer
|
||||
}
|
||||
}
|
||||
78
crates/burn-train/src/metric/rl/cum_reward.rs
Normal file
78
crates/burn-train/src/metric/rl/cum_reward.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::super::{
|
||||
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
|
||||
state::{FormatOptions, NumericMetricState},
|
||||
};
|
||||
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
|
||||
|
||||
/// Metric for the cumulative reward of the last completed episode.
|
||||
#[derive(Clone)]
|
||||
pub struct CumulativeRewardMetric {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
}
|
||||
|
||||
impl CumulativeRewardMetric {
|
||||
/// Creates a new episode length metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: Arc::new("Cum. Reward".to_string()),
|
||||
state: NumericMetricState::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CumulativeRewardMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// The [CumulativeRewardMetric](CumulativeRewardMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct CumulativeRewardInput {
|
||||
cum_reward: f64,
|
||||
}
|
||||
|
||||
impl Metric for CumulativeRewardMetric {
|
||||
type Input = CumulativeRewardInput;
|
||||
|
||||
fn update(
|
||||
&mut self,
|
||||
item: &CumulativeRewardInput,
|
||||
_metadata: &MetricMetadata,
|
||||
) -> SerializedEntry {
|
||||
self.state.update(
|
||||
item.cum_reward,
|
||||
1,
|
||||
FormatOptions::new(self.name()).precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: None,
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for CumulativeRewardMetric {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
71
crates/burn-train/src/metric/rl/ep_len.rs
Normal file
71
crates/burn-train/src/metric/rl/ep_len.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::super::{
|
||||
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
|
||||
state::{FormatOptions, NumericMetricState},
|
||||
};
|
||||
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
|
||||
|
||||
/// Metric for the length of the last completed episode.
|
||||
#[derive(Clone)]
|
||||
pub struct EpisodeLengthMetric {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
}
|
||||
|
||||
impl EpisodeLengthMetric {
|
||||
/// Creates a new episode length metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: Arc::new("Episode length".to_string()),
|
||||
state: NumericMetricState::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EpisodeLengthMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// The [EpisodeLengthMetric](EpisodeLengthMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct EpisodeLengthInput {
|
||||
ep_len: f64,
|
||||
}
|
||||
|
||||
impl Metric for EpisodeLengthMetric {
|
||||
type Input = EpisodeLengthInput;
|
||||
|
||||
fn update(&mut self, item: &EpisodeLengthInput, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
self.state
|
||||
.update(item.ep_len, 1, FormatOptions::new(self.name()).precision(0))
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: Some(String::from("steps")),
|
||||
higher_is_better: true,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for EpisodeLengthMetric {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
78
crates/burn-train/src/metric/rl/exploration_rate.rs
Normal file
78
crates/burn-train/src/metric/rl/exploration_rate.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::super::{
|
||||
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
|
||||
state::{FormatOptions, NumericMetricState},
|
||||
};
|
||||
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
|
||||
|
||||
/// Metric for the length of the last completed episode.
|
||||
#[derive(Clone)]
|
||||
pub struct ExplorationRateMetric {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
}
|
||||
|
||||
impl ExplorationRateMetric {
|
||||
/// Creates a new episode length metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
name: Arc::new("Exploration rate".to_string()),
|
||||
state: NumericMetricState::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ExplorationRateMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// The [ExplorationRateMetric](ExplorationRateMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct ExplorationRateInput {
|
||||
exploration_rate: f64,
|
||||
}
|
||||
|
||||
impl Metric for ExplorationRateMetric {
|
||||
type Input = ExplorationRateInput;
|
||||
|
||||
fn update(
|
||||
&mut self,
|
||||
item: &ExplorationRateInput,
|
||||
_metadata: &MetricMetadata,
|
||||
) -> SerializedEntry {
|
||||
self.state.update(
|
||||
item.exploration_rate,
|
||||
1,
|
||||
FormatOptions::new(self.name()).precision(3),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
|
||||
fn name(&self) -> MetricName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn attributes(&self) -> MetricAttributes {
|
||||
NumericAttributes {
|
||||
unit: Some(String::from("%")),
|
||||
higher_is_better: false,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for ExplorationRateMetric {
|
||||
fn value(&self) -> NumericEntry {
|
||||
self.state.current_value()
|
||||
}
|
||||
|
||||
fn running_value(&self) -> NumericEntry {
|
||||
self.state.running_value()
|
||||
}
|
||||
}
|
||||
7
crates/burn-train/src/metric/rl/mod.rs
Normal file
7
crates/burn-train/src/metric/rl/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
mod cum_reward;
|
||||
mod ep_len;
|
||||
mod exploration_rate;
|
||||
|
||||
pub use cum_reward::*;
|
||||
pub use ep_len::*;
|
||||
pub use exploration_rate::*;
|
||||
@@ -109,6 +109,7 @@ impl NumericMetricState {
|
||||
None => (format!("{value_current}"), format!("{value_running}")),
|
||||
};
|
||||
|
||||
// TODO: naming inconsistent with RL.
|
||||
let formatted = match format.unit {
|
||||
Some(unit) => {
|
||||
format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
|
||||
|
||||
@@ -27,14 +27,14 @@ pub trait MetricsRendererTraining: Send + Sync {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The training progress.
|
||||
fn render_train(&mut self, item: TrainingProgress);
|
||||
fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);
|
||||
|
||||
/// Renders the validation progress.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The validation progress.
|
||||
fn render_valid(&mut self, item: TrainingProgress);
|
||||
fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);
|
||||
|
||||
/// Callback method invoked when training ends, whether it
|
||||
/// completed successfully or was interrupted.
|
||||
@@ -58,7 +58,7 @@ pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining {
|
||||
/// Keep the renderer from automatically closing, requiring manual action to close it.
|
||||
fn manual_close(&mut self);
|
||||
/// Register a new metric.
|
||||
fn register_metric(&mut self, _definition: MetricDefinition);
|
||||
fn register_metric(&mut self, definition: MetricDefinition);
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -102,7 +102,7 @@ pub trait MetricsRendererEvaluation: Send + Sync {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The training progress.
|
||||
fn render_test(&mut self, item: EvaluationProgress);
|
||||
fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec<ProgressType>);
|
||||
|
||||
/// Callback method invoked when testing ends, whether it
|
||||
/// completed successfully or was interrupted.
|
||||
@@ -128,16 +128,13 @@ pub enum MetricState {
|
||||
#[derive(Debug)]
|
||||
pub struct TrainingProgress {
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
pub progress: Option<Progress>,
|
||||
|
||||
/// The epoch.
|
||||
pub epoch: usize,
|
||||
/// The progress of the whole training.
|
||||
pub global_progress: Progress,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
|
||||
/// The iteration.
|
||||
pub iteration: usize,
|
||||
/// The iteration, if it differs from the items processed.
|
||||
pub iteration: Option<usize>,
|
||||
}
|
||||
|
||||
/// Evaluation progress.
|
||||
@@ -146,21 +143,48 @@ pub struct EvaluationProgress {
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The iteration.
|
||||
pub iteration: usize,
|
||||
/// The iteration, if it is different from the processed items.
|
||||
pub iteration: Option<usize>,
|
||||
}
|
||||
|
||||
impl From<&EvaluationProgress> for TrainingProgress {
    /// Converts an evaluation progress into a training progress.
    ///
    /// The evaluation progress becomes the global progress, the iteration is carried over
    /// unchanged, and no local progress is set.
    fn from(value: &EvaluationProgress) -> Self {
        let global_progress = value.progress.clone();
        let iteration = value.iteration;

        Self {
            progress: None,
            global_progress,
            iteration,
        }
    }
}
|
||||
|
||||
impl TrainingProgress {
|
||||
/// Creates a new empty training progress.
|
||||
pub fn none() -> Self {
|
||||
Self {
|
||||
progress: Progress {
|
||||
progress: None,
|
||||
global_progress: Progress {
|
||||
items_processed: 0,
|
||||
items_total: 0,
|
||||
},
|
||||
epoch: 0,
|
||||
epoch_total: 0,
|
||||
iteration: 0,
|
||||
iteration: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Type of progress indicators.
|
||||
pub enum ProgressType {
|
||||
/// Detailed progress.
|
||||
Detailed {
|
||||
/// The tag.
|
||||
tag: String,
|
||||
/// The progress.
|
||||
progress: Progress,
|
||||
},
|
||||
/// Simple value.
|
||||
Value {
|
||||
/// The tag.
|
||||
tag: String,
|
||||
/// The value.
|
||||
value: usize,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::renderer::{
|
||||
EvaluationProgress, MetricState, MetricsRenderer, MetricsRendererEvaluation,
|
||||
MetricsRendererTraining, TrainingProgress,
|
||||
MetricsRendererTraining, ProgressType, TrainingProgress,
|
||||
};
|
||||
|
||||
/// A simple renderer for when the cli feature is not enabled.
|
||||
@@ -19,17 +19,17 @@ impl MetricsRendererTraining for CliMetricsRenderer {
|
||||
|
||||
fn update_valid(&mut self, _state: MetricState) {}
|
||||
|
||||
fn render_train(&mut self, item: TrainingProgress) {
|
||||
fn render_train(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {
|
||||
println!("{item:?}");
|
||||
}
|
||||
|
||||
fn render_valid(&mut self, item: TrainingProgress) {
|
||||
fn render_valid(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {
|
||||
println!("{item:?}");
|
||||
}
|
||||
}
|
||||
|
||||
impl MetricsRendererEvaluation for CliMetricsRenderer {
|
||||
fn render_test(&mut self, item: EvaluationProgress) {
|
||||
fn render_test(&mut self, item: EvaluationProgress, _progress_indicators: Vec<ProgressType>) {
|
||||
println!("{item:?}");
|
||||
}
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ pub(crate) fn default_renderer(
|
||||
) -> Box<dyn MetricsRenderer> {
|
||||
#[cfg(feature = "tui")]
|
||||
if std::io::stdout().is_terminal() {
|
||||
return Box::new(tui::TuiMetricsRenderer::new(interuptor, checkpoint));
|
||||
return Box::new(tui::TuiMetricsRendererWrapper::new(interuptor, checkpoint));
|
||||
}
|
||||
|
||||
Box::new(CliMetricsRenderer::new())
|
||||
|
||||
@@ -65,13 +65,14 @@ impl NumericMetricsState {
|
||||
|
||||
/// Update the state with the training progress.
|
||||
pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) {
|
||||
self.epoch = progress.epoch;
|
||||
self.epoch = progress.global_progress.items_processed;
|
||||
|
||||
if self.num_samples_train.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
self.num_samples_train = Some(progress.progress.items_total);
|
||||
// If the training only has the notion of global progress, num_samples_train remains None.
|
||||
self.num_samples_train = progress.progress.as_ref().map(|p| p.items_total);
|
||||
}
|
||||
|
||||
/// Update the state with the validation progress.
|
||||
@@ -80,16 +81,20 @@ impl NumericMetricsState {
|
||||
return;
|
||||
}
|
||||
|
||||
// If num_samples_train is None, keep the default max_samples for validation.
|
||||
if let Some(num_sample_train) = self.num_samples_train {
|
||||
for (_, (_recent, full)) in self.data.iter_mut() {
|
||||
let ratio = progress.progress.items_total as f64 / num_sample_train as f64;
|
||||
let ratio = match &progress.progress {
|
||||
Some(p) => p.items_total as f64 / num_sample_train as f64,
|
||||
None => progress.global_progress.items_total as f64 / num_sample_train as f64,
|
||||
};
|
||||
|
||||
full.update_max_sample(TuiSplit::Valid, ratio);
|
||||
}
|
||||
}
|
||||
|
||||
self.epoch = progress.epoch;
|
||||
self.num_samples_valid = Some(progress.progress.items_total);
|
||||
self.epoch = progress.global_progress.items_processed;
|
||||
self.num_samples_valid = progress.progress.as_ref().map(|p| p.items_total);
|
||||
}
|
||||
|
||||
/// Update the state with the testing progress.
|
||||
|
||||
@@ -36,8 +36,12 @@ impl ProgressBarState {
|
||||
/// Update the training progress.
|
||||
pub(crate) fn update_train(&mut self, progress: &TrainingProgress) {
|
||||
self.progress_total = calculate_progress(progress, 0, 0);
|
||||
let local_progress = progress
|
||||
.progress
|
||||
.as_ref()
|
||||
.unwrap_or(&progress.global_progress);
|
||||
self.progress_task =
|
||||
progress.progress.items_processed as f64 / progress.progress.items_total as f64;
|
||||
local_progress.items_processed as f64 / local_progress.items_total as f64;
|
||||
self.estimate.update(progress, self.starting_epoch);
|
||||
self.split = TuiSplit::Train;
|
||||
}
|
||||
@@ -45,8 +49,12 @@ impl ProgressBarState {
|
||||
/// Update the validation progress.
|
||||
pub(crate) fn update_valid(&mut self, progress: &TrainingProgress) {
|
||||
// We don't use the validation for the total progress yet.
|
||||
let local_progress = progress
|
||||
.progress
|
||||
.as_ref()
|
||||
.unwrap_or(&progress.global_progress);
|
||||
self.progress_task =
|
||||
progress.progress.items_processed as f64 / progress.progress.items_total as f64;
|
||||
local_progress.items_processed as f64 / local_progress.items_total as f64;
|
||||
self.split = TuiSplit::Valid;
|
||||
}
|
||||
|
||||
@@ -187,7 +195,7 @@ impl ProgressEstimate {
|
||||
}
|
||||
|
||||
// Once the training has been running for at least 10 seconds and has completed 10 iterations.
|
||||
if progress.iteration >= WARMUP_NUM_ITERATION
|
||||
if progress.iteration >= Some(WARMUP_NUM_ITERATION)
|
||||
&& self.started.elapsed() > Duration::from_secs(10)
|
||||
{
|
||||
self.init(progress, starting_epoch);
|
||||
@@ -195,11 +203,17 @@ impl ProgressEstimate {
|
||||
}
|
||||
|
||||
fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) {
|
||||
let epoch = progress.epoch - starting_epoch;
|
||||
let epoch_items = (epoch - 1) * progress.progress.items_total;
|
||||
let iteration_items = progress.progress.items_processed;
|
||||
let epoch = progress.global_progress.items_processed - starting_epoch;
|
||||
|
||||
self.warmup_num_items = match &progress.progress {
|
||||
Some(local_progress) => {
|
||||
let epoch_items = (epoch - 1) * local_progress.items_total;
|
||||
let iteration_items = local_progress.items_processed;
|
||||
epoch_items + iteration_items
|
||||
}
|
||||
None => epoch,
|
||||
};
|
||||
|
||||
self.warmup_num_items = epoch_items + iteration_items;
|
||||
self.started_after_warmup = Some(Instant::now());
|
||||
self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items);
|
||||
}
|
||||
@@ -210,15 +224,19 @@ fn calculate_progress(
|
||||
starting_epoch: usize,
|
||||
ignore_num_items: usize,
|
||||
) -> f64 {
|
||||
let epoch_total = progress.epoch_total - starting_epoch;
|
||||
let epoch = progress.epoch - starting_epoch;
|
||||
let epoch_total = progress.global_progress.items_total - starting_epoch;
|
||||
let epoch = progress.global_progress.items_processed - starting_epoch;
|
||||
match &progress.progress {
|
||||
Some(local_progress) => {
|
||||
let total_items = local_progress.items_total * epoch_total;
|
||||
let epoch_items = (epoch - 1) * local_progress.items_total;
|
||||
let iteration_items = local_progress.items_processed;
|
||||
let num_items = epoch_items + iteration_items - ignore_num_items;
|
||||
|
||||
let total_items = progress.progress.items_total * epoch_total;
|
||||
let epoch_items = (epoch - 1) * progress.progress.items_total;
|
||||
let iteration_items = progress.progress.items_processed;
|
||||
let num_items = epoch_items + iteration_items - ignore_num_items;
|
||||
|
||||
num_items as f64 / total_items as f64
|
||||
num_items as f64 / total_items as f64
|
||||
}
|
||||
None => epoch as f64 / epoch_total as f64,
|
||||
}
|
||||
}
|
||||
|
||||
fn format_eta(eta_secs: u64) -> String {
|
||||
@@ -268,11 +286,14 @@ mod tests {
|
||||
items_processed: 5,
|
||||
items_total: 10,
|
||||
};
|
||||
let global_progress = Progress {
|
||||
items_processed: 9,
|
||||
items_total: 10,
|
||||
};
|
||||
let progress = TrainingProgress {
|
||||
progress: half,
|
||||
epoch: 9,
|
||||
epoch_total: 10,
|
||||
iteration: 500,
|
||||
progress: Some(half),
|
||||
global_progress: global_progress,
|
||||
iteration: Some(500),
|
||||
};
|
||||
|
||||
let starting_epoch = 8;
|
||||
@@ -288,11 +309,14 @@ mod tests {
|
||||
items_processed: 110,
|
||||
items_total: 1000,
|
||||
};
|
||||
let global_progress = Progress {
|
||||
items_processed: 9,
|
||||
items_total: 10,
|
||||
};
|
||||
let progress = TrainingProgress {
|
||||
progress: half,
|
||||
epoch: 9,
|
||||
epoch_total: 10,
|
||||
iteration: 500,
|
||||
progress: Some(half),
|
||||
global_progress: global_progress,
|
||||
iteration: Some(500),
|
||||
};
|
||||
|
||||
let starting_epoch = 8;
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::metric::{MetricDefinition, MetricId};
|
||||
use crate::renderer::tui::TuiSplit;
|
||||
use crate::renderer::{
|
||||
EvaluationName, EvaluationProgress, MetricState, MetricsRenderer, MetricsRendererEvaluation,
|
||||
TrainingProgress,
|
||||
ProgressType, TrainingProgress,
|
||||
};
|
||||
use crate::renderer::{MetricsRendererTraining, tui::NumericMetricsState};
|
||||
use crate::{Interrupter, LearnerSummary};
|
||||
@@ -17,7 +17,9 @@ use ratatui::{
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::panic::{set_hook, take_hook};
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc::{Receiver, Sender};
|
||||
use std::sync::{Arc, Mutex, mpsc};
|
||||
use std::thread::{JoinHandle, spawn};
|
||||
use std::{
|
||||
error::Error,
|
||||
io::{self, Stdout},
|
||||
@@ -38,8 +40,86 @@ type PanicHook = Box<dyn Fn(&std::panic::PanicHookInfo<'_>) + 'static + Sync + S
|
||||
|
||||
const MAX_REFRESH_RATE_MILLIS: u64 = 100;
|
||||
|
||||
enum TuiRendererEvent {
|
||||
MetricRegistration(MetricDefinition),
|
||||
MetricsUpdate((TuiSplit, TuiGroup, MetricState)),
|
||||
StatusUpdateTrain((TuiSplit, TrainingProgress, Vec<ProgressType>)),
|
||||
StatusUpdateTest((EvaluationProgress, Vec<ProgressType>)),
|
||||
TrainEnd(Option<LearnerSummary>),
|
||||
ManualClose(),
|
||||
Close(),
|
||||
Persistent(),
|
||||
}
|
||||
|
||||
/// The terminal UI metrics renderer.
|
||||
pub struct TuiMetricsRenderer {
|
||||
pub struct TuiMetricsRendererWrapper {
|
||||
sender: mpsc::Sender<TuiRendererEvent>,
|
||||
interrupter: Interrupter,
|
||||
handle_join: Option<JoinHandle<()>>,
|
||||
kill_signal: Arc<Mutex<Receiver<()>>>,
|
||||
}
|
||||
|
||||
impl TuiMetricsRendererWrapper {
|
||||
/// Create a new terminal UI renderer.
|
||||
pub fn new(interrupter: Interrupter, checkpoint: Option<usize>) -> Self {
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let (kill_signal_sender, kill_signal_receiver) = mpsc::channel();
|
||||
|
||||
let interrupter_clone = interrupter.clone();
|
||||
let handle_join = spawn(move || {
|
||||
let mut renderer =
|
||||
TuiMetricsRenderer::new(interrupter_clone, checkpoint, kill_signal_sender);
|
||||
|
||||
let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS);
|
||||
loop {
|
||||
match receiver.try_recv() {
|
||||
Ok(event) => renderer.handle_event(event),
|
||||
Err(mpsc::TryRecvError::Empty) => (),
|
||||
Err(mpsc::TryRecvError::Disconnected) => {
|
||||
log::error!("Renderer thread disconnected.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Render
|
||||
if renderer.last_update.elapsed() >= tick_rate
|
||||
&& let Err(err) = renderer.render()
|
||||
{
|
||||
log::error!("Render error: {err}");
|
||||
break;
|
||||
}
|
||||
|
||||
if (renderer.manual_close && renderer.interrupter.should_stop()) || renderer.close {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
sender,
|
||||
interrupter,
|
||||
handle_join: Some(handle_join),
|
||||
kill_signal: Arc::new(Mutex::new(kill_signal_receiver)),
|
||||
}
|
||||
}
|
||||
|
||||
fn send_event(&self, event: TuiRendererEvent) {
|
||||
if self.kill_signal.lock().unwrap().try_recv().is_ok() {
|
||||
panic!("Killing training from user input.")
|
||||
}
|
||||
if let Err(e) = self.sender.send(event) {
|
||||
log::warn!("Failed to send TUI event: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the renderer to persistent mode.
|
||||
pub fn persistent(self) -> Self {
|
||||
self.send_event(TuiRendererEvent::Persistent());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
struct TuiMetricsRenderer {
|
||||
terminal: Terminal<TerminalBackend>,
|
||||
last_update: std::time::Instant,
|
||||
progress: ProgressBarState,
|
||||
@@ -47,75 +127,95 @@ pub struct TuiMetricsRenderer {
|
||||
metrics_numeric: NumericMetricsState,
|
||||
metrics_text: TextMetricsState,
|
||||
status: StatusState,
|
||||
interuptor: Interrupter,
|
||||
interrupter: Interrupter,
|
||||
popup: PopupState,
|
||||
previous_panic_hook: Option<Arc<PanicHook>>,
|
||||
persistent: bool,
|
||||
manual_close: bool,
|
||||
close: bool,
|
||||
summary: Option<LearnerSummary>,
|
||||
kill_signal: Sender<()>,
|
||||
}
|
||||
|
||||
impl MetricsRendererEvaluation for TuiMetricsRenderer {
|
||||
impl MetricsRendererEvaluation for TuiMetricsRendererWrapper {
|
||||
fn update_test(&mut self, name: EvaluationName, state: MetricState) {
|
||||
self.update_metric(TuiSplit::Test, TuiGroup::Named(name.name), state);
|
||||
self.send_event(TuiRendererEvent::MetricsUpdate((
|
||||
TuiSplit::Test,
|
||||
TuiGroup::Named(name.name),
|
||||
state,
|
||||
)));
|
||||
}
|
||||
|
||||
fn render_test(&mut self, item: EvaluationProgress) {
|
||||
self.progress.update_test(&item);
|
||||
self.metrics_numeric.update_progress_test(&item);
|
||||
self.status.update_test(item);
|
||||
self.render().unwrap();
|
||||
fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec<ProgressType>) {
|
||||
self.send_event(TuiRendererEvent::StatusUpdateTest((
|
||||
item,
|
||||
progress_indicators,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
impl MetricsRenderer for TuiMetricsRenderer {
|
||||
impl MetricsRenderer for TuiMetricsRendererWrapper {
|
||||
fn manual_close(&mut self) {
|
||||
loop {
|
||||
self.render().unwrap();
|
||||
if self.interuptor.should_stop() {
|
||||
return;
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
}
|
||||
self.send_event(TuiRendererEvent::ManualClose());
|
||||
let _ = self.handle_join.take().unwrap().join();
|
||||
}
|
||||
|
||||
fn register_metric(&mut self, definition: MetricDefinition) {
|
||||
self.metric_definitions
|
||||
.insert(definition.metric_id.clone(), definition);
|
||||
self.send_event(TuiRendererEvent::MetricRegistration(definition));
|
||||
}
|
||||
}
|
||||
|
||||
impl MetricsRendererTraining for TuiMetricsRenderer {
|
||||
impl MetricsRendererTraining for TuiMetricsRendererWrapper {
|
||||
fn update_train(&mut self, state: MetricState) {
|
||||
self.update_metric(TuiSplit::Train, TuiGroup::Default, state);
|
||||
self.send_event(TuiRendererEvent::MetricsUpdate((
|
||||
TuiSplit::Train,
|
||||
TuiGroup::Default,
|
||||
state,
|
||||
)));
|
||||
}
|
||||
|
||||
fn update_valid(&mut self, state: MetricState) {
|
||||
self.update_metric(TuiSplit::Valid, TuiGroup::Default, state);
|
||||
self.send_event(TuiRendererEvent::MetricsUpdate((
|
||||
TuiSplit::Valid,
|
||||
TuiGroup::Default,
|
||||
state,
|
||||
)));
|
||||
}
|
||||
|
||||
fn render_train(&mut self, item: TrainingProgress) {
|
||||
self.progress.update_train(&item);
|
||||
self.metrics_numeric.update_progress_train(&item);
|
||||
self.status.update_train(item);
|
||||
self.render().unwrap();
|
||||
fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>) {
|
||||
self.send_event(TuiRendererEvent::StatusUpdateTrain((
|
||||
TuiSplit::Train,
|
||||
item,
|
||||
progress_indicators,
|
||||
)));
|
||||
}
|
||||
|
||||
fn render_valid(&mut self, item: TrainingProgress) {
|
||||
self.progress.update_valid(&item);
|
||||
self.metrics_numeric.update_progress_valid(&item);
|
||||
self.status.update_valid(item);
|
||||
self.render().unwrap();
|
||||
fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>) {
|
||||
self.send_event(TuiRendererEvent::StatusUpdateTrain((
|
||||
TuiSplit::Valid,
|
||||
item,
|
||||
progress_indicators,
|
||||
)));
|
||||
}
|
||||
|
||||
fn on_train_end(&mut self, summary: Option<LearnerSummary>) -> Result<(), Box<dyn Error>> {
|
||||
// Reset for following steps.
|
||||
self.interuptor.reset();
|
||||
self.interrupter.reset();
|
||||
// Update the summary
|
||||
self.summary = summary;
|
||||
self.send_event(TuiRendererEvent::TrainEnd(summary));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TuiMetricsRendererWrapper {
|
||||
fn drop(&mut self) {
|
||||
if !std::thread::panicking() {
|
||||
self.send_event(TuiRendererEvent::Close());
|
||||
let _ = self.handle_join.take().unwrap().join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TuiMetricsRenderer {
|
||||
fn update_metric(&mut self, split: TuiSplit, group: TuiGroup, state: MetricState) {
|
||||
match state {
|
||||
@@ -144,8 +244,11 @@ impl TuiMetricsRenderer {
|
||||
};
|
||||
}
|
||||
|
||||
/// Create a new terminal UI renderer.
|
||||
pub fn new(interuptor: Interrupter, checkpoint: Option<usize>) -> Self {
|
||||
pub fn new(
|
||||
interrupter: Interrupter,
|
||||
checkpoint: Option<usize>,
|
||||
kill_signal: Sender<()>,
|
||||
) -> Self {
|
||||
let mut stdout = io::stdout();
|
||||
execute!(stdout, EnterAlternateScreen).unwrap();
|
||||
enable_raw_mode().unwrap();
|
||||
@@ -171,28 +274,57 @@ impl TuiMetricsRenderer {
|
||||
metrics_numeric: NumericMetricsState::default(),
|
||||
metrics_text: TextMetricsState::default(),
|
||||
status: StatusState::default(),
|
||||
interuptor,
|
||||
interrupter,
|
||||
popup: PopupState::Empty,
|
||||
previous_panic_hook: Some(previous_panic_hook),
|
||||
persistent: false,
|
||||
manual_close: false,
|
||||
close: false,
|
||||
summary: None,
|
||||
kill_signal,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the renderer to persistent mode.
|
||||
pub fn persistent(mut self) -> Self {
|
||||
self.persistent = true;
|
||||
self
|
||||
fn handle_event(&mut self, event: TuiRendererEvent) {
|
||||
match event {
|
||||
TuiRendererEvent::MetricRegistration(definition) => {
|
||||
self.metric_definitions
|
||||
.insert(definition.metric_id.clone(), definition);
|
||||
}
|
||||
TuiRendererEvent::MetricsUpdate((split, group, state)) => {
|
||||
self.update_metric(split, group, state);
|
||||
}
|
||||
TuiRendererEvent::StatusUpdateTrain((split, item, status)) => match split {
|
||||
TuiSplit::Train => {
|
||||
self.progress.update_train(&item);
|
||||
self.metrics_numeric.update_progress_train(&item);
|
||||
self.status.update_train(status);
|
||||
}
|
||||
TuiSplit::Valid => {
|
||||
self.progress.update_valid(&item);
|
||||
self.metrics_numeric.update_progress_valid(&item);
|
||||
self.status.update_valid(status);
|
||||
}
|
||||
_ => (),
|
||||
},
|
||||
TuiRendererEvent::StatusUpdateTest((item, status)) => {
|
||||
self.progress.update_test(&item);
|
||||
self.metrics_numeric.update_progress_test(&item);
|
||||
self.status.update_test(status);
|
||||
}
|
||||
TuiRendererEvent::TrainEnd(learner_summary) => {
|
||||
self.interrupter.reset();
|
||||
self.summary = learner_summary;
|
||||
}
|
||||
TuiRendererEvent::ManualClose() => self.manual_close = true,
|
||||
TuiRendererEvent::Persistent() => self.persistent = true,
|
||||
TuiRendererEvent::Close() => self.close = true,
|
||||
}
|
||||
}
|
||||
|
||||
fn render(&mut self) -> Result<(), Box<dyn Error>> {
|
||||
let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS);
|
||||
if self.last_update.elapsed() < tick_rate {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.draw()?;
|
||||
self.handle_events()?;
|
||||
self.handle_user_input()?;
|
||||
|
||||
self.last_update = Instant::now();
|
||||
|
||||
@@ -222,7 +354,7 @@ impl TuiMetricsRenderer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_events(&mut self) -> Result<(), Box<dyn Error>> {
|
||||
fn handle_user_input(&mut self) -> Result<(), Box<dyn Error>> {
|
||||
while event::poll(Duration::from_secs(0))? {
|
||||
let event = event::read()?;
|
||||
self.popup.on_event(&event);
|
||||
@@ -242,7 +374,7 @@ impl TuiMetricsRenderer {
|
||||
training loop, but any remaining code after the loop will be \
|
||||
executed.",
|
||||
's',
|
||||
QuitPopupAccept(self.interuptor.clone()),
|
||||
QuitPopupAccept(self.interrupter.clone()),
|
||||
),
|
||||
Callback::new(
|
||||
"Stop the training immediately.",
|
||||
@@ -250,7 +382,7 @@ impl TuiMetricsRenderer {
|
||||
the current training fails. Any code following the training \
|
||||
won't be executed.",
|
||||
'k',
|
||||
KillPopupAccept,
|
||||
KillPopupAccept(self.kill_signal.clone()),
|
||||
),
|
||||
Callback::new(
|
||||
"Cancel",
|
||||
@@ -337,11 +469,12 @@ impl TuiMetricsRenderer {
|
||||
}
|
||||
|
||||
struct QuitPopupAccept(Interrupter);
|
||||
struct KillPopupAccept;
|
||||
struct KillPopupAccept(Sender<()>);
|
||||
struct PopupCancel;
|
||||
|
||||
impl CallbackFn for KillPopupAccept {
|
||||
fn call(&self) -> bool {
|
||||
self.0.send(()).unwrap();
|
||||
panic!("Killing training from user input.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::renderer::ProgressType;
|
||||
|
||||
use super::TerminalFrame;
|
||||
use crate::renderer::{EvaluationProgress, TrainingProgress};
|
||||
use ratatui::{
|
||||
prelude::{Alignment, Rect},
|
||||
style::{Color, Style, Stylize},
|
||||
@@ -9,7 +10,7 @@ use ratatui::{
|
||||
|
||||
/// Show the training status with various information.
|
||||
pub(crate) struct StatusState {
|
||||
progress: TrainingProgress,
|
||||
progress_indicators: Vec<ProgressType>,
|
||||
mode: Mode,
|
||||
}
|
||||
|
||||
@@ -22,7 +23,7 @@ enum Mode {
|
||||
impl Default for StatusState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
progress: TrainingProgress::none(),
|
||||
progress_indicators: vec![],
|
||||
mode: Mode::Train,
|
||||
}
|
||||
}
|
||||
@@ -30,24 +31,23 @@ impl Default for StatusState {
|
||||
|
||||
impl StatusState {
|
||||
/// Update the training information.
|
||||
pub(crate) fn update_train(&mut self, progress: TrainingProgress) {
|
||||
self.progress = progress;
|
||||
pub(crate) fn update_train(&mut self, progress_indicators: Vec<ProgressType>) {
|
||||
self.progress_indicators = progress_indicators;
|
||||
self.mode = Mode::Train;
|
||||
}
|
||||
/// Update the validation information.
|
||||
pub(crate) fn update_valid(&mut self, progress: TrainingProgress) {
|
||||
self.progress = progress;
|
||||
pub(crate) fn update_valid(&mut self, progress_indicators: Vec<ProgressType>) {
|
||||
self.progress_indicators = progress_indicators;
|
||||
self.mode = Mode::Valid;
|
||||
}
|
||||
/// Update the testing information.
|
||||
pub(crate) fn update_test(&mut self, _progress: EvaluationProgress) {
|
||||
// TODO: Use the progress here.
|
||||
// self.progress = progress;
|
||||
pub(crate) fn update_test(&mut self, progress_indicators: Vec<ProgressType>) {
|
||||
self.progress_indicators = progress_indicators;
|
||||
self.mode = Mode::Evaluation;
|
||||
}
|
||||
/// Create a view.
|
||||
pub(crate) fn view(&self) -> StatusView {
|
||||
StatusView::new(&self.progress, &self.mode)
|
||||
StatusView::new(&self.progress_indicators, &self.mode)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ pub(crate) struct StatusView {
|
||||
}
|
||||
|
||||
impl StatusView {
|
||||
fn new(progress: &TrainingProgress, mode: &Mode) -> Self {
|
||||
fn new(progress_indicators: &[ProgressType], mode: &Mode) -> Self {
|
||||
let title = |title: &str| Span::from(format!(" {title} ")).bold().yellow();
|
||||
let value = |value: String| Span::from(value).italic();
|
||||
let mode = match mode {
|
||||
@@ -65,26 +65,38 @@ impl StatusView {
|
||||
Mode::Evaluation => "Evaluation",
|
||||
};
|
||||
|
||||
Self {
|
||||
lines: vec![
|
||||
vec![title("Mode :"), value(mode.to_string())],
|
||||
vec![
|
||||
title("Epoch :"),
|
||||
value(format!("{}/{}", progress.epoch, progress.epoch_total)),
|
||||
],
|
||||
vec![
|
||||
title("Iteration :"),
|
||||
value(format!("{}", progress.iteration)),
|
||||
],
|
||||
vec![
|
||||
title("Items :"),
|
||||
value(format!(
|
||||
"{}/{}",
|
||||
progress.progress.items_processed, progress.progress.items_total
|
||||
)),
|
||||
],
|
||||
],
|
||||
}
|
||||
let width = progress_indicators
|
||||
.iter()
|
||||
.map(|p| match p {
|
||||
ProgressType::Detailed { tag, .. } => tag.len(),
|
||||
ProgressType::Value { tag, .. } => tag.len(),
|
||||
})
|
||||
.max()
|
||||
.unwrap_or(4);
|
||||
|
||||
let mut lines = vec![vec![
|
||||
title(&format!("{: <width$} :", "Mode")),
|
||||
value(mode.to_string()),
|
||||
]];
|
||||
|
||||
progress_indicators.iter().for_each(|p| match p {
|
||||
ProgressType::Detailed { tag, progress } => lines.push(vec![
|
||||
title(&format!("{: <width$} :", tag)),
|
||||
value(format!(
|
||||
"{}/{}",
|
||||
progress.items_processed, progress.items_total
|
||||
)),
|
||||
]),
|
||||
ProgressType::Value {
|
||||
tag,
|
||||
value: num_items,
|
||||
} => lines.push(vec![
|
||||
title(&format!("{: <width$} :", tag)),
|
||||
value(format!("{}", num_items)),
|
||||
]),
|
||||
});
|
||||
|
||||
Self { lines }
|
||||
}
|
||||
|
||||
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
|
||||
|
||||
@@ -118,6 +118,7 @@ cubecl = ["dep:cubecl"]
|
||||
|
||||
audio = ["burn-core/audio"]
|
||||
vision = ["burn-core/vision"]
|
||||
rl = ["dep:burn-rl"]
|
||||
|
||||
# Backend
|
||||
ir = ["burn-ir"]
|
||||
@@ -177,6 +178,7 @@ burn-collective = { path = "../burn-collective", version = "=0.21.0", optional =
|
||||
burn-store = { path = "../burn-store", version = "=0.21.0", optional = true, default-features = false }
|
||||
burn-nn = { path = "../burn-nn", version = "=0.21.0", default-features = false }
|
||||
burn-optim = { path = "../burn-optim", version = "=0.21.0", default-features = false }
|
||||
burn-rl = { path = "../burn-rl", version = "=0.21.0", optional = true, default-features = false }
|
||||
|
||||
# Backends
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0", optional = true, default-features = false }
|
||||
|
||||
@@ -124,6 +124,12 @@ pub mod train {
|
||||
pub use burn_train::*;
|
||||
}
|
||||
|
||||
/// Module for reinforcement learning.
|
||||
#[cfg(feature = "rl")]
|
||||
pub mod rl {
|
||||
pub use burn_rl::*;
|
||||
}
|
||||
|
||||
/// Backend module.
|
||||
pub mod backend;
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ publish = false
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
burn = {path = "../../crates/burn", features=["autodiff", "webgpu", "vision"]}
|
||||
burn = {path = "../../crates/burn", default-features = false, features=["autodiff", "webgpu", "vision"]}
|
||||
guide = {path = "../guide"}
|
||||
derive-new = { workspace = true }
|
||||
log = { workspace = true }
|
||||
log = { workspace = true }
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::model::ModelConfig;
|
||||
use burn::data::dataloader::Progress;
|
||||
use burn::record::NoStdTrainingRecorder;
|
||||
use burn::train::{
|
||||
EventProcessorTraining, Learner, LearningComponentsTypes, SupervisedLearningStrategy,
|
||||
@@ -20,7 +21,7 @@ use burn::{
|
||||
record::CompactRecorder,
|
||||
tensor::{Device, backend::AutodiffBackend},
|
||||
train::{
|
||||
InferenceStep, LearnerEvent, LearnerItem, MetricEarlyStoppingStrategy, StoppingCondition,
|
||||
InferenceStep, LearnerEvent, MetricEarlyStoppingStrategy, StoppingCondition, TrainingItem,
|
||||
metric::{
|
||||
AccuracyMetric, LossMetric,
|
||||
store::{Aggregate, Direction, Split},
|
||||
@@ -167,12 +168,11 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC> for MyCustomLea
|
||||
let item = learner.train_step(item);
|
||||
learner.optimizer_step(item.grads);
|
||||
|
||||
let item = LearnerItem::new(
|
||||
let item = TrainingItem::new(
|
||||
item.item,
|
||||
progress,
|
||||
epoch,
|
||||
num_epochs,
|
||||
iteration,
|
||||
Progress::new(epoch, num_epochs),
|
||||
Some(iteration),
|
||||
Some(learner.lr_current()),
|
||||
);
|
||||
|
||||
@@ -198,7 +198,13 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC> for MyCustomLea
|
||||
iteration += 1;
|
||||
|
||||
let item = model_valid.step(item);
|
||||
let item = LearnerItem::new(item, progress, epoch, num_epochs, iteration, None);
|
||||
let item = TrainingItem::new(
|
||||
item,
|
||||
progress,
|
||||
Progress::new(epoch, num_epochs),
|
||||
Some(iteration),
|
||||
None,
|
||||
);
|
||||
|
||||
event_processor.process_valid(LearnerEvent::ProcessedItem(item));
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use burn::{
|
||||
Learner, SupervisedTraining,
|
||||
renderer::{
|
||||
EvaluationName, EvaluationProgress, MetricState, MetricsRenderer,
|
||||
MetricsRendererEvaluation, MetricsRendererTraining, TrainingProgress,
|
||||
MetricsRendererEvaluation, MetricsRendererTraining, ProgressType, TrainingProgress,
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -36,11 +36,11 @@ impl MetricsRendererTraining for CustomRenderer {
|
||||
|
||||
fn update_valid(&mut self, _state: MetricState) {}
|
||||
|
||||
fn render_train(&mut self, item: TrainingProgress) {
|
||||
fn render_train(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {
|
||||
dbg!(item);
|
||||
}
|
||||
|
||||
fn render_valid(&mut self, item: TrainingProgress) {
|
||||
fn render_valid(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {
|
||||
dbg!(item);
|
||||
}
|
||||
}
|
||||
@@ -56,7 +56,7 @@ impl MetricsRenderer for CustomRenderer {
|
||||
impl MetricsRendererEvaluation for CustomRenderer {
|
||||
fn update_test(&mut self, _name: EvaluationName, _state: MetricState) {}
|
||||
|
||||
fn render_test(&mut self, item: EvaluationProgress) {
|
||||
fn render_test(&mut self, item: EvaluationProgress, _progress_indicators: Vec<ProgressType>) {
|
||||
dbg!(item);
|
||||
}
|
||||
}
|
||||
|
||||
41
examples/dqn-agent/Cargo.toml
Normal file
41
examples/dqn-agent/Cargo.toml
Normal file
@@ -0,0 +1,41 @@
|
||||
[package]
|
||||
name = "dqn-agent"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
readme.workspace = true
|
||||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["burn/tui"]
|
||||
ndarray = ["burn/ndarray"]
|
||||
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
|
||||
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
|
||||
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
|
||||
tch-cpu = ["burn/tch"]
|
||||
tch-gpu = ["burn/tch"]
|
||||
remote = ["burn/remote"]
|
||||
wgpu = ["burn/wgpu", "burn/default"]
|
||||
metal = ["burn/metal", "burn/default"]
|
||||
cuda = ["burn/cuda"]
|
||||
vulkan = ["burn/vulkan", "burn/default"]
|
||||
rocm = ["burn/rocm", "burn/default"]
|
||||
|
||||
[dependencies]
|
||||
# Disable autotune default for now (convolutions not optimized)
|
||||
burn = { path = "../../crates/burn", features = [
|
||||
"train",
|
||||
# "vision",
|
||||
"metrics",
|
||||
"std",
|
||||
"rl",
|
||||
# "fusion",
|
||||
"ndarray",
|
||||
# "autotune",
|
||||
], default-features = false }
|
||||
# Just for this example.
|
||||
gym-rs = { version = "0.3.1", branch = "main", git = "https://github.com/MathisWellmann/gym-rs" }
|
||||
rand.workspace = true
|
||||
derive-new = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
118
examples/dqn-agent/examples/dqn-agent.rs
Normal file
118
examples/dqn-agent/examples/dqn-agent.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
#[cfg(any(
|
||||
feature = "ndarray",
|
||||
feature = "ndarray-blas-netlib",
|
||||
feature = "ndarray-blas-openblas",
|
||||
feature = "ndarray-blas-accelerate",
|
||||
))]
|
||||
mod ndarray {
|
||||
use burn::backend::{
|
||||
Autodiff,
|
||||
ndarray::{NdArray, NdArrayDevice},
|
||||
};
|
||||
use dqn_agent::training;
|
||||
|
||||
pub fn run() {
|
||||
let device = NdArrayDevice::Cpu;
|
||||
training::run::<Autodiff<NdArray>>(device);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tch-gpu")]
|
||||
mod tch_gpu {
|
||||
use burn::backend::{
|
||||
Autodiff,
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
};
|
||||
use dqn_agent::training;
|
||||
|
||||
pub fn run() {
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
let device = LibTorchDevice::Cuda(0);
|
||||
#[cfg(target_os = "macos")]
|
||||
let device = LibTorchDevice::Mps;
|
||||
|
||||
training::run::<Autodiff<LibTorch>>(device);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "wgpu", feature = "metal", feature = "vulkan"))]
|
||||
mod wgpu {
|
||||
use burn::backend::{
|
||||
Autodiff,
|
||||
wgpu::{Wgpu, WgpuDevice},
|
||||
};
|
||||
use dqn_agent::training;
|
||||
|
||||
pub fn run() {
|
||||
let device = WgpuDevice::default();
|
||||
training::run::<Autodiff<Wgpu>>(device);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
mod cuda {
|
||||
use burn::backend::{Autodiff, Cuda};
|
||||
use dqn_agent::training;
|
||||
|
||||
pub fn run() {
|
||||
let device = Default::default();
|
||||
training::run::<Autodiff<Cuda>>(device);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rocm")]
|
||||
mod rocm {
|
||||
use burn::backend::{Autodiff, Rocm};
|
||||
use dqn_agent::training;
|
||||
|
||||
pub fn run() {
|
||||
let device = Default::default();
|
||||
training::run::<Autodiff<Rocm>>(device);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tch-cpu")]
|
||||
mod tch_cpu {
|
||||
use burn::backend::{
|
||||
Autodiff,
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
};
|
||||
use dqn_agent::training;
|
||||
|
||||
pub fn run() {
|
||||
let device = LibTorchDevice::Cpu;
|
||||
training::run::<Autodiff<LibTorch>>(device);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
mod remote {
|
||||
use burn::backend::{Autodiff, RemoteBackend};
|
||||
use dqn_agent::training;
|
||||
|
||||
pub fn run() {
|
||||
training::run::<Autodiff<RemoteBackend>>(Default::default());
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
#[cfg(any(
|
||||
feature = "ndarray",
|
||||
feature = "ndarray-blas-netlib",
|
||||
feature = "ndarray-blas-openblas",
|
||||
feature = "ndarray-blas-accelerate",
|
||||
))]
|
||||
ndarray::run();
|
||||
#[cfg(feature = "tch-gpu")]
|
||||
tch_gpu::run();
|
||||
#[cfg(feature = "tch-cpu")]
|
||||
tch_cpu::run();
|
||||
#[cfg(any(feature = "wgpu", feature = "metal", feature = "vulkan"))]
|
||||
wgpu::run();
|
||||
#[cfg(feature = "cuda")]
|
||||
cuda::run();
|
||||
#[cfg(feature = "rocm")]
|
||||
rocm::run();
|
||||
#[cfg(feature = "remote")]
|
||||
remote::run();
|
||||
}
|
||||
477
examples/dqn-agent/src/agent.rs
Normal file
477
examples/dqn-agent/src/agent.rs
Normal file
@@ -0,0 +1,477 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn::backend::NdArray;
|
||||
use burn::module::Module;
|
||||
use burn::record::Record;
|
||||
use burn::rl::{
|
||||
Batchable, LearnerTransitionBatch, Policy, PolicyLearner, PolicyState, RLTrainOutput,
|
||||
};
|
||||
use burn::tensor::Transaction;
|
||||
use burn::tensor::activation::softmax;
|
||||
use burn::train::ItemLazy;
|
||||
use burn::train::metric::{Adaptor, LossInput};
|
||||
use burn::{
|
||||
Tensor,
|
||||
config::Config,
|
||||
module::AutodiffModule,
|
||||
nn::{self, loss::MseLoss},
|
||||
optim::{GradientsParams, Optimizer},
|
||||
prelude::Backend,
|
||||
tensor::backend::AutodiffBackend,
|
||||
};
|
||||
use rand::distr::Distribution;
|
||||
use rand::distr::weighted::WeightedIndex;
|
||||
use rand::rng;
|
||||
|
||||
use crate::utils::{
|
||||
EpsilonGreedyPolicy, EpsilonGreedyPolicyState, create_lin_layers, soft_update_linear,
|
||||
};
|
||||
|
||||
pub trait DiscreteActionModel<B: Backend>: Module<B> {
|
||||
type Input: Clone + Send + Batchable;
|
||||
|
||||
fn forward(&self, input: Self::Input) -> DiscreteLogitsTensor<B, 2>;
|
||||
}
|
||||
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MlpNetConfig {
|
||||
/// The number of layers.
|
||||
#[config(default = 3)]
|
||||
pub num_layers: usize,
|
||||
/// The dropout rate.
|
||||
#[config(default = 0.)]
|
||||
pub dropout: f64,
|
||||
/// The input dimension.
|
||||
#[config(default = 4)]
|
||||
pub d_input: usize,
|
||||
/// The output dimension.
|
||||
#[config(default = 2)]
|
||||
pub d_output: usize,
|
||||
/// The size of hidden layers.
|
||||
#[config(default = 256)]
|
||||
pub d_hidden: usize,
|
||||
}
|
||||
|
||||
/// Multilayer Perceptron Network.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MlpNet<B: Backend> {
|
||||
pub linears: Vec<nn::Linear<B>>,
|
||||
pub dropout: nn::Dropout,
|
||||
pub activation: nn::Relu,
|
||||
}
|
||||
|
||||
impl<B: Backend> MlpNet<B> {
|
||||
/// Create the module from the given configuration.
|
||||
pub fn new(config: &MlpNetConfig, device: &B::Device) -> Self {
|
||||
Self {
|
||||
linears: create_lin_layers(
|
||||
config.num_layers,
|
||||
config.d_input,
|
||||
config.d_hidden,
|
||||
config.d_output,
|
||||
device,
|
||||
),
|
||||
dropout: nn::DropoutConfig::new(config.dropout).init(),
|
||||
activation: nn::Relu::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ObservationTensor<B: Backend, const D: usize> {
|
||||
pub state: Tensor<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Batchable for ObservationTensor<B, D> {
|
||||
fn batch(value: Vec<Self>) -> Self {
|
||||
let tensors = value.iter().map(|v| v.state.clone()).collect();
|
||||
Self {
|
||||
state: Tensor::cat(tensors, 0),
|
||||
}
|
||||
}
|
||||
|
||||
fn unbatch(self) -> Vec<Self> {
|
||||
self.state
|
||||
.split(1, 0)
|
||||
.iter()
|
||||
.map(|s| ObservationTensor { state: s.clone() })
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> DiscreteActionModel<B> for MlpNet<B> {
|
||||
type Input = ObservationTensor<B, 2>;
|
||||
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, d_input]`
|
||||
/// - output: `[batch_size, d_output]`
|
||||
fn forward(&self, input: Self::Input) -> DiscreteLogitsTensor<B, 2> {
|
||||
let mut x = input.state;
|
||||
|
||||
for (i, linear) in self.linears.iter().enumerate() {
|
||||
x = linear.forward(x);
|
||||
x = self.dropout.forward(x);
|
||||
if i < self.linears.len() - 1 {
|
||||
x = self.activation.forward(x);
|
||||
}
|
||||
}
|
||||
|
||||
DiscreteLogitsTensor { logits: x }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config, Debug)]
|
||||
pub struct DqnAgentConfig {
|
||||
/// Discount factor (How to value long-term vs short-term rewards)
|
||||
#[config(default = 0.99)]
|
||||
pub gamma: f64,
|
||||
/// The learning rate
|
||||
#[config(default = 3e-4)]
|
||||
pub learning_rate: f64,
|
||||
/// The soft update rate of the target network
|
||||
#[config(default = 0.005)]
|
||||
pub tau: f64,
|
||||
/// Initial value of epsilon (Probability to choose a random action)
|
||||
#[config(default = 0.9)]
|
||||
pub epsilon_start: f64,
|
||||
/// Final value of epsilon (Probability to choose a random action)
|
||||
#[config(default = 0.01)]
|
||||
pub epsilon_end: f64,
|
||||
/// The exponential rate at which the epsilon value decays. Higher = slower decay
|
||||
#[config(default = 2500.0)]
|
||||
pub epsilon_decay: f64,
|
||||
}
|
||||
|
||||
pub trait TargetModel<B: Backend> {
|
||||
fn soft_update(&self, that: &Self, tau: f64) -> Self;
|
||||
}
|
||||
|
||||
impl<B: Backend> TargetModel<B> for MlpNet<B> {
|
||||
fn soft_update(&self, that: &Self, tau: f64) -> Self {
|
||||
let mut linears = Vec::with_capacity(self.linears.len());
|
||||
for i in 0..self.linears.len() {
|
||||
let layer = soft_update_linear(self.linears[i].clone(), &that.linears[i].clone(), tau);
|
||||
linears.insert(i, layer);
|
||||
}
|
||||
Self {
|
||||
linears,
|
||||
dropout: self.dropout.clone(),
|
||||
activation: self.activation.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DqnState<B: Backend, M: DiscreteActionModel<B>> {
|
||||
model: M,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend, M: DiscreteActionModel<B>> PolicyState<B> for DqnState<B, M> {
|
||||
type Record = M::Record;
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
self.model.clone().into_record()
|
||||
}
|
||||
|
||||
fn load_record(&self, record: Self::Record) -> Self {
|
||||
Self {
|
||||
model: self.model.clone().load_record(record),
|
||||
_backend: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DQN<B: Backend, M: DiscreteActionModel<B>> {
|
||||
model: M,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend, M: DiscreteActionModel<B>> DQN<B, M> {
|
||||
pub fn new(policy: M) -> Self {
|
||||
Self {
|
||||
model: policy,
|
||||
_backend: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DiscreteLogitsTensor<B: Backend, const D: usize> {
|
||||
pub logits: Tensor<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Batchable for DiscreteLogitsTensor<B, D> {
|
||||
fn batch(value: Vec<Self>) -> Self {
|
||||
let tensors = value.iter().map(|v| v.logits.clone()).collect();
|
||||
Self {
|
||||
logits: Tensor::cat(tensors, 0),
|
||||
}
|
||||
}
|
||||
|
||||
fn unbatch(self) -> Vec<Self> {
|
||||
self.logits
|
||||
.split(1, 0)
|
||||
.iter()
|
||||
.map(|l| DiscreteLogitsTensor { logits: l.clone() })
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DiscreteActionTensor<B: Backend, const D: usize> {
|
||||
pub actions: Tensor<B, D>,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Batchable for DiscreteActionTensor<B, D> {
|
||||
fn batch(value: Vec<Self>) -> Self {
|
||||
let tensors = value.iter().map(|v| v.actions.clone()).collect();
|
||||
Self {
|
||||
actions: Tensor::cat(tensors, 0),
|
||||
}
|
||||
}
|
||||
|
||||
fn unbatch(self) -> Vec<Self> {
|
||||
self.actions
|
||||
.split(1, 0)
|
||||
.iter()
|
||||
.map(|a| DiscreteActionTensor { actions: a.clone() })
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, M: DiscreteActionModel<B>> Policy<B> for DQN<B, M> {
|
||||
type Observation = M::Input;
|
||||
type ActionDistribution = DiscreteLogitsTensor<B, 2>;
|
||||
type Action = DiscreteActionTensor<B, 2>;
|
||||
|
||||
type ActionContext = ();
|
||||
type PolicyState = DqnState<B, M>;
|
||||
|
||||
fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {
|
||||
self.model.forward(states)
|
||||
}
|
||||
|
||||
fn action(
|
||||
&mut self,
|
||||
states: Self::Observation,
|
||||
deterministic: bool,
|
||||
) -> (Self::Action, Vec<Self::ActionContext>) {
|
||||
let logits = self.forward(states).logits;
|
||||
if deterministic {
|
||||
let output = DiscreteActionTensor {
|
||||
actions: logits.argmax(1).float(),
|
||||
};
|
||||
return (output, vec![]);
|
||||
}
|
||||
|
||||
let mut actions = vec![];
|
||||
let probs = softmax(logits, 1);
|
||||
let probs = probs.split(1, 0);
|
||||
let mut rng = rng();
|
||||
for p in probs {
|
||||
let dist = WeightedIndex::new(p.to_data().to_vec::<f32>().unwrap()).unwrap();
|
||||
let action = dist.sample(&mut rng);
|
||||
actions.push(Tensor::<B, 1>::from_floats([action], &p.device()));
|
||||
}
|
||||
|
||||
let output = DiscreteActionTensor {
|
||||
actions: Tensor::stack(actions, 1),
|
||||
};
|
||||
(output, vec![])
|
||||
}
|
||||
|
||||
fn update(&mut self, update: Self::PolicyState) {
|
||||
self.model = update.model;
|
||||
}
|
||||
|
||||
fn state(&self) -> Self::PolicyState {
|
||||
DqnState {
|
||||
model: self.model.clone(),
|
||||
_backend: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {
|
||||
let state = self.state().load_record(record);
|
||||
Self {
|
||||
model: state.model,
|
||||
_backend: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Record)]
|
||||
pub struct DqnLearningRecord<B: AutodiffBackend, M: AutodiffModule<B>, O: Optimizer<M, B>> {
|
||||
policy_model: M::Record,
|
||||
target_model: M::Record,
|
||||
optimizer: O::Record,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DqnLearningAgent<B, M, O>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
M: DiscreteActionModel<B> + AutodiffModule<B> + TargetModel<B> + 'static,
|
||||
M::InnerModule: DiscreteActionModel<B::InnerBackend> + TargetModel<B::InnerBackend>,
|
||||
O: Optimizer<M, B> + 'static,
|
||||
{
|
||||
policy_model: M,
|
||||
target_model: M,
|
||||
agent: EpsilonGreedyPolicy<B, DQN<B, M>>,
|
||||
optimizer: O,
|
||||
config: DqnAgentConfig,
|
||||
}
|
||||
|
||||
impl<B, M, O> DqnLearningAgent<B, M, O>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
M: DiscreteActionModel<B> + AutodiffModule<B> + TargetModel<B> + 'static,
|
||||
M::InnerModule: DiscreteActionModel<B::InnerBackend> + TargetModel<B::InnerBackend>,
|
||||
O: Optimizer<M, B> + 'static,
|
||||
{
|
||||
pub fn new(model: M, optimizer: O, config: DqnAgentConfig) -> Self {
|
||||
let agent = EpsilonGreedyPolicy::new(
|
||||
DQN::new(model.clone()),
|
||||
config.epsilon_start,
|
||||
config.epsilon_end,
|
||||
config.epsilon_decay,
|
||||
);
|
||||
Self {
|
||||
policy_model: model.clone(),
|
||||
target_model: model,
|
||||
agent,
|
||||
optimizer,
|
||||
config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SimpleTrainOutput<B: Backend> {
|
||||
pub policy_model_loss: Tensor<B, 1>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ItemLazy for SimpleTrainOutput<B> {
|
||||
type ItemSync = SimpleTrainOutput<NdArray>;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
let [loss] = Transaction::default()
|
||||
.register(self.policy_model_loss)
|
||||
.execute()
|
||||
.try_into()
|
||||
.expect("Correct amount of tensor data");
|
||||
|
||||
let device = &Default::default();
|
||||
|
||||
SimpleTrainOutput {
|
||||
policy_model_loss: Tensor::from_data(loss, device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for SimpleTrainOutput<B> {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(self.policy_model_loss.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, M, O> PolicyLearner<B> for DqnLearningAgent<B, M, O>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
M: DiscreteActionModel<B> + AutodiffModule<B> + TargetModel<B> + 'static,
|
||||
M::Input: Clone,
|
||||
M::InnerModule: DiscreteActionModel<B::InnerBackend> + TargetModel<B::InnerBackend>,
|
||||
O: Optimizer<M, B> + 'static,
|
||||
{
|
||||
type TrainContext = SimpleTrainOutput<B>;
|
||||
type InnerPolicy = EpsilonGreedyPolicy<B, DQN<B, M>>;
|
||||
type Record = DqnLearningRecord<B, M, O>;
|
||||
|
||||
fn train(
|
||||
&mut self,
|
||||
input: LearnerTransitionBatch<B, Self::InnerPolicy>,
|
||||
) -> RLTrainOutput<Self::TrainContext, <Self::InnerPolicy as Policy<B>>::PolicyState> {
|
||||
let states_batch = input.states;
|
||||
let next_states_batch = input.next_states;
|
||||
let actions_batch = input.actions.actions;
|
||||
let rewards_batch = input.rewards;
|
||||
let dones_batch = input.dones;
|
||||
|
||||
// Optimize
|
||||
let logits = self.policy_model.forward(states_batch).logits;
|
||||
let state_action_values = logits.gather(1, actions_batch.int());
|
||||
|
||||
let next_state_values = self.target_model.forward(next_states_batch.clone());
|
||||
let next_state_values = next_state_values.logits.max_dim(1).squeeze::<1>();
|
||||
|
||||
let not_done_batch = Tensor::ones_like(&dones_batch) - dones_batch;
|
||||
let expected_state_action_values = (next_state_values * not_done_batch.squeeze())
|
||||
.mul_scalar(self.config.gamma)
|
||||
+ rewards_batch.squeeze();
|
||||
let expected_state_action_values = expected_state_action_values.unsqueeze_dim::<2>(1);
|
||||
|
||||
let loss = MseLoss::new().forward(
|
||||
state_action_values,
|
||||
expected_state_action_values,
|
||||
nn::loss::Reduction::Mean,
|
||||
);
|
||||
let gradients = loss.backward();
|
||||
let gradient_params = GradientsParams::from_grads(gradients, &self.policy_model);
|
||||
self.policy_model = self.optimizer.step(
|
||||
self.config.learning_rate,
|
||||
self.policy_model.clone(),
|
||||
gradient_params,
|
||||
);
|
||||
self.target_model = self
|
||||
.target_model
|
||||
.soft_update(&self.policy_model, self.config.tau);
|
||||
let policy_update = EpsilonGreedyPolicyState::new(
|
||||
DqnState {
|
||||
model: self.policy_model.clone(),
|
||||
_backend: PhantomData,
|
||||
},
|
||||
self.agent.state().step,
|
||||
);
|
||||
self.agent.update(policy_update.clone());
|
||||
RLTrainOutput {
|
||||
policy: policy_update,
|
||||
item: SimpleTrainOutput {
|
||||
policy_model_loss: loss,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn policy(&self) -> Self::InnerPolicy {
|
||||
self.agent.clone()
|
||||
}
|
||||
|
||||
fn update_policy(&mut self, update: Self::InnerPolicy) {
|
||||
self.agent = update;
|
||||
}
|
||||
|
||||
fn record(&self) -> Self::Record {
|
||||
DqnLearningRecord {
|
||||
policy_model: self.policy_model.clone().into_record(),
|
||||
target_model: self.target_model.clone().into_record(),
|
||||
optimizer: self.optimizer.to_record(),
|
||||
}
|
||||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
let policy_model = self.policy_model.load_record(record.policy_model);
|
||||
let target_model = self.target_model.load_record(record.target_model);
|
||||
let optimizer = self.optimizer.load_record(record.optimizer);
|
||||
Self {
|
||||
policy_model,
|
||||
target_model,
|
||||
agent: self.agent,
|
||||
optimizer,
|
||||
config: self.config,
|
||||
}
|
||||
}
|
||||
}
|
||||
101
examples/dqn-agent/src/env.rs
Normal file
101
examples/dqn-agent/src/env.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use burn::rl::{Environment, StepResult};
|
||||
use burn::{
|
||||
Tensor,
|
||||
prelude::{Backend, ToElement},
|
||||
};
|
||||
use gym_rs::{
|
||||
core::Env,
|
||||
envs::classical_control::cartpole::{CartPoleEnv, CartPoleObservation},
|
||||
};
|
||||
|
||||
use crate::agent::{DiscreteActionTensor, ObservationTensor};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CartPoleAction {
|
||||
action: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> From<DiscreteActionTensor<B, 2>> for CartPoleAction {
|
||||
fn from(value: DiscreteActionTensor<B, 2>) -> Self {
|
||||
Self {
|
||||
action: value.actions.int().into_scalar().to_usize(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> From<CartPoleAction> for DiscreteActionTensor<B, 2> {
|
||||
fn from(value: CartPoleAction) -> Self {
|
||||
DiscreteActionTensor {
|
||||
actions: Tensor::<B, 1>::from_data([value.action], &Default::default()).unsqueeze(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CartPoleState {
|
||||
pub state: [f64; 4],
|
||||
}
|
||||
|
||||
impl From<CartPoleObservation> for CartPoleState {
|
||||
fn from(observation: CartPoleObservation) -> Self {
|
||||
let vec = Vec::<f64>::from(observation);
|
||||
Self {
|
||||
state: [vec[0], vec[1], vec[2], vec[3]],
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<B: Backend> From<CartPoleState> for ObservationTensor<B, 2> {
|
||||
fn from(val: CartPoleState) -> Self {
|
||||
ObservationTensor {
|
||||
state: Tensor::<B, 1>::from_floats(val.state, &Default::default()).unsqueeze(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CartPoleWrapper {
|
||||
gym_env: CartPoleEnv,
|
||||
step_index: usize,
|
||||
}
|
||||
|
||||
impl Default for CartPoleWrapper {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CartPoleWrapper {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
gym_env: CartPoleEnv::new(gym_rs::utils::renderer::RenderMode::None),
|
||||
step_index: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Environment for CartPoleWrapper {
|
||||
type State = CartPoleState;
|
||||
type Action = CartPoleAction;
|
||||
|
||||
const MAX_STEPS: usize = 500;
|
||||
|
||||
fn state(&self) -> Self::State {
|
||||
CartPoleState::from(self.gym_env.state)
|
||||
}
|
||||
|
||||
fn step(&mut self, action: Self::Action) -> StepResult<Self::State> {
|
||||
let action_reward = self.gym_env.step(action.action);
|
||||
self.step_index += 1;
|
||||
StepResult {
|
||||
next_state: CartPoleState::from(action_reward.observation),
|
||||
reward: action_reward.reward.into_inner(),
|
||||
done: action_reward.done,
|
||||
truncated: self.step_index >= Self::MAX_STEPS,
|
||||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.gym_env.reset(None, false, None);
|
||||
self.step_index = 0;
|
||||
}
|
||||
}
|
||||
4
examples/dqn-agent/src/lib.rs
Normal file
4
examples/dqn-agent/src/lib.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod agent;
|
||||
pub mod env;
|
||||
pub mod training;
|
||||
pub mod utils;
|
||||
64
examples/dqn-agent/src/training.rs
Normal file
64
examples/dqn-agent/src/training.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use burn::{
|
||||
grad_clipping::GradientClippingConfig,
|
||||
optim::AdamWConfig,
|
||||
record::CompactRecorder,
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{
|
||||
OffPolicyConfig, RLTraining,
|
||||
metric::{CumulativeRewardMetric, EpisodeLengthMetric, ExplorationRateMetric, LossMetric},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
agent::{DqnAgentConfig, DqnLearningAgent, MlpNet, MlpNetConfig},
|
||||
env::CartPoleWrapper,
|
||||
};
|
||||
|
||||
static ARTIFACT_DIR: &str = "/tmp/burn-example-dqn-agent";
|
||||
|
||||
pub fn run<B: AutodiffBackend>(device: B::Device) {
|
||||
let dqn_config = DqnAgentConfig {
|
||||
gamma: 0.99,
|
||||
learning_rate: 3e-4,
|
||||
tau: 0.005,
|
||||
epsilon_start: 0.99,
|
||||
epsilon_end: 0.05,
|
||||
epsilon_decay: 6000.0,
|
||||
};
|
||||
let model_config = MlpNetConfig {
|
||||
num_layers: 3,
|
||||
dropout: 0.0,
|
||||
d_input: 4,
|
||||
d_output: 2,
|
||||
d_hidden: 64,
|
||||
};
|
||||
let learning_config = OffPolicyConfig {
|
||||
num_envs: 8,
|
||||
autobatch_size: 8,
|
||||
replay_buffer_size: 50_000,
|
||||
train_interval: 8,
|
||||
eval_interval: 4_000,
|
||||
eval_episodes: 5,
|
||||
train_batch_size: 128,
|
||||
train_steps: 4,
|
||||
warmup_steps: 0,
|
||||
};
|
||||
|
||||
let policy_model = MlpNet::<B>::new(&model_config, &device);
|
||||
let optimizer = AdamWConfig::new()
|
||||
.with_grad_clipping(Some(GradientClippingConfig::Value(100.0)))
|
||||
.init();
|
||||
let agent = DqnLearningAgent::new(policy_model, optimizer, dqn_config);
|
||||
let learner = RLTraining::new(ARTIFACT_DIR, CartPoleWrapper::new)
|
||||
.metrics_train((LossMetric::new(),))
|
||||
.metrics_agent((ExplorationRateMetric::new(),))
|
||||
.metrics_episode((EpisodeLengthMetric::new(), CumulativeRewardMetric::new()))
|
||||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.num_steps(40_000)
|
||||
.with_learning_strategy(burn::train::RLStrategies::OffPolicyStrategy(
|
||||
learning_config,
|
||||
))
|
||||
.summary();
|
||||
|
||||
let _result = learner.launch(agent);
|
||||
}
|
||||
226
examples/dqn-agent/src/utils.rs
Normal file
226
examples/dqn-agent/src/utils.rs
Normal file
@@ -0,0 +1,226 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn::{
|
||||
Tensor,
|
||||
module::{Param, ParamId},
|
||||
nn::{self, Linear},
|
||||
prelude::Backend,
|
||||
record::Record,
|
||||
rl::{Policy, PolicyState},
|
||||
tensor::Device,
|
||||
train::{
|
||||
ItemLazy,
|
||||
metric::{Adaptor, ExplorationRateInput},
|
||||
},
|
||||
};
|
||||
use derive_new::new;
|
||||
use rand::{random, random_range};
|
||||
|
||||
use crate::agent::{DiscreteActionTensor, DiscreteLogitsTensor};
|
||||
|
||||
pub fn create_lin_layers<B: Backend>(
|
||||
num_layers: usize,
|
||||
d_input: usize,
|
||||
d_hidden: usize,
|
||||
d_output: usize,
|
||||
device: &Device<B>,
|
||||
) -> Vec<Linear<B>> {
|
||||
let mut linears = Vec::with_capacity(num_layers);
|
||||
|
||||
if num_layers == 1 {
|
||||
linears.push(nn::LinearConfig::new(d_input, d_output).init(device));
|
||||
return linears;
|
||||
}
|
||||
for i in 0..num_layers {
|
||||
if i == 0 {
|
||||
linears.push(nn::LinearConfig::new(d_input, d_hidden).init(device));
|
||||
} else if i == num_layers - 1 {
|
||||
linears.push(nn::LinearConfig::new(d_hidden, d_output).init(device));
|
||||
} else {
|
||||
linears.push(nn::LinearConfig::new(d_hidden, d_hidden).init(device));
|
||||
}
|
||||
}
|
||||
linears
|
||||
}
|
||||
|
||||
pub fn soft_update_linear<B: Backend>(this: Linear<B>, that: &Linear<B>, tau: f64) -> Linear<B> {
|
||||
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
|
||||
let bias = match (&this.bias, &that.bias) {
|
||||
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Linear::<B> { weight, bias }
|
||||
}
|
||||
|
||||
fn soft_update_tensor<const N: usize, B: Backend>(
|
||||
this: &Param<Tensor<B, N>>,
|
||||
that: &Param<Tensor<B, N>>,
|
||||
tau: f64,
|
||||
) -> Param<Tensor<B, N>> {
|
||||
let that_weight = that.val();
|
||||
let this_weight = this.val();
|
||||
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
|
||||
|
||||
Param::initialized(ParamId::new(), new_weight)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EpsilonGreedyPolicyOutput {
|
||||
pub epsilon: f64,
|
||||
}
|
||||
|
||||
impl ItemLazy for EpsilonGreedyPolicyOutput {
|
||||
type ItemSync = EpsilonGreedyPolicyOutput;
|
||||
|
||||
fn sync(self) -> Self::ItemSync {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Adaptor<ExplorationRateInput> for EpsilonGreedyPolicyOutput {
|
||||
fn adapt(&self) -> ExplorationRateInput {
|
||||
ExplorationRateInput::new(self.epsilon)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Record)]
|
||||
pub struct EpsilonGreedyPolicyRecord<B: Backend, P: Policy<B>> {
|
||||
pub inner_state: <P::PolicyState as PolicyState<B>>::Record,
|
||||
pub step: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, new)]
|
||||
pub struct EpsilonGreedyPolicyState<B: Backend, P: Policy<B>> {
|
||||
pub inner_state: P::PolicyState,
|
||||
pub step: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend, P: Policy<B>> PolicyState<B> for EpsilonGreedyPolicyState<B, P> {
|
||||
type Record = EpsilonGreedyPolicyRecord<B, P>;
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
EpsilonGreedyPolicyRecord {
|
||||
inner_state: self.inner_state.into_record(),
|
||||
step: self.step,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_record(&self, record: Self::Record) -> Self {
|
||||
let inner_state = self.inner_state.load_record(record.inner_state);
|
||||
Self {
|
||||
inner_state,
|
||||
step: record.step,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct EpsilonGreedyPolicy<B: Backend, P: Policy<B>> {
|
||||
inner_policy: P,
|
||||
eps_start: f64,
|
||||
eps_end: f64,
|
||||
eps_decay: f64,
|
||||
step: usize,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend, P: Policy<B>> EpsilonGreedyPolicy<B, P> {
|
||||
pub fn new(inner_policy: P, eps_start: f64, eps_end: f64, eps_decay: f64) -> Self {
|
||||
Self {
|
||||
inner_policy,
|
||||
eps_start,
|
||||
eps_end,
|
||||
eps_decay,
|
||||
step: 0,
|
||||
_backend: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_threshold(&self) -> f64 {
|
||||
self.eps_end
|
||||
+ (self.eps_start - self.eps_end) * f64::exp(-(self.step as f64) / self.eps_decay)
|
||||
}
|
||||
|
||||
fn step(&mut self) -> f64 {
|
||||
let thresh = self.get_threshold();
|
||||
self.step += 1;
|
||||
thresh
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, P> Policy<B> for EpsilonGreedyPolicy<B, P>
|
||||
where
|
||||
B: Backend,
|
||||
P: Policy<
|
||||
B,
|
||||
ActionDistribution = DiscreteLogitsTensor<B, 2>,
|
||||
Action = DiscreteActionTensor<B, 2>,
|
||||
>,
|
||||
{
|
||||
type ActionContext = EpsilonGreedyPolicyOutput;
|
||||
type PolicyState = EpsilonGreedyPolicyState<B, P>;
|
||||
|
||||
type Observation = P::Observation;
|
||||
type ActionDistribution = DiscreteLogitsTensor<B, 2>;
|
||||
type Action = DiscreteActionTensor<B, 2>;
|
||||
|
||||
fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {
|
||||
self.inner_policy.forward(states)
|
||||
}
|
||||
|
||||
fn action(
|
||||
&mut self,
|
||||
states: Self::Observation,
|
||||
deterministic: bool,
|
||||
) -> (Self::Action, Vec<Self::ActionContext>) {
|
||||
let logits = self.inner_policy.forward(states).logits;
|
||||
let greedy_actions = logits.argmax(1);
|
||||
let greedy_actions = greedy_actions.split(1, 0);
|
||||
|
||||
let mut contexts = vec![];
|
||||
let mut actions = vec![];
|
||||
for a in greedy_actions {
|
||||
let threshold = self.step();
|
||||
let threshold = if deterministic { 0.0 } else { threshold };
|
||||
contexts.push(EpsilonGreedyPolicyOutput { epsilon: threshold });
|
||||
if random::<f64>() > threshold {
|
||||
actions.push(a.clone().float());
|
||||
} else {
|
||||
actions.push(
|
||||
Tensor::<B, 1>::from_floats([random_range(0..2)], &a.device()).unsqueeze(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let output = Tensor::cat(actions, 0);
|
||||
(DiscreteActionTensor { actions: output }, contexts)
|
||||
}
|
||||
|
||||
fn update(&mut self, update: Self::PolicyState) {
|
||||
// Note : updating an epsilon greedy policy doesn't change the step.
|
||||
self.inner_policy.update(update.inner_state);
|
||||
}
|
||||
|
||||
fn state(&self) -> Self::PolicyState {
|
||||
EpsilonGreedyPolicyState {
|
||||
inner_state: self.inner_policy.state(),
|
||||
step: self.step,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {
|
||||
let state = self.state().load_record(record);
|
||||
let inner_policy = self
|
||||
.inner_policy
|
||||
.load_record(state.inner_state.into_record());
|
||||
EpsilonGreedyPolicy {
|
||||
inner_policy,
|
||||
eps_start: self.eps_start,
|
||||
eps_end: self.eps_end,
|
||||
eps_decay: self.eps_decay,
|
||||
step: state.step,
|
||||
_backend: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -118,6 +118,9 @@ pub(crate) fn handle_command(
|
||||
"burn-router".to_string(),
|
||||
"burn-tch".to_string(),
|
||||
"burn-wgpu".to_string(),
|
||||
// dqn-agent example relies on gym-rs dependency which requires SDL2.
|
||||
// It would be good to remove the gym-rs dependency in the future.
|
||||
"dqn-agent".to_string(),
|
||||
]);
|
||||
|
||||
// Burn remote tests don't work on windows for now
|
||||
|
||||
Reference in New Issue
Block a user