diff --git a/crates/burn-nn/Cargo.toml b/crates/burn-nn/Cargo.toml index ac0099f3f..037adf1a8 100644 --- a/crates/burn-nn/Cargo.toml +++ b/crates/burn-nn/Cargo.toml @@ -26,9 +26,6 @@ doc = [ "pretrained", ] pretrained = ["std", "burn-store/pytorch", "burn-std/network", "dirs"] -# Added for some test cases that should only be run locally -# (e.g., test cases with pretrained weights for gram matrix loss) -test-local = [] std = [ "burn-core/std", "num-traits/std", diff --git a/crates/burn-nn/src/loss/pretrained/gram_matrix/gram_matrix_loss.rs b/crates/burn-nn/src/loss/pretrained/gram_matrix/gram_matrix_loss.rs index 0c365db93..7ac0cbc75 100644 --- a/crates/burn-nn/src/loss/pretrained/gram_matrix/gram_matrix_loss.rs +++ b/crates/burn-nn/src/loss/pretrained/gram_matrix/gram_matrix_loss.rs @@ -313,6 +313,9 @@ mod tests { } #[test] + // TODO: run tests only locally, and #[serial]'ised? + // #[cfg(feature = "test-local")] + #[ignore = "downloads pre-trained weights"] fn test_gram_matrix_loss_config_valid_weights() { let device = Default::default(); let layer_weights = vec![0.0, 0.2, 0.2, 0.25, 0.4]; @@ -511,7 +514,7 @@ mod tests { } #[test] - #[cfg(feature = "test-local")] + #[ignore = "downloads pre-trained weights"] fn test_gram_matrix_loss_pretrained_weights_identical_inputs() { let device = Default::default(); let loss_fn = @@ -532,7 +535,7 @@ mod tests { } #[test] - #[cfg(feature = "test-local")] + #[ignore = "downloads pre-trained weights"] fn test_gram_matrix_loss_pretrained_weights_different_inputs() { let device = Default::default(); let loss_fn = diff --git a/xtask/src/commands/test.rs b/xtask/src/commands/test.rs index 5f4f6b26e..6b1da6f19 100644 --- a/xtask/src/commands/test.rs +++ b/xtask/src/commands/test.rs @@ -285,17 +285,16 @@ pub(crate) fn handle_command( )?; // burn-nn (pretrained and local tests) - let mut nn_features = "pretrained".to_string(); // If the "CI" environment variable is missing, we are running locally. - if std::env::var("CI").is_err() { - nn_features.push_str(",test-local"); - } + // if std::env::var("CI").is_err() { + // nn_features.push_str(",test-local"); + // } helpers::custom_crates_tests( vec!["burn-nn"], - handle_test_args(&["--features", &nn_features], args.release), + handle_test_args(&["--features", "pretrained"], args.release), None, None, - &format!("std burn-nn with features: {}", nn_features), + "std burn-nn", )?; } CiTestType::GcpCudaRunner => (),