Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New autodiff graph memory management strategy #1698

Merged
merged 9 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 88,10 @@ name = "custom-gelu"
path = "benches/custom_gelu.rs"
harness = false

[[bench]]
name = "autodiff"
harness = false

[[bin]]
name = "burnbench"
path = "src/bin/burnbench.rs"
81 changes: 81 additions & 0 deletions backend-comparison/benches/autodiff.rs
Original file line number Diff line number Diff line change
@@ -0,0 1,81 @@
use backend_comparison::persistence::save;
use burn::{
module::Module,
nn,
tensor::{
backend::{AutodiffBackend, Backend},
Distribution, Tensor,
},
};
use burn_common::benchmark::{run_benchmark, Benchmark};

pub struct AutodiffOverheadBenchmark<B: AutodiffBackend> {
config: nn::LstmConfig,
lstm: nn::Lstm<B>,
device: B::Device,
}

impl<B: AutodiffBackend> Benchmark for AutodiffOverheadBenchmark<B> {
type Args = Tensor<B, 3>;

fn name(&self) -> String {
"autodiff_overhead".into()
}

fn shapes(&self) -> Vec<Vec<usize>> {
vec![]
}

fn execute(&self, input: Self::Args) {
for _ in 0..20 {
let input = input.clone().detach();
let mut cell = input.clone();
let lstm = self.lstm.clone().fork(&input.device());

for _ in 0..10 {
let (cells, _) = lstm.forward(input.clone(), None);
cell = cell cells;
}

cell.backward();
}
}

fn prepare(&self) -> Self::Args {
let shape = [1, 3, self.config.d_hidden];
Tensor::random(shape, Distribution::Default, &self.device)
}

fn sync(&self) {
B::sync(&self.device)
}
}

#[allow(dead_code)]
fn bench<B: Backend>(
device: &B::Device,
feature_name: &str,
url: Option<&str>,
token: Option<&str>,
) {
let config = nn::LstmConfig::new(3, 3, true);
let lstm = config.init(device);
let benchmark = AutodiffOverheadBenchmark::<burn::backend::Autodiff<B>> {
lstm,
config,
device: device.clone(),
};

save::<B>(
vec![run_benchmark(benchmark)],
device,
feature_name,
url,
token,
)
.unwrap();
}

fn main() {
backend_comparison::bench_on_backend!();
}
2 changes: 2 additions & 0 deletions backend-comparison/src/burnbenchapp/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 98,8 @@ enum BenchmarkValues {
MaxPool2d,
#[strum(to_string = "load-record")]
LoadRecord,
#[strum(to_string = "autodiff")]
Autodiff,
}

pub fn execute() {
Expand Down
1 change: 1 addition & 0 deletions crates/burn-autodiff/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 14,7 @@ version.workspace = true
default = ["std"]
export_tests = ["burn-tensor-testgen"]
std = []
async = [] # Require std

[dependencies]
burn-common = { path = "../burn-common", version = "0.14.0" }
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-autodiff/src/graph/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 6,12 @@ use std::collections::HashMap;
pub trait Step: Send std::fmt::Debug {
/// Executes the step and consumes it.
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer);
/// Depth of the operation relative to the first node added to a graph.
fn depth(&self) -> usize;
/// The node associated to the step.
fn node(&self) -> NodeID;
/// The parents of the node associated to the step.
fn parents(&self) -> Vec<NodeID>;
fn order(&self) -> usize;
}

pub type StepBoxed = Box<dyn Step>;
Expand Down
40 changes: 28 additions & 12 deletions crates/burn-autodiff/src/graph/traversal.rs
Original file line number Diff line number Diff line change
@@ -1,27 1,33 @@
use std::collections::{HashMap, HashSet};

use super::{Step, StepBoxed};
use crate::NodeID;

use super::StepBoxed;
use std::collections::{HashMap, HashSet};

/// Breadth for search algorithm.
pub struct BreadthFirstSearch;

pub trait TraversalItem {
fn id(&self) -> NodeID;
fn parents(&self) -> Vec<NodeID>;
}

impl BreadthFirstSearch {
/// Traverse the graph of backward steps from a root node.
pub fn traverse<F: FnMut(NodeID, StepBoxed)>(
pub fn traverse<F, I>(
&self,
root_id: NodeID,
root_step: StepBoxed,
steps: &mut HashMap<NodeID, StepBoxed>,
root_step: I,
steps: &mut HashMap<NodeID, I>,
mut callback: F,
) {
let root_order = root_step.order();
let mut visited = HashSet::with_capacity(root_order);
let mut parents = Vec::with_capacity(root_order);
) where
F: FnMut(NodeID, I),
I: TraversalItem,
{
let mut visited = HashSet::new();
let mut parents = Vec::new();

visited.insert(root_id);
parents.append(&mut root_step.parents());

callback(root_id, root_step);

while let Some(id) = parents.pop() {
Expand All @@ -30,7 36,7 @@ impl BreadthFirstSearch {
None => continue,
};

let step_node = step.node();
let step_node = step.id();
let step_parents = step.parents();

if visited.contains(&step_node) {
Expand All @@ -49,3 55,13 @@ impl BreadthFirstSearch {
}
}
}

impl TraversalItem for StepBoxed {
fn id(&self) -> NodeID {
Step::node(self.as_ref())
}

fn parents(&self) -> Vec<NodeID> {
Step::parents(self.as_ref())
}
}
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 271,7 @@ where
self.ops.node.parents.clone()
}

fn order(&self) -> usize {
fn depth(&self) -> usize {
self.ops.node.order
}
}
Expand All @@ -293,7 293,7 @@ impl<const N: usize> Step for UntrackedOpsStep<N> {
fn parents(&self) -> Vec<NodeID> {
self.ops.node.parents.clone()
}
fn order(&self) -> usize {
fn depth(&self) -> usize {
self.ops.node.order
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2068,7 2068,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
.map(|node| node.id)
.collect()
}
fn order(&self) -> usize {
fn depth(&self) -> usize {
self.output.order
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/runtime/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 15,9 @@ pub trait AutodiffClient: Send Clone {
}

/// Client implementation in used.
#[cfg(feature = "std")]
#[cfg(feature = "async")]
pub type AutodiffClientImpl = super::mspc::ChannelClient;

/// Client implementation in used.
#[cfg(not(feature = "std"))]
#[cfg(not(feature = "async"))]
pub type AutodiffClientImpl = super::mutex::MutexClient;
Loading
Loading