Fusion wgpu compilation cache #1069

Merged: 14 commits, Dec 18, 2023
burn-fusion/src/backend.rs: 2 changes (1 addition, 1 deletion)
@@ -96,7 +96,7 @@ pub trait OptimizationBuilder<B: FusionBackend>: Send {
/// The operation created from the [builder](OptimizationBuilder).
pub trait Optimization<B: FusionBackend>: Send {
/// Execute the operation.
fn execute(&self, context: &mut Context<'_, B>);
fn execute(&mut self, context: &mut Context<'_, B>);
/// The number of registered operations in this optimization.
fn len(&self) -> usize;
/// If the current optimization is empty.
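The `&mut self` on `execute` is what lets an optimization keep mutable state across runs, for example a lazily compiled kernel, which fits the compilation-cache goal of this PR. A minimal sketch of the pattern, using simplified stand-in types rather than the real burn-fusion trait:

```rust
// Hypothetical, simplified stand-ins (not the real burn-fusion types): the
// point is only that `execute(&mut self, ...)` lets an optimization build an
// expensive artifact once and reuse it on later executions.
struct Context;

trait Optimization {
    /// Execute the operation.
    fn execute(&mut self, context: &mut Context);
    /// The number of registered operations in this optimization.
    fn len(&self) -> usize;
}

struct FusedKernel {
    source: String,
    compiled: Option<String>, // stand-in for a cached compiled pipeline
    num_ops: usize,
}

impl Optimization for FusedKernel {
    fn execute(&mut self, _context: &mut Context) {
        // Compile on first use, then reuse the cached result.
        if self.compiled.is_none() {
            self.compiled = Some(format!("compiled({})", self.source));
        }
        let compiled = self.compiled.as_ref().unwrap();
        println!("running {}", compiled);
    }

    fn len(&self) -> usize {
        self.num_ops
    }
}

fn main() {
    let mut context = Context;
    let mut optimization = FusedKernel {
        source: "fused elementwise kernel".into(),
        compiled: None,
        num_ops: 3,
    };
    optimization.execute(&mut context); // compiles, then runs
    optimization.execute(&mut context); // reuses the compilation
    assert_eq!(optimization.len(), 3);
}
```

The rest of the PR threads `&mut` through `execute_optimization`, `BuildAction`, and the optimization cache so that this borrow is available at the call site.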
burn-fusion/src/graph/base.rs: 2 changes (1 addition, 1 deletion)
@@ -53,7 +53,7 @@ impl<B: FusionBackend> Graph<B> {
pub(crate) fn execute_optimization(
&mut self,
handles: &mut HandleContainer<B>,
optimization: &dyn Optimization<B>,
optimization: &mut dyn Optimization<B>,
) {
let num_keep = optimization.len();
let mut context = self.converter.context(handles);
burn-fusion/src/graph/context.rs: 14 changes (0 additions, 14 deletions)
@@ -682,20 +682,6 @@ impl<E: Element> NumericOpsDescription<E> {
out: desc.out.to_relative(converter),
})
}
NumericOpsDescription::ClampMax(desc) => {
NumericOpsDescription::ClampMax(ScalarOpsDescription {
lhs: desc.lhs.to_relative(converter),
rhs: local_elem(converter, &desc.rhs),
out: desc.out.to_relative(converter),
})
}
NumericOpsDescription::ClampMin(desc) => {
NumericOpsDescription::ClampMin(ScalarOpsDescription {
lhs: desc.lhs.to_relative(converter),
rhs: local_elem(converter, &desc.rhs),
out: desc.out.to_relative(converter),
})
}
}
}
}
burn-fusion/src/graph/execution.rs: 10 changes (5 additions, 5 deletions)
@@ -71,7 +71,7 @@ impl<B: FusionBackend> GraphExecution<B> {
};
}
CacheResult::Found(ops) => {
graph.execute_optimization(handles, ops.as_ref());
graph.execute_optimization(handles, ops.as_mut());
self.reset(graph);
}
};
@@ -107,14 +107,14 @@ impl<B: FusionBackend> GraphExecution<B> {
}
}

match find_best_optimization_index(&self.optimizations) {
match find_best_optimization_index(&mut self.optimizations) {
Some(index) => {
let (relative, next_ops) = Self::split_relative_graph_owned(graph, mode);
let optimization = &self.optimizations[index];
let ops = self
.optimization_cache
.complete(optimization, relative, next_ops);
BuildAction::ExecuteOptimization(ops.as_ref())
BuildAction::ExecuteOptimization(ops.as_mut())
}
None => {
// TODO: Cache this result too.
@@ -184,7 +184,7 @@ impl<B: FusionBackend> GraphExecution<B> {
}

enum BuildAction<'a, B: FusionBackend> {
ExecuteOptimization(&'a dyn Optimization<B>),
ExecuteOptimization(&'a mut dyn Optimization<B>),
ExecuteOperations,
ContinueBuilding,
}
@@ -202,7 +202,7 @@ fn still_optimizing<B: FusionBackend>(optimizations: &[Box<dyn OptimizationBuild
}

fn find_best_optimization_index<B: FusionBackend>(
optimizations: &[Box<dyn OptimizationBuilder<B>>],
optimizations: &mut [Box<dyn OptimizationBuilder<B>>],
) -> Option<usize> {
let mut best_index = None;
let mut best_score = 0;
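Because `execute` now takes `&mut self`, every layer that hands an optimization to the executor has to carry a mutable borrow, which is why `execute_optimization` and `BuildAction::ExecuteOptimization` switched to `&mut` above. A small self-contained sketch of that mutable trait-object threading, with made-up names rather than burn's real types:

```rust
// Hypothetical, simplified sketch (not burn's real types): once `execute`
// takes `&mut self`, the executor must receive a mutable trait object.
trait Optimization {
    fn execute(&mut self);
}

enum BuildAction<'a> {
    ExecuteOptimization(&'a mut dyn Optimization),
    ExecuteOperations,
}

// Toy optimization that mutates internal state on every execution.
struct Counter(usize);

impl Optimization for Counter {
    fn execute(&mut self) {
        self.0 += 1;
    }
}

fn run(action: BuildAction<'_>) {
    match action {
        BuildAction::ExecuteOptimization(optimization) => optimization.execute(),
        BuildAction::ExecuteOperations => {}
    }
}

fn main() {
    let mut optimization = Counter(0);
    run(BuildAction::ExecuteOptimization(&mut optimization));
    run(BuildAction::ExecuteOperations);
    run(BuildAction::ExecuteOptimization(&mut optimization));
    assert_eq!(optimization.0, 2);
}
```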
burn-fusion/src/graph/ops.rs: 18 changes (0 additions, 18 deletions)
@@ -379,16 +379,6 @@ pub enum NumericOpsDescription<E> {
/// Float => [clamp](burn_tensor::ops::TensorOps::clamp).
/// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp).
Clamp(ClampOpsDescription<E>),
/// Operation corresponding to:
///
/// Float => [clamp max](burn_tensor::ops::TensorOps::clamp_max).
/// Int => [clamp max](burn_tensor::ops::IntTensorOps::int_clamp_max).
ClampMax(ScalarOpsDescription<E>),
/// Operation corresponding to:
///
/// Float => [clamp min](burn_tensor::ops::TensorOps::clamp_min).
/// Int => [clamp min](burn_tensor::ops::IntTensorOps::int_clamp_min).
ClampMin(ScalarOpsDescription<E>),
}

/// Operation description specific to an int tensor.
@@ -900,12 +890,6 @@ impl<E: Element> NumericOpsDescription<E> {
NumericOpsDescription::Clamp(desc) => {
vec![&desc.tensor, &desc.out]
}
NumericOpsDescription::ClampMin(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOpsDescription::ClampMax(desc) => {
vec![&desc.lhs, &desc.out]
}
NumericOpsDescription::Abs(desc) => {
vec![&desc.input, &desc.out]
}
@@ -1144,8 +1128,6 @@ impl<E> core::hash::Hash for NumericOpsDescription<E> {
NumericOpsDescription::MaxDim(desc) => desc.hash(state),
NumericOpsDescription::MinDim(desc) => desc.hash(state),
NumericOpsDescription::Clamp(desc) => desc.hash(state),
NumericOpsDescription::ClampMax(desc) => desc.hash(state),
NumericOpsDescription::ClampMin(desc) => desc.hash(state),
}
}
}
burn-fusion/src/graph/path/base.rs: 25 changes (14 additions, 11 deletions)
@@ -60,16 +60,13 @@ impl<O> OptimizationCache<O> {
}

if let Some(candidate) = self.found {
return CacheResult::Found(&self.optimizations.get(candidate).unwrap().value);
return CacheResult::Found(&mut self.optimizations.get_mut(candidate).unwrap().value);
}

// Invalidate candidates.
let mut invalidated_candidate = Vec::new();
for id in self.candidates.iter() {
let item = match self.optimizations.get(*id) {
Some(item) => item,
None => panic!("Should have an optimization"),
};
let item = &self.optimizations[*id];
let next_ops = graph.last().expect("Validated earlier");
let next_ops_index = graph.len() - 1;
let next_ops_candidate = match item.graph.get(next_ops_index) {
@@ -93,20 +90,24 @@ impl<O> OptimizationCache<O> {
Condition::NextOps(ops) => ops,
Condition::Sync => {
self.found = Some(*id);
return CacheResult::Found(&item.value);
break;
}
};

if item.end_conditions.contains(ops) {
self.found = Some(*id);
return CacheResult::Found(&item.value);
break;
} else {
self.availables.push((*id, graph.len()));
invalidated_candidate.push(*id);
}
}
}

if let Some(id) = self.found {
return CacheResult::Found(&mut self.optimizations[id].value);
}

let mut updated_candidates = Vec::new();
core::mem::swap(&mut updated_candidates, &mut self.candidates);

@@ -136,7 +137,7 @@ impl<O> OptimizationCache<O> {
factory: &Factory,
graph: Vec<TensorOpsDescription>,
next_ops: Option<TensorOpsDescription>,
) -> &'a O {
) -> &'a mut O {
let existing_optim = self
.availables
.iter()
@@ -149,7 +150,7 @@ impl<O> OptimizationCache<O> {
optimization.end_conditions.push(ops)
};

return &optimization.value;
return &mut optimization.value;
};

self.starters
@@ -164,7 +165,9 @@ impl<O> OptimizationCache<O> {
};

self.optimizations.push(optimization);
&self.optimizations.last().unwrap().value

let last_index = self.optimizations.len() - 1;
&mut self.optimizations[last_index].value
}

// Signal that a new path will begin.
@@ -188,7 +191,7 @@ pub enum CacheResult<'a, T> {
/// happens.
OnPath,
/// An optimization has been found, and the best action is to execute it!
Found(&'a T),
Found(&'a mut T),
}

/// When checking if an optimization is possible, a start or an end condition ensures that this optimization is
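In `OptimizationCache`, handing out `CacheResult::Found(&'a mut T)` is why the loop above records `self.found` and breaks instead of returning from inside the iteration: the mutable borrow is taken only after the shared borrow of the candidate list has ended, and `complete` does the same by indexing into `self.optimizations` after the push. A condensed sketch of that borrow pattern, with hypothetical names rather than the real cache:

```rust
// Condensed, hypothetical sketch of the borrow pattern used above (not the
// real OptimizationCache): record the matching index while iterating, break,
// and only take the mutable borrow once the shared iteration is over.
struct Cache<O> {
    items: Vec<O>,
    found: Option<usize>,
}

impl<O> Cache<O> {
    fn follow(&mut self, matches: impl Fn(&O) -> bool) -> Option<&mut O> {
        for (id, item) in self.items.iter().enumerate() {
            if matches(item) {
                // Returning `&mut self.items[id]` here would conflict with
                // the shared borrow held by the iterator, so only record it.
                self.found = Some(id);
                break;
            }
        }
        match self.found {
            Some(id) => Some(&mut self.items[id]),
            None => None,
        }
    }
}

fn main() {
    let mut cache = Cache { items: vec![1, 2, 3], found: None };
    if let Some(value) = cache.follow(|v| *v == 2) {
        // The caller gets mutable access, mirroring CacheResult::Found(&'a mut T).
        *value += 10;
    }
    assert_eq!(cache.items, vec![1, 12, 3]);
}
```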
burn-fusion/src/ops/float.rs: 42 changes (0 additions, 42 deletions)
@@ -265,48 +265,6 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}

fn clamp_min<const D: usize>(
tensor: FloatTensor<Self, D>,
min: FloatElem<Self>,
) -> FloatTensor<Self, D> {
scalar_float_ops!(ClampMinOps, B::clamp_min);

let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

let desc = ScalarOpsDescription {
lhs: tensor.into_description(),
rhs: min.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::ClampMin(desc.clone())),
ClampMinOps::<D>::new(desc),
);

out
}

fn clamp_max<const D: usize>(
tensor: FloatTensor<Self, D>,
max: FloatElem<Self>,
) -> FloatTensor<Self, D> {
scalar_float_ops!(ClampMaxOps, B::clamp_max);

let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

let desc = ScalarOpsDescription {
lhs: tensor.into_description(),
rhs: max.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::ClampMax(desc.clone())),
ClampMaxOps::<D>::new(desc),
);

out
}

fn clamp<const D: usize>(
tensor: FloatTensor<Self, D>,
min: FloatElem<Self>,
burn-fusion/src/ops/int.rs: 42 changes (0 additions, 42 deletions)
@@ -1034,48 +1034,6 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}

fn int_clamp_min<const D: usize>(
tensor: IntTensor<Self, D>,
min: IntElem<Self>,
) -> IntTensor<Self, D> {
scalar_int_ops!(ClampMinOps, B::int_clamp_min);

let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

let desc = ScalarOpsDescription {
lhs: tensor.into_description(),
rhs: min.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ClampMin(desc.clone())),
ClampMinOps::<D>::new(desc),
);

out
}

fn int_clamp_max<const D: usize>(
tensor: IntTensor<Self, D>,
max: IntElem<Self>,
) -> IntTensor<Self, D> {
scalar_int_ops!(ClampMaxOps, B::int_clamp_max);

let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

let desc = ScalarOpsDescription {
lhs: tensor.into_description(),
rhs: max.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ClampMax(desc.clone())),
ClampMaxOps::<D>::new(desc),
);

out
}

fn int_clamp<const D: usize>(
tensor: IntTensor<Self, D>,
min: IntElem<Self>,
@@ -5,7 +5,7 @@ use std::fmt::Display;
///
/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size
/// X and Y, but with Z=1.
#[derive(Hash, new)]
#[derive(new)]
pub struct Body {
operators: Vec<Operator>,
}
@@ -2,7 +2,7 @@ use super::Elem;
use std::fmt::Display;

/// Not all functions are native to WGSL, so this struct allows to support more functions.
#[derive(Hash, PartialEq, Eq, Clone)]
#[derive(PartialEq, Eq, Clone)]
pub enum Function {
Powf(Elem),
Erf(Elem),