* wip burn-rl

* clean up types traits, remove paradigm trait, improve learner

* code quality

* book

* doc

* error in doc

* revert TrainStep rename and remove paradigm components

* trainstep refactor with assoc types

* fix docs

* fix warning

* comment

* remove evaluator step

* add more rl stuff and change Adaptor trait

* metrics for rl and revert adaptor change

* docs + cleanup event processor for rl

* policy vs agent refactor

* rework renderer

* policy no longer generic over env

* early version of checkpointing

* versioning

* off policy training loop

* naming consistency and loading from record

* naming

* add other backends

* strategies and configs

* env initializer and cum reward metric

* transition backend thing compiles

* mostly docs and naming

* reorganize files

* batchable trait and re-arrange some bounds

* start env runner refactor

* rework autobatcher and env_runner episodes

* change autobatcher dynamic and tweak dqn params

* fix typos and warnings

* params tweaking

* file naming

* fix docs

* metrics api and dependency bump

* address PR comments and fix lint

* kill process when user kills training

* handle kill signal renderer

* format

* remove dqn-agent from ci tests bc of dependency issue
This commit is contained in:
Charles23R
2026-02-04 15:55:44 -05:00
committed by GitHub
parent ddd6438d9c
commit 4b259aa3b3
75 changed files with 5150 additions and 333 deletions

View File

@@ -11,6 +11,7 @@ ignore = [
"RUSTSEC-2024-0436", # Paste used to generate macro, should be removed at some point.
"RUSTSEC-2025-0119", # `number_prefix` used by `tokenizers`, only in the examples.
"RUSTSEC-2025-0141", # `bincode` is no longer maintained.
"RUSTSEC-2024-0388", # `derivative` dependency in the DQN example is unmaintained.
] # advisory IDs to ignore e.g. ["RUSTSEC-2019-0001", ...]
informational_warnings = [
"unmaintained",

171
Cargo.lock generated
View File

@@ -146,6 +146,15 @@ version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]]
name = "ar_archive_writer"
version = "0.5.1"
@@ -661,6 +670,7 @@ dependencies = [
"burn-nn",
"burn-optim",
"burn-remote",
"burn-rl",
"burn-rocm",
"burn-router",
"burn-store",
@@ -1067,6 +1077,18 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "burn-rl"
version = "0.21.0"
dependencies = [
"burn-core",
"burn-ndarray",
"burn-optim",
"derive-new",
"log",
"rand 0.9.2",
]
[[package]]
name = "burn-rocm"
version = "0.21.0"
@@ -1183,9 +1205,11 @@ dependencies = [
"burn-core",
"burn-ndarray",
"burn-optim",
"burn-rl",
"derive-new",
"log",
"nvml-wrapper",
"rand 0.9.2",
"ratatui",
"rstest",
"serde",
@@ -1331,6 +1355,12 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "c_vec"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdd7a427adc0135366d99db65b36dae9237130997e560ed61118041fb72be6e8"
[[package]]
name = "candle-core"
version = "0.9.2"
@@ -2709,6 +2739,17 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "derivative"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "derive-new"
version = "0.7.0"
@@ -2911,6 +2952,16 @@ version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
[[package]]
name = "dqn-agent"
version = "0.21.0"
dependencies = [
"burn",
"derive-new",
"gym-rs",
"rand 0.9.2",
]
[[package]]
name = "drawille"
version = "0.3.0"
@@ -4045,6 +4096,23 @@ dependencies = [
"serde",
]
[[package]]
name = "gym-rs"
version = "0.3.1"
source = "git+https://github.com/MathisWellmann/gym-rs?branch=main#5283afaa86a3a7c45c46c882cfad459f02539b62"
dependencies = [
"derivative",
"derive-new",
"log",
"nalgebra",
"num-traits",
"ordered-float 4.6.0",
"rand 0.8.5",
"rand_pcg",
"sdl2",
"serde",
]
[[package]]
name = "h2"
version = "0.4.13"
@@ -5252,6 +5320,33 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "nalgebra"
version = "0.33.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b"
dependencies = [
"approx",
"matrixmultiply",
"nalgebra-macros",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]]
name = "nalgebra-macros"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.114",
]
[[package]]
name = "named-tensor"
version = "0.21.0"
@@ -5907,6 +6002,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951"
dependencies = [
"num-traits",
"rand 0.8.5",
"serde",
]
[[package]]
@@ -7014,6 +7111,7 @@ dependencies = [
"libc",
"rand_chacha 0.3.1",
"rand_core 0.6.4",
"serde",
]
[[package]]
@@ -7053,6 +7151,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom 0.2.17",
"serde",
]
[[package]]
@@ -7074,6 +7173,15 @@ dependencies = [
"rand 0.9.2",
]
[[package]]
name = "rand_pcg"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e"
dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "range-alloc"
version = "0.1.4"
@@ -7604,6 +7712,15 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984"
[[package]]
name = "safe_arch"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96b02de82ddbe1b636e6170c21be622223aea188ef2e139be0a5b219ec215323"
dependencies = [
"bytemuck",
]
[[package]]
name = "safetensors"
version = "0.3.3"
@@ -7704,6 +7821,31 @@ version = "3.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca"
[[package]]
name = "sdl2"
version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b498da7d14d1ad6c839729bd4ad6fc11d90a57583605f3b4df2cd709a9cd380"
dependencies = [
"bitflags 1.3.2",
"c_vec",
"lazy_static",
"libc",
"sdl2-sys",
]
[[package]]
name = "sdl2-sys"
version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "951deab27af08ed9c6068b7b0d05a93c91f0a8eb16b6b816a5e73452a43521d3"
dependencies = [
"cfg-if",
"cmake",
"libc",
"version-compare",
]
[[package]]
name = "security-framework"
version = "2.11.1"
@@ -7971,6 +8113,19 @@ dependencies = [
"libc",
]
[[package]]
name = "simba"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c99284beb21666094ba2b75bbceda012e610f5479dfcc2d6e2426f53197ffd95"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
"wide",
]
[[package]]
name = "simd-adler32"
version = "0.3.8"
@@ -9440,6 +9595,12 @@ version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version-compare"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "579a42fc0b8e0c63b76519a339be31bed574929511fa53c1a3acae26eb258f29"
[[package]]
name = "version_check"
version = "0.9.5"
@@ -9856,6 +10017,16 @@ dependencies = [
"web-sys",
]
[[package]]
name = "wide"
version = "0.7.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ce5da8ecb62bcd8ec8b7ea19f69a51275e91299be594ea5cc6ef7819e16cd03"
dependencies = [
"bytemuck",
"safe_arch",
]
[[package]]
name = "widestring"
version = "1.2.1"

View File

@@ -7,7 +7,7 @@ use serde::{Serialize, de::DeserializeOwned};
/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
pub trait Record<B: Backend>: Send {
/// Type of the item that can be serialized and deserialized.
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned + Clone;
/// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings).
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;

View File

@@ -285,7 +285,7 @@ mod tests {
#[test]
#[should_panic]
fn err_when_invalid_item() {
#[derive(new, Serialize, Deserialize)]
#[derive(new, Serialize, Deserialize, Clone)]
struct Item<S: PrecisionSettings> {
value: S::FloatElem,
}

View File

@@ -102,7 +102,7 @@ enum LrSchedulerItem {
}
#[derive(Record)]
/// Record item for the [componsed learning rate scheduler](ComposedLrScheduler).
/// Record item for the [composed learning rate scheduler](ComposedLrScheduler).
pub enum LrSchedulerRecord<B: Backend> {
/// The linear variant.
Linear(<LinearLrScheduler as LrScheduler>::Record<B>),
@@ -115,7 +115,7 @@ pub enum LrSchedulerRecord<B: Backend> {
}
#[derive(Record)]
/// Records for the [componsed learning rate scheduler](ComposedLrScheduler).
/// Records for the [composed learning rate scheduler](ComposedLrScheduler).
pub struct ComposedLrSchedulerRecord<B: Backend> {
schedulers: Vec<LrSchedulerRecord<B>>,
}

View File

@@ -19,7 +19,7 @@ where
}
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Clone)]
#[serde(bound = "")]
pub enum AdaptorRecordItem<
O: SimpleOptimizer<B::InnerBackend>,

View File

@@ -56,7 +56,7 @@ impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
}
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Clone)]
#[serde(bound = "")]
pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
/// Rank 0.

23
crates/burn-rl/Cargo.toml Normal file
View File

@@ -0,0 +1,23 @@
[package]
name = "burn-rl"
edition.workspace = true
license.workspace = true
readme.workspace = true
version.workspace = true
[dependencies]
burn-core = { path = "../burn-core", version = "=0.21.0", features = [
"dataset",
"std",
], default-features = false }
burn-optim = { path = "../burn-optim", version = "=0.21.0", features = [
"std",
], default-features = false }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0" }
derive-new.workspace = true
rand.workspace = true
log = { workspace = true }
[lints]
workspace = true

View File

@@ -0,0 +1 @@
../../LICENSE-APACHE

1
crates/burn-rl/LICENSE-MIT Symbolic link
View File

@@ -0,0 +1 @@
../../LICENSE-MIT

6
crates/burn-rl/README.md Normal file
View File

@@ -0,0 +1,6 @@
# Burn RL
<!-- This crate should be used with [burn](https://github.com/tracel-ai/burn). -->
<!-- [![Current Crates.io Version](https://img.shields.io/crates/v/burn-train.svg)](https://crates.io/crates/burn-train)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-train/blob/master/README.md) -->

View File

@@ -0,0 +1,46 @@
/// The result of taking a step in an environment.
///
/// This is the value returned by [`Environment::step`]; `S` is the state type of the
/// environment.
pub struct StepResult<S> {
/// The updated state, i.e. the state of the environment after the step.
pub next_state: S,
/// The reward obtained for the step.
pub reward: f64,
/// If the environment reached a terminal state.
pub done: bool,
/// If the environment reached its max length.
pub truncated: bool,
}
/// Trait to be implemented for a RL environment.
///
/// An environment exposes its current state, can be advanced with an action, and can be
/// reset to start over.
pub trait Environment {
/// The type of the state.
type State;
/// The type of actions.
type Action;
/// The maximum number of steps for one episode.
const MAX_STEPS: usize;
/// Returns the current state.
fn state(&self) -> Self::State;
/// Take a step in the environment given an action.
///
/// Returns a [`StepResult`] holding the next state, the reward and the termination flags.
fn step(&mut self, action: Self::Action) -> StepResult<Self::State>;
/// Reset the environment to an initial state.
fn reset(&mut self);
}
/// Trait to define how to initialize an environment.
///
/// By default, any cloneable function or closure that takes no argument and returns an
/// environment implements it.
pub trait EnvironmentInit<E: Environment>: Clone {
/// Initialize the environment, returning a new instance.
fn init(&self) -> E;
}
/// Blanket implementation letting any cloneable zero-argument function or closure that
/// returns an environment act as an environment initializer.
impl<F, E> EnvironmentInit<E> for F
where
    E: Environment,
    F: Clone + Fn() -> E,
{
    /// Builds the environment by invoking the wrapped function.
    fn init(&self) -> E {
        self()
    }
}

View File

@@ -0,0 +1,3 @@
mod base;
pub use base::*;

21
crates/burn-rl/src/lib.rs Normal file
View File

@@ -0,0 +1,21 @@
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! A library for training reinforcement learning agents.
/// Module for implementing an environment.
pub mod environment;
/// Module for implementing a policy.
pub mod policy;
/// Transition buffer.
pub mod transition_buffer;
pub use environment::*;
pub use policy::*;
pub use transition_buffer::*;
#[cfg(test)]
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;
#[cfg(test)]
pub(crate) mod tests {}

View File

@@ -0,0 +1,314 @@
use std::{
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
mpsc::{self, Sender},
},
thread::spawn,
};
use burn_core::prelude::Backend;
use crate::{ActionContext, Batchable, Policy, PolicyState};
/// State of the inference server used by [`AsyncPolicy`]: the wrapped policy together with
/// the requests queued until the next batched call to that policy.
#[derive(Clone)]
struct PolicyInferenceServer<B: Backend, P: Policy<B>> {
// `num_agents` used to make sure autobatching doesn't block the agents if they are less than the autobatch size.
num_agents: Arc<AtomicUsize>,
// Upper bound on the number of queued requests before a flush is triggered.
max_autobatch_size: usize,
// The policy that serves the batched requests.
inner_policy: P,
// Queued `action` requests, answered by `flush_actions`.
batch_action: Vec<ActionItem<P::Observation, P::Action, P::ActionContext>>,
// Queued `forward` requests, answered by `flush_logits`.
batch_logits: Vec<ForwardItem<P::Observation, P::ActionDistribution>>,
}
impl<B, P> PolicyInferenceServer<B, P>
where
    B: Backend,
    P: Policy<B>,
    P::Observation: Clone + Batchable,
    P::ActionDistribution: Clone + Batchable,
    P::Action: Clone + Batchable,
    P::ActionContext: Clone,
{
    /// Create a server wrapping `inner_policy`, with no registered agent and empty queues.
    pub fn new(max_autobatch_size: usize, inner_policy: P) -> Self {
        Self {
            num_agents: Arc::new(AtomicUsize::new(0)),
            max_autobatch_size,
            inner_policy,
            batch_action: vec![],
            batch_logits: vec![],
        }
    }

    /// Number of queued requests at which a queue gets flushed.
    ///
    /// This is the smaller of the number of registered agents and the maximum autobatch
    /// size, so that agents are not left waiting for a batch that can never fill up.
    fn flush_threshold(&self) -> usize {
        self.num_agents
            .load(Ordering::Relaxed)
            .min(self.max_autobatch_size)
    }

    /// Queue an action request and flush the action queue once the threshold is reached.
    pub fn push_action(&mut self, item: ActionItem<P::Observation, P::Action, P::ActionContext>) {
        self.batch_action.push(item);
        if self.len_actions() >= self.flush_threshold() {
            self.flush_actions();
        }
    }

    /// Queue a forward request and flush the forward queue once the threshold is reached.
    pub fn push_logits(&mut self, item: ForwardItem<P::Observation, P::ActionDistribution>) {
        self.batch_logits.push(item);
        if self.len_logits() >= self.flush_threshold() {
            self.flush_logits();
        }
    }

    /// Number of queued action requests.
    pub fn len_actions(&self) -> usize {
        self.batch_action.len()
    }

    /// Number of queued forward requests.
    pub fn len_logits(&self) -> usize {
        // Must look at the forward queue: reading the action queue here would make forward
        // requests wait on (or be flushed by) unrelated action requests.
        self.batch_logits.len()
    }

    /// Run the policy once on all queued action requests and answer each requester.
    ///
    /// Does nothing when the queue is empty.
    ///
    /// # Panics
    ///
    /// Panics if a result cannot be sent back to its requester, or if the policy returns
    /// fewer actions or contexts than there are queued requests.
    pub fn flush_actions(&mut self) {
        if self.len_actions() == 0 {
            return;
        }
        let input: Vec<_> = self
            .batch_action
            .iter()
            .map(|m| m.inference_state.clone())
            .collect();
        // Only deterministic if all actions are requested as deterministic.
        let deterministic = self.batch_action.iter().all(|item| item.deterministic);
        let (actions, context) = self
            .inner_policy
            .action(P::Observation::batch(input), deterministic);
        let actions: Vec<_> = actions.unbatch();
        for (i, item) in self.batch_action.iter().enumerate() {
            item.sender
                .send(ActionContext {
                    context: vec![context[i].clone()],
                    action: actions[i].clone(),
                })
                .expect("Autobatcher should be able to send resulting actions.");
        }
        self.batch_action.clear();
    }

    /// Run the policy once on all queued forward requests and answer each requester.
    ///
    /// Does nothing when the queue is empty.
    ///
    /// # Panics
    ///
    /// Panics if a result cannot be sent back to its requester, or if the policy returns
    /// fewer outputs than there are queued requests.
    pub fn flush_logits(&mut self) {
        if self.len_logits() == 0 {
            return;
        }
        let input: Vec<_> = self
            .batch_logits
            .iter()
            .map(|m| m.inference_state.clone())
            .collect();
        let output = self.inner_policy.forward(P::Observation::batch(input));
        let logits: Vec<_> = output.unbatch();
        for (i, item) in self.batch_logits.iter().enumerate() {
            item.sender
                .send(logits[i].clone())
                .expect("Autobatcher should be able to send resulting probabilities.");
        }
        // Clear the queue that was just served. Clearing the action queue here would drop
        // pending action requests and leave the answered forward requests queued.
        self.batch_logits.clear();
    }

    /// Answer every pending request with the current policy, then apply `policy_update`.
    pub fn update_policy(&mut self, policy_update: P::PolicyState) {
        // Both flushes are no-ops on an empty queue.
        self.flush_actions();
        self.flush_logits();
        self.inner_policy.update(policy_update);
    }

    /// Returns the state of the inner policy.
    pub fn state(&self) -> P::PolicyState {
        self.inner_policy.state()
    }

    /// Register `num` additional agents.
    pub fn increment_agents(&mut self, num: usize) {
        self.num_agents.fetch_add(num, Ordering::Relaxed);
    }

    /// Unregister `num` agents, then flush any queue that now meets the lowered threshold.
    pub fn decrement_agents(&mut self, num: usize) {
        self.num_agents.fetch_sub(num, Ordering::Relaxed);
        if self.len_actions() >= self.flush_threshold() {
            self.flush_actions();
        }
        if self.len_logits() >= self.flush_threshold() {
            self.flush_logits();
        }
    }
}
/// Messages handled by the inference server thread spawned in [`AsyncPolicy::new`].
enum InferenceMessage<B: Backend, P: Policy<B>> {
/// Queue a request for an action.
ActionMessage(ActionItem<P::Observation, P::Action, P::ActionContext>),
/// Queue a request for an action distribution.
ForwardMessage(ForwardItem<P::Observation, P::ActionDistribution>),
/// Update the inner policy with the given state.
PolicyUpdate(P::PolicyState),
/// Ask for the current policy state; the answer is sent through the given sender.
PolicyRequest(Sender<P::PolicyState>),
/// Register this many additional agents.
IncrementAgents(usize),
/// Unregister this many agents.
DecrementAgents(usize),
}
/// A queued request for an action.
#[derive(Clone)]
struct ActionItem<S, A, C> {
// Channel on which the resulting action and its context are sent back.
sender: Sender<ActionContext<A, Vec<C>>>,
// The observation given as input to the policy.
inference_state: S,
// Whether the requester asked for a deterministic action.
deterministic: bool,
}
/// A queued request for the output of the policy's forward pass.
#[derive(Clone)]
struct ForwardItem<S, O> {
// Channel on which the resulting output is sent back.
sender: Sender<O>,
// The observation given as input to the policy.
inference_state: S,
}
/// An asynchronous policy using an inference server with autobatching.
///
/// The value only holds the sending side of the channel to the server thread; cloning it
/// yields another handle to the same server.
#[derive(Clone)]
pub struct AsyncPolicy<B: Backend, P: Policy<B>> {
// Channel used to send requests to the inference server thread.
inference_state_sender: Sender<InferenceMessage<B, P>>,
}
impl<B, P> AsyncPolicy<B, P>
where
    B: Backend,
    P: Policy<B> + Clone + Send + 'static,
    P::ActionContext: Clone + Send,
    P::PolicyState: Send,
    P::Observation: Clone + Send + Batchable,
    P::ActionDistribution: Clone + Send + Batchable,
    P::Action: Clone + Send + Batchable,
{
    /// Create the policy.
    ///
    /// Spawns a thread that owns `inner_policy` and serves the requests sent by this handle
    /// and its clones. The thread stops once receiving from the channel fails, which happens
    /// when every handle has been dropped.
    ///
    /// # Arguments
    ///
    /// * `autobatch_size` - Number of observations to accumulate before running a pass of inference.
    /// * `inner_policy` - The policy used to take actions.
    pub fn new(autobatch_size: usize, inner_policy: P) -> Self {
        let (sender, receiver) = mpsc::channel();
        // The server takes ownership of the policy: it is not used here afterwards, so
        // there is no need to clone it.
        let mut autobatcher = PolicyInferenceServer::new(autobatch_size, inner_policy);
        spawn(move || {
            loop {
                match receiver.recv() {
                    Ok(msg) => match msg {
                        InferenceMessage::ActionMessage(item) => autobatcher.push_action(item),
                        InferenceMessage::ForwardMessage(item) => autobatcher.push_logits(item),
                        InferenceMessage::PolicyUpdate(update) => autobatcher.update_policy(update),
                        InferenceMessage::PolicyRequest(sender) => sender
                            .send(autobatcher.state())
                            .expect("Autobatcher should be able to send current policy state."),
                        InferenceMessage::IncrementAgents(num) => autobatcher.increment_agents(num),
                        InferenceMessage::DecrementAgents(num) => autobatcher.decrement_agents(num),
                    },
                    Err(err) => {
                        log::error!("Error in AsyncPolicy : {}", err);
                        break;
                    }
                }
            }
        });
        Self {
            inference_state_sender: sender,
        }
    }

    /// Increment the number of agents using the inference server.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be sent to the server thread.
    pub fn increment_agents(&self, num: usize) {
        self.inference_state_sender
            .send(InferenceMessage::IncrementAgents(num))
            .expect("Can send message to autobatcher.")
    }

    /// Decrement the number of agents using the inference server.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be sent to the server thread.
    pub fn decrement_agents(&self, num: usize) {
        self.inference_state_sender
            .send(InferenceMessage::DecrementAgents(num))
            .expect("Can send message to autobatcher.")
    }
}
impl<B, P> Policy<B> for AsyncPolicy<B, P>
where
B: Backend,
P: Policy<B> + Send + 'static,
{
type ActionContext = P::ActionContext;
type PolicyState = P::PolicyState;
type Observation = P::Observation;
type ActionDistribution = P::ActionDistribution;
type Action = P::Action;
fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {
let (action_sender, action_receiver) = std::sync::mpsc::channel();
let item = ForwardItem {
sender: action_sender,
inference_state: states,
};
self.inference_state_sender
.send(InferenceMessage::ForwardMessage(item))
.expect("Should be able to send message to inference_server");
action_receiver
.recv()
.expect("AsyncPolicy should receive queued probabilities.")
}
fn action(
&mut self,
states: Self::Observation,
deterministic: bool,
) -> (Self::Action, Vec<Self::ActionContext>) {
let (action_sender, action_receiver) = std::sync::mpsc::channel();
let item = ActionItem {
sender: action_sender,
inference_state: states,
deterministic,
};
self.inference_state_sender
.send(InferenceMessage::ActionMessage(item))
.expect("should be able to send message to inference_server.");
let action = action_receiver
.recv()
.expect("AsyncPolicy should receive queued actions.");
(action.action, action.context)
}
fn update(&mut self, update: Self::PolicyState) {
self.inference_state_sender
.send(InferenceMessage::PolicyUpdate(update))
.expect("AsyncPolicy should be able to send policy state.")
}
fn state(&self) -> Self::PolicyState {
let (sender, receiver) = mpsc::channel();
self.inference_state_sender
.send(InferenceMessage::PolicyRequest(sender))
.expect("should be able to send message to inference_server.");
receiver
.recv()
.expect("AsyncPolicy should be able to receive policy state.")
}
fn load_record(self, _record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {
// Not needed for now
todo!()
}
}

View File

@@ -0,0 +1,108 @@
use derive_new::new;
use burn_core::{prelude::*, record::Record, tensor::backend::AutodiffBackend};
use crate::TransitionBatch;
/// An action along with additional context about the decision.
///
/// `A` is the type of the action and `C` the type of the context.
#[derive(Clone, new)]
pub struct ActionContext<A, C> {
/// The context.
pub context: C,
/// The action.
pub action: A,
}
/// The state of a policy.
///
/// A state can be converted into a record and rebuilt from one.
pub trait PolicyState<B: Backend> {
/// The type of the record.
type Record: Record<B>;
/// Convert the state to a record, consuming the state.
fn into_record(self) -> Self::Record;
/// Load the state from a record, returning the resulting state.
fn load_record(&self, record: Self::Record) -> Self;
}
/// Trait for a RL policy.
///
/// A policy maps observations to action distributions and actions, and exposes its
/// parameterization as a [`PolicyState`].
pub trait Policy<B: Backend>: Clone {
/// The observation given as input to the policy.
type Observation;
/// The action distribution parameters defining how the action will be sampled.
type ActionDistribution;
/// The action.
type Action;
/// Additional context on the policy's decision.
type ActionContext;
/// The current parameterization of the policy.
type PolicyState: PolicyState<B>;
/// Produces the action distribution from a batch of observations.
fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution;
/// Gives the action from a batch of observations.
///
/// Returns the action together with a list of additional contexts on the decision.
/// `deterministic` requests a deterministic action.
fn action(
&mut self,
obs: Self::Observation,
deterministic: bool,
) -> (Self::Action, Vec<Self::ActionContext>);
/// Update the policy's parameters.
fn update(&mut self, update: Self::PolicyState);
/// Returns the current parameterization.
fn state(&self) -> Self::PolicyState;
/// Loads the policy parameters from a record, consuming and returning the policy.
fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self;
}
/// Trait for a type that can be batched and unbatched (split).
///
/// A batch has the same type as the items it is made of.
pub trait Batchable: Sized {
/// Create a batch from a list of items.
fn batch(value: Vec<Self>) -> Self;
/// Create a list from batched items.
fn unbatch(self) -> Vec<Self>;
}
/// A training output.
///
/// Returned by [`PolicyLearner::train`], where `P` is the state of the trained policy and
/// `TO` the learner's training context.
pub struct RLTrainOutput<TO, P> {
/// The policy.
pub policy: P,
/// The item.
pub item: TO,
}
/// Batched transitions for a PolicyLearner.
///
/// This is a [`TransitionBatch`] whose states and actions use the observation and action
/// types of the policy `P`.
pub type LearnerTransitionBatch<B, P> =
TransitionBatch<B, <P as Policy<B>>::Observation, <P as Policy<B>>::Action>;
/// Learner for a policy.
///
/// The observation, action distribution and action types of the inner policy must be
/// cloneable and [`Batchable`].
pub trait PolicyLearner<B>
where
B: AutodiffBackend,
<Self::InnerPolicy as Policy<B>>::Observation: Clone + Batchable,
<Self::InnerPolicy as Policy<B>>::ActionDistribution: Clone + Batchable,
<Self::InnerPolicy as Policy<B>>::Action: Clone + Batchable,
{
/// Additional context of a training step.
type TrainContext;
/// The policy to train.
type InnerPolicy: Policy<B>;
/// The record of the learner.
type Record: Record<B>;
/// Execute a training step on the policy.
///
/// Takes a batch of transitions and returns the policy state along with the training
/// context of the step.
fn train(
&mut self,
input: LearnerTransitionBatch<B, Self::InnerPolicy>,
) -> RLTrainOutput<Self::TrainContext, <Self::InnerPolicy as Policy<B>>::PolicyState>;
/// Returns the learner's current policy.
fn policy(&self) -> Self::InnerPolicy;
/// Update the learner's policy.
fn update_policy(&mut self, update: Self::InnerPolicy);
/// Convert the learner's state into a record.
fn record(&self) -> Self::Record;
/// Load the learner's state from a record, consuming and returning the learner.
fn load_record(self, record: Self::Record) -> Self;
}

View File

@@ -0,0 +1,5 @@
mod async_policy;
mod base;
pub use async_policy::*;
pub use base::*;

View File

@@ -0,0 +1,234 @@
use burn_core::{Tensor, prelude::Backend, tensor::backend::AutodiffBackend};
use derive_new::new;
use rand::{rng, seq::index::sample};
use crate::Batchable;
/// A state transition in an environment.
///
/// `S` is the type of the states and `A` the type of the action; the reward and the
/// terminal flag are stored as one-dimensional tensors on the backend `B`.
#[derive(Clone, new)]
pub struct Transition<B: Backend, S, A> {
/// The initial state.
pub state: S,
/// The state after the step was taken.
pub next_state: S,
/// The action taken in the step.
pub action: A,
/// The reward.
pub reward: Tensor<B, 1>,
/// If the environment has reached a terminal state.
pub done: Tensor<B, 1>,
}
/// A batch of transitions.
///
/// `SB` is the type of the batched states and `AB` the type of the batched actions; the
/// rewards and the terminal flags are stored as two-dimensional tensors on the backend `B`.
pub struct TransitionBatch<B: Backend, SB, AB> {
/// Batched initial states.
pub states: SB,
/// Batched resulting states.
pub next_states: SB,
/// Batched actions.
pub actions: AB,
/// Batched rewards.
pub rewards: Tensor<B, 2>,
/// Batched flags for terminal states.
pub dones: Tensor<B, 2>,
}
/// Builds a batch out of a list of borrowed transitions.
///
/// States and actions are cloned, converted into their batched types and merged with
/// [`Batchable::batch`]. Rewards and terminal flags are stacked along dimension 0 on the
/// transitions' backend `BT`, then recreated from their data on the backend `B`.
impl<BT, B, S, A, SB, AB> From<Vec<&Transition<BT, S, A>>> for TransitionBatch<B, SB, AB>
where
    BT: Backend,
    B: AutodiffBackend,
    S: Into<SB> + Clone,
    A: Into<AB> + Clone,
    SB: Batchable,
    AB: Batchable,
{
    fn from(value: Vec<&Transition<BT, S, A>>) -> Self {
        let count = value.len();
        let mut states: Vec<SB> = Vec::with_capacity(count);
        let mut next_states: Vec<SB> = Vec::with_capacity(count);
        let mut actions: Vec<AB> = Vec::with_capacity(count);
        let mut rewards = Vec::with_capacity(count);
        let mut dones = Vec::with_capacity(count);

        // Split every transition into its per-field lists in a single pass.
        for transition in &value {
            states.push(transition.state.clone().into());
            next_states.push(transition.next_state.clone().into());
            actions.push(transition.action.clone().into());
            rewards.push(transition.reward.clone());
            dones.push(transition.done.clone());
        }

        let stacked_rewards = Tensor::stack::<2>(rewards, 0);
        let stacked_dones = Tensor::stack::<2>(dones, 0);

        Self {
            states: SB::batch(states),
            next_states: SB::batch(next_states),
            actions: AB::batch(actions),
            rewards: Tensor::from_data(stacked_rewards.to_data(), &Default::default()),
            dones: Tensor::from_data(stacked_dones.to_data(), &Default::default()),
        }
    }
}
/// A circular buffer for transitions.
///
/// Items are stored up to a fixed capacity, after which new items overwrite existing ones.
pub struct TransitionBuffer<T> {
// The stored items; never meant to hold more than `capacity` of them.
buffer: Vec<T>,
// Maximum number of items kept.
capacity: usize,
// Index of the slot written by the next overwrite.
cursor: usize,
}
impl<T> TransitionBuffer<T> {
    /// Creates a new circular buffer with a fixed capacity.
    ///
    /// Note that `push` and `append` panic on a buffer created with a capacity of zero.
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: Vec::with_capacity(capacity),
            capacity,
            cursor: 0,
        }
    }

    /// Add an item, overwriting the oldest if full.
    ///
    /// # Panics
    ///
    /// Panics if the capacity is zero.
    pub fn push(&mut self, item: T) {
        if self.buffer.len() < self.capacity {
            self.buffer.push(item);
        } else {
            self.buffer[self.cursor] = item;
        }
        // Advance the cursor while filling as well, so that it always designates the next
        // slot to write. `append` relies on this; without it, mixing `push` and `append`
        // could grow the buffer past its capacity.
        self.cursor = (self.cursor + 1) % self.capacity;
    }

    /// Append a list of items to the current buffer.
    ///
    /// The items are moved out of `items`, which is left empty. When more items than the
    /// capacity are given, only the last `capacity` of them are kept and everything
    /// previously stored is discarded.
    ///
    /// # Panics
    ///
    /// Panics if the capacity is zero.
    pub fn append(&mut self, items: &mut Vec<T>) {
        let incoming = items.len();

        if incoming > self.capacity {
            // Slot where the oldest surviving item goes; later items follow and wrap around.
            let start = self.capacity - (incoming % self.capacity);
            items.drain(..incoming - self.capacity);
            // Every stored item is superseded. Dropping them first avoids keeping stale
            // items when the buffer was only partially filled.
            self.buffer.clear();
            let wrapped_from = self.capacity - start;
            // Slots `[0, start)` receive the newest items, slots `[start, capacity)` the
            // oldest surviving ones.
            self.buffer.extend(items.drain(wrapped_from..));
            self.buffer.extend(items.drain(..));
            self.cursor = start % self.capacity;
            return;
        }

        // Number of items that fit before the end of the storage, and number that wrap
        // around to the front.
        let first_part = incoming.min(self.capacity - self.cursor);
        let second_part = incoming - first_part;

        if self.buffer.len() < self.capacity {
            self.buffer.extend(items.drain(..first_part));
        } else {
            self.buffer[self.cursor..self.cursor + first_part]
                .iter_mut()
                .zip(items.drain(..first_part))
                .for_each(|(slot, item)| *slot = item);
        }

        self.buffer[..second_part]
            .iter_mut()
            .zip(items.drain(..second_part))
            .for_each(|(slot, item)| *slot = item);

        self.cursor = (self.cursor + incoming) % self.capacity;
    }

    /// Returns the current number of items stored.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` if no item is stored.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Clear the buffer.
    pub fn clear(&mut self) {
        self.buffer.clear();
        // Start writing from the first slot again; a stale cursor would make later
        // overwrites hit the wrong items.
        self.cursor = 0;
    }

    /// Sample the buffer at the given indices.
    ///
    /// Indices that are out of bounds are skipped.
    pub fn sample(&self, indices: Vec<usize>) -> Vec<&T> {
        indices
            .iter()
            .filter_map(|&idx| self.buffer.get(idx))
            .collect()
    }

    /// Sample `batch_size` transitions at random.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is greater than the number of items stored.
    pub fn random_sample(&self, batch_size: usize) -> Vec<&T> {
        assert!(batch_size <= self.len());
        let mut rng = rng();
        let indices = sample(&mut rng, self.len(), batch_size).into_vec();
        self.sample(indices)
    }
}
#[cfg(test)]
mod tests {
use burn_core::tensor::TensorData;
use crate::TestBackend;
use super::*;
// Builds a transition whose states, action, reward and terminal flag are all
// one-dimensional tensors on the test backend.
fn transition() -> Transition<TestBackend, Tensor<TestBackend, 1>, Tensor<TestBackend, 1>> {
Transition::new(
Tensor::from_data(TensorData::from([1.0, 2.0]), &Default::default()),
Tensor::from_data(TensorData::from([1.0, 2.0]), &Default::default()),
Tensor::from_data(TensorData::from([1.0]), &Default::default()),
Tensor::from_data(TensorData::from([1.0]), &Default::default()),
Tensor::from_data(TensorData::from([1.0]), &Default::default()),
)
}
// The length grows with each push until the capacity is reached, then stays constant.
#[test]
fn len_returns_number_of_elements() {
let mut buffer: TransitionBuffer<
Transition<TestBackend, Tensor<TestBackend, 1>, Tensor<TestBackend, 1>>,
> = TransitionBuffer::new(2);
assert_eq!(buffer.len(), 0);
buffer.push(transition());
assert_eq!(buffer.len(), 1);
buffer.push(transition());
assert_eq!(buffer.len(), 2);
buffer.push(transition());
assert_eq!(buffer.len(), 2)
}
// Checks the exact slot layout after appends that fill the buffer, wrap around, and
// exceed the capacity.
#[test]
fn append_works() {
let mut buffer = TransitionBuffer::new(4);
assert_eq!(buffer.len(), 0);
buffer.append(&mut vec![0, 1]);
assert_eq!(buffer.len(), 2);
assert_eq!(buffer.buffer, vec![0, 1]);
// Fills the two remaining slots, then wraps around to the front.
buffer.append(&mut vec![2, 3, 4, 5]);
assert_eq!(buffer.len(), 4);
assert_eq!(buffer.buffer, vec![4, 5, 2, 3]);
// More items than the capacity on an empty buffer: only the last four are kept.
let mut buffer = TransitionBuffer::new(4);
buffer.append(&mut vec![0, 1, 2, 3, 4, 5]);
assert_eq!(buffer.len(), 4);
assert_eq!(buffer.buffer, vec![4, 5, 2, 3]);
// More items than the capacity on a full buffer.
buffer.append(&mut vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]);
assert_eq!(buffer.len(), 4);
assert_eq!(buffer.buffer, vec![20, 17, 18, 19]);
// A small append afterwards overwrites the oldest items (17 and 18).
buffer.append(&mut vec![21, 22]);
assert_eq!(buffer.len(), 4);
assert_eq!(buffer.buffer, vec![20, 21, 22, 19]);
}
}

View File

@@ -0,0 +1,3 @@
mod base;
pub use base::*;

View File

@@ -37,6 +37,7 @@ burn-core = { path = "../burn-core", version = "=0.21.0", features = [
burn-optim = { path = "../burn-optim", version = "=0.21.0", features = [
"std",
], default-features = false }
burn-rl = { path = "../burn-rl", version = "=0.21.0", default-features = false }
burn-collective = { path = "../burn-collective", version = "=0.21.0", optional = true }
log = { workspace = true }
@@ -62,6 +63,7 @@ async-channel = { workspace = true }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0" }
rstest.workspace = true
thiserror.workspace = true
rand.workspace = true
[dev-dependencies]
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0" }

View File

@@ -1,7 +1,8 @@
use crate::{
AsyncProcessorEvaluation, FullEventProcessorEvaluation, InferenceStep, Interrupter,
AsyncProcessorEvaluation, EvaluationItem, FullEventProcessorEvaluation, InferenceStep,
Interrupter,
evaluator::components::EvaluatorComponentTypes,
metric::processor::{EvaluatorEvent, EventProcessorEvaluation, LearnerItem},
metric::processor::{EvaluatorEvent, EventProcessorEvaluation},
renderer::{EvaluationName, MetricsRenderer},
};
use burn_core::{data::dataloader::DataLoader, module::Module};
@@ -43,7 +44,7 @@ impl<EC: EvaluatorComponentTypes> Evaluator<EC> {
iteration += 1;
let item = self.model.step(item);
let item = LearnerItem::new(item, progress, 0, 1, iteration, None);
let item = EvaluationItem::new(item, progress, Some(iteration));
self.event_processor
.process_test(EvaluatorEvent::ProcessedItem(name.clone(), item));

View File

@@ -3,6 +3,7 @@ mod base;
mod classification;
mod early_stopping;
mod regression;
mod rl;
mod summary;
mod supervised;
mod train_val;
@@ -12,6 +13,7 @@ pub use base::*;
pub use classification::*;
pub use early_stopping::*;
pub use regression::*;
pub use rl::*;
pub use summary::*;
pub use supervised::*;
pub use train_val::*;

View File

@@ -0,0 +1,75 @@
use burn_core::tensor::Device;
use burn_rl::{Policy, PolicyLearner, PolicyState};
use crate::RLAgentRecord;
use crate::{
RLComponentsTypes, RLPolicyRecord,
checkpoint::Checkpointer,
checkpoint::{AsyncCheckpointer, CheckpointingAction, CheckpointingStrategy},
metric::store::EventStoreClient,
};
#[derive(new)]
/// Used to create, delete, or load checkpoints of the training process.
///
/// Two records are tracked per checkpoint: the policy state and the learning agent.
/// A [CheckpointingStrategy] decides, for each epoch, which checkpoints are saved or deleted.
pub struct RLCheckpointer<RLC: RLComponentsTypes> {
/// Checkpointer for the record of the policy.
policy: AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
/// Checkpointer for the record of the learning agent.
learning_agent: AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
/// Strategy queried on every call to `checkpoint` to obtain the actions to perform.
strategy: Box<dyn CheckpointingStrategy>,
}
impl<RLC: RLComponentsTypes> RLCheckpointer<RLC> {
    /// Run the checkpointing strategy for `epoch` and apply every action it returns.
    ///
    /// # Arguments
    ///
    /// * `policy` - The policy state to record when a checkpoint is saved.
    /// * `learning_agent` - The learning agent to record when a checkpoint is saved.
    /// * `epoch` - The identifier under which the checkpoint is stored.
    /// * `store` - The event store handed to the strategy to make its decision.
    ///
    /// # Panics
    ///
    /// Panics if one of the underlying checkpointers fails to save or delete a record.
    pub fn checkpoint(
        &mut self,
        policy: &RLC::PolicyState,
        learning_agent: &RLC::LearningAgent,
        epoch: usize,
        store: &EventStoreClient,
    ) {
        for action in self.strategy.checkpointing(epoch, store) {
            match action {
                CheckpointingAction::Save => {
                    let policy_record = policy.clone().into_record();
                    self.policy
                        .save(epoch, policy_record)
                        .expect("Can save policy checkpoint.");

                    let agent_record = learning_agent.record();
                    self.learning_agent
                        .save(epoch, agent_record)
                        .expect("Can save learning agent checkpoint.");
                }
                CheckpointingAction::Delete(stale_epoch) => {
                    self.policy
                        .delete(stale_epoch)
                        .expect("Can delete policy checkpoint.");
                    self.learning_agent
                        .delete(stale_epoch)
                        .expect("Can delete learning agent checkpoint.");
                }
            }
        }
    }

    /// Restore the policy and the learning agent saved at `epoch`.
    ///
    /// The policy record is restored first and loaded into the agent's current policy, then the
    /// agent record is loaded and the restored policy is handed back to the agent.
    ///
    /// # Panics
    ///
    /// Panics if either record cannot be restored.
    pub fn load_checkpoint(
        &self,
        learning_agent: RLC::LearningAgent,
        device: &Device<RLC::Backend>,
        epoch: usize,
    ) -> RLC::LearningAgent {
        let policy_record = self
            .policy
            .restore(epoch, device)
            .expect("Can load model checkpoint.");
        let restored_policy = learning_agent.policy().load_record(policy_record);

        let agent_record = self
            .learning_agent
            .restore(epoch, device)
            .expect("Can load learning agent checkpoint.");
        let mut restored_agent = learning_agent.load_record(agent_record);
        restored_agent.update_policy(restored_policy);

        restored_agent
    }
}

View File

@@ -0,0 +1,115 @@
use std::marker::PhantomData;
use burn_core::tensor::backend::AutodiffBackend;
use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, PolicyState};
use crate::{AgentEvaluationEvent, AsyncProcessorTraining, ItemLazy, RLEvent};
/// All components used by the reinforcement learning paradigm, grouped in one trait.
///
/// The associated types are tied together by their bounds: the environment's state and action
/// must convert to and from the policy's observation and action, and the learning agent must
/// wrap the very same policy type.
pub trait RLComponentsTypes {
/// The backend used for training.
type Backend: AutodiffBackend;
/// The learning environment.
type Env: Environment<State = Self::State, Action = Self::Action> + 'static;
/// Specifies how to initialize the environment.
type EnvInit: EnvironmentInit<Self::Env> + Send + 'static;
/// The type of the environment state. It must be convertible into the policy's observation.
type State: Into<<Self::Policy as Policy<Self::Backend>>::Observation> + Clone + Send + 'static;
/// The type of the environment action. It must be convertible from and into the policy's action.
type Action: From<<Self::Policy as Policy<Self::Backend>>::Action>
+ Into<<Self::Policy as Policy<Self::Backend>>::Action>
+ Clone
+ Send
+ 'static;
/// The policy used to take actions in the environment.
type Policy: Policy<
Self::Backend,
Observation = Self::PolicyObs,
ActionDistribution = Self::PolicyAD,
Action = Self::PolicyAction,
ActionContext = Self::ActionContext,
PolicyState = Self::PolicyState,
> + Send
+ 'static;
/// The policy's observation type.
type PolicyObs: Clone + Send + Batchable + 'static;
/// The policy's action distribution type.
type PolicyAD: Clone + Send + Batchable;
/// The policy's action type.
type PolicyAction: Clone + Send + Batchable;
/// Additional data as context for an agent's action.
type ActionContext: ItemLazy + Clone + Send + 'static;
/// The state of the parameterized policy.
type PolicyState: Clone + Send + PolicyState<Self::Backend> + 'static;
/// The learning agent. Its inner policy is [Self::Policy](RLComponentsTypes::Policy) and its
/// training context is [Self::TrainingOutput](RLComponentsTypes::TrainingOutput).
type LearningAgent: PolicyLearner<
Self::Backend,
TrainContext = Self::TrainingOutput,
InnerPolicy = Self::Policy,
> + Send
+ 'static;
/// The output data of a training step.
type TrainingOutput: ItemLazy + Clone + Send;
}
/// Concrete type that implements the [RLComponentsTypes](RLComponentsTypes) trait.
///
/// It stores no data: every field is a [PhantomData] that only carries a type parameter
/// (`B` backend, `E` environment, `EI` environment initializer, `A` learning agent).
pub struct RLComponentsMarker<B, E, EI, A> {
_backend: PhantomData<B>,
_env: PhantomData<E>,
_env_init: PhantomData<EI>,
_agent: PhantomData<A>,
}
// Maps the four type parameters of the marker onto every associated type of
// `RLComponentsTypes`. The `where` clause restates the bounds that the trait places on its
// associated types, expressed in terms of `B`, `E`, `EI` and `A`.
impl<B, E, EI, A> RLComponentsTypes for RLComponentsMarker<B, E, EI, A>
where
B: AutodiffBackend,
E: Environment + 'static,
EI: EnvironmentInit<E> + Send + 'static,
A: PolicyLearner<B> + Send + 'static,
A::TrainContext: ItemLazy + Clone + Send,
A::InnerPolicy: Policy<B> + Send,
<A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
<A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
<A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
<A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
<A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
+ Into<<A::InnerPolicy as Policy<B>>::Action>
+ Clone
+ Send
+ 'static,
{
type Backend = B;
type Env = E;
type EnvInit = EI;
type LearningAgent = A;
// Everything policy-related is derived from the agent's inner policy.
type Policy = A::InnerPolicy;
type PolicyObs = <A::InnerPolicy as Policy<B>>::Observation;
type PolicyAD = <A::InnerPolicy as Policy<B>>::ActionDistribution;
type PolicyAction = <A::InnerPolicy as Policy<B>>::Action;
type ActionContext = <A::InnerPolicy as Policy<B>>::ActionContext;
type PolicyState = <A::InnerPolicy as Policy<B>>::PolicyState;
type TrainingOutput = A::TrainContext;
// State and action types come from the environment.
type State = E::State;
type Action = E::Action;
}
/// The inner policy type of the learning agent of `RLC`.
pub(crate) type RlPolicy<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
<RLC as RLComponentsTypes>::Backend,
>>::InnerPolicy;
/// The event processor type for reinforcement learning.
///
/// Training events are [RLEvent]s and evaluation events are [AgentEvaluationEvent]s.
pub type RLEventProcessorType<RLC> = AsyncProcessorTraining<
RLEvent<<RLC as RLComponentsTypes>::TrainingOutput, <RLC as RLComponentsTypes>::ActionContext>,
AgentEvaluationEvent<<RLC as RLComponentsTypes>::ActionContext>,
>;
/// The record of the policy, i.e. the record type of the policy's state.
pub type RLPolicyRecord<RLC> = <<<RLC as RLComponentsTypes>::Policy as Policy<
<RLC as RLComponentsTypes>::Backend,
>>::PolicyState as PolicyState<<RLC as RLComponentsTypes>::Backend>>::Record;
/// The record of the learning agent.
pub type RLAgentRecord<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<
<RLC as RLComponentsTypes>::Backend,
>>::Record;

View File

@@ -0,0 +1,504 @@
use rand::prelude::SliceRandom;
use std::{
sync::mpsc::{Receiver, Sender},
thread::spawn,
};
use burn_core::{Tensor, data::dataloader::Progress, prelude::Backend, tensor::Device};
use burn_rl::EnvironmentInit;
use burn_rl::Policy;
use burn_rl::Transition;
use burn_rl::{AsyncPolicy, Environment};
use crate::{
AgentEnvLoop, AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
Interrupter, RLComponentsTypes, RLEvent, RLEventProcessorType, RLTimeStep, RLTrajectory,
RlPolicy, TimeStep, Trajectory,
};
/// Request sent by the controlling thread to an environment worker thread.
enum RequestMessage {
/// Release the pending timestep on the transition channel.
Step(),
/// Keep stepping until the current episode ends, then send it as one trajectory.
Episode(),
}
/// An asynchronous agent/environment interface.
///
/// The environment is stepped in a dedicated thread (spawned by `start`) which communicates with
/// the owner of this struct through channels.
pub struct AgentEnvAsyncLoop<BT: Backend, RLC: RLComponentsTypes> {
/// Initializer used by the worker thread to build its environment.
env_init: RLC::EnvInit,
/// Identifier written in the `env_id` field of every timestep produced by this runner.
id: usize,
/// When true, events are reported as evaluation events instead of training events.
eval: bool,
/// Handle to the policy used to select actions.
agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
/// Forwarded to the policy on every action selection.
deterministic: bool,
/// Device on which the reward and done tensors of the transitions are created.
transition_device: Device<BT>,
/// Receiving end for single timesteps produced by the worker thread.
transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
/// Sending end cloned into the worker thread for single timesteps.
transition_sender: Sender<RLTimeStep<BT, RLC>>,
/// Receiving end for complete trajectories produced by the worker thread.
trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
/// Sending end cloned into the worker thread for complete trajectories.
trajectory_sender: Sender<RLTrajectory<BT, RLC>>,
/// Channel used to send requests to the worker thread. `None` until `start` is called.
request_sender: Option<Sender<RequestMessage>>,
}
impl<BT: Backend, RLC: RLComponentsTypes> AgentEnvAsyncLoop<BT, RLC> {
    /// Create a new asynchronous runner.
    ///
    /// No thread is spawned here: the worker only starts once `start` is called.
    ///
    /// # Arguments
    ///
    /// * `env_init` - Specifies how to initialize the environment.
    /// * `id` - The identifier stored in every timestep produced by this runner.
    /// * `eval` - Whether events are reported as evaluation events.
    /// * `agent` - The policy handle used to select actions.
    /// * `deterministic` - Forwarded to the policy on every action selection.
    /// * `transition_device` - The device on which transition tensors are created.
    pub fn new(
        env_init: RLC::EnvInit,
        id: usize,
        eval: bool,
        agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
        deterministic: bool,
        transition_device: &Device<BT>,
    ) -> Self {
        let (transition_sender, transition_receiver) = std::sync::mpsc::channel();
        let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();

        Self {
            env_init,
            id,
            eval,
            // `agent` is owned: move it instead of cloning it and dropping the original.
            agent,
            deterministic,
            transition_device: transition_device.clone(),
            transition_receiver,
            transition_sender,
            trajectory_receiver,
            trajectory_sender,
            request_sender: None,
        }
    }
}
impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvAsyncLoop<BT, RLC>
where
BT: Backend,
RLC: RLComponentsTypes,
{
fn start(&mut self) {
let id = self.id;
let mut agent = self.agent.clone();
let deterministic = self.deterministic;
let transition_sender = self.transition_sender.clone();
let trajectory_sender = self.trajectory_sender.clone();
let device = self.transition_device.clone();
let env_init = self.env_init.clone();
let (request_sender, request_receiver) = std::sync::mpsc::channel();
self.request_sender = Some(request_sender);
let mut current_steps = vec![];
let mut current_reward = 0.0;
let mut step_num = 0;
spawn(move || {
let mut env = env_init.init();
env.reset();
let mut request_episode = false;
loop {
let state = env.state();
let (action, context) = agent.action(state.clone().into(), deterministic);
let env_action = RLC::Action::from(action);
let step_result = env.step(env_action.clone());
current_reward += step_result.reward;
step_num += 1;
let transition = Transition::new(
state.clone(),
step_result.next_state,
env_action,
Tensor::from_data([step_result.reward], &device),
Tensor::from_data(
[(step_result.done || step_result.truncated) as i32 as f64],
&device,
),
);
if !request_episode {
agent.decrement_agents(1);
let request = match request_receiver.recv() {
Ok(req) => req,
Err(err) => {
log::error!("Error in env runner : {}", err);
break;
}
};
agent.increment_agents(1);
match request {
RequestMessage::Step() => (),
RequestMessage::Episode() => request_episode = true,
}
}
let time_step = TimeStep {
env_id: id,
transition,
done: step_result.done,
ep_len: step_num,
cum_reward: current_reward,
action_context: context[0].clone(),
};
current_steps.push(time_step.clone());
if !request_episode && let Err(err) = transition_sender.send(time_step) {
log::error!("Error in env runner : {}", err);
break;
}
if step_result.done || step_result.truncated {
if request_episode {
request_episode = false;
trajectory_sender
.send(Trajectory {
timesteps: current_steps.clone(),
})
.expect("Can send trajectory to main thread.");
}
current_steps.clear();
env.reset();
current_reward = 0.;
step_num = 0;
}
}
});
}
fn run_steps(
&mut self,
num_steps: usize,
processor: &mut RLEventProcessorType<RLC>,
interrupter: &Interrupter,
progress: &mut Progress,
) -> Vec<RLTimeStep<BT, RLC>> {
let mut items = vec![];
for _ in 0..num_steps {
self.request_sender
.as_ref()
.expect("Call start before running steps.")
.send(RequestMessage::Step())
.expect("Can request transitions.");
let transition = self
.transition_receiver
.recv()
.expect("Can receive transitions.");
items.push(transition.clone());
if !self.eval {
progress.items_processed += 1;
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
transition.action_context,
progress.clone(),
None,
)));
if transition.done {
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
EpisodeSummary {
episode_length: transition.ep_len,
cum_reward: transition.cum_reward,
},
progress.clone(),
None,
)));
}
}
if interrupter.should_stop() {
break;
}
}
items
}
fn run_episodes(
&mut self,
num_episodes: usize,
processor: &mut RLEventProcessorType<RLC>,
interrupter: &Interrupter,
_progress: &mut Progress,
) -> Vec<RLTrajectory<BT, RLC>> {
let mut items = vec![];
self.agent.increment_agents(1);
for episode_num in 0..num_episodes {
self.request_sender
.as_ref()
.expect("Call start before running episodes.")
.send(RequestMessage::Episode())
.expect("Can request episodes.");
let trajectory = self
.trajectory_receiver
.recv()
.expect("Main thread can receive trajectory.");
for (i, step) in trajectory.timesteps.iter().enumerate() {
// TODO : clean this.
if self.eval {
processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
step.action_context.clone(),
Progress::new(i, i),
None,
)));
if step.done {
processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
EvaluationItem::new(
EpisodeSummary {
episode_length: step.ep_len,
cum_reward: step.cum_reward,
},
Progress::new(episode_num + 1, num_episodes),
None,
),
));
}
} else {
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
step.action_context.clone(),
Progress::new(i, i),
None,
)));
if step.done {
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
EpisodeSummary {
episode_length: step.ep_len,
cum_reward: step.cum_reward,
},
Progress::new(episode_num + 1, num_episodes),
None,
)));
}
}
}
items.push(trajectory);
if interrupter.should_stop() {
break;
}
}
self.agent.decrement_agents(1);
items
}
fn update_policy(&mut self, update: RLC::PolicyState) {
self.agent.update(update);
}
fn policy(&self) -> RLC::PolicyState {
self.agent.state()
}
}
/// An asynchronous runner for multiple agent/environment interfaces.
///
/// `start` creates one [AgentEnvAsyncLoop] per environment; all of them share the transition
/// and trajectory channels of this struct.
pub struct MultiAgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
/// Initializer cloned for every environment worker.
env_init: RLC::EnvInit,
/// Number of environment workers started by `start`.
num_envs: usize,
/// When true, events are reported as evaluation events instead of training events.
eval: bool,
/// Handle to the policy shared by every worker.
agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
/// Forwarded to the policy on every action selection.
deterministic: bool,
/// Device on which the reward and done tensors of the transitions are created.
device: Device<BT>,
/// Receiving end for the timesteps produced by all workers.
transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
/// Sending end handed to every worker for single timesteps.
transition_sender: Sender<RLTimeStep<BT, RLC>>,
/// Receiving end for the trajectories produced by all workers.
trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
/// Sending end handed to every worker for complete trajectories.
trajectory_sender: Sender<RLTrajectory<BT, RLC>>,
/// One request channel per worker, indexed by environment id. Filled by `start`.
request_senders: Vec<Sender<RequestMessage>>,
}
impl<BT: Backend, RLC: RLComponentsTypes> MultiAgentEnvLoop<BT, RLC> {
    /// Create a new asynchronous runner for multiple agent/environment interfaces.
    ///
    /// No worker is created here: they are spawned by `start`.
    ///
    /// # Arguments
    ///
    /// * `env_init` - Specifies how to initialize each environment.
    /// * `num_envs` - The number of environments to run.
    /// * `eval` - Whether events are reported as evaluation events.
    /// * `agent` - The policy handle shared by every environment worker.
    /// * `deterministic` - Forwarded to the policy on every action selection.
    /// * `device` - The device on which transition tensors are created.
    pub fn new(
        env_init: RLC::EnvInit,
        num_envs: usize,
        eval: bool,
        agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
        deterministic: bool,
        device: &Device<BT>,
    ) -> Self {
        let (transition_sender, transition_receiver) = std::sync::mpsc::channel();
        let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();

        Self {
            env_init,
            num_envs,
            eval,
            // `agent` is owned: move it instead of cloning it and dropping the original.
            agent,
            deterministic,
            device: device.clone(),
            transition_receiver,
            transition_sender,
            trajectory_receiver,
            trajectory_sender,
            request_senders: Vec::with_capacity(num_envs),
        }
    }
}
impl<BT, RLC> AgentEnvLoop<BT, RLC> for MultiAgentEnvLoop<BT, RLC>
where
BT: Backend,
RLC: RLComponentsTypes,
{
// TODO: start() shouldn't exist.
fn start(&mut self) {
// Double batching : The environments are always one step ahead of requests. This allows inference for the first batch of steps.
self.agent.increment_agents(self.num_envs);
for i in 0..self.num_envs {
let mut runner = AgentEnvAsyncLoop::<BT, RLC>::new(
self.env_init.clone(),
i,
self.eval,
self.agent.clone(),
self.deterministic,
&self.device,
);
runner.transition_sender = self.transition_sender.clone();
runner.trajectory_sender = self.trajectory_sender.clone();
runner.start();
self.request_senders
.push(runner.request_sender.clone().unwrap());
}
// Double batching : The environments are always one step ahead.
self.request_senders.iter().for_each(|s| {
s.send(RequestMessage::Step())
.expect("Main thread can send step requests.")
});
}
fn run_steps(
&mut self,
num_steps: usize,
processor: &mut RLEventProcessorType<RLC>,
interrupter: &Interrupter,
progress: &mut Progress,
) -> Vec<RLTimeStep<BT, RLC>> {
let mut items = vec![];
for _ in 0..num_steps {
let transition = self
.transition_receiver
.recv()
.expect("Can receive transitions.");
items.push(transition.clone());
self.request_senders[transition.env_id]
.send(RequestMessage::Step())
.expect("Main thread can request steps.");
if !self.eval {
progress.items_processed += 1;
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
transition.action_context,
progress.clone(),
None,
)));
if transition.done {
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
EpisodeSummary {
episode_length: transition.ep_len,
cum_reward: transition.cum_reward,
},
progress.clone(),
None,
)));
}
}
if interrupter.should_stop() {
break;
}
}
items
}
fn update_policy(&mut self, update: RLC::PolicyState) {
self.agent.update(update);
}
fn run_episodes(
&mut self,
num_episodes: usize,
processor: &mut RLEventProcessorType<RLC>,
interrupter: &Interrupter,
_progress: &mut Progress,
) -> Vec<RLTrajectory<BT, RLC>> {
// Send `num_episodes` initial requests.
let mut idx = vec![];
if num_episodes < self.num_envs {
let mut rng = rand::rng();
let mut vec: Vec<usize> = (0..self.num_envs).collect();
vec.shuffle(&mut rng);
idx = vec.into_iter().take(num_episodes).collect();
} else {
idx = (0..self.num_envs).collect();
}
let num_requests = self.num_envs.min(num_episodes);
idx.into_iter().for_each(|i| {
self.request_senders[i]
.send(RequestMessage::Episode())
.expect("Main thread can request steps.");
});
let mut items = vec![];
for episode_num in 0..num_episodes {
let trajectory = self
.trajectory_receiver
.recv()
.expect("Can receive trajectory.");
items.push(trajectory.clone());
if items.len() + num_requests <= num_episodes {
self.request_senders[trajectory.timesteps[0].env_id]
.send(RequestMessage::Episode())
.expect("Main thread can request steps.");
}
for (i, step) in trajectory.timesteps.iter().enumerate() {
if self.eval {
processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
step.action_context.clone(),
Progress::new(i, i),
None,
)));
if step.done {
processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
EvaluationItem::new(
EpisodeSummary {
episode_length: step.ep_len,
cum_reward: step.cum_reward,
},
Progress::new(episode_num + 1, num_episodes),
None,
),
));
}
} else {
processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
step.action_context.clone(),
Progress::new(i, i),
None,
)));
if step.done {
processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
EpisodeSummary {
episode_length: step.ep_len,
cum_reward: step.cum_reward,
},
Progress::new(episode_num + 1, num_episodes),
None,
)));
}
}
}
if interrupter.should_stop() {
break;
}
}
items
}
fn policy(&self) -> RLC::PolicyState {
self.agent.state()
}
}

View File

@@ -0,0 +1,273 @@
use std::marker::PhantomData;
use burn_core::data::dataloader::Progress;
use burn_core::{Tensor, prelude::Backend};
use burn_rl::Policy;
use burn_rl::Transition;
use burn_rl::{Environment, EnvironmentInit};
use crate::RLEvent;
use crate::{
AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
RLEventProcessorType,
};
use crate::{Interrupter, RLComponentsTypes};
/// A trajectory, i.e. a list of ordered [TimeStep](TimeStep).
///
/// The type parameters are forwarded to [TimeStep](TimeStep): `B` backend, `S` state,
/// `A` action and `C` action context.
#[derive(Clone, new)]
pub struct Trajectory<B: Backend, S, A, C> {
/// A list of ordered [TimeStep](TimeStep)s.
pub timesteps: Vec<TimeStep<B, S, A, C>>,
}
/// A timestep describing an iteration of the state/decision process.
#[derive(Clone)]
pub struct TimeStep<B: Backend, S, A, C> {
/// The environment id.
pub env_id: usize,
/// The [burn_rl::Transition](burn_rl::Transition).
pub transition: Transition<B, S, A>,
/// True if the environment reaches a terminal state.
pub done: bool,
/// The running length of the current episode.
pub ep_len: usize,
/// The running cumulative reward.
pub cum_reward: f64,
/// The action's context for this timestep.
pub action_context: C,
}
/// A [TimeStep](TimeStep) whose state, action and action context types come from `RLC`.
pub(crate) type RLTimeStep<B, RLC> = TimeStep<
B,
<RLC as RLComponentsTypes>::State,
<RLC as RLComponentsTypes>::Action,
<RLC as RLComponentsTypes>::ActionContext,
>;
/// A [Trajectory](Trajectory) whose state, action and action context types come from `RLC`.
pub(crate) type RLTrajectory<B, RLC> = Trajectory<
B,
<RLC as RLComponentsTypes>::State,
<RLC as RLComponentsTypes>::Action,
<RLC as RLComponentsTypes>::ActionContext,
>;
/// Trait for a structure that implements an agent/environment interface.
pub trait AgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
/// Start the loop.
fn start(&mut self);
/// Run a certain number of timesteps.
///
/// # Arguments
///
/// * `num_steps` - The number of time_steps to run.
/// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).
/// * `interrupter` - An [crate::Interrupter](crate::Interrupter).
/// * `progress` - A mutable reference to the learning progress.
///
/// # Returns
///
/// A list of ordered timesteps.
fn run_steps(
&mut self,
num_steps: usize,
processor: &mut RLEventProcessorType<RLC>,
interrupter: &Interrupter,
progress: &mut Progress,
) -> Vec<RLTimeStep<BT, RLC>>;
/// Run a certain number of episodes.
///
/// # Arguments
///
/// * `num_episodes` - The number of episodes to run.
/// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).
/// * `interrupter` - An [crate::Interrupter](crate::Interrupter).
/// * `progress` - A mutable reference to the learning progress.
///
/// # Returns
///
/// A list of trajectories, one per episode run.
fn run_episodes(
&mut self,
num_episodes: usize,
processor: &mut RLEventProcessorType<RLC>,
interrupter: &Interrupter,
progress: &mut Progress,
) -> Vec<RLTrajectory<BT, RLC>>;
/// Update the runner's agent.
fn update_policy(&mut self, update: RLC::PolicyState);
/// Get the state of the runner's agent.
fn policy(&self) -> RLC::PolicyState;
}
/// A simple, synchronized agent/environment interface.
///
/// The environment is stepped in the calling thread.
pub struct AgentEnvBaseLoop<B: Backend, RLC: RLComponentsTypes> {
/// The environment being stepped.
env: RLC::Env,
/// When true, training events are not emitted and `run_episodes` emits evaluation events.
eval: bool,
/// The policy used to select actions.
agent: RLC::Policy,
/// Forwarded to the policy on every action selection.
deterministic: bool,
/// Cumulative reward of the episode in progress.
current_reward: f64,
/// Number of episodes finished so far. NOTE(review): incremented but never read in this file.
run_num: usize,
/// Number of steps taken in the episode in progress.
step_num: usize,
_backend: PhantomData<B>,
}
impl<B: Backend, RLC: RLComponentsTypes> AgentEnvBaseLoop<B, RLC> {
    /// Create a new base runner.
    ///
    /// The environment is built immediately from `env_init`.
    ///
    /// # Arguments
    ///
    /// * `env_init` - Specifies how to initialize the environment.
    /// * `agent` - The policy used to select actions.
    /// * `eval` - Whether the runner is used for evaluation.
    /// * `deterministic` - Forwarded to the policy on every action selection.
    pub fn new(
        env_init: RLC::EnvInit,
        agent: RLC::Policy,
        eval: bool,
        deterministic: bool,
    ) -> Self {
        Self {
            env: env_init.init(),
            eval,
            // `agent` is owned: move it instead of cloning the whole policy.
            agent,
            deterministic,
            current_reward: 0.0,
            run_num: 0,
            step_num: 0,
            _backend: PhantomData,
        }
    }
}
impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvBaseLoop<BT, RLC>
where
    BT: Backend,
    RLC: RLComponentsTypes,
{
    /// Reset the environment so it is ready to be stepped.
    fn start(&mut self) {
        self.env.reset();
    }

    /// Step the environment `num_steps` times in the calling thread.
    ///
    /// When not in evaluation mode, each timestep increments `progress` and is reported to
    /// `processor` as a training event, followed by an episode-end event if it is terminal.
    /// The environment and the running episode counters are reset whenever an episode ends,
    /// either on a terminal state or on truncation. Stops early if `interrupter` asks to stop.
    fn run_steps(
        &mut self,
        num_steps: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTimeStep<BT, RLC>> {
        let mut items = vec![];
        let device = Default::default();

        for _ in 0..num_steps {
            let state = self.env.state();
            let (action, context) = self.agent.action(state.clone().into(), self.deterministic);

            // Convert the policy action once and reuse it for the step and the transition.
            let env_action = RLC::Action::from(action);
            let step_result = self.env.step(env_action.clone());

            self.current_reward += step_result.reward;
            self.step_num += 1;

            let transition = Transition::new(
                state.clone(),
                step_result.next_state,
                env_action,
                Tensor::from_data([step_result.reward], &device),
                Tensor::from_data(
                    [(step_result.done || step_result.truncated) as i32 as f64],
                    &device,
                ),
            );

            items.push(TimeStep {
                env_id: 0,
                transition,
                done: step_result.done,
                ep_len: self.step_num,
                cum_reward: self.current_reward,
                action_context: context[0].clone(),
            });

            if !self.eval {
                progress.items_processed += 1;
                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
                    context[0].clone(),
                    progress.clone(),
                    None,
                )));

                if step_result.done {
                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
                        EpisodeSummary {
                            episode_length: self.step_num,
                            cum_reward: self.current_reward,
                        },
                        progress.clone(),
                        None,
                    )));
                }
            }

            if interrupter.should_stop() {
                break;
            }

            if step_result.done || step_result.truncated {
                self.env.reset();
                self.current_reward = 0.;
                self.step_num = 0;
                self.run_num += 1;
            }
        }

        items
    }

    /// Replace the state of the policy.
    fn update_policy(&mut self, update: RLC::PolicyState) {
        self.agent.update(update);
    }

    /// Run `num_episodes` episodes from a freshly reset environment.
    ///
    /// Each episode is played one step at a time through `run_steps` and ends when the
    /// environment reaches a terminal state or is truncated. In evaluation mode every timestep
    /// is also reported to `processor` as an evaluation event. Stops early if `interrupter`
    /// asks to stop, in which case the last trajectory may be incomplete.
    fn run_episodes(
        &mut self,
        num_episodes: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTrajectory<BT, RLC>> {
        // A reset environment starts a new episode, so the running counters restart with it.
        self.env.reset();
        self.current_reward = 0.;
        self.step_num = 0;

        let mut items = vec![];
        for ep in 0..num_episodes {
            let mut steps = vec![];

            loop {
                let step = self
                    .run_steps(1, processor, interrupter, progress)
                    .pop()
                    .expect("Running one step yields one timestep.");

                // `run_steps` resets the step counter to zero whenever the episode ended, on a
                // terminal state or on truncation. Checking only `step.done` would let a
                // truncated episode continue into the next one within the same trajectory.
                let episode_over = step.done || self.step_num == 0;

                if self.eval {
                    // Same value as reported before: number of steps including this one, plus one.
                    let reported = steps.len() + 2;
                    processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
                        step.action_context.clone(),
                        Progress::new(reported, reported),
                        None,
                    )));

                    if step.done {
                        processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
                            EvaluationItem::new(
                                EpisodeSummary {
                                    episode_length: step.ep_len,
                                    cum_reward: step.cum_reward,
                                },
                                Progress::new(ep + 1, num_episodes),
                                None,
                            ),
                        ));
                    }
                }

                steps.push(step);

                if interrupter.should_stop() || episode_over {
                    break;
                }
            }

            items.push(Trajectory::new(steps));

            if interrupter.should_stop() {
                break;
            }
        }

        items
    }

    /// Get the current state of the policy.
    fn policy(&self) -> RLC::PolicyState {
        self.agent.state()
    }
}

View File

@@ -0,0 +1,5 @@
mod async_runner;
mod base;
pub use async_runner::*;
pub use base::*;

View File

@@ -0,0 +1,15 @@
mod checkpointer;
mod components;
mod env_runner;
mod off_policy;
mod output;
mod paradigm;
mod strategy;
pub use checkpointer::*;
pub use components::*;
pub use env_runner::*;
pub use off_policy::*;
pub use output::*;
pub use paradigm::*;
pub use strategy::*;

View File

@@ -0,0 +1,172 @@
use std::marker::PhantomData;
use crate::{
AgentEnvAsyncLoop, AgentEnvLoop, EvaluationItem, EventProcessorTraining, MultiAgentEnvLoop,
RLComponents, RLComponentsTypes, RLEvent, RLEventProcessorType, RLStrategy,
};
use burn_core::{self as burn};
use burn_core::{config::Config, data::dataloader::Progress};
use burn_ndarray::NdArray;
use burn_rl::{AsyncPolicy, Policy, PolicyLearner, Transition, TransitionBuffer};
/// Parameters of an off-policy training with multiple environments and double-batching.
#[derive(Config, Debug)]
pub struct OffPolicyConfig {
/// The number of environments to run simultaneously for experience collection.
#[config(default = 1)]
pub num_envs: usize,
/// Number of environment states to accumulate before running one step of inference with the policy.
/// Must be less than or equal to the number of simultaneous environments.
#[config(default = 1)]
pub autobatch_size: usize,
/// Max number of transitions stored in the replay buffer.
#[config(default = 1024)]
pub replay_buffer_size: usize,
/// The number of steps to collect between each step of training.
#[config(default = 1)]
pub train_interval: usize,
/// Number of optimization steps done each `train_interval`.
#[config(default = 1)]
pub train_steps: usize,
/// The number of steps to collect between each evaluation.
#[config(default = 10_000)]
pub eval_interval: usize,
/// The number of episodes to run for each evaluation.
#[config(default = 1)]
pub eval_episodes: usize,
/// The number of transitions sampled from the replay buffer for each optimization step.
#[config(default = 32)]
pub train_batch_size: usize,
/// Number of steps to collect before starting to train.
#[config(default = 0)]
pub warmup_steps: usize,
}
/// Off-policy reinforcement learning strategy with multi-env experience collection and double-batching.
pub struct OffPolicyStrategy<RLC: RLComponentsTypes> {
/// The parameters of the training loop.
config: OffPolicyConfig,
/// Ties the strategy to its component types without storing any of them.
_components: PhantomData<RLC>,
}
impl<RLC: RLComponentsTypes> OffPolicyStrategy<RLC> {
    /// Build an off-policy strategy driven by the given configuration.
    pub fn new(config: OffPolicyConfig) -> Self {
        let _components = PhantomData;
        Self {
            _components,
            config,
        }
    }
}
impl<RLC> RLStrategy<RLC> for OffPolicyStrategy<RLC>
where
    RLC: RLComponentsTypes,
{
    /// Run the off-policy training loop.
    ///
    /// Each iteration collects `train_interval` steps from the training environments, stores
    /// their transitions in the replay buffer and, once the buffer holds a full batch and the
    /// warmup is over, runs `train_steps` optimization steps. The policy used for collection is
    /// updated with the result of the previous round of optimization. Every `eval_interval`
    /// collected steps, the policy is evaluated on a dedicated environment and a checkpoint is
    /// made if a checkpointer is configured. An `eval_interval` of zero disables evaluation.
    ///
    /// # Arguments
    ///
    /// * `training_components` - The event processor, checkpointer, interrupter and step budget.
    /// * `learner_agent` - The agent being trained.
    /// * `starting_epoch` - The step count from which the progress starts.
    /// * `env_init` - Specifies how to initialize the environments.
    ///
    /// # Returns
    ///
    /// The policy of the trained agent and the event processor.
    fn learn(
        &self,
        training_components: RLComponents<RLC>,
        learner_agent: &mut RLC::LearningAgent,
        starting_epoch: usize,
        env_init: RLC::EnvInit,
    ) -> (RLC::Policy, RLEventProcessorType<RLC>) {
        let mut event_processor = training_components.event_processor;
        let mut checkpointer = training_components.checkpointer;
        let num_steps_total = training_components.num_steps;

        let mut env_runner = MultiAgentEnvLoop::<NdArray, RLC>::new(
            env_init.clone(),
            self.config.num_envs,
            false,
            AsyncPolicy::new(
                self.config.num_envs.min(self.config.autobatch_size),
                learner_agent.policy(),
            ),
            false,
            &Default::default(),
        );
        env_runner.start();

        let mut env_runner_valid = AgentEnvAsyncLoop::<NdArray, RLC>::new(
            env_init,
            0,
            true,
            AsyncPolicy::new(1, learner_agent.policy()),
            true,
            &Default::default(),
        );
        env_runner_valid.start();

        let mut transition_buffer =
            TransitionBuffer::<Transition<NdArray, RLC::State, RLC::Action>>::new(
                self.config.replay_buffer_size,
            );

        let eval_interval = self.config.eval_interval;
        // Saturating: `eval_interval` and `starting_epoch` may both be zero.
        let mut valid_next = (eval_interval + starting_epoch).saturating_sub(1);
        let mut progress = Progress {
            items_processed: starting_epoch,
            items_total: num_steps_total,
        };

        let mut intermediary_update: Option<<RLC::Policy as Policy<RLC::Backend>>::PolicyState> =
            None;
        while progress.items_processed < num_steps_total {
            if training_components.interrupter.should_stop() {
                let reason = training_components
                    .interrupter
                    .get_message()
                    .unwrap_or_else(|| String::from("Reason unknown"));
                log::info!("Training interrupted: {reason}");
                break;
            }

            let items = env_runner.run_steps(
                self.config.train_interval,
                &mut event_processor,
                &training_components.interrupter,
                &mut progress,
            );

            // The timesteps are owned: move their transitions instead of cloning them.
            let mut collected: Vec<_> = items.into_iter().map(|item| item.transition).collect();
            transition_buffer.append(&mut collected);

            if transition_buffer.len() >= self.config.train_batch_size
                && progress.items_processed >= self.config.warmup_steps
            {
                // The collection policy lags one round behind the learner.
                if let Some(ref u) = intermediary_update {
                    env_runner.update_policy(u.clone());
                }

                for _ in 0..self.config.train_steps {
                    let transitions = transition_buffer.random_sample(self.config.train_batch_size);
                    let train_item = learner_agent.train(transitions.into());
                    intermediary_update = Some(train_item.policy);

                    event_processor.process_train(RLEvent::TrainStep(EvaluationItem::new(
                        train_item.item,
                        progress.clone(),
                        None,
                    )));
                }
            }

            // Evaluate as soon as the target is reached or passed. Requiring the target to lie
            // strictly after the previous step count would stall evaluation for good once a
            // target is missed (e.g. `eval_interval == 1` or `train_interval > eval_interval`).
            if eval_interval > 0 && valid_next <= progress.items_processed {
                env_runner_valid.update_policy(learner_agent.policy().state());
                env_runner_valid.run_episodes(
                    self.config.eval_episodes,
                    &mut event_processor,
                    &training_components.interrupter,
                    &mut progress,
                );

                if let Some(checkpointer) = &mut checkpointer {
                    checkpointer.checkpoint(
                        &env_runner.policy(),
                        learner_agent,
                        valid_next,
                        &training_components.event_store,
                    );
                }

                // Skip the targets already passed so a single evaluation covers them.
                while valid_next <= progress.items_processed {
                    valid_next += eval_interval;
                }
            }
        }

        (learner_agent.policy(), event_processor)
    }
}

View File

@@ -0,0 +1,32 @@
use crate::{
ItemLazy,
metric::{Adaptor, CumulativeRewardInput, EpisodeLengthInput},
};
/// Summary of an episode, used as the item of episode-end events.
pub struct EpisodeSummary {
/// The total length of the episode.
pub episode_length: usize,
/// The final cumulative reward.
pub cum_reward: f64,
}
// The summary only holds plain values, so it is its own synchronized form.
impl ItemLazy for EpisodeSummary {
type ItemSync = EpisodeSummary;
// Returns the summary unchanged.
fn sync(self) -> Self::ItemSync {
self
}
}
// Exposes the episode length as the input of the episode length metric.
impl Adaptor<EpisodeLengthInput> for EpisodeSummary {
    fn adapt(&self) -> EpisodeLengthInput {
        let length = self.episode_length as f64;
        EpisodeLengthInput::new(length)
    }
}
// Exposes the final cumulative reward as the input of the cumulative reward metric.
impl Adaptor<CumulativeRewardInput> for EpisodeSummary {
    fn adapt(&self) -> CumulativeRewardInput {
        let reward = self.cum_reward;
        CumulativeRewardInput::new(reward)
    }
}

View File

@@ -0,0 +1,521 @@
use crate::checkpoint::{
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
KeepLastNCheckpoints, MetricCheckpointingStrategy,
};
use crate::learner::base::Interrupter;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, EpisodeLengthMetric, Metric, Numeric};
use crate::renderer::{MetricsRenderer, default_renderer};
use crate::{
ApplicationLoggerInstaller, AsyncProcessorTraining, FileApplicationLoggerInstaller, ItemLazy,
LearnerSummaryConfig, OffPolicyConfig, OffPolicyStrategy, RLAgentRecord, RLCheckpointer,
RLComponents, RLComponentsMarker, RLComponentsTypes, RLEventProcessor, RLMetrics,
RLPolicyRecord, RLStrategy,
};
use crate::{EpisodeSummary, RLStrategies};
use burn_core::record::FileRecorder;
use burn_core::tensor::backend::AutodiffBackend;
use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner};
use std::collections::BTreeSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Structure to configure and launch reinforcement learning trainings.
pub struct RLTraining<RLC: RLComponentsTypes> {
// Not that complex. Extracting into yet another type would only make it more confusing.
#[allow(clippy::type_complexity)]
// Optional pair of checkpointers: the first for policy records, the second for agent records.
checkpointers: Option<(
AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
)>,
// NOTE(review): presumably the total number of environment steps to train for; confirm in the builder methods.
num_steps: usize,
// NOTE(review): presumably the checkpoint to resume from, if any; confirm in the builder methods.
checkpoint: Option<usize>,
// Directory given to `new`, described there as the directory to save the checkpoints.
directory: PathBuf,
grad_accumulation: Option<usize>,
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
metrics: RLMetrics<RLC::TrainingOutput, RLC::ActionContext>,
event_store: LogEventStore,
interrupter: Interrupter,
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
learning_strategy: RLStrategies<RLC>,
// Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order
summary_metrics: BTreeSet<String>,
summary: bool,
// Specifies how to initialize the environment.
env_initializer: RLC::EnvInit,
}
impl<B, E, EI, A> RLTraining<RLComponentsMarker<B, E, EI, A>>
where
    B: AutodiffBackend,
    E: Environment + 'static,
    EI: EnvironmentInit<E> + Send + 'static,
    A: PolicyLearner<B> + Send + 'static,
    A::TrainContext: ItemLazy + Clone + Send,
    A::InnerPolicy: Policy<B> + Send,
    <A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
    <A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
    <A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
    <A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
    <A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
    E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
    E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
        + Into<<A::InnerPolicy as Policy<B>>::Action>
        + Clone
        + Send
        + 'static,
{
    /// Creates a new runner for reinforcement learning.
    ///
    /// The defaults are: a single environment step, no checkpoint to resume from, no
    /// checkpointer, application logs written to `experiment.log` inside `directory`, the
    /// off-policy learning strategy with its default configuration, and no summary.
    ///
    /// # Arguments
    ///
    /// * `directory` - The directory to save the checkpoints.
    /// * `env_initializer` - Specifies how to initialize the environment.
    pub fn new(directory: impl AsRef<Path>, env_initializer: EI) -> Self {
        let directory = directory.as_ref().to_path_buf();
        let experiment_log_file = directory.join("experiment.log");

        // Default checkpointing: always keep the last two checkpoints, plus the one with the
        // best evaluation result. Episode length on the validation split is used as the
        // evaluation signal and the checkpoint with the HIGHEST mean is kept: selecting the
        // lowest value would retain the checkpoint with the shortest evaluation episodes.
        // NOTE(review): this assumes longer evaluation episodes are better; use
        // `with_checkpointing_strategy` for environments where that does not hold.
        let checkpointer_strategy = ComposedCheckpointingStrategy::builder()
            .add(KeepLastNCheckpoints::new(2))
            .add(MetricCheckpointingStrategy::new(
                &EpisodeLengthMetric::new(),
                Aggregate::Mean,
                Direction::Highest,
                Split::Valid,
            ))
            .build();

        Self {
            num_steps: 1,
            checkpoint: None,
            checkpointers: None,
            directory,
            grad_accumulation: None,
            metrics: RLMetrics::default(),
            event_store: LogEventStore::default(),
            renderer: None,
            interrupter: Interrupter::new(),
            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
                experiment_log_file,
            ))),
            checkpointer_strategy: Box::new(checkpointer_strategy),
            learning_strategy: RLStrategies::OffPolicyStrategy(OffPolicyConfig::new()),
            summary_metrics: BTreeSet::new(),
            summary: false,
            env_initializer,
        }
    }
}
impl<RLC: RLComponentsTypes + 'static> RLTraining<RLC> {
/// Replace the default learning strategy (Off Policy learning) with the provided one.
///
/// # Arguments
///
/// * `training_strategy` - The training strategy.
pub fn with_learning_strategy(mut self, learning_strategy: RLStrategies<RLC>) -> Self {
self.learning_strategy = learning_strategy;
self
}
/// Register a metric logger, replacing the default one.
///
/// The default file logger is only installed at launch when no logger has been registered.
///
/// # Arguments
///
/// * `logger` - The training logger.
pub fn with_metric_logger<ML: MetricLogger + 'static>(self, logger: ML) -> Self {
    let mut training = self;
    training.event_store.register_logger(logger);
    training
}
/// Update the checkpointing_strategy.
pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
mut self,
strategy: CS,
) -> Self {
self.checkpointer_strategy = Box::new(strategy);
self
}
/// Replace the default CLI renderer with a custom one.
///
/// # Arguments
///
/// * `renderer` - The custom renderer.
pub fn renderer<MR>(mut self, renderer: MR) -> Self
where
MR: MetricsRenderer + 'static,
{
self.renderer = Some(Box::new(renderer));
self
}
/// Register numerical metrics for a training step of the agent.
///
/// `metrics` is any value implementing [`TrainMetricRegistration`], e.g. a tuple of metrics.
pub fn metrics_train<Me>(self, metrics: Me) -> Self
where
    Me: TrainMetricRegistration<RLC>,
{
    <Me as TrainMetricRegistration<RLC>>::register(metrics, self)
}
/// Register textual metrics for a training step of the agent.
///
/// `metrics` is any value implementing [`TrainTextMetricRegistration`], e.g. a tuple of metrics.
pub fn text_metrics_train<Me>(self, metrics: Me) -> Self
where
    Me: TrainTextMetricRegistration<RLC>,
{
    <Me as TrainTextMetricRegistration<RLC>>::register(metrics, self)
}
/// Register numerical metrics for each action of the agent.
///
/// `metrics` is any value implementing [`AgentMetricRegistration`], e.g. a tuple of metrics.
pub fn metrics_agent<Me>(self, metrics: Me) -> Self
where
    Me: AgentMetricRegistration<RLC>,
{
    <Me as AgentMetricRegistration<RLC>>::register(metrics, self)
}
/// Register textual metrics for each action of the agent.
///
/// `metrics` is any value implementing [`AgentTextMetricRegistration`], e.g. a tuple of metrics.
pub fn text_metrics_agent<Me>(self, metrics: Me) -> Self
where
    Me: AgentTextMetricRegistration<RLC>,
{
    <Me as AgentTextMetricRegistration<RLC>>::register(metrics, self)
}
/// Register numerical metrics for a completed episode.
///
/// `metrics` is any value implementing [`EpisodeMetricRegistration`], e.g. a tuple of metrics.
pub fn metrics_episode<Me>(self, metrics: Me) -> Self
where
    Me: EpisodeMetricRegistration<RLC>,
{
    <Me as EpisodeMetricRegistration<RLC>>::register(metrics, self)
}
/// Register textual metrics for a completed episode.
///
/// `metrics` is any value implementing [`EpisodeTextMetricRegistration`], e.g. a tuple of metrics.
pub fn text_metrics_episode<Me>(self, metrics: Me) -> Self
where
    Me: EpisodeTextMetricRegistration<RLC>,
{
    <Me as EpisodeTextMetricRegistration<RLC>>::register(metrics, self)
}
/// Register a textual [metric](Metric) for a training step.
///
/// The synchronized training output must be adaptable into the metric's input.
pub fn text_metric_train<Me>(self, metric: Me) -> Self
where
    Me: Metric + 'static,
    <RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
    let mut training = self;
    training.metrics.register_text_metric_train(metric);
    training
}
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a training step.
///
/// The metric name is also recorded so that it is listed in the summary metrics.
pub fn metric_train<Me>(self, metric: Me) -> Self
where
    Me: Metric + Numeric + 'static,
    <RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
    let mut training = self;
    let metric_name = metric.name().to_string();
    training.summary_metrics.insert(metric_name);
    training.metrics.register_metric_train(metric);
    training
}
/// Register a textual [metric](Metric) for each action taken by the agent.
///
/// The metric is registered twice: a clone through `register_text_metric_agent` and the
/// original through `register_text_metric_agent_valid`.
pub fn text_metric_agent<Me>(self, metric: Me) -> Self
where
    Me: Metric + 'static,
    <RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
    let mut training = self;
    let first_copy = metric.clone();
    training.metrics.register_text_metric_agent(first_copy);
    training.metrics.register_text_metric_agent_valid(metric);
    training
}
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for each action taken by the agent.
///
/// The metric name is recorded for the summary, then the metric is registered twice: a clone
/// through `register_agent_metric` and the original through `register_agent_metric_valid`.
pub fn metric_agent<Me>(self, metric: Me) -> Self
where
    Me: Metric + Numeric + 'static,
    <RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
{
    let mut training = self;
    let metric_name = metric.name().to_string();
    training.summary_metrics.insert(metric_name);
    let first_copy = metric.clone();
    training.metrics.register_agent_metric(first_copy);
    training.metrics.register_agent_metric_valid(metric);
    training
}
/// Register a textual [metric](Metric) for a completed episode.
///
/// The metric is registered twice: a clone through `register_text_metric_episode` and the
/// original through `register_text_metric_episode_valid`.
pub fn text_metric_episode<Me>(self, metric: Me) -> Self
where
    Me: Metric + 'static,
    EpisodeSummary: Adaptor<Me::Input> + 'static,
{
    let mut training = self;
    let first_copy = metric.clone();
    training.metrics.register_text_metric_episode(first_copy);
    training.metrics.register_text_metric_episode_valid(metric);
    training
}
/// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a completed episode.
///
/// The metric name is recorded for the summary, then the metric is registered twice: a clone
/// through `register_episode_metric` and the original through `register_episode_metric_valid`.
pub fn metric_episode<Me>(self, metric: Me) -> Self
where
    Me: Metric + Numeric + 'static,
    EpisodeSummary: Adaptor<Me::Input> + 'static,
{
    let mut training = self;
    let metric_name = metric.name().to_string();
    training.summary_metrics.insert(metric_name);
    let first_copy = metric.clone();
    training.metrics.register_episode_metric(first_copy);
    training.metrics.register_episode_metric_valid(metric);
    training
}
/// The number of environment steps to train for.
pub fn num_steps(mut self, num_steps: usize) -> Self {
self.num_steps = num_steps;
self
}
/// The step from which the training must resume.
pub fn checkpoint(mut self, checkpoint: usize) -> Self {
self.checkpoint = Some(checkpoint);
self
}
/// Provides a handle that can be used to interrupt training.
pub fn interrupter(&self) -> Interrupter {
self.interrupter.clone()
}
/// Override the handle for stopping training with an externally provided handle
pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
self.interrupter = interrupter;
self
}
/// By default, Rust logs are captured and written into
/// `experiment.log`. If disabled, standard Rust log handling
/// will apply.
pub fn with_application_logger(
mut self,
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
) -> Self {
self.tracing_logger = logger;
self
}
/// Register a checkpointer that will save the environment runner's [policy](Policy)
/// and the [PolicyLearner](PolicyLearner) state to different files.
///
/// Both checkpointers write into the `checkpoint` sub-directory of the training directory,
/// under the names `policy` and `learning-agent` respectively.
///
/// # Arguments
///
/// * `recorder` - The file recorder used by both checkpointers.
pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
where
    FR: FileRecorder<RLC::Backend> + 'static,
    FR: FileRecorder<<RLC::Backend as AutodiffBackend>::InnerBackend> + 'static,
{
    let checkpoint_dir = self.directory.join("checkpoint");
    let checkpointer_policy =
        FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "policy");
    // Last use of `recorder`: move it instead of cloning a second time.
    let checkpointer_learning =
        FileCheckpointer::new(recorder, &checkpoint_dir, "learning-agent");
    self.checkpointers = Some((
        AsyncCheckpointer::new(checkpointer_policy),
        AsyncCheckpointer::new(checkpointer_learning),
    ));
    self
}
/// Enable the training summary report.
///
/// The summary will be displayed after `.launch()`, when the renderer is dropped.
pub fn summary(mut self) -> Self {
self.summary = true;
self
}
/// Launch the training with the specified [PolicyLearner](PolicyLearner) on the specified environment.
///
/// Consumes the builder: installs the application logger (a failure is only logged as a
/// warning), falls back to the default renderer and to a file metric logger when none were
/// provided, assembles the training components and hands them to the configured learning
/// strategy.
///
/// # Arguments
///
/// * `learner_agent` - The learning agent to train.
///
/// # Returns
///
/// The [RLResult] produced by the learning strategy.
pub fn launch(mut self, learner_agent: RLC::LearningAgent) -> RLResult<RLC::Policy> {
    if let Some(logger) = self.tracing_logger.as_ref()
        && let Err(e) = logger.install()
    {
        log::warn!("Failed to install the experiment logger: {e}");
    }
    let renderer = self
        .renderer
        .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));

    // Only fall back to the file logger when the user registered no logger at all.
    if !self.event_store.has_loggers() {
        self.event_store
            .register_logger(FileMetricLogger::new(self.directory.clone()));
    }
    let event_store = Arc::new(EventStoreClient::new(self.event_store));
    let event_processor = AsyncProcessorTraining::new(RLEventProcessor::new(
        self.metrics,
        renderer,
        event_store.clone(),
    ));

    // Checkpointing is optional: it is only set up when file checkpointers were registered.
    let checkpointer = self.checkpointers.map(|(policy, learning_agent)| {
        RLCheckpointer::new(policy, learning_agent, self.checkpointer_strategy)
    });
    let summary = if self.summary {
        Some(LearnerSummaryConfig {
            directory: self.directory,
            metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
        })
    } else {
        None
    };
    let components = RLComponents::<RLC> {
        checkpoint: self.checkpoint,
        checkpointer,
        interrupter: self.interrupter,
        event_processor,
        event_store,
        num_steps: self.num_steps,
        grad_accumulation: self.grad_accumulation,
        summary,
    };
    match self.learning_strategy {
        RLStrategies::OffPolicyStrategy(config) => {
            let strategy = OffPolicyStrategy::new(config);
            strategy.train(learner_agent, components, self.env_initializer)
        }
        RLStrategies::Custom(strategy) => {
            strategy.train(learner_agent, components, self.env_initializer)
        }
    }
}
}
/// The result of reinforcement learning, containing the final policy along with the [renderer](MetricsRenderer).
///
/// This is the value returned by launching a training and by a strategy's `train` method.
pub struct RLResult<P> {
/// The learned policy.
pub policy: P,
/// The renderer that can be used for follow up training and evaluation.
pub renderer: Box<dyn MetricsRenderer>,
}
/// Trait to fake variadic generics for agent metrics (numeric metrics computed for each action of the agent).
///
/// Implemented for tuples of metrics so that several metrics can be registered in one call.
pub trait AgentMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics on the given builder and return the updated builder.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
/// Trait to fake variadic generics for agent text metrics (textual metrics computed for each action of the agent).
///
/// Implemented for tuples of metrics so that several metrics can be registered in one call.
pub trait AgentTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics on the given builder and return the updated builder.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
/// Trait to fake variadic generics for train step metrics (numeric metrics computed for a training step of the agent).
///
/// Implemented for tuples of metrics so that several metrics can be registered in one call.
pub trait TrainMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics on the given builder and return the updated builder.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
/// Trait to fake variadic generics for train step text metrics (textual metrics computed for a training step of the agent).
///
/// Implemented for tuples of metrics so that several metrics can be registered in one call.
pub trait TrainTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics on the given builder and return the updated builder.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
/// Trait to fake variadic generics for episode metrics (numeric metrics computed for a completed episode).
///
/// Implemented for tuples of metrics so that several metrics can be registered in one call.
pub trait EpisodeMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics on the given builder and return the updated builder.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
/// Trait to fake variadic generics for episode text metrics (textual metrics computed for a completed episode).
///
/// Implemented for tuples of metrics so that several metrics can be registered in one call.
pub trait EpisodeTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
/// Register the metrics on the given builder and return the updated builder.
fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
macro_rules! gen_tuple {
($($M:ident),*) => {
impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainTextMetricRegistration<RLC> for ($($M,)*)
where
$(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
$($M: Metric + 'static,)*
{
#[allow(non_snake_case)]
fn register(
self,
builder: RLTraining<RLC>,
) -> RLTraining<RLC> {
let ($($M,)*) = self;
$(let builder = builder.text_metric_train($M.clone());)*
builder
}
}
impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainMetricRegistration<RLC> for ($($M,)*)
where
$(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
$($M: Metric + Numeric + 'static,)*
{
#[allow(non_snake_case)]
fn register(
self,
builder: RLTraining<RLC>,
) -> RLTraining<RLC> {
let ($($M,)*) = self;
$(let builder = builder.metric_train($M.clone());)*
builder
}
}
impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentTextMetricRegistration<RLC> for ($($M,)*)
where
$(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
$($M: Metric + 'static,)*
{
#[allow(non_snake_case)]
fn register(
self,
builder: RLTraining<RLC>,
) -> RLTraining<RLC> {
let ($($M,)*) = self;
$(let builder = builder.text_metric_agent($M.clone());)*
builder
}
}
impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentMetricRegistration<RLC> for ($($M,)*)
where
$(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
$($M: Metric + Numeric + 'static,)*
{
#[allow(non_snake_case)]
fn register(
self,
builder: RLTraining<RLC>,
) -> RLTraining<RLC> {
let ($($M,)*) = self;
$(let builder = builder.metric_agent($M.clone());)*
builder
}
}
impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeTextMetricRegistration<RLC> for ($($M,)*)
where
$(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
$($M: Metric + 'static,)*
{
#[allow(non_snake_case)]
fn register(
self,
builder: RLTraining<RLC>,
) -> RLTraining<RLC> {
let ($($M,)*) = self;
$(let builder = builder.text_metric_episode($M.clone());)*
builder
}
}
impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeMetricRegistration<RLC> for ($($M,)*)
where
$(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
$($M: Metric + Numeric + 'static,)*
{
#[allow(non_snake_case)]
fn register(
self,
builder: RLTraining<RLC>,
) -> RLTraining<RLC> {
let ($($M,)*) = self;
$(let builder = builder.metric_episode($M.clone());)*
builder
}
}
};
}
gen_tuple!(M1);
gen_tuple!(M1, M2);
gen_tuple!(M1, M2, M3);
gen_tuple!(M1, M2, M3, M4);
gen_tuple!(M1, M2, M3, M4, M5);
gen_tuple!(M1, M2, M3, M4, M5, M6);

View File

@@ -0,0 +1,99 @@
use std::sync::Arc;
use crate::{
Interrupter, LearnerSummaryConfig, OffPolicyConfig, RLCheckpointer, RLComponentsTypes, RLEvent,
RLEventProcessorType, RLResult,
metric::{processor::EventProcessorTraining, store::EventStoreClient},
};
/// Struct to minimise parameters passed to [RLStrategy::train].
///
/// Bundles everything a learning strategy needs besides the learning agent and the
/// environment initializer.
pub struct RLComponents<RLC: RLComponentsTypes> {
/// The total number of environment steps.
pub num_steps: usize,
/// The step number from which to continue the training.
pub checkpoint: Option<usize>,
/// A checkpointer used to load and save learning checkpoints.
pub checkpointer: Option<RLCheckpointer<RLC>>,
/// Enables gradient accumulation.
pub grad_accumulation: Option<usize>,
/// An [Interrupter](Interrupter) that allows aborting the training/evaluation process early.
pub interrupter: Interrupter,
/// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and evaluation.
pub event_processor: RLEventProcessorType<RLC>,
/// A reference to an [EventStoreClient](EventStoreClient).
pub event_store: Arc<EventStoreClient>,
/// Config for creating a summary of the learning; `None` when no summary is requested.
pub summary: Option<LearnerSummaryConfig>,
}
/// The strategy for reinforcement learning.
#[derive(Clone)]
pub enum RLStrategies<RLC: RLComponentsTypes> {
/// Off-policy training, configured by the given [OffPolicyConfig].
OffPolicyStrategy(OffPolicyConfig),
/// Training using a custom learning strategy
Custom(CustomRLStrategy<RLC>),
}
/// A shared, reference-counted ([Arc]) reference to an implementation of [RLStrategy].
pub type CustomRLStrategy<LC> = Arc<dyn RLStrategy<LC>>;
/// Provides the `train` function for any reinforcement learning strategy.
///
/// Implementors only provide [`learn`](RLStrategy::learn); `train` wraps it with checkpoint
/// loading and the start/end event notifications.
pub trait RLStrategy<RLC: RLComponentsTypes> {
    /// Train the learner agent with this strategy.
    ///
    /// When a checkpoint number is set, the agent is restored through the checkpointer (if
    /// there is one) and the training loop starts right after that checkpoint; otherwise it
    /// starts at `1`.
    ///
    /// # Returns
    ///
    /// The policy returned by the training loop together with the renderer taken back from
    /// the event processor.
    fn train(
        &self,
        mut learner_agent: RLC::LearningAgent,
        mut training_components: RLComponents<RLC>,
        env_init: RLC::EnvInit,
    ) -> RLResult<RLC::Policy> {
        let first_step = if let Some(checkpoint) = training_components.checkpoint {
            if let Some(checkpointer) = training_components.checkpointer.as_mut() {
                learner_agent =
                    checkpointer.load_checkpoint(learner_agent, &Default::default(), checkpoint);
            }
            checkpoint + 1
        } else {
            1
        };

        // Keep a copy of the summary configuration: the components are moved into `learn`.
        let summary_config = training_components.summary.clone();

        // Notify the event processor that the training starts.
        training_components
            .event_processor
            .process_train(RLEvent::Start);

        // Run the strategy specific training loop.
        let (policy, mut event_processor) =
            self.learn(training_components, &mut learner_agent, first_step, env_init);

        // A summary that fails to initialize is silently dropped.
        let summary = summary_config.and_then(|config| config.init().ok());

        // Signal training end. For the TUI renderer, this handles the exit & return to main screen.
        // TODO: summary makes sense for RL?
        event_processor.process_train(RLEvent::End(summary));

        RLResult {
            policy,
            renderer: event_processor.renderer(),
        }
    }

    /// Training loop for this strategy.
    ///
    /// Receives the training components, the learner agent, the step to start from and the
    /// environment initializer; returns the resulting policy and the event processor.
    fn learn(
        &self,
        training_components: RLComponents<RLC>,
        learner_agent: &mut RLC::LearningAgent,
        starting_epoch: usize,
        env_init: RLC::EnvInit,
    ) -> (RLC::Policy, RLEventProcessorType<RLC>);
}

View File

@@ -16,10 +16,10 @@ use crate::renderer::{MetricsRenderer, default_renderer};
use crate::single::SingleDevicetrainingStrategy;
use crate::{
ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller,
InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerModelRecord,
LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig, LearningCheckpointer,
LearningComponentsMarker, LearningComponentsTypes, LearningResult, TrainStep, TrainingBackend,
TrainingComponents, TrainingModelInput, TrainingStrategy,
InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerEvent,
LearnerModelRecord, LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig,
LearningCheckpointer, LearningComponentsMarker, LearningComponentsTypes, LearningResult,
TrainStep, TrainingBackend, TrainingComponents, TrainingModelInput, TrainingStrategy,
};
use crate::{Learner, SupervisedLearningStrategy};
use burn_core::data::dataloader::DataLoader;
@@ -38,7 +38,8 @@ pub type TrainLoader<LC> = Arc<dyn DataLoader<TrainingBackend<LC>, TrainingModel
pub type ValidLoader<LC> = Arc<dyn DataLoader<InferenceBackend<LC>, InferenceModelInput<LC>>>;
/// The event processor type for supervised learning.
pub type SupervisedTrainingEventProcessor<LC> = AsyncProcessorTraining<
FullEventProcessorTraining<TrainingModelOutput<LC>, InferenceModelOutput<LC>>,
LearnerEvent<TrainingModelOutput<LC>>,
LearnerEvent<InferenceModelOutput<LC>>,
>;
/// Structure to configure and launch supervised learning trainings.
@@ -181,7 +182,7 @@ impl<LC: LearningComponentsTypes> SupervisedTraining<LC> {
metrics.register(self)
}
/// Register all metrics as numeric for the training and validation set.
/// Register all metrics as text for the training and validation set.
pub fn metrics_text<Me: TextMetricRegistration<LC>>(self, metrics: Me) -> Self {
metrics.register(self)
}

View File

@@ -1,4 +1,5 @@
use burn_collective::{PeerId, ReduceOperation};
use burn_core::data::dataloader::Progress;
use burn_core::module::AutodiffModule;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::GradientsAccumulator;
@@ -9,7 +10,7 @@ use std::sync::{Arc, Mutex};
use crate::SupervisedTrainingEventProcessor;
use crate::learner::base::Interrupter;
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, LearnerItem};
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
use crate::{
InferenceStep, Learner, LearningComponentsTypes, TrainLoader, TrainingBackend, ValidLoader,
};
@@ -18,14 +19,12 @@ use crate::{
#[derive(new)]
pub struct DdpValidEpoch<LC: LearningComponentsTypes> {
dataloader: ValidLoader<LC>,
epoch_total: usize,
}
/// A training epoch.
#[derive(new)]
pub struct DdpTrainEpoch<LC: LearningComponentsTypes> {
dataloader: TrainLoader<LC>,
epoch_total: usize,
grad_accumulation: Option<usize>,
}
@@ -39,10 +38,11 @@ impl<LC: LearningComponentsTypes> DdpValidEpoch<LC> {
pub fn run(
&self,
model: &<LC as LearningComponentsTypes>::TrainingModel,
epoch: usize,
global_progress: &Progress,
processor: &mut SupervisedTrainingEventProcessor<LC>,
interrupter: &Interrupter,
) {
let epoch = global_progress.items_processed;
log::info!("Executing validation step for epoch {}", epoch);
let model = model.valid();
@@ -54,7 +54,13 @@ impl<LC: LearningComponentsTypes> DdpValidEpoch<LC> {
iteration += 1;
let item = model.step(item);
let item = LearnerItem::new(item, progress, epoch, self.epoch_total, iteration, None);
let item = TrainingItem::new(
item,
progress,
global_progress.clone(),
Some(iteration),
None,
);
processor.process_valid(LearnerEvent::ProcessedItem(item));
@@ -84,13 +90,14 @@ impl<LC: LearningComponentsTypes> DdpTrainEpoch<LC> {
pub fn run(
&self,
learner: &mut Learner<LC>,
epoch: usize,
global_progress: &Progress,
processor: Arc<Mutex<SupervisedTrainingEventProcessor<LC>>>,
interrupter: &Interrupter,
peer_id: PeerId,
peer_count: usize,
is_main: bool,
) {
let epoch = global_progress.items_processed;
log::info!("Executing training step for epoch {}", epoch,);
let mut iterator = self.dataloader.iter();
@@ -143,12 +150,11 @@ impl<LC: LearningComponentsTypes> DdpTrainEpoch<LC> {
}
}
let item = LearnerItem::new(
let item = TrainingItem::new(
item.item,
progress,
epoch,
self.epoch_total,
iteration,
global_progress.clone(),
Some(iteration),
Some(learner.lr_current()),
);

View File

@@ -1,5 +1,6 @@
use crate::ddp::epoch::{DdpTrainEpoch, DdpValidEpoch};
use crate::ddp::strategy::WorkerComponents;
use crate::single::TrainingLoop;
use crate::{
Learner, LearningCheckpointer, LearningComponentsTypes, SupervisedTrainingEventProcessor,
TrainLoader, TrainingBackend, ValidLoader,
@@ -83,18 +84,19 @@ where
// Changed the train epoch to keep the dataloaders
let epoch_train = DdpTrainEpoch::<LC>::new(
self.dataloader_train.clone(),
num_epochs,
self.components.grad_accumulation,
);
let epoch_valid = self
.dataloader_valid
.map(|dataloader| DdpValidEpoch::<LC>::new(dataloader, num_epochs));
.map(|dataloader| DdpValidEpoch::<LC>::new(dataloader));
self.learner.fork(&self.device);
for epoch in self.starting_epoch..num_epochs + 1 {
for training_progress in TrainingLoop::new(self.starting_epoch, num_epochs) {
let epoch = training_progress.items_processed;
epoch_train.run(
&mut self.learner,
epoch,
&training_progress,
self.event_processor.clone(),
&interrupter,
self.peer_id,
@@ -111,7 +113,7 @@ where
let mut event_processor = self.event_processor.lock().unwrap();
runner.run(
&self.learner.model(),
epoch,
&training_progress,
&mut event_processor,
&interrupter,
);

View File

@@ -1,10 +1,11 @@
use crate::learner::base::Interrupter;
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, LearnerItem};
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
use crate::train::MultiDevicesTrainStep;
use crate::{
Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedTrainingEventProcessor,
TrainLoader, TrainingBackend,
};
use burn_core::data::dataloader::Progress;
use burn_core::prelude::DeviceOps;
use burn_core::tensor::Device;
use burn_core::tensor::backend::DeviceId;
@@ -16,7 +17,6 @@ use std::collections::HashMap;
#[derive(new)]
pub struct MultiDeviceTrainEpoch<LC: LearningComponentsTypes> {
dataloaders: Vec<TrainLoader<LC>>,
epoch_total: usize,
grad_accumulation: Option<usize>,
}
@@ -38,30 +38,39 @@ impl<LC: LearningComponentsTypes> MultiDeviceTrainEpoch<LC> {
pub fn run(
&self,
learner: &mut Learner<LC>,
epoch: usize,
global_progress: &Progress,
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
interrupter: &Interrupter,
devices: Vec<Device<TrainingBackend<LC>>>,
strategy: MultiDeviceOptim,
) {
match strategy {
MultiDeviceOptim::OptimMainDevice => {
self.run_optim_main(learner, epoch, event_processor, interrupter, devices)
}
MultiDeviceOptim::OptimSharded => {
self.run_optim_distr(learner, epoch, event_processor, interrupter, devices)
}
MultiDeviceOptim::OptimMainDevice => self.run_optim_main(
learner,
global_progress,
event_processor,
interrupter,
devices,
),
MultiDeviceOptim::OptimSharded => self.run_optim_distr(
learner,
global_progress,
event_processor,
interrupter,
devices,
),
}
}
fn run_optim_main(
&self,
learner: &mut Learner<LC>,
epoch: usize,
global_progress: &Progress,
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
interrupter: &Interrupter,
devices: Vec<Device<TrainingBackend<LC>>>,
) {
let epoch = global_progress.items_processed;
log::info!(
"Executing training step for epoch {} on devices {:?}",
epoch,
@@ -108,12 +117,11 @@ impl<LC: LearningComponentsTypes> MultiDeviceTrainEpoch<LC> {
for item in progress_items {
iteration += 1;
let item = LearnerItem::new(
let item = TrainingItem::new(
item,
progress.clone(),
epoch,
self.epoch_total,
iteration,
global_progress.clone(),
Some(iteration),
Some(learner.lr_current()),
);
@@ -131,11 +139,12 @@ impl<LC: LearningComponentsTypes> MultiDeviceTrainEpoch<LC> {
fn run_optim_distr(
&self,
learner: &mut Learner<LC>,
epoch: usize,
global_progress: &Progress,
event_processor: &mut SupervisedTrainingEventProcessor<LC>,
interrupter: &Interrupter,
devices: Vec<Device<TrainingBackend<LC>>>,
) {
let epoch = global_progress.items_processed;
log::info!(
"Executing training step for epoch {} on devices {:?}",
epoch,
@@ -189,12 +198,11 @@ impl<LC: LearningComponentsTypes> MultiDeviceTrainEpoch<LC> {
for item in progress_items {
iteration += 1;
let item = LearnerItem::new(
let item = TrainingItem::new(
item,
progress.clone(),
epoch,
self.epoch_total,
iteration,
global_progress.clone(),
Some(iteration),
Some(learner.lr_current()),
);

View File

@@ -1,8 +1,9 @@
use crate::{
Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedLearningStrategy,
SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, TrainingComponents,
TrainingModel, ValidLoader, multi::epoch::MultiDeviceTrainEpoch,
single::epoch::SingleDeviceValidEpoch,
TrainingModel, ValidLoader,
multi::epoch::MultiDeviceTrainEpoch,
single::{TrainingLoop, epoch::SingleDeviceValidEpoch},
};
use burn_core::{data::dataloader::split::split_dataloader, tensor::Device};
@@ -39,20 +40,19 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
let mut event_processor = training_components.event_processor;
let mut checkpointer = training_components.checkpointer;
let mut early_stopping = training_components.early_stopping;
let num_epochs = training_components.num_epochs;
let epoch_train = MultiDeviceTrainEpoch::<LC>::new(
dataloader_train.clone(),
num_epochs,
training_components.grad_accumulation,
);
let epoch_valid: SingleDeviceValidEpoch<LC> =
SingleDeviceValidEpoch::new(dataloader_valid.clone(), num_epochs);
SingleDeviceValidEpoch::new(dataloader_valid.clone());
for epoch in starting_epoch..training_components.num_epochs + 1 {
for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) {
let epoch = training_progress.items_processed;
epoch_train.run(
&mut learner,
epoch,
&training_progress,
&mut event_processor,
&training_components.interrupter,
self.devices.to_vec(),
@@ -70,7 +70,7 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
epoch_valid.run(
&learner,
epoch,
&training_progress,
&mut event_processor,
&training_components.interrupter,
);

View File

@@ -1,9 +1,10 @@
use crate::learner::base::Interrupter;
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, LearnerItem};
use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};
use crate::{
InferenceStep, Learner, LearningComponentsTypes, SupervisedTrainingEventProcessor, TrainLoader,
ValidLoader,
};
use burn_core::data::dataloader::Progress;
use burn_core::module::AutodiffModule;
use burn_optim::GradientsAccumulator;
@@ -11,14 +12,12 @@ use burn_optim::GradientsAccumulator;
#[derive(new)]
pub struct SingleDeviceValidEpoch<LC: LearningComponentsTypes> {
dataloader: ValidLoader<LC>,
epoch_total: usize,
}
/// A training epoch.
#[derive(new)]
pub struct SingleDeviceTrainEpoch<LC: LearningComponentsTypes> {
dataloader: TrainLoader<LC>,
epoch_total: usize,
grad_accumulation: Option<usize>,
}
@@ -32,10 +31,11 @@ impl<LC: LearningComponentsTypes> SingleDeviceValidEpoch<LC> {
pub fn run(
&self,
learner: &Learner<LC>,
epoch: usize,
global_progress: &Progress,
processor: &mut SupervisedTrainingEventProcessor<LC>,
interrupter: &Interrupter,
) {
let epoch = global_progress.items_processed;
log::info!("Executing validation step for epoch {}", epoch);
let model = learner.model().valid();
@@ -47,7 +47,13 @@ impl<LC: LearningComponentsTypes> SingleDeviceValidEpoch<LC> {
iteration += 1;
let item = model.step(item);
let item = LearnerItem::new(item, progress, epoch, self.epoch_total, iteration, None);
let item = TrainingItem::new(
item,
progress,
global_progress.clone(),
Some(iteration),
None,
);
processor.process_valid(LearnerEvent::ProcessedItem(item));
@@ -75,10 +81,11 @@ impl<LC: LearningComponentsTypes> SingleDeviceTrainEpoch<LC> {
pub fn run(
&self,
learner: &mut Learner<LC>,
epoch: usize,
global_progress: &Progress,
processor: &mut SupervisedTrainingEventProcessor<LC>,
interrupter: &Interrupter,
) {
let epoch = global_progress.items_processed;
log::info!("Executing training step for epoch {}", epoch,);
// Single device / dataloader
@@ -110,12 +117,11 @@ impl<LC: LearningComponentsTypes> SingleDeviceTrainEpoch<LC> {
None => learner.optimizer_step(item.grads),
}
let item = LearnerItem::new(
let item = TrainingItem::new(
item.item,
progress,
epoch,
self.epoch_total,
iteration,
global_progress.clone(),
Some(iteration),
Some(learner.lr_current()),
);

View File

@@ -3,7 +3,7 @@ use crate::{
TrainLoader, TrainingBackend, TrainingComponents, TrainingModel, ValidLoader,
single::epoch::{SingleDeviceTrainEpoch, SingleDeviceValidEpoch},
};
use burn_core::tensor::Device;
use burn_core::{data::dataloader::Progress, tensor::Device};
/// Simplest learning strategy possible, with only a single devices doing both the training and
/// validation.
@@ -16,6 +16,31 @@ impl<LC: LearningComponentsTypes> SingleDevicetrainingStrategy<LC> {
}
}
/// Iteration state of a training loop, consumed through its `Iterator` implementation.
#[derive(new)]
pub(crate) struct TrainingLoop {
// The iteration number reported by the next call to `next`.
next_iteration: usize,
// The last iteration number to report (inclusive).
total_iteration: usize,
}
/// An iterator that returns the progress of the training.
impl Iterator for TrainingLoop {
    type Item = Progress;

    /// Yields the progress of the current iteration, then advances by one.
    ///
    /// Returns `None` once the iteration counter exceeds `total_iteration`, so the last
    /// yielded item (if any) has `items_processed == items_total`.
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.next_iteration;
        let total = self.total_iteration;

        (current <= total).then(|| {
            self.next_iteration = current + 1;
            Progress {
                items_processed: current,
                items_total: total,
            }
        })
    }
}
impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
for SingleDevicetrainingStrategy<LC>
{
@@ -33,20 +58,17 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
let mut event_processor = training_components.event_processor;
let mut checkpointer = training_components.checkpointer;
let mut early_stopping = training_components.early_stopping;
let num_epochs = training_components.num_epochs;
let epoch_train: SingleDeviceTrainEpoch<LC> = SingleDeviceTrainEpoch::new(
dataloader_train,
num_epochs,
training_components.grad_accumulation,
);
let epoch_train: SingleDeviceTrainEpoch<LC> =
SingleDeviceTrainEpoch::new(dataloader_train, training_components.grad_accumulation);
let epoch_valid: SingleDeviceValidEpoch<LC> =
SingleDeviceValidEpoch::new(dataloader_valid.clone(), num_epochs);
SingleDeviceValidEpoch::new(dataloader_valid.clone());
for epoch in starting_epoch..training_components.num_epochs + 1 {
for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) {
let epoch = training_progress.items_processed;
epoch_train.run(
&mut learner,
epoch,
&training_progress,
&mut event_processor,
&training_components.interrupter,
);
@@ -62,7 +84,7 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>
epoch_valid.run(
&learner,
epoch,
&training_progress,
&mut event_processor,
&training_components.interrupter,
);

View File

@@ -8,14 +8,11 @@ pub struct MetricMetadata {
/// The current progress.
pub progress: Progress,
/// The current epoch.
pub epoch: usize,
/// The total number of epochs.
pub epoch_total: usize,
/// The global progress of the training (e.g. epochs).
pub global_progress: Progress,
/// The current iteration.
pub iteration: usize,
pub iteration: Option<usize>,
/// The current learning rate.
pub lr: Option<LearningRate>,
@@ -30,9 +27,11 @@ impl MetricMetadata {
items_processed: 1,
items_total: 1,
},
epoch: 0,
epoch_total: 1,
iteration: 0,
global_progress: Progress {
items_processed: 0,
items_total: 1,
},
iteration: Some(0),
lr: None,
}
}

View File

@@ -38,7 +38,14 @@ impl Metric for IterationSpeedMetric {
fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry {
let raw = match self.instant {
Some(val) => metadata.iteration as f64 / val.elapsed().as_secs_f64(),
Some(val) => {
// If iteration is not logged, compute the speed over the number of items processed.
// 1 iteration should equal 1 item when iteration is not logged.
metadata
.iteration
.unwrap_or(metadata.progress.items_processed) as f64
/ val.elapsed().as_secs_f64()
}
None => {
self.instant = Some(std::time::Instant::now());
0.0

View File

@@ -37,6 +37,7 @@ mod loss;
mod perplexity;
mod precision;
mod recall;
mod rl;
mod top_k_acc;
mod wer;
@@ -53,6 +54,7 @@ pub use loss::*;
pub use perplexity::*;
pub use precision::*;
pub use recall::*;
pub use rl::*;
pub use top_k_acc::*;
pub use wer::*;

View File

@@ -1,11 +1,11 @@
use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation};
use super::{EventProcessorTraining, LearnerEvent};
use super::EventProcessorTraining;
use async_channel::{Receiver, Sender};
/// Event processor for the training process.
pub struct AsyncProcessorTraining<P: EventProcessorTraining> {
sender: Sender<Message<P>>,
pub struct AsyncProcessorTraining<ET, EV> {
sender: Sender<Message<ET, EV>>,
}
/// Event processor for the model evaluation.
@@ -13,9 +13,9 @@ pub struct AsyncProcessorEvaluation<P: EventProcessorEvaluation> {
sender: Sender<EvalMessage<P>>,
}
struct WorkerTraining<P: EventProcessorTraining> {
struct WorkerTraining<ET, EV, P: EventProcessorTraining<ET, EV>> {
processor: P,
rec: Receiver<Message<P>>,
rec: Receiver<Message<ET, EV>>,
}
struct WorkerEvaluation<P: EventProcessorEvaluation> {
@@ -23,8 +23,10 @@ struct WorkerEvaluation<P: EventProcessorEvaluation> {
rec: Receiver<EvalMessage<P>>,
}
impl<P: EventProcessorTraining + 'static> WorkerTraining<P> {
pub fn start(processor: P, rec: Receiver<Message<P>>) {
impl<ET: Send + 'static, EV: Send + 'static, P: EventProcessorTraining<ET, EV> + 'static>
WorkerTraining<ET, EV, P>
{
pub fn start(processor: P, rec: Receiver<Message<ET, EV>>) {
let mut worker = Self { processor, rec };
std::thread::spawn(move || {
@@ -59,9 +61,9 @@ impl<P: EventProcessorEvaluation + 'static> WorkerEvaluation<P> {
}
}
impl<P: EventProcessorTraining + 'static> AsyncProcessorTraining<P> {
impl<ET: Send + 'static, EV: Send + 'static> AsyncProcessorTraining<ET, EV> {
/// Create an event processor for training.
pub fn new(processor: P) -> Self {
pub fn new<P: EventProcessorTraining<ET, EV> + 'static>(processor: P) -> Self {
let (sender, rec) = async_channel::bounded(1);
WorkerTraining::start(processor, rec);
@@ -81,9 +83,9 @@ impl<P: EventProcessorEvaluation + 'static> AsyncProcessorEvaluation<P> {
}
}
enum Message<P: EventProcessorTraining> {
Train(LearnerEvent<P::ItemTrain>),
Valid(LearnerEvent<P::ItemValid>),
enum Message<EventTrain, EventValid> {
Train(EventTrain),
Valid(EventValid),
Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
}
@@ -92,15 +94,12 @@ enum EvalMessage<P: EventProcessorEvaluation> {
Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
}
impl<P: EventProcessorTraining> EventProcessorTraining for AsyncProcessorTraining<P> {
type ItemTrain = P::ItemTrain;
type ItemValid = P::ItemValid;
fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>) {
impl<ET: Send, EV: Send> EventProcessorTraining<ET, EV> for AsyncProcessorTraining<ET, EV> {
fn process_train(&mut self, event: ET) {
self.sender.send_blocking(Message::Train(event)).unwrap();
}
fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>) {
fn process_valid(&mut self, event: EV) {
self.sender.send_blocking(Message::Valid(event)).unwrap();
}

View File

@@ -2,7 +2,7 @@ use burn_core::data::dataloader::Progress;
use burn_optim::LearningRate;
use crate::{
LearnerSummary,
EpisodeSummary, LearnerSummary,
renderer::{EvaluationName, MetricsRenderer},
};
@@ -11,19 +11,45 @@ pub enum LearnerEvent<T> {
/// Signal the start of the process (e.g., training start)
Start,
/// Signal that an item has been processed.
ProcessedItem(LearnerItem<T>),
ProcessedItem(TrainingItem<T>),
/// Signal the end of an epoch.
EndEpoch(usize),
/// Signal the end of the process (e.g., training end).
End(Option<LearnerSummary>),
}
/// Event happening during reinforcement learning.
///
/// `TS` is the item carried by an agent training step ([`RLEvent::TrainStep`]) and `ES` is the
/// item carried by a timestep of the agent-environment interface ([`RLEvent::TimeStep`]).
pub enum RLEvent<TS, ES> {
/// Signal the start of the process (e.g., learning starts).
Start,
/// Signal an agent's training step.
TrainStep(EvaluationItem<TS>),
/// Signal a timestep of the agent-environment interface.
TimeStep(EvaluationItem<ES>),
/// Signal an episode end.
EpisodeEnd(EvaluationItem<EpisodeSummary>),
/// Signal the end of the process (e.g., learning ends).
End(Option<LearnerSummary>),
}
/// Event happening during the evaluation of a reinforcement learning agent.
///
/// `T` is the item carried by a timestep of the agent-environment interface.
pub enum AgentEvaluationEvent<T> {
/// Signal the start of the process (e.g., evaluation start).
Start,
/// Signal a timestep of the agent-environment interface.
TimeStep(EvaluationItem<T>),
/// Signal an episode end.
EpisodeEnd(EvaluationItem<EpisodeSummary>),
/// Signal the end of the process (e.g., evaluation end).
End,
}
/// Event happening during the evaluation process.
pub enum EvaluatorEvent<T> {
/// Signal the start of the process (e.g., training start)
Start,
/// Signal that an item has been processed.
ProcessedItem(EvaluationName, LearnerItem<T>),
ProcessedItem(EvaluationName, EvaluationItem<T>),
/// Signal the end of the process (e.g., training end).
End,
}
@@ -40,16 +66,11 @@ pub trait ItemLazy: Send {
}
/// Process events happening during training and validation.
pub trait EventProcessorTraining: Send {
/// The training item.
type ItemTrain: ItemLazy;
/// The validation item.
type ItemValid: ItemLazy;
pub trait EventProcessorTraining<TrainEvent, ValidEvent>: Send {
/// Collect a training event.
fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>);
fn process_train(&mut self, event: TrainEvent);
/// Collect a validation event.
fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>);
fn process_valid(&mut self, event: ValidEvent);
/// Returns the renderer used for training.
fn renderer(self) -> Box<dyn MetricsRenderer>;
}
@@ -68,41 +89,62 @@ pub trait EventProcessorEvaluation: Send {
/// A learner item.
#[derive(new)]
pub struct LearnerItem<T> {
pub struct TrainingItem<T> {
/// The item.
pub item: T,
/// The progress.
pub progress: Progress,
/// The epoch.
pub epoch: usize,
/// The global progress of the training (e.g. epochs).
pub global_progress: Progress,
/// The total number of epochs.
pub epoch_total: usize,
/// The iteration.
pub iteration: usize,
/// The iteration, if it is different from the items processed.
pub iteration: Option<usize>,
/// The learning rate.
pub lr: Option<LearningRate>,
}
impl<T: ItemLazy> ItemLazy for LearnerItem<T> {
type ItemSync = LearnerItem<T::ItemSync>;
impl<T: ItemLazy> ItemLazy for TrainingItem<T> {
type ItemSync = TrainingItem<T::ItemSync>;
fn sync(self) -> Self::ItemSync {
LearnerItem {
TrainingItem {
item: self.item.sync(),
progress: self.progress,
epoch: self.epoch,
epoch_total: self.epoch_total,
global_progress: self.global_progress,
iteration: self.iteration,
lr: self.lr,
}
}
}
/// An evaluation item.
///
/// Unlike a training item it carries no global progress and no learning rate.
#[derive(new)]
pub struct EvaluationItem<T> {
/// The item.
pub item: T,
/// The progress.
pub progress: Progress,
/// The iteration, if it is different from the items processed.
pub iteration: Option<usize>,
}
impl<T: ItemLazy> ItemLazy for EvaluationItem<T> {
    type ItemSync = EvaluationItem<T::ItemSync>;

    /// Synchronizes the wrapped item; the progress and iteration are carried over unchanged.
    fn sync(self) -> Self::ItemSync {
        let Self {
            item,
            progress,
            iteration,
        } = self;

        EvaluationItem {
            item: item.sync(),
            progress,
            iteration,
        }
    }
}
impl ItemLazy for () {
type ItemSync = ();

View File

@@ -1,7 +1,9 @@
use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining};
use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation, MetricsEvaluation};
use crate::metric::store::{EpochSummary, EventStoreClient, Split};
use crate::renderer::{MetricState, MetricsRenderer};
use crate::renderer::{
EvaluationProgress, MetricState, MetricsRenderer, ProgressType, TrainingProgress,
};
use std::sync::Arc;
/// An [event processor](EventProcessorTraining) that handles:
@@ -34,6 +36,30 @@ impl<T: ItemLazy, V: ItemLazy> FullEventProcessorTraining<T, V> {
store,
}
}
/// Builds the progress indicators shown by the renderer for a training/validation item.
///
/// The list always starts with the "Epoch" indicator (global progress), followed by the
/// "Iteration" value when an iteration is set and the "Items" indicator when an item-level
/// progress is available.
fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
    let epoch = ProgressType::Detailed {
        tag: String::from("Epoch"),
        progress: progress.global_progress.clone(),
    };
    let iteration = progress.iteration.map(|value| ProgressType::Value {
        tag: String::from("Iteration"),
        value,
    });
    let items = progress
        .progress
        .as_ref()
        .map(|item_progress| ProgressType::Detailed {
            tag: String::from("Items"),
            progress: item_progress.clone(),
        });

    // `Option` is an iterator of zero or one element, so absent indicators are simply skipped.
    std::iter::once(epoch).chain(iteration).chain(items).collect()
}
}
impl<T: ItemLazy> FullEventProcessorEvaluation<T> {
@@ -48,6 +74,23 @@ impl<T: ItemLazy> FullEventProcessorEvaluation<T> {
store,
}
}
/// Builds the progress indicators shown by the renderer for an evaluation item.
///
/// The "Iteration" value comes first when an iteration is set; the "Items" indicator is
/// always present and comes last.
fn progress_indicators(&self, progress: &EvaluationProgress) -> Vec<ProgressType> {
    let iteration = progress.iteration.map(|value| ProgressType::Value {
        tag: String::from("Iteration"),
        value,
    });
    let items = ProgressType::Detailed {
        tag: String::from("Items"),
        progress: progress.progress.clone(),
    };

    // `Option` is an iterator of zero or one element, so a missing iteration is skipped.
    iteration
        .into_iter()
        .chain(std::iter::once(items))
        .collect()
}
}
impl<T: ItemLazy> EventProcessorEvaluation for FullEventProcessorEvaluation<T> {
@@ -95,7 +138,8 @@ impl<T: ItemLazy> EventProcessorEvaluation for FullEventProcessorEvaluation<T> {
)
});
self.renderer.render_test(progress);
let indicators = self.progress_indicators(&progress);
self.renderer.render_test(progress, indicators);
}
EvaluatorEvent::End => {
self.renderer.on_test_end().ok();
@@ -108,11 +152,10 @@ impl<T: ItemLazy> EventProcessorEvaluation for FullEventProcessorEvaluation<T> {
}
}
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for FullEventProcessorTraining<T, V> {
type ItemTrain = T;
type ItemValid = V;
fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>) {
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>
for FullEventProcessorTraining<T, V>
{
fn process_train(&mut self, event: LearnerEvent<T>) {
match event {
LearnerEvent::Start => {
let definitions = self.metrics.metric_definitions();
@@ -149,7 +192,8 @@ impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for FullEventProcessorTrai
))
});
self.renderer.render_train(progress);
let indicators = self.progress_indicators(&progress);
self.renderer.render_train(progress, indicators);
}
LearnerEvent::EndEpoch(epoch) => {
self.store
@@ -165,7 +209,7 @@ impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for FullEventProcessorTrai
}
}
fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>) {
fn process_valid(&mut self, event: LearnerEvent<V>) {
match event {
LearnerEvent::Start => {} // no-op for now
LearnerEvent::ProcessedItem(item) => {
@@ -193,7 +237,8 @@ impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for FullEventProcessorTrai
))
});
self.renderer.render_valid(progress);
let indicators = self.progress_indicators(&progress);
self.renderer.render_valid(progress, indicators);
}
LearnerEvent::EndEpoch(epoch) => {
self.store
@@ -206,7 +251,7 @@ impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for FullEventProcessorTrai
LearnerEvent::End(_) => {} // no-op for now
}
}
fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
fn renderer(self) -> Box<dyn MetricsRenderer> {
self.renderer
}
}

View File

@@ -1,7 +1,8 @@
use std::collections::HashMap;
use super::{ItemLazy, LearnerItem};
use super::{ItemLazy, TrainingItem};
use crate::{
EvaluationItem,
metric::{
Adaptor, Metric, MetricDefinition, MetricEntry, MetricId, MetricMetadata, Numeric,
store::{MetricsUpdate, NumericMetricUpdate},
@@ -83,19 +84,19 @@ impl<T: ItemLazy> MetricsEvaluation<T> {
/// Update the testing information from the testing item.
pub(crate) fn update_test(
&mut self,
item: &LearnerItem<T::ItemSync>,
item: &EvaluationItem<T::ItemSync>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.test.len());
let mut entries_numeric = Vec::with_capacity(self.test_numeric.len());
for metric in self.test.iter_mut() {
let state = metric.update(item, metadata);
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.test_numeric.iter_mut() {
let numeric_update = metric.update(item, metadata);
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
@@ -162,19 +163,19 @@ impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
/// Update the training information from the training item.
pub(crate) fn update_train(
&mut self,
item: &LearnerItem<T::ItemSync>,
item: &TrainingItem<T::ItemSync>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.train.len());
let mut entries_numeric = Vec::with_capacity(self.train_numeric.len());
for metric in self.train.iter_mut() {
let state = metric.update(item, metadata);
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.train_numeric.iter_mut() {
let numeric_update = metric.update(item, metadata);
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
@@ -184,19 +185,19 @@ impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
/// Update the training information from the validation item.
pub(crate) fn update_valid(
&mut self,
item: &LearnerItem<V::ItemSync>,
item: &TrainingItem<V::ItemSync>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.valid.len());
let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len());
for metric in self.valid.iter_mut() {
let state = metric.update(item, metadata);
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.valid_numeric.iter_mut() {
let numeric_update = metric.update(item, metadata);
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
@@ -224,19 +225,28 @@ impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
}
}
impl<T> From<&LearnerItem<T>> for TrainingProgress {
fn from(item: &LearnerItem<T>) -> Self {
impl<T> From<&TrainingItem<T>> for TrainingProgress {
fn from(item: &TrainingItem<T>) -> Self {
Self {
progress: item.progress.clone(),
epoch: item.epoch,
epoch_total: item.epoch_total,
progress: Some(item.progress.clone()),
global_progress: item.global_progress.clone(),
iteration: item.iteration,
}
}
}
impl<T> From<&LearnerItem<T>> for EvaluationProgress {
fn from(item: &LearnerItem<T>) -> Self {
impl<T> From<&EvaluationItem<T>> for TrainingProgress {
fn from(item: &EvaluationItem<T>) -> Self {
Self {
progress: None,
global_progress: item.progress.clone(),
iteration: item.iteration,
}
}
}
impl<T> From<&EvaluationItem<T>> for EvaluationProgress {
fn from(item: &EvaluationItem<T>) -> Self {
Self {
progress: item.progress.clone(),
iteration: item.iteration,
@@ -244,31 +254,41 @@ impl<T> From<&LearnerItem<T>> for EvaluationProgress {
}
}
impl<T> From<&LearnerItem<T>> for MetricMetadata {
fn from(item: &LearnerItem<T>) -> Self {
impl<T> From<&TrainingItem<T>> for MetricMetadata {
fn from(item: &TrainingItem<T>) -> Self {
Self {
progress: item.progress.clone(),
epoch: item.epoch,
epoch_total: item.epoch_total,
global_progress: item.global_progress.clone(),
iteration: item.iteration,
lr: item.lr,
}
}
}
trait NumericMetricUpdater<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> NumericMetricUpdate;
impl<T> From<&EvaluationItem<T>> for MetricMetadata {
    /// Builds metric metadata from an evaluation item.
    ///
    /// The item's progress is used for both the progress and the global progress, and no
    /// learning rate is reported.
    fn from(item: &EvaluationItem<T>) -> Self {
        let progress = item.progress.clone();

        Self {
            global_progress: progress.clone(),
            progress,
            iteration: item.iteration,
            lr: None,
        }
    }
}
pub(crate) trait NumericMetricUpdater<T>: Send + Sync {
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate;
fn clear(&mut self);
}
trait MetricUpdater<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
pub(crate) trait MetricUpdater<T>: Send + Sync {
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry;
fn clear(&mut self);
}
struct MetricWrapper<M> {
id: MetricId,
metric: M,
pub(crate) struct MetricWrapper<M> {
pub id: MetricId,
pub metric: M,
}
impl<M: Metric> MetricWrapper<M> {
@@ -286,8 +306,8 @@ where
M: Metric + Numeric + 'static,
T: Adaptor<M::Input>,
{
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> NumericMetricUpdate {
let serialized_entry = self.metric.update(&item.item.adapt(), metadata);
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate {
let serialized_entry = self.metric.update(&item.adapt(), metadata);
let update = MetricEntry::new(self.id.clone(), serialized_entry);
let numeric = self.metric.value();
let running = self.metric.running_value();
@@ -310,8 +330,8 @@ where
M: Metric + 'static,
T: Adaptor<M::Input>,
{
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry {
let serialized_entry = self.metric.update(&item.item.adapt(), metadata);
fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry {
let serialized_entry = self.metric.update(&item.adapt(), metadata);
MetricEntry::new(self.id.clone(), serialized_entry)
}

View File

@@ -14,11 +14,10 @@ pub(crate) struct MinimalEventProcessor<T: ItemLazy, V: ItemLazy> {
store: Arc<EventStoreClient>,
}
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for MinimalEventProcessor<T, V> {
type ItemTrain = T;
type ItemValid = V;
fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>) {
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>
for MinimalEventProcessor<T, V>
{
fn process_train(&mut self, event: LearnerEvent<T>) {
match event {
LearnerEvent::Start => {
let definitions = self.metrics.metric_definitions();
@@ -47,7 +46,7 @@ impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining for MinimalEventProcessor<
}
}
fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>) {
fn process_valid(&mut self, event: LearnerEvent<V>) {
match event {
LearnerEvent::Start => {} // no-op for now
LearnerEvent::ProcessedItem(item) => {

View File

@@ -3,10 +3,14 @@ mod base;
mod full;
mod metrics;
mod minimal;
mod rl_metrics;
mod rl_processor;
pub use base::*;
pub(crate) use full::*;
pub(crate) use metrics::*;
pub(crate) use rl_metrics::*;
pub(crate) use rl_processor::*;
#[cfg(test)]
pub(crate) use minimal::*;
@@ -17,7 +21,7 @@ pub use async_wrapper::{AsyncProcessorEvaluation, AsyncProcessorTraining};
pub(crate) mod test_utils {
use crate::metric::{
Adaptor, LossInput,
processor::{EventProcessorTraining, LearnerEvent, LearnerItem, MinimalEventProcessor},
processor::{EventProcessorTraining, LearnerEvent, MinimalEventProcessor, TrainingItem},
};
use burn_core::tensor::{ElementConversion, Tensor, backend::Backend};
@@ -47,14 +51,16 @@ pub(crate) mod test_utils {
items_processed: 1,
items_total: 10,
};
let num_epochs = 3;
let dummy_iteration = 1;
let dummy_global_progress = burn_core::data::dataloader::Progress {
items_processed: epoch,
items_total: 3,
};
let dummy_iteration = Some(1);
processor.process_train(LearnerEvent::ProcessedItem(LearnerItem::new(
processor.process_train(LearnerEvent::ProcessedItem(TrainingItem::new(
value,
dummy_progress,
epoch,
num_epochs,
dummy_global_progress,
dummy_iteration,
None,
)));

View File

@@ -0,0 +1,268 @@
use std::collections::HashMap;
use crate::{
EpisodeSummary, EvaluationItem, ItemLazy, MetricUpdater, MetricWrapper, NumericMetricUpdater,
metric::{
Adaptor, Metric, MetricDefinition, MetricId, MetricMetadata, Numeric, store::MetricsUpdate,
},
};
/// Registry of the metrics tracked during reinforcement learning.
///
/// `TS` is the train-step item type and `ES` the environment-step item type; metrics are
/// updated with their synced forms (`ItemLazy::ItemSync`). For each kind of event there is one
/// list of non-numeric metrics and one `*_numeric` list of numeric metrics.
pub(crate) struct RLMetrics<TS: ItemLazy, ES: ItemLazy> {
// Non-numeric metrics, one list per event kind.
train_step: Vec<Box<dyn MetricUpdater<TS::ItemSync>>>,
env_step: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,
env_step_valid: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,
episode_end: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,
episode_end_valid: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,
// Numeric metrics, one list per event kind.
train_step_numeric: Vec<Box<dyn NumericMetricUpdater<TS::ItemSync>>>,
env_step_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,
env_step_valid_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,
episode_end_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,
episode_end_valid_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,
// Definition of every registered metric, keyed by metric id.
metric_definitions: HashMap<MetricId, MetricDefinition>,
}
impl<TS: ItemLazy, ES: ItemLazy> Default for RLMetrics<TS, ES> {
    /// Creates an empty registry: no metric is registered and no definition is stored.
    fn default() -> Self {
        Self {
            // Non-numeric metrics.
            train_step: Vec::new(),
            env_step: Vec::new(),
            env_step_valid: Vec::new(),
            episode_end: Vec::new(),
            episode_end_valid: Vec::new(),
            // Numeric metrics.
            train_step_numeric: Vec::new(),
            env_step_numeric: Vec::new(),
            env_step_valid_numeric: Vec::new(),
            episode_end_numeric: Vec::new(),
            episode_end_valid_numeric: Vec::new(),
            // Definitions.
            metric_definitions: HashMap::new(),
        }
    }
}
impl<TS: ItemLazy, ES: ItemLazy> RLMetrics<TS, ES> {
/// Register a training metric.
pub(crate) fn register_text_metric_agent<Me: Metric + 'static>(&mut self, metric: Me)
where
ES::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.env_step.push(Box::new(metric))
}
/// Register a training metric.
pub(crate) fn register_agent_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
where
ES::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.env_step_numeric.push(Box::new(metric))
}
/// Register a training metric.
pub(crate) fn register_text_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
where
TS::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.train_step.push(Box::new(metric))
}
/// Register a training metric.
pub(crate) fn register_metric_train<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
where
TS::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.train_step_numeric.push(Box::new(metric))
}
/// Register a validation env-step metric.
pub(crate) fn register_text_metric_agent_valid<Me: Metric + 'static>(&mut self, metric: Me)
where
ES::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.env_step_valid.push(Box::new(metric))
}
/// Register a validation env-step numeric metric.
pub(crate) fn register_agent_metric_valid<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
where
ES::ItemSync: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.env_step_valid_numeric.push(Box::new(metric))
}
/// Register an episode-end metric.
pub(crate) fn register_text_metric_episode<Me: Metric + 'static>(&mut self, metric: Me)
where
EpisodeSummary: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.episode_end.push(Box::new(metric))
}
/// Register an episode-end numeric metric.
pub(crate) fn register_episode_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
where
EpisodeSummary: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.episode_end_numeric.push(Box::new(metric))
}
/// Register an episode-end metric for validation.
pub(crate) fn register_text_metric_episode_valid<Me: Metric + 'static>(&mut self, metric: Me)
where
EpisodeSummary: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.episode_end_valid.push(Box::new(metric))
}
/// Register an episode-end numeric metric for validation.
pub(crate) fn register_episode_metric_valid<Me: Metric + Numeric + 'static>(
&mut self,
metric: Me,
) where
EpisodeSummary: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.register_definition(&metric);
self.episode_end_valid_numeric.push(Box::new(metric))
}
fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
self.metric_definitions.insert(
metric.id.clone(),
MetricDefinition::new(metric.id.clone(), &metric.metric),
);
}
/// Get metric definitions for all splits
pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
self.metric_definitions.values().cloned().collect()
}
/// Update the training information from the training item.
pub(crate) fn update_train_step(
&mut self,
item: &EvaluationItem<TS::ItemSync>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.train_step.len());
let mut entries_numeric = Vec::with_capacity(self.train_step_numeric.len());
for metric in self.train_step.iter_mut() {
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.train_step_numeric.iter_mut() {
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
MetricsUpdate::new(entries, entries_numeric)
}
/// Update the env-step metrics from an environment step item.
pub(crate) fn update_env_step(
&mut self,
item: &EvaluationItem<ES::ItemSync>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.env_step.len());
let mut entries_numeric = Vec::with_capacity(self.env_step_numeric.len());
for metric in self.env_step.iter_mut() {
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.env_step_numeric.iter_mut() {
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
MetricsUpdate::new(entries, entries_numeric)
}
/// Update the env-step metrics for validation from an environment step item.
pub(crate) fn update_env_step_valid(
&mut self,
item: &EvaluationItem<ES::ItemSync>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.env_step_valid.len());
let mut entries_numeric = Vec::with_capacity(self.env_step_valid_numeric.len());
for metric in self.env_step_valid.iter_mut() {
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.env_step_valid_numeric.iter_mut() {
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
MetricsUpdate::new(entries, entries_numeric)
}
/// Update the episode-end metrics from an episode summary.
pub(crate) fn update_episode_end(
&mut self,
item: &EvaluationItem<EpisodeSummary>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.episode_end.len());
let mut entries_numeric = Vec::with_capacity(self.episode_end_numeric.len());
for metric in self.episode_end.iter_mut() {
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.episode_end_numeric.iter_mut() {
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
MetricsUpdate::new(entries, entries_numeric)
}
/// Update the episode-end metrics for validation from an episode summary.
pub(crate) fn update_episode_end_valid(
&mut self,
item: &EvaluationItem<EpisodeSummary>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.episode_end_valid.len());
let mut entries_numeric = Vec::with_capacity(self.episode_end_valid_numeric.len());
for metric in self.episode_end_valid.iter_mut() {
let state = metric.update(&item.item, metadata);
entries.push(state);
}
for metric in self.episode_end_valid_numeric.iter_mut() {
let numeric_update = metric.update(&item.item, metadata);
entries_numeric.push(numeric_update);
}
MetricsUpdate::new(entries, entries_numeric)
}
}

View File

@@ -0,0 +1,151 @@
use std::sync::Arc;
use crate::{
AgentEvaluationEvent, EventProcessorTraining, ItemLazy, RLEvent, RLMetrics,
metric::store::{Event, EventStoreClient, MetricsUpdate},
renderer::{MetricState, MetricsRenderer, ProgressType, TrainingProgress},
};
/// An [event processor](EventProcessorTraining) that handles:
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
/// - Render metrics using a [metrics renderer](MetricsRenderer).
///
/// `TS` is the train-step item type and `ES` the environment-step item type.
#[derive(new)]
pub struct RLEventProcessor<TS: ItemLazy, ES: ItemLazy> {
// Registered metrics, updated on each event.
metrics: RLMetrics<TS, ES>,
// Renderer receiving metric states and progress.
renderer: Box<dyn MetricsRenderer>,
// Client of the event store receiving the metric events.
store: Arc<EventStoreClient>,
}
impl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {
    /// Builds the progress indicators rendered while training.
    ///
    /// Returns a single detailed "Step" indicator backed by the global progress.
    fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
        vec![ProgressType::Detailed {
            tag: String::from("Step"),
            progress: progress.global_progress.clone(),
        }]
    }

    /// Builds the progress indicators rendered while evaluating the agent.
    ///
    /// Returns a single detailed "Step" indicator backed by the global progress.
    fn progress_indicators_eval(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
        vec![ProgressType::Detailed {
            tag: String::from("Step"),
            progress: progress.global_progress.clone(),
        }]
    }
}
impl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {
fn process_update_train(&mut self, update: MetricsUpdate) {
self.store
.add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone()));
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|numeric_update| {
self.renderer.update_train(MetricState::Numeric(
numeric_update.entry,
numeric_update.numeric_entry,
))
});
}
fn process_update_valid(&mut self, update: MetricsUpdate) {
self.store
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone()));
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|numeric_update| {
self.renderer.update_valid(MetricState::Numeric(
numeric_update.entry,
numeric_update.numeric_entry,
))
});
}
}
impl<TS: ItemLazy, ES: ItemLazy> EventProcessorTraining<RLEvent<TS, ES>, AgentEvaluationEvent<ES>>
    for RLEventProcessor<TS, ES>
{
    /// Handles one event emitted by the training side of the RL loop.
    ///
    /// * `Start` stores the metric definitions and registers them with the renderer.
    /// * `TrainStep`, `TimeStep` and `EpisodeEnd` sync the item, update the matching metrics and
    ///   publish the resulting update; `TimeStep` additionally renders the training progress.
    /// * `End` notifies the renderer that training is over.
    fn process_train(&mut self, event: RLEvent<TS, ES>) {
        match event {
            RLEvent::Start => {
                let definitions = self.metrics.metric_definitions();
                let init = Event::MetricsInit(definitions.clone());
                self.store.add_event_train(init);
                for definition in definitions.iter() {
                    self.renderer.register_metric(definition.clone());
                }
            }
            RLEvent::TrainStep(lazy) => {
                let step = lazy.sync();
                let metadata = (&step).into();
                let update = self.metrics.update_train_step(&step, &metadata);
                self.process_update_train(update);
            }
            RLEvent::TimeStep(lazy) => {
                let step = lazy.sync();
                // Both conversions borrow the synced item; their target types are inferred from
                // the calls below.
                let progress = (&step).into();
                let metadata = (&step).into();
                let update = self.metrics.update_env_step(&step, &metadata);
                self.process_update_train(update);
                let indicators = self.progress_indicators(&progress);
                self.renderer.render_train(progress, indicators);
            }
            RLEvent::EpisodeEnd(lazy) => {
                let episode = lazy.sync();
                let metadata = (&episode).into();
                let update = self.metrics.update_episode_end(&episode, &metadata);
                self.process_update_train(update);
            }
            RLEvent::End(learner_summary) => {
                // A renderer failure at the end of training is deliberately ignored.
                let _ = self.renderer.on_train_end(learner_summary);
            }
        }
    }

    /// Handles one event emitted while evaluating the agent.
    ///
    /// `TimeStep` and `EpisodeEnd` update the validation metrics; only `EpisodeEnd` renders the
    /// validation progress. `Start` and `End` are currently ignored.
    fn process_valid(&mut self, event: AgentEvaluationEvent<ES>) {
        match event {
            AgentEvaluationEvent::Start => {} // no-op for now
            AgentEvaluationEvent::TimeStep(lazy) => {
                let step = lazy.sync();
                let metadata = (&step).into();
                let update = self.metrics.update_env_step_valid(&step, &metadata);
                self.process_update_valid(update);
            }
            AgentEvaluationEvent::EpisodeEnd(lazy) => {
                let episode = lazy.sync();
                let progress = (&episode).into();
                let metadata = (&episode).into();
                let update = self.metrics.update_episode_end_valid(&episode, &metadata);
                self.process_update_valid(update);
                let indicators = self.progress_indicators_eval(&progress);
                self.renderer.render_valid(progress, indicators);
            }
            AgentEvaluationEvent::End => {} // no-op for now
        }
    }

    /// Consumes the processor and hands back its renderer.
    fn renderer(self) -> Box<dyn MetricsRenderer> {
        self.renderer
    }
}

View File

@@ -0,0 +1,78 @@
use std::sync::Arc;
use super::super::{
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
state::{FormatOptions, NumericMetricState},
};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
/// Metric for the cumulative reward of the last completed episode.
///
/// Each update takes a [CumulativeRewardInput](CumulativeRewardInput) and records its value in a
/// numeric state.
#[derive(Clone)]
pub struct CumulativeRewardMetric {
// Name returned by `Metric::name` and used when formatting entries.
name: MetricName,
// Numeric state fed by `Metric::update` and read back through `Numeric`.
state: NumericMetricState,
}
impl CumulativeRewardMetric {
    /// Creates a new cumulative reward metric named `Cum. Reward` with a fresh numeric state.
    pub fn new() -> Self {
        let name = Arc::new(String::from("Cum. Reward"));
        let state = NumericMetricState::new();
        Self { name, state }
    }
}
impl Default for CumulativeRewardMetric {
/// Equivalent to [CumulativeRewardMetric::new](CumulativeRewardMetric::new).
fn default() -> Self {
Self::new()
}
}
/// The [CumulativeRewardMetric](CumulativeRewardMetric) input type.
///
/// The constructor is generated by the `new` derive and is expected to take the `cum_reward`
/// value as its only argument.
#[derive(new)]
pub struct CumulativeRewardInput {
// Cumulative reward value handed to the metric on update.
cum_reward: f64,
}
impl Metric for CumulativeRewardMetric {
type Input = CumulativeRewardInput;
fn update(
&mut self,
item: &CumulativeRewardInput,
_metadata: &MetricMetadata,
) -> SerializedEntry {
self.state.update(
item.cum_reward,
1,
FormatOptions::new(self.name()).precision(2),
)
}
fn clear(&mut self) {
self.state.reset()
}
fn name(&self) -> MetricName {
self.name.clone()
}
fn attributes(&self) -> MetricAttributes {
NumericAttributes {
unit: None,
higher_is_better: true,
}
.into()
}
}
impl Numeric for CumulativeRewardMetric {
/// Returns the current value reported by the numeric state.
fn value(&self) -> NumericEntry {
self.state.current_value()
}
/// Returns the running value reported by the numeric state.
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}

View File

@@ -0,0 +1,71 @@
use std::sync::Arc;
use super::super::{
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
state::{FormatOptions, NumericMetricState},
};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
/// Metric for the length of the last completed episode.
///
/// Each update takes an [EpisodeLengthInput](EpisodeLengthInput) and records its value in a
/// numeric state.
#[derive(Clone)]
pub struct EpisodeLengthMetric {
// Name returned by `Metric::name` and used when formatting entries.
name: MetricName,
// Numeric state fed by `Metric::update` and read back through `Numeric`.
state: NumericMetricState,
}
impl EpisodeLengthMetric {
    /// Creates a new episode length metric named `Episode length` with a fresh numeric state.
    pub fn new() -> Self {
        let name = Arc::new(String::from("Episode length"));
        let state = NumericMetricState::new();
        Self { name, state }
    }
}
impl Default for EpisodeLengthMetric {
/// Equivalent to [EpisodeLengthMetric::new](EpisodeLengthMetric::new).
fn default() -> Self {
Self::new()
}
}
/// The [EpisodeLengthMetric](EpisodeLengthMetric) input type.
///
/// The constructor is generated by the `new` derive and is expected to take the `ep_len` value
/// as its only argument.
#[derive(new)]
pub struct EpisodeLengthInput {
// Episode length handed to the metric on update; reported with the `steps` unit.
ep_len: f64,
}
impl Metric for EpisodeLengthMetric {
type Input = EpisodeLengthInput;
fn update(&mut self, item: &EpisodeLengthInput, _metadata: &MetricMetadata) -> SerializedEntry {
self.state
.update(item.ep_len, 1, FormatOptions::new(self.name()).precision(0))
}
fn clear(&mut self) {
self.state.reset()
}
fn name(&self) -> MetricName {
self.name.clone()
}
fn attributes(&self) -> MetricAttributes {
NumericAttributes {
unit: Some(String::from("steps")),
higher_is_better: true,
}
.into()
}
}
impl Numeric for EpisodeLengthMetric {
/// Returns the current value reported by the numeric state.
fn value(&self) -> NumericEntry {
self.state.current_value()
}
/// Returns the running value reported by the numeric state.
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}

View File

@@ -0,0 +1,78 @@
use std::sync::Arc;
use super::super::{
MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
state::{FormatOptions, NumericMetricState},
};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
/// Metric for the exploration rate.
///
/// Each update takes an [ExplorationRateInput](ExplorationRateInput) and records its value in a
/// numeric state.
#[derive(Clone)]
pub struct ExplorationRateMetric {
// Name returned by `Metric::name` and used when formatting entries.
name: MetricName,
// Numeric state fed by `Metric::update` and read back through `Numeric`.
state: NumericMetricState,
}
impl ExplorationRateMetric {
    /// Creates a new exploration rate metric named `Exploration rate` with a fresh numeric state.
    pub fn new() -> Self {
        let name = Arc::new(String::from("Exploration rate"));
        let state = NumericMetricState::new();
        Self { name, state }
    }
}
impl Default for ExplorationRateMetric {
/// Equivalent to [ExplorationRateMetric::new](ExplorationRateMetric::new).
fn default() -> Self {
Self::new()
}
}
/// The [ExplorationRateMetric](ExplorationRateMetric) input type.
///
/// The constructor is generated by the `new` derive and is expected to take the
/// `exploration_rate` value as its only argument.
#[derive(new)]
pub struct ExplorationRateInput {
// Exploration rate handed to the metric on update.
// NOTE(review): the metric reports it with a `%` unit; confirm whether callers pass a
// percentage or a fraction.
exploration_rate: f64,
}
impl Metric for ExplorationRateMetric {
type Input = ExplorationRateInput;
fn update(
&mut self,
item: &ExplorationRateInput,
_metadata: &MetricMetadata,
) -> SerializedEntry {
self.state.update(
item.exploration_rate,
1,
FormatOptions::new(self.name()).precision(3),
)
}
fn clear(&mut self) {
self.state.reset()
}
fn name(&self) -> MetricName {
self.name.clone()
}
fn attributes(&self) -> MetricAttributes {
NumericAttributes {
unit: Some(String::from("%")),
higher_is_better: false,
}
.into()
}
}
impl Numeric for ExplorationRateMetric {
/// Returns the current value reported by the numeric state.
fn value(&self) -> NumericEntry {
self.state.current_value()
}
/// Returns the running value reported by the numeric state.
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}

View File

@@ -0,0 +1,7 @@
mod cum_reward;
mod ep_len;
mod exploration_rate;
pub use cum_reward::*;
pub use ep_len::*;
pub use exploration_rate::*;

View File

@@ -109,6 +109,7 @@ impl NumericMetricState {
None => (format!("{value_current}"), format!("{value_running}")),
};
// TODO: The "epoch"/"batch" labels used in the formatted string below are inconsistent with RL terminology.
let formatted = match format.unit {
Some(unit) => {
format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")

View File

@@ -27,14 +27,14 @@ pub trait MetricsRendererTraining: Send + Sync {
/// # Arguments
///
/// * `item` - The training progress.
fn render_train(&mut self, item: TrainingProgress);
fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);
/// Renders the validation progress.
///
/// # Arguments
///
/// * `item` - The validation progress.
fn render_valid(&mut self, item: TrainingProgress);
fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);
/// Callback method invoked when training ends, whether it
/// completed successfully or was interrupted.
@@ -58,7 +58,7 @@ pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining {
/// Keep the renderer from automatically closing, requiring manual action to close it.
fn manual_close(&mut self);
/// Register a new metric.
fn register_metric(&mut self, _definition: MetricDefinition);
fn register_metric(&mut self, definition: MetricDefinition);
}
#[derive(Clone)]
@@ -102,7 +102,7 @@ pub trait MetricsRendererEvaluation: Send + Sync {
/// # Arguments
///
/// * `item` - The evaluation progress.
fn render_test(&mut self, item: EvaluationProgress);
fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec<ProgressType>);
/// Callback method invoked when testing ends, whether it
/// completed successfully or was interrupted.
@@ -128,16 +128,13 @@ pub enum MetricState {
#[derive(Debug)]
pub struct TrainingProgress {
/// The progress.
pub progress: Progress,
pub progress: Option<Progress>,
/// The epoch.
pub epoch: usize,
/// The progress of the whole training.
pub global_progress: Progress,
/// The total number of epochs.
pub epoch_total: usize,
/// The iteration.
pub iteration: usize,
/// The iteration, if it differs from the items processed.
pub iteration: Option<usize>,
}
/// Evaluation progress.
@@ -146,21 +143,48 @@ pub struct EvaluationProgress {
/// The progress.
pub progress: Progress,
/// The iteration.
pub iteration: usize,
/// The iteration, if it is different from the processed items.
pub iteration: Option<usize>,
}
impl From<&EvaluationProgress> for TrainingProgress {
fn from(value: &EvaluationProgress) -> Self {
TrainingProgress {
progress: None,
global_progress: value.progress.clone(),
iteration: value.iteration,
}
}
}
impl TrainingProgress {
/// Creates a new empty training progress.
pub fn none() -> Self {
Self {
progress: Progress {
progress: None,
global_progress: Progress {
items_processed: 0,
items_total: 0,
},
epoch: 0,
epoch_total: 0,
iteration: 0,
iteration: None,
}
}
}
/// Type of progress indicators.
///
/// A list of these is handed to the renderers alongside the progress itself; each entry is one
/// labelled line to display.
#[derive(Debug, Clone)]
pub enum ProgressType {
    /// Detailed progress, carrying both the processed and the total number of items.
    Detailed {
        /// The tag.
        tag: String,
        /// The progress.
        progress: Progress,
    },
    /// Simple value.
    Value {
        /// The tag.
        tag: String,
        /// The value.
        value: usize,
    },
}

View File

@@ -1,6 +1,6 @@
use crate::renderer::{
EvaluationProgress, MetricState, MetricsRenderer, MetricsRendererEvaluation,
MetricsRendererTraining, TrainingProgress,
MetricsRendererTraining, ProgressType, TrainingProgress,
};
/// A simple renderer for when the cli feature is not enabled.
@@ -19,17 +19,17 @@ impl MetricsRendererTraining for CliMetricsRenderer {
fn update_valid(&mut self, _state: MetricState) {}
fn render_train(&mut self, item: TrainingProgress) {
fn render_train(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {
println!("{item:?}");
}
fn render_valid(&mut self, item: TrainingProgress) {
fn render_valid(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {
println!("{item:?}");
}
}
impl MetricsRendererEvaluation for CliMetricsRenderer {
fn render_test(&mut self, item: EvaluationProgress) {
fn render_test(&mut self, item: EvaluationProgress, _progress_indicators: Vec<ProgressType>) {
println!("{item:?}");
}

View File

@@ -27,7 +27,7 @@ pub(crate) fn default_renderer(
) -> Box<dyn MetricsRenderer> {
#[cfg(feature = "tui")]
if std::io::stdout().is_terminal() {
return Box::new(tui::TuiMetricsRenderer::new(interuptor, checkpoint));
return Box::new(tui::TuiMetricsRendererWrapper::new(interuptor, checkpoint));
}
Box::new(CliMetricsRenderer::new())

View File

@@ -65,13 +65,14 @@ impl NumericMetricsState {
/// Update the state with the training progress.
pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) {
self.epoch = progress.epoch;
self.epoch = progress.global_progress.items_processed;
if self.num_samples_train.is_some() {
return;
}
self.num_samples_train = Some(progress.progress.items_total);
// If the training only has the notion of global progress, num_samples_train remains None.
self.num_samples_train = progress.progress.as_ref().map(|p| p.items_total);
}
/// Update the state with the validation progress.
@@ -80,16 +81,20 @@ impl NumericMetricsState {
return;
}
// If num_samples_train is None, keep the default max_samples for validation.
if let Some(num_sample_train) = self.num_samples_train {
for (_, (_recent, full)) in self.data.iter_mut() {
let ratio = progress.progress.items_total as f64 / num_sample_train as f64;
let ratio = match &progress.progress {
Some(p) => p.items_total as f64 / num_sample_train as f64,
None => progress.global_progress.items_total as f64 / num_sample_train as f64,
};
full.update_max_sample(TuiSplit::Valid, ratio);
}
}
self.epoch = progress.epoch;
self.num_samples_valid = Some(progress.progress.items_total);
self.epoch = progress.global_progress.items_processed;
self.num_samples_valid = progress.progress.as_ref().map(|p| p.items_total);
}
/// Update the state with the testing progress.

View File

@@ -36,8 +36,12 @@ impl ProgressBarState {
/// Update the training progress.
pub(crate) fn update_train(&mut self, progress: &TrainingProgress) {
self.progress_total = calculate_progress(progress, 0, 0);
let local_progress = progress
.progress
.as_ref()
.unwrap_or(&progress.global_progress);
self.progress_task =
progress.progress.items_processed as f64 / progress.progress.items_total as f64;
local_progress.items_processed as f64 / local_progress.items_total as f64;
self.estimate.update(progress, self.starting_epoch);
self.split = TuiSplit::Train;
}
@@ -45,8 +49,12 @@ impl ProgressBarState {
/// Update the validation progress.
pub(crate) fn update_valid(&mut self, progress: &TrainingProgress) {
// We don't use the validation for the total progress yet.
let local_progress = progress
.progress
.as_ref()
.unwrap_or(&progress.global_progress);
self.progress_task =
progress.progress.items_processed as f64 / progress.progress.items_total as f64;
local_progress.items_processed as f64 / local_progress.items_total as f64;
self.split = TuiSplit::Valid;
}
@@ -187,7 +195,7 @@ impl ProgressEstimate {
}
// Once the training has been running for at least 10 seconds and has completed 10 iterations.
if progress.iteration >= WARMUP_NUM_ITERATION
if progress.iteration >= Some(WARMUP_NUM_ITERATION)
&& self.started.elapsed() > Duration::from_secs(10)
{
self.init(progress, starting_epoch);
@@ -195,11 +203,17 @@ impl ProgressEstimate {
}
fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) {
let epoch = progress.epoch - starting_epoch;
let epoch_items = (epoch - 1) * progress.progress.items_total;
let iteration_items = progress.progress.items_processed;
let epoch = progress.global_progress.items_processed - starting_epoch;
self.warmup_num_items = match &progress.progress {
Some(local_progress) => {
let epoch_items = (epoch - 1) * local_progress.items_total;
let iteration_items = local_progress.items_processed;
epoch_items + iteration_items
}
None => epoch,
};
self.warmup_num_items = epoch_items + iteration_items;
self.started_after_warmup = Some(Instant::now());
self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items);
}
@@ -210,15 +224,19 @@ fn calculate_progress(
starting_epoch: usize,
ignore_num_items: usize,
) -> f64 {
let epoch_total = progress.epoch_total - starting_epoch;
let epoch = progress.epoch - starting_epoch;
let epoch_total = progress.global_progress.items_total - starting_epoch;
let epoch = progress.global_progress.items_processed - starting_epoch;
match &progress.progress {
Some(local_progress) => {
let total_items = local_progress.items_total * epoch_total;
let epoch_items = (epoch - 1) * local_progress.items_total;
let iteration_items = local_progress.items_processed;
let num_items = epoch_items + iteration_items - ignore_num_items;
let total_items = progress.progress.items_total * epoch_total;
let epoch_items = (epoch - 1) * progress.progress.items_total;
let iteration_items = progress.progress.items_processed;
let num_items = epoch_items + iteration_items - ignore_num_items;
num_items as f64 / total_items as f64
num_items as f64 / total_items as f64
}
None => epoch as f64 / epoch_total as f64,
}
}
fn format_eta(eta_secs: u64) -> String {
@@ -268,11 +286,14 @@ mod tests {
items_processed: 5,
items_total: 10,
};
let global_progress = Progress {
items_processed: 9,
items_total: 10,
};
let progress = TrainingProgress {
progress: half,
epoch: 9,
epoch_total: 10,
iteration: 500,
progress: Some(half),
global_progress: global_progress,
iteration: Some(500),
};
let starting_epoch = 8;
@@ -288,11 +309,14 @@ mod tests {
items_processed: 110,
items_total: 1000,
};
let global_progress = Progress {
items_processed: 9,
items_total: 10,
};
let progress = TrainingProgress {
progress: half,
epoch: 9,
epoch_total: 10,
iteration: 500,
progress: Some(half),
global_progress: global_progress,
iteration: Some(500),
};
let starting_epoch = 8;

View File

@@ -2,7 +2,7 @@ use crate::metric::{MetricDefinition, MetricId};
use crate::renderer::tui::TuiSplit;
use crate::renderer::{
EvaluationName, EvaluationProgress, MetricState, MetricsRenderer, MetricsRendererEvaluation,
TrainingProgress,
ProgressType, TrainingProgress,
};
use crate::renderer::{MetricsRendererTraining, tui::NumericMetricsState};
use crate::{Interrupter, LearnerSummary};
@@ -17,7 +17,9 @@ use ratatui::{
};
use std::collections::HashMap;
use std::panic::{set_hook, take_hook};
use std::sync::Arc;
use std::sync::mpsc::{Receiver, Sender};
use std::sync::{Arc, Mutex, mpsc};
use std::thread::{JoinHandle, spawn};
use std::{
error::Error,
io::{self, Stdout},
@@ -38,8 +40,86 @@ type PanicHook = Box<dyn Fn(&std::panic::PanicHookInfo<'_>) + 'static + Sync + S
const MAX_REFRESH_RATE_MILLIS: u64 = 100;
/// Messages sent from the renderer wrapper to the thread that owns the terminal UI.
enum TuiRendererEvent {
/// Registers a metric definition, stored under its metric id.
MetricRegistration(MetricDefinition),
/// Updates a metric for the given split and group.
MetricsUpdate((TuiSplit, TuiGroup, MetricState)),
/// Updates the progress and status for the train or valid split; other splits are ignored.
StatusUpdateTrain((TuiSplit, TrainingProgress, Vec<ProgressType>)),
/// Updates the progress and status for the test split.
StatusUpdateTest((EvaluationProgress, Vec<ProgressType>)),
/// Signals the end of training: the interrupter is reset and the summary is stored.
TrainEnd(Option<LearnerSummary>),
/// Marks the renderer as manually closed; the render loop then stops once the interrupter
/// reports that it should stop.
ManualClose(),
/// Asks the render loop to stop.
Close(),
/// Sets the renderer to persistent mode.
Persistent(),
}
/// The terminal UI metrics renderer.
pub struct TuiMetricsRenderer {
pub struct TuiMetricsRendererWrapper {
// Channel used to forward renderer events to the rendering thread.
sender: mpsc::Sender<TuiRendererEvent>,
// Interrupter whose clone is handed to the rendering thread.
interrupter: Interrupter,
// Handle of the rendering thread; taken when the thread is joined.
handle_join: Option<JoinHandle<()>>,
// Receives a unit message when the user asks to kill the training from the TUI.
kill_signal: Arc<Mutex<Receiver<()>>>,
}
impl TuiMetricsRendererWrapper {
    /// Create a new terminal UI renderer.
    ///
    /// Spawns a thread that owns the actual terminal renderer. The thread applies the events
    /// sent by this wrapper, redraws at most once per `MAX_REFRESH_RATE_MILLIS`, and stops when
    /// it is asked to close, when rendering fails, or when the wrapper's sender is dropped.
    pub fn new(interrupter: Interrupter, checkpoint: Option<usize>) -> Self {
        let (sender, receiver) = mpsc::channel();
        let (kill_signal_sender, kill_signal_receiver) = mpsc::channel();
        let interrupter_clone = interrupter.clone();

        let handle_join = spawn(move || {
            let mut renderer =
                TuiMetricsRenderer::new(interrupter_clone, checkpoint, kill_signal_sender);
            let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS);

            loop {
                // Wait for the next event, but never longer than the time left until the next
                // frame is due. Polling with `try_recv` here would spin at full speed and keep a
                // CPU core busy while nothing happens.
                let until_next_frame = tick_rate.saturating_sub(renderer.last_update.elapsed());
                match receiver.recv_timeout(until_next_frame) {
                    Ok(event) => renderer.handle_event(event),
                    Err(mpsc::RecvTimeoutError::Timeout) => (),
                    Err(mpsc::RecvTimeoutError::Disconnected) => {
                        log::error!("Renderer thread disconnected.");
                        break;
                    }
                }

                // Render
                if renderer.last_update.elapsed() >= tick_rate
                    && let Err(err) = renderer.render()
                {
                    log::error!("Render error: {err}");
                    break;
                }

                if (renderer.manual_close && renderer.interrupter.should_stop()) || renderer.close {
                    break;
                }
            }
        });

        Self {
            sender,
            interrupter,
            handle_join: Some(handle_join),
            kill_signal: Arc::new(Mutex::new(kill_signal_receiver)),
        }
    }

    /// Forwards an event to the rendering thread.
    ///
    /// A failed send is only logged as a warning.
    ///
    /// # Panics
    ///
    /// Panics when the rendering thread has signalled that the user asked to kill the training,
    /// which is how the kill request reaches the calling thread. Also panics if the kill signal
    /// mutex is poisoned.
    fn send_event(&self, event: TuiRendererEvent) {
        if self.kill_signal.lock().unwrap().try_recv().is_ok() {
            panic!("Killing training from user input.")
        }

        if let Err(e) = self.sender.send(event) {
            log::warn!("Failed to send TUI event: {e}");
        }
    }

    /// Set the renderer to persistent mode.
    pub fn persistent(self) -> Self {
        self.send_event(TuiRendererEvent::Persistent());
        self
    }
}
struct TuiMetricsRenderer {
terminal: Terminal<TerminalBackend>,
last_update: std::time::Instant,
progress: ProgressBarState,
@@ -47,75 +127,95 @@ pub struct TuiMetricsRenderer {
metrics_numeric: NumericMetricsState,
metrics_text: TextMetricsState,
status: StatusState,
interuptor: Interrupter,
interrupter: Interrupter,
popup: PopupState,
previous_panic_hook: Option<Arc<PanicHook>>,
persistent: bool,
manual_close: bool,
close: bool,
summary: Option<LearnerSummary>,
kill_signal: Sender<()>,
}
impl MetricsRendererEvaluation for TuiMetricsRenderer {
impl MetricsRendererEvaluation for TuiMetricsRendererWrapper {
fn update_test(&mut self, name: EvaluationName, state: MetricState) {
self.update_metric(TuiSplit::Test, TuiGroup::Named(name.name), state);
self.send_event(TuiRendererEvent::MetricsUpdate((
TuiSplit::Test,
TuiGroup::Named(name.name),
state,
)));
}
fn render_test(&mut self, item: EvaluationProgress) {
self.progress.update_test(&item);
self.metrics_numeric.update_progress_test(&item);
self.status.update_test(item);
self.render().unwrap();
fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec<ProgressType>) {
self.send_event(TuiRendererEvent::StatusUpdateTest((
item,
progress_indicators,
)));
}
}
impl MetricsRenderer for TuiMetricsRenderer {
impl MetricsRenderer for TuiMetricsRendererWrapper {
fn manual_close(&mut self) {
loop {
self.render().unwrap();
if self.interuptor.should_stop() {
return;
}
std::thread::sleep(Duration::from_millis(100));
}
self.send_event(TuiRendererEvent::ManualClose());
let _ = self.handle_join.take().unwrap().join();
}
fn register_metric(&mut self, definition: MetricDefinition) {
self.metric_definitions
.insert(definition.metric_id.clone(), definition);
self.send_event(TuiRendererEvent::MetricRegistration(definition));
}
}
impl MetricsRendererTraining for TuiMetricsRenderer {
impl MetricsRendererTraining for TuiMetricsRendererWrapper {
fn update_train(&mut self, state: MetricState) {
self.update_metric(TuiSplit::Train, TuiGroup::Default, state);
self.send_event(TuiRendererEvent::MetricsUpdate((
TuiSplit::Train,
TuiGroup::Default,
state,
)));
}
fn update_valid(&mut self, state: MetricState) {
self.update_metric(TuiSplit::Valid, TuiGroup::Default, state);
self.send_event(TuiRendererEvent::MetricsUpdate((
TuiSplit::Valid,
TuiGroup::Default,
state,
)));
}
fn render_train(&mut self, item: TrainingProgress) {
self.progress.update_train(&item);
self.metrics_numeric.update_progress_train(&item);
self.status.update_train(item);
self.render().unwrap();
fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>) {
self.send_event(TuiRendererEvent::StatusUpdateTrain((
TuiSplit::Train,
item,
progress_indicators,
)));
}
fn render_valid(&mut self, item: TrainingProgress) {
self.progress.update_valid(&item);
self.metrics_numeric.update_progress_valid(&item);
self.status.update_valid(item);
self.render().unwrap();
fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>) {
self.send_event(TuiRendererEvent::StatusUpdateTrain((
TuiSplit::Valid,
item,
progress_indicators,
)));
}
fn on_train_end(&mut self, summary: Option<LearnerSummary>) -> Result<(), Box<dyn Error>> {
// Reset for following steps.
self.interuptor.reset();
self.interrupter.reset();
// Update the summary
self.summary = summary;
self.send_event(TuiRendererEvent::TrainEnd(summary));
Ok(())
}
}
impl Drop for TuiMetricsRendererWrapper {
    /// Asks the rendering thread to close and waits for it to finish.
    ///
    /// Nothing is done while the current thread is already panicking.
    fn drop(&mut self) {
        if !std::thread::panicking() {
            self.send_event(TuiRendererEvent::Close());
            // `manual_close` may already have taken and joined the handle; unwrapping it here
            // would then panic inside `drop`.
            if let Some(handle) = self.handle_join.take() {
                let _ = handle.join();
            }
        }
    }
}
impl TuiMetricsRenderer {
fn update_metric(&mut self, split: TuiSplit, group: TuiGroup, state: MetricState) {
match state {
@@ -144,8 +244,11 @@ impl TuiMetricsRenderer {
};
}
/// Create a new terminal UI renderer.
pub fn new(interuptor: Interrupter, checkpoint: Option<usize>) -> Self {
pub fn new(
interrupter: Interrupter,
checkpoint: Option<usize>,
kill_signal: Sender<()>,
) -> Self {
let mut stdout = io::stdout();
execute!(stdout, EnterAlternateScreen).unwrap();
enable_raw_mode().unwrap();
@@ -171,28 +274,57 @@ impl TuiMetricsRenderer {
metrics_numeric: NumericMetricsState::default(),
metrics_text: TextMetricsState::default(),
status: StatusState::default(),
interuptor,
interrupter,
popup: PopupState::Empty,
previous_panic_hook: Some(previous_panic_hook),
persistent: false,
manual_close: false,
close: false,
summary: None,
kill_signal,
}
}
/// Set the renderer to persistent mode.
pub fn persistent(mut self) -> Self {
self.persistent = true;
self
/// Applies one event received from the wrapper to the renderer state.
///
/// Nothing is drawn here; the events only update the state that the next render will show.
fn handle_event(&mut self, event: TuiRendererEvent) {
    match event {
        TuiRendererEvent::MetricRegistration(definition) => {
            let id = definition.metric_id.clone();
            self.metric_definitions.insert(id, definition);
        }
        TuiRendererEvent::MetricsUpdate((split, group, state)) => {
            self.update_metric(split, group, state);
        }
        TuiRendererEvent::StatusUpdateTrain((TuiSplit::Train, progress, indicators)) => {
            self.progress.update_train(&progress);
            self.metrics_numeric.update_progress_train(&progress);
            self.status.update_train(indicators);
        }
        TuiRendererEvent::StatusUpdateTrain((TuiSplit::Valid, progress, indicators)) => {
            self.progress.update_valid(&progress);
            self.metrics_numeric.update_progress_valid(&progress);
            self.status.update_valid(indicators);
        }
        // Any other split carried by a train status update is ignored.
        TuiRendererEvent::StatusUpdateTrain(_) => (),
        TuiRendererEvent::StatusUpdateTest((progress, indicators)) => {
            self.progress.update_test(&progress);
            self.metrics_numeric.update_progress_test(&progress);
            self.status.update_test(indicators);
        }
        TuiRendererEvent::TrainEnd(learner_summary) => {
            self.interrupter.reset();
            self.summary = learner_summary;
        }
        TuiRendererEvent::ManualClose() => {
            self.manual_close = true;
        }
        TuiRendererEvent::Persistent() => {
            self.persistent = true;
        }
        TuiRendererEvent::Close() => {
            self.close = true;
        }
    }
}
fn render(&mut self) -> Result<(), Box<dyn Error>> {
let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS);
if self.last_update.elapsed() < tick_rate {
return Ok(());
}
self.draw()?;
self.handle_events()?;
self.handle_user_input()?;
self.last_update = Instant::now();
@@ -222,7 +354,7 @@ impl TuiMetricsRenderer {
Ok(())
}
fn handle_events(&mut self) -> Result<(), Box<dyn Error>> {
fn handle_user_input(&mut self) -> Result<(), Box<dyn Error>> {
while event::poll(Duration::from_secs(0))? {
let event = event::read()?;
self.popup.on_event(&event);
@@ -242,7 +374,7 @@ impl TuiMetricsRenderer {
training loop, but any remaining code after the loop will be \
executed.",
's',
QuitPopupAccept(self.interuptor.clone()),
QuitPopupAccept(self.interrupter.clone()),
),
Callback::new(
"Stop the training immediately.",
@@ -250,7 +382,7 @@ impl TuiMetricsRenderer {
the current training fails. Any code following the training \
won't be executed.",
'k',
KillPopupAccept,
KillPopupAccept(self.kill_signal.clone()),
),
Callback::new(
"Cancel",
@@ -337,11 +469,12 @@ impl TuiMetricsRenderer {
}
struct QuitPopupAccept(Interrupter);
struct KillPopupAccept;
struct KillPopupAccept(Sender<()>);
struct PopupCancel;
impl CallbackFn for KillPopupAccept {
/// Sends the kill signal through the wrapped sender, then panics on the calling thread.
///
/// This never returns normally. The `unwrap` also panics if the receiving end of the kill
/// signal has been dropped.
fn call(&self) -> bool {
// Send first: the panic below only stops the current thread, so the signal is what lets
// the receiving side react.
self.0.send(()).unwrap();
panic!("Killing training from user input.");
}
}

View File

@@ -1,5 +1,6 @@
use crate::renderer::ProgressType;
use super::TerminalFrame;
use crate::renderer::{EvaluationProgress, TrainingProgress};
use ratatui::{
prelude::{Alignment, Rect},
style::{Color, Style, Stylize},
@@ -9,7 +10,7 @@ use ratatui::{
/// Show the training status with various information.
pub(crate) struct StatusState {
progress: TrainingProgress,
progress_indicators: Vec<ProgressType>,
mode: Mode,
}
@@ -22,7 +23,7 @@ enum Mode {
impl Default for StatusState {
fn default() -> Self {
Self {
progress: TrainingProgress::none(),
progress_indicators: vec![],
mode: Mode::Train,
}
}
@@ -30,24 +31,23 @@ impl Default for StatusState {
impl StatusState {
/// Update the training information.
pub(crate) fn update_train(&mut self, progress: TrainingProgress) {
self.progress = progress;
pub(crate) fn update_train(&mut self, progress_indicators: Vec<ProgressType>) {
self.progress_indicators = progress_indicators;
self.mode = Mode::Train;
}
/// Update the validation information.
pub(crate) fn update_valid(&mut self, progress: TrainingProgress) {
self.progress = progress;
pub(crate) fn update_valid(&mut self, progress_indicators: Vec<ProgressType>) {
self.progress_indicators = progress_indicators;
self.mode = Mode::Valid;
}
/// Update the testing information.
pub(crate) fn update_test(&mut self, _progress: EvaluationProgress) {
// TODO: Use the progress here.
// self.progress = progress;
pub(crate) fn update_test(&mut self, progress_indicators: Vec<ProgressType>) {
self.progress_indicators = progress_indicators;
self.mode = Mode::Evaluation;
}
/// Create a view.
pub(crate) fn view(&self) -> StatusView {
StatusView::new(&self.progress, &self.mode)
StatusView::new(&self.progress_indicators, &self.mode)
}
}
@@ -56,7 +56,7 @@ pub(crate) struct StatusView {
}
impl StatusView {
fn new(progress: &TrainingProgress, mode: &Mode) -> Self {
fn new(progress_indicators: &[ProgressType], mode: &Mode) -> Self {
let title = |title: &str| Span::from(format!(" {title} ")).bold().yellow();
let value = |value: String| Span::from(value).italic();
let mode = match mode {
@@ -65,26 +65,38 @@ impl StatusView {
Mode::Evaluation => "Evaluation",
};
Self {
lines: vec![
vec![title("Mode :"), value(mode.to_string())],
vec![
title("Epoch :"),
value(format!("{}/{}", progress.epoch, progress.epoch_total)),
],
vec![
title("Iteration :"),
value(format!("{}", progress.iteration)),
],
vec![
title("Items :"),
value(format!(
"{}/{}",
progress.progress.items_processed, progress.progress.items_total
)),
],
],
}
let width = progress_indicators
.iter()
.map(|p| match p {
ProgressType::Detailed { tag, .. } => tag.len(),
ProgressType::Value { tag, .. } => tag.len(),
})
.max()
.unwrap_or(4);
let mut lines = vec![vec![
title(&format!("{: <width$} :", "Mode")),
value(mode.to_string()),
]];
progress_indicators.iter().for_each(|p| match p {
ProgressType::Detailed { tag, progress } => lines.push(vec![
title(&format!("{: <width$} :", tag)),
value(format!(
"{}/{}",
progress.items_processed, progress.items_total
)),
]),
ProgressType::Value {
tag,
value: num_items,
} => lines.push(vec![
title(&format!("{: <width$} :", tag)),
value(format!("{}", num_items)),
]),
});
Self { lines }
}
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {

View File

@@ -118,6 +118,7 @@ cubecl = ["dep:cubecl"]
audio = ["burn-core/audio"]
vision = ["burn-core/vision"]
rl = ["dep:burn-rl"]
# Backend
ir = ["burn-ir"]
@@ -177,6 +178,7 @@ burn-collective = { path = "../burn-collective", version = "=0.21.0", optional =
burn-store = { path = "../burn-store", version = "=0.21.0", optional = true, default-features = false }
burn-nn = { path = "../burn-nn", version = "=0.21.0", default-features = false }
burn-optim = { path = "../burn-optim", version = "=0.21.0", default-features = false }
burn-rl = { path = "../burn-rl", version = "=0.21.0", optional = true, default-features = false }
# Backends
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0", optional = true, default-features = false }

View File

@@ -124,6 +124,12 @@ pub mod train {
pub use burn_train::*;
}
/// Module for reinforcement learning.
///
/// Re-exports everything from the `burn_rl` crate; only available when the `rl` feature is
/// enabled.
#[cfg(feature = "rl")]
pub mod rl {
pub use burn_rl::*;
}
/// Backend module.
pub mod backend;

View File

@@ -9,7 +9,7 @@ publish = false
workspace = true
[dependencies]
burn = {path = "../../crates/burn", features=["autodiff", "webgpu", "vision"]}
burn = {path = "../../crates/burn", default-features = false, features=["autodiff", "webgpu", "vision"]}
guide = {path = "../guide"}
derive-new = { workspace = true }
log = { workspace = true }
log = { workspace = true }

View File

@@ -1,4 +1,5 @@
use crate::model::ModelConfig;
use burn::data::dataloader::Progress;
use burn::record::NoStdTrainingRecorder;
use burn::train::{
EventProcessorTraining, Learner, LearningComponentsTypes, SupervisedLearningStrategy,
@@ -20,7 +21,7 @@ use burn::{
record::CompactRecorder,
tensor::{Device, backend::AutodiffBackend},
train::{
InferenceStep, LearnerEvent, LearnerItem, MetricEarlyStoppingStrategy, StoppingCondition,
InferenceStep, LearnerEvent, MetricEarlyStoppingStrategy, StoppingCondition, TrainingItem,
metric::{
AccuracyMetric, LossMetric,
store::{Aggregate, Direction, Split},
@@ -167,12 +168,11 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC> for MyCustomLea
let item = learner.train_step(item);
learner.optimizer_step(item.grads);
let item = LearnerItem::new(
let item = TrainingItem::new(
item.item,
progress,
epoch,
num_epochs,
iteration,
Progress::new(epoch, num_epochs),
Some(iteration),
Some(learner.lr_current()),
);
@@ -198,7 +198,13 @@ impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC> for MyCustomLea
iteration += 1;
let item = model_valid.step(item);
let item = LearnerItem::new(item, progress, epoch, num_epochs, iteration, None);
let item = TrainingItem::new(
item,
progress,
Progress::new(epoch, num_epochs),
Some(iteration),
None,
);
event_processor.process_valid(LearnerEvent::ProcessedItem(item));
}

View File

@@ -7,7 +7,7 @@ use burn::{
Learner, SupervisedTraining,
renderer::{
EvaluationName, EvaluationProgress, MetricState, MetricsRenderer,
MetricsRendererEvaluation, MetricsRendererTraining, TrainingProgress,
MetricsRendererEvaluation, MetricsRendererTraining, ProgressType, TrainingProgress,
},
},
};
@@ -36,11 +36,11 @@ impl MetricsRendererTraining for CustomRenderer {
fn update_valid(&mut self, _state: MetricState) {}
fn render_train(&mut self, item: TrainingProgress) {
fn render_train(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {
dbg!(item);
}
fn render_valid(&mut self, item: TrainingProgress) {
fn render_valid(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {
dbg!(item);
}
}
@@ -56,7 +56,7 @@ impl MetricsRenderer for CustomRenderer {
impl MetricsRendererEvaluation for CustomRenderer {
fn update_test(&mut self, _name: EvaluationName, _state: MetricState) {}
fn render_test(&mut self, item: EvaluationProgress) {
fn render_test(&mut self, item: EvaluationProgress, _progress_indicators: Vec<ProgressType>) {
dbg!(item);
}
}

View File

@@ -0,0 +1,41 @@
[package]
name = "dqn-agent"
edition.workspace = true
license.workspace = true
readme.workspace = true
version.workspace = true
[features]
default = ["burn/tui"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
remote = ["burn/remote"]
wgpu = ["burn/wgpu", "burn/default"]
metal = ["burn/metal", "burn/default"]
cuda = ["burn/cuda"]
vulkan = ["burn/vulkan", "burn/default"]
rocm = ["burn/rocm", "burn/default"]
[dependencies]
# Disable autotune default for now (convolutions not optimized)
burn = { path = "../../crates/burn", features = [
"train",
# "vision",
"metrics",
"std",
"rl",
# "fusion",
"ndarray",
# "autotune",
], default-features = false }
# Just for this example.
gym-rs = { version = "0.3.1", branch = "main", git = "https://github.com/MathisWellmann/gym-rs" }
rand.workspace = true
derive-new = { workspace = true }
[lints]
workspace = true

View File

@@ -0,0 +1,118 @@
#[cfg(any(
    feature = "ndarray",
    feature = "ndarray-blas-netlib",
    feature = "ndarray-blas-openblas",
    feature = "ndarray-blas-accelerate",
))]
mod ndarray {
    use burn::backend::{
        Autodiff,
        ndarray::{NdArray, NdArrayDevice},
    };
    use dqn_agent::training;

    /// Launches the example on the NdArray backend using the CPU device.
    pub fn run() {
        training::run::<Autodiff<NdArray>>(NdArrayDevice::Cpu);
    }
}
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn::backend::{
        Autodiff,
        libtorch::{LibTorch, LibTorchDevice},
    };
    use dqn_agent::training;

    /// Launches the example on the LibTorch backend.
    ///
    /// The device is `Mps` when compiled for macOS and `Cuda(0)` on every other target.
    pub fn run() {
        let device = if cfg!(target_os = "macos") {
            LibTorchDevice::Mps
        } else {
            LibTorchDevice::Cuda(0)
        };
        training::run::<Autodiff<LibTorch>>(device);
    }
}
#[cfg(any(feature = "wgpu", feature = "metal", feature = "vulkan"))]
mod wgpu {
    use burn::backend::{
        Autodiff,
        wgpu::{Wgpu, WgpuDevice},
    };
    use dqn_agent::training;

    /// Launches the example on the Wgpu backend using the default device.
    pub fn run() {
        training::run::<Autodiff<Wgpu>>(WgpuDevice::default());
    }
}
#[cfg(feature = "cuda")]
mod cuda {
    use burn::backend::{Autodiff, Cuda};
    use dqn_agent::training;

    /// Launches the example on the Cuda backend using its default device.
    pub fn run() {
        training::run::<Autodiff<Cuda>>(Default::default());
    }
}
#[cfg(feature = "rocm")]
mod rocm {
    use burn::backend::{Autodiff, Rocm};
    use dqn_agent::training;

    /// Launches the example on the Rocm backend using its default device.
    pub fn run() {
        training::run::<Autodiff<Rocm>>(Default::default());
    }
}
#[cfg(feature = "tch-cpu")]
mod tch_cpu {
    use burn::backend::{
        Autodiff,
        libtorch::{LibTorch, LibTorchDevice},
    };
    use dqn_agent::training;

    /// Launches the example on the LibTorch backend using the CPU device.
    pub fn run() {
        training::run::<Autodiff<LibTorch>>(LibTorchDevice::Cpu);
    }
}
#[cfg(feature = "remote")]
mod remote {
    use burn::backend::{Autodiff, RemoteBackend};
    use dqn_agent::training;

    /// Launches the example on the remote backend using its default device.
    pub fn run() {
        let device = Default::default();
        training::run::<Autodiff<RemoteBackend>>(device);
    }
}
/// Entry point: runs the example once for every backend feature that is enabled.
///
/// Each call is gated by the same `cfg` condition as the module that defines it, so a
/// build with none of these features enabled produces a `main` that does nothing, and a
/// build with several enabled runs them one after another in the order listed here.
fn main() {
    #[cfg(any(
        feature = "ndarray",
        feature = "ndarray-blas-netlib",
        feature = "ndarray-blas-openblas",
        feature = "ndarray-blas-accelerate",
    ))]
    ndarray::run();
    #[cfg(feature = "tch-gpu")]
    tch_gpu::run();
    #[cfg(feature = "tch-cpu")]
    tch_cpu::run();
    #[cfg(any(feature = "wgpu", feature = "metal", feature = "vulkan"))]
    wgpu::run();
    #[cfg(feature = "cuda")]
    cuda::run();
    #[cfg(feature = "rocm")]
    rocm::run();
    #[cfg(feature = "remote")]
    remote::run();
}

View File

@@ -0,0 +1,477 @@
use std::marker::PhantomData;
use burn::backend::NdArray;
use burn::module::Module;
use burn::record::Record;
use burn::rl::{
Batchable, LearnerTransitionBatch, Policy, PolicyLearner, PolicyState, RLTrainOutput,
};
use burn::tensor::Transaction;
use burn::tensor::activation::softmax;
use burn::train::ItemLazy;
use burn::train::metric::{Adaptor, LossInput};
use burn::{
Tensor,
config::Config,
module::AutodiffModule,
nn::{self, loss::MseLoss},
optim::{GradientsParams, Optimizer},
prelude::Backend,
tensor::backend::AutodiffBackend,
};
use rand::distr::Distribution;
use rand::distr::weighted::WeightedIndex;
use rand::rng;
use crate::utils::{
EpsilonGreedyPolicy, EpsilonGreedyPolicyState, create_lin_layers, soft_update_linear,
};
/// A module that turns a batch of observations into logits over discrete actions.
pub trait DiscreteActionModel<B: Backend>: Module<B> {
    /// The observation type accepted by [`DiscreteActionModel::forward`].
    type Input: Clone + Send + Batchable;
    /// Runs the model on `input` and returns the resulting action logits.
    fn forward(&self, input: Self::Input) -> DiscreteLogitsTensor<B, 2>;
}
/// Configuration used by [`MlpNet::new`] to build the network.
#[derive(Config, Debug)]
pub struct MlpNetConfig {
    /// The number of linear layers.
    #[config(default = 3)]
    pub num_layers: usize,
    /// The dropout rate.
    #[config(default = 0.)]
    pub dropout: f64,
    /// The input dimension (size of one observation).
    #[config(default = 4)]
    pub d_input: usize,
    /// The output dimension (one logit per action).
    #[config(default = 2)]
    pub d_output: usize,
    /// The size of the hidden layers.
    #[config(default = 256)]
    pub d_hidden: usize,
}
/// Multilayer Perceptron Network.
#[derive(Module, Debug)]
pub struct MlpNet<B: Backend> {
    /// The linear layers, applied in order.
    pub linears: Vec<nn::Linear<B>>,
    /// The dropout module used in the forward pass.
    pub dropout: nn::Dropout,
    /// The ReLU activation used in the forward pass.
    pub activation: nn::Relu,
}
impl<B: Backend> MlpNet<B> {
    /// Builds the network described by `config`, allocating its layers on `device`.
    pub fn new(config: &MlpNetConfig, device: &B::Device) -> Self {
        let linears = create_lin_layers(
            config.num_layers,
            config.d_input,
            config.d_hidden,
            config.d_output,
            device,
        );
        let dropout = nn::DropoutConfig::new(config.dropout).init();
        let activation = nn::Relu::new();

        Self {
            linears,
            dropout,
            activation,
        }
    }
}
/// A batch of observations stored as a single tensor.
#[derive(Clone)]
pub struct ObservationTensor<B: Backend, const D: usize> {
    /// The observation values; dimension 0 is the batch dimension.
    pub state: Tensor<B, D>,
}

impl<B: Backend, const D: usize> Batchable for ObservationTensor<B, D> {
    /// Concatenates the given observations along dimension 0.
    fn batch(value: Vec<Self>) -> Self {
        // The vector is owned, so the tensors can be moved out instead of cloned.
        let tensors = value.into_iter().map(|item| item.state).collect();
        Self {
            state: Tensor::cat(tensors, 0),
        }
    }

    /// Splits the batch into chunks of size 1 along dimension 0.
    fn unbatch(self) -> Vec<Self> {
        self.state
            .split(1, 0)
            .into_iter()
            .map(|state| ObservationTensor { state })
            .collect()
    }
}
impl<B: Backend> DiscreteActionModel<B> for MlpNet<B> {
    type Input = ObservationTensor<B, 2>;

    /// Applies the forward pass on the input tensor.
    ///
    /// Dropout and the activation are applied after every hidden layer. The last
    /// layer is left untouched so that the returned logits are never zeroed by
    /// dropout nor clamped by the activation.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, d_input]`
    /// - output: `[batch_size, d_output]`
    fn forward(&self, input: Self::Input) -> DiscreteLogitsTensor<B, 2> {
        // Index of the output layer; `saturating_sub` keeps this safe for an empty network.
        let last = self.linears.len().saturating_sub(1);

        let mut x = input.state;
        for (i, linear) in self.linears.iter().enumerate() {
            x = linear.forward(x);
            if i < last {
                x = self.dropout.forward(x);
                x = self.activation.forward(x);
            }
        }

        DiscreteLogitsTensor { logits: x }
    }
}
/// Hyper-parameters of the DQN learning agent.
#[derive(Config, Debug)]
pub struct DqnAgentConfig {
    /// Discount factor (how to value long-term vs short-term rewards).
    #[config(default = 0.99)]
    pub gamma: f64,
    /// The learning rate passed to the optimizer at every training step.
    #[config(default = 3e-4)]
    pub learning_rate: f64,
    /// The soft update rate of the target network.
    #[config(default = 0.005)]
    pub tau: f64,
    /// Initial value of epsilon (probability to choose a random action).
    #[config(default = 0.9)]
    pub epsilon_start: f64,
    /// Final value of epsilon (probability to choose a random action).
    #[config(default = 0.01)]
    pub epsilon_end: f64,
    /// The exponential rate at which the epsilon value decays. Higher = slower decay.
    #[config(default = 2500.0)]
    pub epsilon_decay: f64,
}
/// A model that can be used as a slowly-moving target network.
pub trait TargetModel<B: Backend> {
    /// Returns a new model whose parameters are a blend of `self` and `that`,
    /// weighted by `tau`.
    fn soft_update(&self, that: &Self, tau: f64) -> Self;
}
impl<B: Backend> TargetModel<B> for MlpNet<B> {
    /// Soft-updates every linear layer of `self` towards the matching layer of `that`.
    ///
    /// # Panics
    ///
    /// Panics if `that` has fewer linear layers than `self`.
    fn soft_update(&self, that: &Self, tau: f64) -> Self {
        let linears = self
            .linears
            .iter()
            .enumerate()
            // `soft_update_linear` only borrows the second layer, so no clone is needed for it.
            .map(|(i, layer)| soft_update_linear(layer.clone(), &that.linears[i], tau))
            .collect();

        Self {
            linears,
            dropout: self.dropout.clone(),
            activation: self.activation.clone(),
        }
    }
}
/// The state of a [`DQN`] policy: the model it uses to pick actions.
#[derive(Clone)]
pub struct DqnState<B: Backend, M: DiscreteActionModel<B>> {
    model: M,
    _backend: PhantomData<B>,
}

impl<B: Backend, M: DiscreteActionModel<B>> PolicyState<B> for DqnState<B, M> {
    type Record = M::Record;

    /// Converts the wrapped model into its record.
    fn into_record(self) -> Self::Record {
        // `self` is owned, so the model can be consumed directly without cloning it.
        self.model.into_record()
    }

    /// Returns a new state whose model is a copy of the current one loaded with `record`.
    fn load_record(&self, record: Self::Record) -> Self {
        Self {
            model: self.model.clone().load_record(record),
            _backend: PhantomData,
        }
    }
}
/// A policy that selects discrete actions from the logits of a wrapped model.
#[derive(Clone)]
pub struct DQN<B: Backend, M: DiscreteActionModel<B>> {
    model: M,
    _backend: PhantomData<B>,
}

impl<B: Backend, M: DiscreteActionModel<B>> DQN<B, M> {
    /// Wraps `policy` as the model used to compute action logits.
    pub fn new(policy: M) -> Self {
        let model = policy;
        Self {
            _backend: PhantomData,
            model,
        }
    }
}
/// A batch of action logits stored as a single tensor.
#[derive(Clone)]
pub struct DiscreteLogitsTensor<B: Backend, const D: usize> {
    /// The logits; dimension 0 is the batch dimension.
    pub logits: Tensor<B, D>,
}

impl<B: Backend, const D: usize> Batchable for DiscreteLogitsTensor<B, D> {
    /// Concatenates the given logits along dimension 0.
    fn batch(value: Vec<Self>) -> Self {
        // The vector is owned, so the tensors can be moved out instead of cloned.
        let tensors = value.into_iter().map(|item| item.logits).collect();
        Self {
            logits: Tensor::cat(tensors, 0),
        }
    }

    /// Splits the batch into chunks of size 1 along dimension 0.
    fn unbatch(self) -> Vec<Self> {
        self.logits
            .split(1, 0)
            .into_iter()
            .map(|logits| DiscreteLogitsTensor { logits })
            .collect()
    }
}
/// A batch of discrete actions stored as a single tensor.
#[derive(Clone)]
pub struct DiscreteActionTensor<B: Backend, const D: usize> {
    /// The action indices, stored as floats; dimension 0 is the batch dimension.
    pub actions: Tensor<B, D>,
}

impl<B: Backend, const D: usize> Batchable for DiscreteActionTensor<B, D> {
    /// Concatenates the given actions along dimension 0.
    fn batch(value: Vec<Self>) -> Self {
        // The vector is owned, so the tensors can be moved out instead of cloned.
        let tensors = value.into_iter().map(|item| item.actions).collect();
        Self {
            actions: Tensor::cat(tensors, 0),
        }
    }

    /// Splits the batch into chunks of size 1 along dimension 0.
    fn unbatch(self) -> Vec<Self> {
        self.actions
            .split(1, 0)
            .into_iter()
            .map(|actions| DiscreteActionTensor { actions })
            .collect()
    }
}
impl<B: Backend, M: DiscreteActionModel<B>> Policy<B> for DQN<B, M> {
    type Observation = M::Input;
    type ActionDistribution = DiscreteLogitsTensor<B, 2>;
    type Action = DiscreteActionTensor<B, 2>;
    type ActionContext = ();
    type PolicyState = DqnState<B, M>;

    /// Computes the action logits for a batch of observations.
    fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {
        self.model.forward(states)
    }

    /// Selects one action per observation in the batch.
    ///
    /// When `deterministic` is true the action with the highest logit is chosen;
    /// otherwise an action is sampled from the softmax of the logits. Both paths
    /// return a tensor of shape `[batch_size, 1]`. No action context is produced.
    ///
    /// # Panics
    ///
    /// Panics if the probabilities of a row cannot be used as sampling weights
    /// (for example when they contain NaN).
    fn action(
        &mut self,
        states: Self::Observation,
        deterministic: bool,
    ) -> (Self::Action, Vec<Self::ActionContext>) {
        let logits = self.forward(states).logits;
        if deterministic {
            let output = DiscreteActionTensor {
                actions: logits.argmax(1).float(),
            };
            return (output, vec![]);
        }

        let rows = softmax(logits, 1).split(1, 0);
        let mut actions = Vec::with_capacity(rows.len());
        let mut rng = rng();
        for row in rows {
            // Convert to f32 first so that sampling works whatever the backend float type is.
            let weights = row
                .to_data()
                .convert::<f32>()
                .to_vec::<f32>()
                .expect("probabilities were converted to f32");
            let dist = WeightedIndex::new(weights)
                .expect("softmax output should be valid sampling weights");
            let action = dist.sample(&mut rng);
            actions.push(Tensor::<B, 1>::from_floats([action], &row.device()));
        }

        // Stack along a new leading dimension: `batch_size` tensors of shape `[1]` become
        // `[batch_size, 1]`, which matches the shape returned by the deterministic path.
        let output = DiscreteActionTensor {
            actions: Tensor::stack(actions, 0),
        };
        (output, vec![])
    }

    /// Replaces the model with the one carried by `update`.
    fn update(&mut self, update: Self::PolicyState) {
        self.model = update.model;
    }

    /// Returns a state holding a copy of the current model.
    fn state(&self) -> Self::PolicyState {
        DqnState {
            model: self.model.clone(),
            _backend: PhantomData,
        }
    }

    /// Returns a policy whose model has been loaded from `record`.
    fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {
        let state = self.state().load_record(record);
        Self {
            model: state.model,
            _backend: PhantomData,
        }
    }
}
/// Everything needed to checkpoint a [`DqnLearningAgent`].
#[derive(Record)]
pub struct DqnLearningRecord<B: AutodiffBackend, M: AutodiffModule<B>, O: Optimizer<M, B>> {
    /// Record of the model being trained.
    policy_model: M::Record,
    /// Record of the target model.
    target_model: M::Record,
    /// Record of the optimizer.
    optimizer: O::Record,
}
/// A DQN learner: trains a policy model against a soft-updated target model and
/// exposes an epsilon-greedy policy built on the policy model.
#[derive(Clone)]
pub struct DqnLearningAgent<B, M, O>
where
    B: AutodiffBackend,
    M: DiscreteActionModel<B> + AutodiffModule<B> + TargetModel<B> + 'static,
    M::InnerModule: DiscreteActionModel<B::InnerBackend> + TargetModel<B::InnerBackend>,
    O: Optimizer<M, B> + 'static,
{
    /// The model updated by the optimizer at every training step.
    policy_model: M,
    /// The model used to compute the bootstrap targets; soft-updated towards `policy_model`.
    target_model: M,
    /// The epsilon-greedy policy handed out by `policy()`.
    agent: EpsilonGreedyPolicy<B, DQN<B, M>>,
    /// The optimizer applied to `policy_model`.
    optimizer: O,
    /// The hyper-parameters of the agent.
    config: DqnAgentConfig,
}
impl<B, M, O> DqnLearningAgent<B, M, O>
where
    B: AutodiffBackend,
    M: DiscreteActionModel<B> + AutodiffModule<B> + TargetModel<B> + 'static,
    M::InnerModule: DiscreteActionModel<B::InnerBackend> + TargetModel<B::InnerBackend>,
    O: Optimizer<M, B> + 'static,
{
    /// Creates a learner whose policy model, target model and acting policy all start
    /// as copies of `model`. The epsilon schedule is read from `config`.
    pub fn new(model: M, optimizer: O, config: DqnAgentConfig) -> Self {
        let acting_policy = DQN::new(model.clone());
        let agent = EpsilonGreedyPolicy::new(
            acting_policy,
            config.epsilon_start,
            config.epsilon_end,
            config.epsilon_decay,
        );
        let target_model = model.clone();

        Self {
            policy_model: model,
            target_model,
            agent,
            optimizer,
            config,
        }
    }
}
/// The output of one training step: the loss of the policy model.
#[derive(Clone)]
pub struct SimpleTrainOutput<B: Backend> {
    /// The loss computed for the policy model.
    pub policy_model_loss: Tensor<B, 1>,
}

impl<B: Backend> ItemLazy for SimpleTrainOutput<B> {
    type ItemSync = SimpleTrainOutput<NdArray>;

    /// Reads the loss back through a transaction and rebuilds it on the default
    /// `NdArray` device.
    fn sync(self) -> Self::ItemSync {
        let transaction = Transaction::default().register(self.policy_model_loss);
        let [loss_data] = transaction
            .execute()
            .try_into()
            .expect("Correct amount of tensor data");

        SimpleTrainOutput {
            policy_model_loss: Tensor::from_data(loss_data, &Default::default()),
        }
    }
}

impl<B: Backend> Adaptor<LossInput<B>> for SimpleTrainOutput<B> {
    /// Exposes the policy model loss to the loss metric.
    fn adapt(&self) -> LossInput<B> {
        let loss = self.policy_model_loss.clone();
        LossInput::new(loss)
    }
}
impl<B, M, O> PolicyLearner<B> for DqnLearningAgent<B, M, O>
where
    B: AutodiffBackend,
    M: DiscreteActionModel<B> + AutodiffModule<B> + TargetModel<B> + 'static,
    M::Input: Clone,
    M::InnerModule: DiscreteActionModel<B::InnerBackend> + TargetModel<B::InnerBackend>,
    O: Optimizer<M, B> + 'static,
{
    type TrainContext = SimpleTrainOutput<B>;
    type InnerPolicy = EpsilonGreedyPolicy<B, DQN<B, M>>;
    type Record = DqnLearningRecord<B, M, O>;

    /// Performs one DQN update on a batch of transitions.
    ///
    /// The policy model is regressed (MSE) towards
    /// `reward + gamma * (1 - done) * max_a Q_target(next_state, a)`, then the target
    /// model is soft-updated towards the new policy model and the acting policy is
    /// refreshed. Returns the policy update together with the loss.
    fn train(
        &mut self,
        input: LearnerTransitionBatch<B, Self::InnerPolicy>,
    ) -> RLTrainOutput<Self::TrainContext, <Self::InnerPolicy as Policy<B>>::PolicyState> {
        let states_batch = input.states;
        let next_states_batch = input.next_states;
        let actions_batch = input.actions.actions;
        let rewards_batch = input.rewards;
        let dones_batch = input.dones;

        // Q-values of the actions that were actually taken.
        let logits = self.policy_model.forward(states_batch).logits;
        let state_action_values = logits.gather(1, actions_batch.int());

        // Bootstrap values from the target model. They are detached because the TD target
        // is a constant for this update: no gradient may flow back through the target
        // network into the loss.
        let next_state_values = self
            .target_model
            .forward(next_states_batch)
            .logits
            .max_dim(1)
            .squeeze::<1>()
            .detach();

        let not_done_batch = Tensor::ones_like(&dones_batch) - dones_batch;
        let expected_state_action_values = (next_state_values * not_done_batch.squeeze())
            .mul_scalar(self.config.gamma)
            + rewards_batch.squeeze();
        let expected_state_action_values = expected_state_action_values.unsqueeze_dim::<2>(1);

        let loss = MseLoss::new().forward(
            state_action_values,
            expected_state_action_values,
            nn::loss::Reduction::Mean,
        );

        // Optimize the policy model.
        let gradients = loss.backward();
        let gradient_params = GradientsParams::from_grads(gradients, &self.policy_model);
        self.policy_model = self.optimizer.step(
            self.config.learning_rate,
            self.policy_model.clone(),
            gradient_params,
        );

        // Move the target model slightly towards the freshly updated policy model.
        self.target_model = self
            .target_model
            .soft_update(&self.policy_model, self.config.tau);

        // Publish the new weights; the exploration step counter is carried over unchanged.
        let policy_update = EpsilonGreedyPolicyState::new(
            DqnState {
                model: self.policy_model.clone(),
                _backend: PhantomData,
            },
            self.agent.state().step,
        );
        self.agent.update(policy_update.clone());

        RLTrainOutput {
            policy: policy_update,
            item: SimpleTrainOutput {
                policy_model_loss: loss,
            },
        }
    }

    /// Returns a copy of the acting policy.
    fn policy(&self) -> Self::InnerPolicy {
        self.agent.clone()
    }

    /// Replaces the acting policy.
    fn update_policy(&mut self, update: Self::InnerPolicy) {
        self.agent = update;
    }

    /// Captures the policy model, the target model and the optimizer as a record.
    fn record(&self) -> Self::Record {
        DqnLearningRecord {
            policy_model: self.policy_model.clone().into_record(),
            target_model: self.target_model.clone().into_record(),
            optimizer: self.optimizer.to_record(),
        }
    }

    /// Restores the policy model, the target model and the optimizer from `record`.
    /// The acting policy and the configuration are kept as they are.
    fn load_record(self, record: Self::Record) -> Self {
        let policy_model = self.policy_model.load_record(record.policy_model);
        let target_model = self.target_model.load_record(record.target_model);
        let optimizer = self.optimizer.load_record(record.optimizer);
        Self {
            policy_model,
            target_model,
            agent: self.agent,
            optimizer,
            config: self.config,
        }
    }
}

View File

@@ -0,0 +1,101 @@
use burn::rl::{Environment, StepResult};
use burn::{
Tensor,
prelude::{Backend, ToElement},
};
use gym_rs::{
core::Env,
envs::classical_control::cartpole::{CartPoleEnv, CartPoleObservation},
};
use crate::agent::{DiscreteActionTensor, ObservationTensor};
/// An action of the cart-pole environment, stored as an index.
#[derive(Clone)]
pub struct CartPoleAction {
    action: usize,
}

impl<B: Backend> From<DiscreteActionTensor<B, 2>> for CartPoleAction {
    /// Reads the single value held by the tensor and uses it as the action index.
    fn from(value: DiscreteActionTensor<B, 2>) -> Self {
        let index = value.actions.int().into_scalar().to_usize();
        Self { action: index }
    }
}

impl<B: Backend> From<CartPoleAction> for DiscreteActionTensor<B, 2> {
    /// Stores the action index in a tensor created on the default device.
    fn from(value: CartPoleAction) -> Self {
        let flat = Tensor::<B, 1>::from_data([value.action], &Default::default());
        DiscreteActionTensor {
            actions: flat.unsqueeze(),
        }
    }
}
/// The state of the cart-pole environment as four floating point values.
#[derive(Clone)]
pub struct CartPoleState {
    pub state: [f64; 4],
}

impl From<CartPoleObservation> for CartPoleState {
    /// Copies the first four values of the observation.
    fn from(observation: CartPoleObservation) -> Self {
        let values = Vec::<f64>::from(observation);
        Self {
            state: std::array::from_fn(|i| values[i]),
        }
    }
}

impl<B: Backend> From<CartPoleState> for ObservationTensor<B, 2> {
    /// Stores the four state values in a tensor created on the default device.
    fn from(val: CartPoleState) -> Self {
        let flat = Tensor::<B, 1>::from_floats(val.state, &Default::default());
        ObservationTensor {
            state: flat.unsqueeze(),
        }
    }
}
/// Wraps the `gym-rs` cart-pole environment and counts the steps of the current episode.
#[derive(Clone)]
pub struct CartPoleWrapper {
    gym_env: CartPoleEnv,
    step_index: usize,
}

impl Default for CartPoleWrapper {
    /// Equivalent to [`CartPoleWrapper::new`].
    fn default() -> Self {
        CartPoleWrapper::new()
    }
}

impl CartPoleWrapper {
    /// Creates the environment with rendering disabled and the step counter at zero.
    pub fn new() -> Self {
        let gym_env = CartPoleEnv::new(gym_rs::utils::renderer::RenderMode::None);
        Self {
            gym_env,
            step_index: 0,
        }
    }
}
impl Environment for CartPoleWrapper {
    type State = CartPoleState;
    type Action = CartPoleAction;

    const MAX_STEPS: usize = 500;

    /// Returns the current state of the wrapped environment.
    fn state(&self) -> Self::State {
        let observation = self.gym_env.state;
        CartPoleState::from(observation)
    }

    /// Applies `action`, increments the step counter and reports the outcome.
    ///
    /// The result is flagged as truncated once the counter reaches `MAX_STEPS`.
    fn step(&mut self, action: Self::Action) -> StepResult<Self::State> {
        let outcome = self.gym_env.step(action.action);
        self.step_index += 1;
        let truncated = self.step_index >= Self::MAX_STEPS;

        StepResult {
            next_state: CartPoleState::from(outcome.observation),
            reward: outcome.reward.into_inner(),
            done: outcome.done,
            truncated,
        }
    }

    /// Resets the wrapped environment and the step counter.
    fn reset(&mut self) {
        self.step_index = 0;
        self.gym_env.reset(None, false, None);
    }
}

View File

@@ -0,0 +1,4 @@
pub mod agent;
pub mod env;
pub mod training;
pub mod utils;

View File

@@ -0,0 +1,64 @@
use burn::{
grad_clipping::GradientClippingConfig,
optim::AdamWConfig,
record::CompactRecorder,
tensor::backend::AutodiffBackend,
train::{
OffPolicyConfig, RLTraining,
metric::{CumulativeRewardMetric, EpisodeLengthMetric, ExplorationRateMetric, LossMetric},
},
};
use crate::{
agent::{DqnAgentConfig, DqnLearningAgent, MlpNet, MlpNetConfig},
env::CartPoleWrapper,
};
static ARTIFACT_DIR: &str = "/tmp/burn-example-dqn-agent";

/// Trains a DQN agent on the cart-pole environment using backend `B` on `device`.
///
/// Builds the network, the optimizer and the agent, then launches an off-policy
/// training run whose artifacts are written under `ARTIFACT_DIR`.
pub fn run<B: AutodiffBackend>(device: B::Device) {
    // Network and optimizer.
    let network_config = MlpNetConfig {
        num_layers: 3,
        dropout: 0.0,
        d_input: 4,
        d_output: 2,
        d_hidden: 64,
    };
    let network = MlpNet::<B>::new(&network_config, &device);
    let optim = AdamWConfig::new()
        .with_grad_clipping(Some(GradientClippingConfig::Value(100.0)))
        .init();

    // Agent.
    let agent_config = DqnAgentConfig {
        gamma: 0.99,
        learning_rate: 3e-4,
        tau: 0.005,
        epsilon_start: 0.99,
        epsilon_end: 0.05,
        epsilon_decay: 6000.0,
    };
    let agent = DqnLearningAgent::new(network, optim, agent_config);

    // Training loop settings.
    let strategy_config = OffPolicyConfig {
        num_envs: 8,
        autobatch_size: 8,
        replay_buffer_size: 50_000,
        train_interval: 8,
        eval_interval: 4_000,
        eval_episodes: 5,
        train_batch_size: 128,
        train_steps: 4,
        warmup_steps: 0,
    };
    let strategy = burn::train::RLStrategies::OffPolicyStrategy(strategy_config);

    let training = RLTraining::new(ARTIFACT_DIR, CartPoleWrapper::new)
        .metrics_train((LossMetric::new(),))
        .metrics_agent((ExplorationRateMetric::new(),))
        .metrics_episode((EpisodeLengthMetric::new(), CumulativeRewardMetric::new()))
        .with_file_checkpointer(CompactRecorder::new())
        .num_steps(40_000)
        .with_learning_strategy(strategy)
        .summary();

    let _result = training.launch(agent);
}

View File

@@ -0,0 +1,226 @@
use std::marker::PhantomData;
use burn::{
Tensor,
module::{Param, ParamId},
nn::{self, Linear},
prelude::Backend,
record::Record,
rl::{Policy, PolicyState},
tensor::Device,
train::{
ItemLazy,
metric::{Adaptor, ExplorationRateInput},
},
};
use derive_new::new;
use rand::{random, random_range};
use crate::agent::{DiscreteActionTensor, DiscreteLogitsTensor};
/// Creates `num_layers` linear layers chained from `d_input` to `d_output`.
///
/// The first layer reads `d_input` features, the last one produces `d_output`
/// features, and every other boundary has `d_hidden` features. A single layer maps
/// `d_input` directly to `d_output`; zero layers yields an empty vector.
pub fn create_lin_layers<B: Backend>(
    num_layers: usize,
    d_input: usize,
    d_hidden: usize,
    d_output: usize,
    device: &Device<B>,
) -> Vec<Linear<B>> {
    (0..num_layers)
        .map(|index| {
            let fan_in = if index == 0 { d_input } else { d_hidden };
            let fan_out = if index + 1 == num_layers {
                d_output
            } else {
                d_hidden
            };
            nn::LinearConfig::new(fan_in, fan_out).init(device)
        })
        .collect()
}
/// Blends the parameters of `this` with those of `that` using the rate `tau`.
///
/// The resulting layer has a bias only when both layers have one.
pub fn soft_update_linear<B: Backend>(this: Linear<B>, that: &Linear<B>, tau: f64) -> Linear<B> {
    let weight = soft_update_tensor(&this.weight, &that.weight, tau);
    let bias = this
        .bias
        .as_ref()
        .zip(that.bias.as_ref())
        .map(|(ours, theirs)| soft_update_tensor(ours, theirs, tau));

    Linear::<B> { weight, bias }
}
/// Returns a new parameter holding `this * (1 - tau) + that * tau`.
///
/// The result is detached from the autodiff graph: the blended weights are plain
/// values, so they must not keep a reference to the tensors they were computed from
/// nor let gradients flow back into them.
fn soft_update_tensor<const N: usize, B: Backend>(
    this: &Param<Tensor<B, N>>,
    that: &Param<Tensor<B, N>>,
    tau: f64,
) -> Param<Tensor<B, N>> {
    let that_weight = that.val();
    let this_weight = this.val();
    let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
    Param::initialized(ParamId::new(), new_weight.detach())
}
/// The context attached to an action chosen by [`EpsilonGreedyPolicy`].
#[derive(Clone)]
pub struct EpsilonGreedyPolicyOutput {
    /// The exploration threshold that was used when choosing the action.
    pub epsilon: f64,
}
impl ItemLazy for EpsilonGreedyPolicyOutput {
    type ItemSync = EpsilonGreedyPolicyOutput;
    /// Returns the value unchanged: it holds a plain `f64`, so there is nothing to convert.
    fn sync(self) -> Self::ItemSync {
        self
    }
}
impl Adaptor<ExplorationRateInput> for EpsilonGreedyPolicyOutput {
    /// Exposes `epsilon` to the exploration rate metric.
    fn adapt(&self) -> ExplorationRateInput {
        ExplorationRateInput::new(self.epsilon)
    }
}
/// The record of an [`EpsilonGreedyPolicyState`].
#[derive(Record)]
pub struct EpsilonGreedyPolicyRecord<B: Backend, P: Policy<B>> {
    /// The record of the wrapped policy's state.
    pub inner_state: <P::PolicyState as PolicyState<B>>::Record,
    /// The value of the exploration step counter.
    pub step: usize,
}
/// The state of an [`EpsilonGreedyPolicy`].
#[derive(Clone, new)]
pub struct EpsilonGreedyPolicyState<B: Backend, P: Policy<B>> {
    /// The state of the wrapped policy.
    pub inner_state: P::PolicyState,
    /// The value of the exploration step counter.
    pub step: usize,
}
impl<B: Backend, P: Policy<B>> PolicyState<B> for EpsilonGreedyPolicyState<B, P> {
    type Record = EpsilonGreedyPolicyRecord<B, P>;

    /// Converts the inner state into its record and carries the step counter over.
    fn into_record(self) -> Self::Record {
        let Self { inner_state, step } = self;
        EpsilonGreedyPolicyRecord {
            inner_state: inner_state.into_record(),
            step,
        }
    }

    /// Builds a state from `record`: the inner state is loaded from the inner record
    /// and the step counter is taken from the record.
    fn load_record(&self, record: Self::Record) -> Self {
        let EpsilonGreedyPolicyRecord { inner_state, step } = record;
        Self {
            inner_state: self.inner_state.load_record(inner_state),
            step,
        }
    }
}
/// Wraps a policy and replaces its greedy action by a random one with a probability
/// that decays exponentially with the number of actions chosen so far.
#[derive(Clone, Debug)]
pub struct EpsilonGreedyPolicy<B: Backend, P: Policy<B>> {
    /// The policy providing the logits.
    inner_policy: P,
    /// The exploration threshold at step 0.
    eps_start: f64,
    /// The value the exploration threshold decays towards.
    eps_end: f64,
    /// The decay constant dividing the step count in the exponential.
    eps_decay: f64,
    /// The number of thresholds handed out so far.
    step: usize,
    _backend: PhantomData<B>,
}
impl<B: Backend, P: Policy<B>> EpsilonGreedyPolicy<B, P> {
    /// Wraps `inner_policy` with an exploration schedule that starts at step 0.
    pub fn new(inner_policy: P, eps_start: f64, eps_end: f64, eps_decay: f64) -> Self {
        Self {
            step: 0,
            _backend: PhantomData,
            inner_policy,
            eps_start,
            eps_end,
            eps_decay,
        }
    }

    /// Computes `eps_end + (eps_start - eps_end) * exp(-step / eps_decay)`.
    fn get_threshold(&self) -> f64 {
        let span = self.eps_start - self.eps_end;
        let decay = f64::exp(-(self.step as f64) / self.eps_decay);
        self.eps_end + span * decay
    }

    /// Returns the threshold for the current step, then advances the step counter.
    fn step(&mut self) -> f64 {
        let current = self.get_threshold();
        self.step += 1;
        current
    }
}
impl<B, P> Policy<B> for EpsilonGreedyPolicy<B, P>
where
    B: Backend,
    P: Policy<
            B,
            ActionDistribution = DiscreteLogitsTensor<B, 2>,
            Action = DiscreteActionTensor<B, 2>,
        >,
{
    type ActionContext = EpsilonGreedyPolicyOutput;
    type PolicyState = EpsilonGreedyPolicyState<B, P>;

    type Observation = P::Observation;
    type ActionDistribution = DiscreteLogitsTensor<B, 2>;
    type Action = DiscreteActionTensor<B, 2>;

    /// Delegates to the wrapped policy.
    fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {
        self.inner_policy.forward(states)
    }

    /// Selects one action per observation in the batch.
    ///
    /// For each observation the greedy action (highest logit) is kept unless a uniform
    /// random draw falls at or below the current exploration threshold, in which case an
    /// action is drawn uniformly among all actions. When `deterministic` is true the
    /// greedy action is always kept and the reported threshold is 0. The step counter
    /// advances once per observation in both cases.
    ///
    /// Returns the actions as a `[batch_size, 1]` tensor and one context per observation.
    fn action(
        &mut self,
        states: Self::Observation,
        deterministic: bool,
    ) -> (Self::Action, Vec<Self::ActionContext>) {
        let logits = self.inner_policy.forward(states).logits;
        // The number of available actions is the width of the logits.
        let [_, num_actions] = logits.dims();
        let greedy_actions = logits.argmax(1).split(1, 0);

        let mut contexts = Vec::with_capacity(greedy_actions.len());
        let mut actions = Vec::with_capacity(greedy_actions.len());

        for greedy in greedy_actions {
            let threshold = self.step();
            let threshold = if deterministic { 0.0 } else { threshold };
            contexts.push(EpsilonGreedyPolicyOutput { epsilon: threshold });

            // `deterministic` is checked explicitly: a random draw of exactly 0.0 would
            // otherwise fail the `> 0.0` test and trigger exploration.
            if deterministic || random::<f64>() > threshold {
                actions.push(greedy.float());
            } else {
                let random_action = random_range(0..num_actions);
                actions.push(
                    Tensor::<B, 1>::from_floats([random_action], &greedy.device()).unsqueeze(),
                );
            }
        }

        let output = Tensor::cat(actions, 0);
        (DiscreteActionTensor { actions: output }, contexts)
    }

    fn update(&mut self, update: Self::PolicyState) {
        // Note: updating an epsilon greedy policy doesn't change the step.
        self.inner_policy.update(update.inner_state);
    }

    /// Returns the wrapped policy's state together with the current step counter.
    fn state(&self) -> Self::PolicyState {
        EpsilonGreedyPolicyState {
            inner_state: self.inner_policy.state(),
            step: self.step,
        }
    }

    /// Returns a policy whose wrapped policy and step counter are restored from `record`;
    /// the epsilon schedule parameters are kept.
    fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {
        let state = self.state().load_record(record);
        let inner_policy = self
            .inner_policy
            .load_record(state.inner_state.into_record());
        EpsilonGreedyPolicy {
            inner_policy,
            eps_start: self.eps_start,
            eps_end: self.eps_end,
            eps_decay: self.eps_decay,
            step: state.step,
            _backend: PhantomData,
        }
    }
}

View File

@@ -118,6 +118,9 @@ pub(crate) fn handle_command(
"burn-router".to_string(),
"burn-tch".to_string(),
"burn-wgpu".to_string(),
// dqn-agent example relies on gym-rs dependency which requires SDL2.
// It would be good to remove the gym-rs dependency in the future.
"dqn-agent".to_string(),
]);
// Burn remote tests don't work on windows for now