New autodiff graph memory management strategy #1698

Merged · 9 commits · Apr 26, 2024
Changes from 1 commit
fix and clean
louisfd committed Apr 26, 2024
commit 99860acf63846a8e0d9ff7cea973b81fe4843b11
36 changes: 11 additions & 25 deletions crates/burn-autodiff/src/runtime/memory_management.rs
@@ -1,7 +1,6 @@
use crate::{tensor::NodeRefCount, NodeID};
use std::{
collections::{HashMap, HashSet},
fmt::Display,
mem,
sync::Arc,
};
@@ -13,20 +12,6 @@ pub struct GraphMemoryManagement {
statuses: HashMap<NodeID, NodeMemoryStatus>,
}

impl Display for GraphMemoryManagement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(
format!(
"{} {} {}",
self.nodes.len(),
self.leaves.len(),
self.statuses.len()
)
.as_str(),
)
}
}

#[derive(Debug, Clone)]
enum NodeMemoryStatus {
Useful,
@@ -36,7 +21,7 @@ enum NodeMemoryStatus {

#[derive(Clone)]
enum Mode {
Retain,
TagAsUseful,
Explore,
}

@@ -90,8 +75,6 @@ impl GraphMemoryManagement {
}

// Replace leaves by the new ones and delete everything not useful anymore
println!("{:?}", new_leaves);
println!("{:?}", deletables);
mem::swap(&mut self.leaves, &mut new_leaves);
self.statuses.clear();
for node_to_delete in deletables {
@@ -108,6 +91,8 @@

match self.nodes.get(&node_id).cloned() {
// If node exists and any of its parents is unavailable, it is unavailable as well
// If node exists but the parents vec is empty, it is a tensor that never had parents;
// the status remains unknown
Some(parents) => {
let mut node_status = NodeMemoryStatus::Unknown;
for parent in parents {
@@ -116,13 +101,14 @@
node_status = NodeMemoryStatus::Unavailable;
}
}
self.statuses.insert(node_id.clone(), node_status.clone());
self.statuses.insert(node_id, node_status.clone());
node_status
}
// If node does not exist, it was deleted, and all its descendants are unavailable
// If node does not exist, it was either
// - deleted, so all its descendants are unavailable
// - not requiring grad or detached, the status remains unknown (TODO REGISTER THEM WITH EMPTY PARENTS LIKE THOSE WITH REGISTER_GRAD)
None => {
self.statuses
.insert(node_id.clone(), NodeMemoryStatus::Unavailable);
self.statuses.insert(node_id, NodeMemoryStatus::Unavailable);
NodeMemoryStatus::Unavailable
}
}
@@ -132,10 +118,10 @@ impl GraphMemoryManagement {
let parents = self.nodes.get(&node_id).cloned().unwrap_or(vec![]);

match mode {
Mode::Retain => {
Mode::TagAsUseful => {
self.statuses.insert(node_id, NodeMemoryStatus::Useful);
for parent in parents {
self.useful_propagation(parent, Mode::Retain)
self.useful_propagation(parent, Mode::TagAsUseful)
}
}
Mode::Explore => {
@@ -162,7 +148,7 @@
let mut mode = Mode::Explore;
if self.is_referenced(node_id) {
self.statuses.insert(node_id, NodeMemoryStatus::Useful);
mode = Mode::Retain;
mode = Mode::TagAsUseful;
}

for parent in parents {
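For reviewers skimming the hunks above: a condensed, self-contained model of the two-phase status propagation this file implements, using the renamed `TagAsUseful` idea. The `u64` ids, the `Graph` struct and the `referenced` callback are hypothetical stand-ins rather than the crate's types; this is a sketch of the idea, not the actual implementation.

```rust
use std::collections::HashMap;

type NodeId = u64; // stand-in for NodeID

#[derive(Clone, PartialEq)]
enum Status {
    Useful,
    Unavailable,
    Unknown,
}

// Hypothetical stand-in for GraphMemoryManagement.
struct Graph {
    // node -> parents; a node absent from this map was deleted (or never registered)
    parents: HashMap<NodeId, Vec<NodeId>>,
    statuses: HashMap<NodeId, Status>,
}

impl Graph {
    // Phase 1: a node is Unavailable if it was deleted or any parent is Unavailable;
    // a registered node whose parent list is empty keeps the Unknown status.
    fn unavailable_propagation(&mut self, node: NodeId) -> Status {
        if let Some(status) = self.statuses.get(&node) {
            return status.clone();
        }
        let status = match self.parents.get(&node).cloned() {
            Some(parents) => {
                let mut status = Status::Unknown;
                for parent in parents {
                    if self.unavailable_propagation(parent) == Status::Unavailable {
                        status = Status::Unavailable;
                    }
                }
                status
            }
            None => Status::Unavailable,
        };
        self.statuses.insert(node, status.clone());
        status
    }

    // Phase 2: walking up from a leaf in "explore" mode, switch to "tag as useful"
    // as soon as an externally referenced node is met, then mark the whole ancestry.
    fn useful_propagation(
        &mut self,
        node: NodeId,
        tag_as_useful: bool,
        referenced: &dyn Fn(NodeId) -> bool,
    ) {
        let parents = self.parents.get(&node).cloned().unwrap_or_default();
        let tag = tag_as_useful || referenced(node);
        if tag {
            self.statuses.insert(node, Status::Useful);
        }
        for parent in parents {
            self.useful_propagation(parent, tag, referenced);
        }
    }
}
```

After both passes, anything never tagged `Useful` is a deletion candidate, which matches the "delete everything not useful anymore" bookkeeping above before the statuses map is cleared.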
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/runtime/mutex.rs
@@ -33,7 +33,7 @@ impl AutodiffClient for MutexClient {
}
fn backward<B: Backend, const D: usize>(&self, root: AutodiffTensor<B, D>) -> Gradients {
let mut server = SERVER.lock();
let node_id = root.node.id.clone();
let node_id = root.node.id;
let grads = Gradients::new::<B, D>(root.node, root.primitive);

if let Some(server) = server.as_mut() {
34 changes: 6 additions & 28 deletions crates/burn-autodiff/src/runtime/server.rs
@@ -8,17 +8,6 @@ use crate::{
};
use std::collections::HashMap;

enum MemoryManagementMode {
/// Faster but introduces bugs
DeleteAll,
/// Slower but can't bug
KeepAll,
/// Nearly as fast as delete all, soon no bugs
Managed,
}

static GRAPH_MM_ACTIVATED: MemoryManagementMode = MemoryManagementMode::Managed;

#[derive(Default)]
pub struct AutodiffServer {
steps: HashMap<NodeID, StepBoxed>,
@@ -31,9 +20,7 @@ impl AutodiffServer {
let parents = step.parents();
let node_id = *rc.as_ref();

if let MemoryManagementMode::Managed = GRAPH_MM_ACTIVATED {
self.memory_management.register(rc, parents);
}
self.memory_management.register(rc, parents);

self.steps.insert(node_id, step);
self.actions_builder.insert(node_id, actions);
@@ -51,19 +38,12 @@

let gradients = Self::execute_steps(tape, grads, checkpointer);

if let MemoryManagementMode::DeleteAll = GRAPH_MM_ACTIVATED {
self.steps.clear();
self.actions_builder.clear();
} else if let MemoryManagementMode::Managed = GRAPH_MM_ACTIVATED {
// Cleanup
let mut on_free_graph = |node_id: &NodeID| {
// Cleanup
self.memory_management
.free_unavailable_nodes(|node_id: &NodeID| {
self.steps.remove(node_id);
self.actions_builder.remove(node_id);
};

self.memory_management
.free_unavailable_nodes(&mut on_free_graph);
}
});

gradients
}
@@ -79,9 +59,7 @@
.collect::<Vec<_>>();

BreadthFirstSearch.traverse(node, node_step, &mut self.steps, |id, step| {
if let MemoryManagementMode::Managed = GRAPH_MM_ACTIVATED {
self.memory_management.consume_node(id);
}
self.memory_management.consume_node(id);

let depth = step.depth();
if depth == 0 {
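The hunks above remove the static `MemoryManagementMode` switch, so registration and cleanup now always go through `memory_management`. Below is a rough sketch of the resulting cleanup shape, with hypothetical stand-in types (`u64` ids, `String` steps); it is not the crate's actual server code.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for GraphMemoryManagement on the server side.
struct MemoryManagement {
    unavailable: Vec<u64>, // nodes found unavailable after a backward pass
}

impl MemoryManagement {
    // Calls `on_free` once per node that can be dropped, then forgets them.
    fn free_unavailable_nodes(&mut self, mut on_free: impl FnMut(&u64)) {
        for node_id in self.unavailable.drain(..) {
            on_free(&node_id);
        }
    }
}

// Hypothetical stand-in for AutodiffServer.
struct Server {
    memory_management: MemoryManagement,
    steps: HashMap<u64, String>,
    actions_builder: HashMap<u64, String>,
}

impl Server {
    fn cleanup_after_backward(&mut self) {
        // Split the borrows so the closure only touches `steps` and
        // `actions_builder` while `memory_management` is borrowed mutably.
        let steps = &mut self.steps;
        let actions = &mut self.actions_builder;
        self.memory_management.free_unavailable_nodes(|node_id| {
            steps.remove(node_id);
            actions.remove(node_id);
        });
    }
}
```

Combined with the now-unconditional `consume_node(id)` call inside the breadth-first traversal, this appears to make the old `DeleteAll` / `KeepAll` escape hatches unnecessary, which is presumably why the enum was dropped.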
10 changes: 7 additions & 3 deletions crates/burn-autodiff/src/tensor.rs
@@ -18,7 +18,7 @@ pub struct AutodiffTensor<B: Backend, const D: usize> {
pub type NodeRefCount = Arc<NodeID>;

#[derive(new, Debug)]
struct RootStep {
pub(crate) struct RootStep {
node: NodeRef,
}

@@ -57,7 +57,7 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
Self {
rc: Arc::new(node.id),
primitive,
node,
node: node.clone(),
}
}

@@ -113,7 +113,11 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
.unwrap_or_else(AutodiffClientImpl::new);

let node: NodeRef = Node::new(
parent_nodes.iter().map(|node| node.id).collect(),
parent_nodes
.iter()
.filter_map(|node| node.clone_if_require_grad())
.map(|node| node.id)
.collect(),
order,
NodeID::new(),
requirement,
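The last hunk above routes parents through `clone_if_require_grad()` before collecting their ids, so detached or no-grad parents never show up in the new node's parent list. A minimal sketch of that filtering, using hypothetical stand-in types (the real `NodeRef` and `clone_if_require_grad` live in the crate):

```rust
#[derive(Clone, Debug)]
struct NodeRef {
    id: u64,
    requires_grad: bool,
}

impl NodeRef {
    // Stand-in for `clone_if_require_grad`: Some(clone) only when gradients are tracked.
    fn clone_if_require_grad(&self) -> Option<NodeRef> {
        self.requires_grad.then(|| self.clone())
    }
}

fn parent_ids(parent_nodes: &[NodeRef]) -> Vec<u64> {
    parent_nodes
        .iter()
        .filter_map(|node| node.clone_if_require_grad())
        .map(|node| node.id)
        .collect()
}

fn main() {
    let tracked = NodeRef { id: 1, requires_grad: true };
    let detached = NodeRef { id: 2, requires_grad: false };
    // Only the tracked parent's id is recorded in the graph.
    assert_eq!(parent_ids(&[tracked, detached]), vec![1]);
}
```

This lines up with the new comment in `memory_management.rs` above: a registered node whose parent vec ends up empty is treated as a root-like tensor whose status stays unknown.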
13 changes: 9 additions & 4 deletions crates/burn-autodiff/src/tests/cat.rs
@@ -6,17 +6,22 @@ mod tests {
#[test]
fn should_diff_cat() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0]], &device).require_grad();
let data_1 = Data::from([[2.0, -1.0], [5.0, 2.0]]);
let data_2 = Data::from([[5.0, 4.0], [-1.0, 4.0]]);

let tensor_1 = TestAutodiffTensor::from_data(data_1.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let grads = tensor_3.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

// Redeclared because consumed in previous backward
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();

let mut tensor_1_list = Vec::new();
let mut tensor_2_list = Vec::new();

79 changes: 31 additions & 48 deletions crates/burn-autodiff/src/tests/memory_management.rs
@@ -4,7 +4,7 @@ mod tests {
use burn_tensor::{activation, Data, Tensor};

#[test]
fn test_mm_independant_trees() {
fn test_mm_independent_trees() {
let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();

@@ -189,70 +189,53 @@
assert!(tensor_3.grad(&grads).is_some());
}

// #[test]
fn test_mm_with_impacting_detach() {
let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]);

#[test]
fn test_mm_with_missing_require_grad_after_cleanup() {
let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();

let tensor_3 = tensor_1.clone() * tensor_2.clone();

let tensor_3a = tensor_3.clone() - tensor_3.detach().max_dim(1);
let tensor_3b = tensor_3a.exp();
let tensor_3c = tensor_3b.clone().sum_dim(1);

let tensor_3d = tensor_3b.div(tensor_3c);
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);
let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);

let tensor_4 = tensor_3d.matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone() * tensor_2.clone();
let tensor_5 = tensor_4 * tensor_3.clone();

let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
// Trivial backward, just to trigger cleanup
Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)
.require_grad()
.backward();

grad_1
.to_data()
.assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3);
let grads = tensor_5.backward();
assert!(tensor_1.grad(&grads).is_some());
assert!(tensor_2.grad(&grads).is_none());
assert!(tensor_3.grad(&grads).is_none());
}

#[test]
fn test_mm_with_missing_require_grad() {
fn test_mm_with_detach_after_cleanup() {
let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();

let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);
let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data, &device);
let tensor_2 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_3 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();

let tensor_4 = tensor_1.clone() * tensor_2.clone();
let tensor_5 = tensor_4 * tensor_3.clone();
let tensor_5 = tensor_4 * tensor_3.clone().detach();

// Trivial backward, just to trigger cleanup
Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)
.require_grad()
.backward();

let grads = tensor_5.backward();
assert!(tensor_1.grad(&grads).is_some());
assert!(tensor_2.grad(&grads).is_none());
assert!(tensor_2.grad(&grads).is_some());
assert!(tensor_3.grad(&grads).is_none());
}

// #[test]
// fn test_mm_with_detach() {
// let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
// let device = Default::default();
// let tensor_1 =
// Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
// let tensor_2 =
// Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
// let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data, &device).require_grad();

// let tensor_4 = tensor_1.clone() * tensor_2.clone();
// let tensor_5 = tensor_4 * tensor_3.detach();

// let grads = tensor_5.backward();
// assert!(tensor_1.grad(&grads).is_some());
// assert!(tensor_2.grad(&grads).is_some());
// }
}
25 changes: 16 additions & 9 deletions crates/burn-autodiff/src/tests/slice.rs
@@ -85,16 +85,23 @@ mod tests {
let data_3: Data<f32, 2> = Data::from([[14.0, 97.0, 100.0, 9.0], [2.0, 3.0, 15.0, 7.0]]);

let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_1_slice = TestAutodiffTensor::from_data(data_1.clone(), &device).require_grad();
let tensor_2_slice = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();

let tensor_1_cat = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2_cat = TestAutodiffTensor::from_data(data_2, &device).require_grad();

let tensor_3 = TestAutodiffTensor::from_data(data_3, &device);

let slice_assign_output = TestAutodiffTensor::zeros([2, 4], &Default::default());
let slice_assign_output = slice_assign_output.slice_assign([0..2, 0..2], tensor_1.clone());
let slice_assign_output = slice_assign_output.slice_assign([0..2, 2..4], tensor_2.clone());
let slice_assign_output =
slice_assign_output.slice_assign([0..2, 0..2], tensor_1_slice.clone());
let slice_assign_output =
slice_assign_output.slice_assign([0..2, 2..4], tensor_2_slice.clone());
let slice_assign_output = slice_assign_output / tensor_3.clone();

let cat_output = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1);
let cat_output =
TestAutodiffTensor::cat(vec![tensor_1_cat.clone(), tensor_2_cat.clone()], 1);
let cat_output = cat_output / tensor_3;

slice_assign_output
@@ -104,10 +111,10 @@
let slice_assign_grads = slice_assign_output.backward();
let cat_grads = cat_output.backward();

let slice_assign_grad_1 = tensor_1.grad(&slice_assign_grads).unwrap();
let slice_assign_grad_2 = tensor_2.grad(&slice_assign_grads).unwrap();
let cat_grad_1 = tensor_1.grad(&cat_grads).unwrap();
let cat_grad_2 = tensor_2.grad(&cat_grads).unwrap();
let slice_assign_grad_1 = tensor_1_slice.grad(&slice_assign_grads).unwrap();
let slice_assign_grad_2 = tensor_2_slice.grad(&slice_assign_grads).unwrap();
let cat_grad_1 = tensor_1_cat.grad(&cat_grads).unwrap();
let cat_grad_2 = tensor_2_cat.grad(&cat_grads).unwrap();

slice_assign_grad_1
.to_data()