Migrate/jit/adaptive avg pool backward #1530

Merged (17 commits) on Mar 26, 2024
crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs — 298 changes: 252 additions & 46 deletions
@@ -1,17 +1,244 @@
use std::marker::PhantomData;

use crate::{
- compute::StaticKernel,
+ codegen::{
+ execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
+ OutputInfo, WorkgroupLaunch,
+ },
element::JitElement,
- kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
- kernel_wgsl,
+ gpu::{gpu, Elem, Scope, Variable, Visibility},
+ kernel::{DynamicKernelSource, SourceTemplate},
tensor::JitTensor,
- Runtime,
+ Compiler, Runtime, RuntimeInt,
};
- use burn_compute::server::Handle;

- kernel_wgsl!(
- AdaptiveAvgPool2dBackward,
- "../../template/pool/adaptive_avg_pool2d_backward.wgsl"
- );
#[derive(new)]
struct AdaptiveAvgPool2dBackwardEagerKernel<R, E> {
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}

struct AdaptiveAvgPool2dBackwardComputeShader {
grad: Variable,
output: Variable,
}

impl AdaptiveAvgPool2dBackwardComputeShader {
fn expand(self, scope: &mut Scope) {
let grad = self.grad;
let output = self.output;
let id = Variable::Id;
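// One thread per element of `output` (the gradient w.r.t. the pooling input);
// `id` is the flat index this thread writes.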

let grad_stride_0 = scope.create_local(Elem::UInt);
let grad_stride_1 = scope.create_local(Elem::UInt);
let grad_stride_2 = scope.create_local(Elem::UInt);
let grad_stride_3 = scope.create_local(Elem::UInt);

let grad_shape_2 = scope.create_local(Elem::UInt);
let grad_shape_3 = scope.create_local(Elem::UInt);

let output_stride_0 = scope.create_local(Elem::UInt);
let output_stride_1 = scope.create_local(Elem::UInt);
let output_stride_2 = scope.create_local(Elem::UInt);
let output_stride_3 = scope.create_local(Elem::UInt);

let output_shape_0 = scope.create_local(Elem::UInt);
let output_shape_1 = scope.create_local(Elem::UInt);
let output_shape_2 = scope.create_local(Elem::UInt);
let output_shape_3 = scope.create_local(Elem::UInt);

gpu!(scope, grad_stride_0 = stride(grad, 0u32));
gpu!(scope, grad_stride_1 = stride(grad, 1u32));
gpu!(scope, grad_stride_2 = stride(grad, 2u32));
gpu!(scope, grad_stride_3 = stride(grad, 3u32));

gpu!(scope, grad_shape_2 = shape(grad, 2u32));
gpu!(scope, grad_shape_3 = shape(grad, 3u32));

gpu!(scope, output_stride_0 = stride(output, 0u32));
gpu!(scope, output_stride_1 = stride(output, 1u32));
gpu!(scope, output_stride_2 = stride(output, 2u32));
gpu!(scope, output_stride_3 = stride(output, 3u32));

gpu!(scope, output_shape_0 = shape(output, 0u32));
gpu!(scope, output_shape_1 = shape(output, 1u32));
gpu!(scope, output_shape_2 = shape(output, 2u32));
gpu!(scope, output_shape_3 = shape(output, 3u32));

let b = scope.create_local(Elem::UInt);
let c = scope.create_local(Elem::UInt);
let ih = scope.create_local(Elem::UInt);
let iw = scope.create_local(Elem::UInt);

gpu!(scope, b = id / output_stride_0);
gpu!(scope, b = b % output_shape_0);

gpu!(scope, c = id / output_stride_1);
gpu!(scope, c = c % output_shape_1);

gpu!(scope, ih = id / output_stride_2);
gpu!(scope, ih = ih % output_shape_2);

gpu!(scope, iw = id / output_stride_3);
gpu!(scope, iw = iw % output_shape_3);

let oh_start = Self::start_index(scope, ih, output_shape_2, grad_shape_2);
let oh_end = Self::end_index(scope, ih, output_shape_2, grad_shape_2);

let ow_start = Self::start_index(scope, iw, output_shape_3, grad_shape_3);
let ow_end = Self::end_index(scope, iw, output_shape_3, grad_shape_3);
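// [oh_start, oh_end) x [ow_start, ow_end) bounds the output cells whose
// adaptive windows can contain (ih, iw); only those cells received this
// input element in the forward pass.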

let grad_acc = scope.create_local(output.item());
let contributed_h = scope.create_local(Elem::Bool);
let contributed_w = scope.create_local(Elem::Bool);
let contributed_tmp = scope.create_local(Elem::Bool);

let count = scope.create_local(Elem::UInt);
let count_tmp = scope.create_local(Elem::UInt);
let count_float = scope.create_local(output.item());
let the_grad = scope.create_local(output.item());
let avg = scope.create_local(output.item());

let index_base = scope.create_local(Elem::UInt);
let index_tmp = scope.create_local(Elem::UInt);
let index = scope.create_local(Elem::UInt);
gpu!(scope, index_base = b * grad_stride_0);
gpu!(scope, index_tmp = c * grad_stride_1);
gpu!(scope, index_base += index_tmp);

gpu!(
scope,
range(oh_start, oh_end).for_each(|oh, scope| {
let ih_start = Self::start_index(scope, oh, grad_shape_2, output_shape_2);
let ih_end = Self::end_index(scope, oh, grad_shape_2, output_shape_2);
gpu!(scope, contributed_h = ih >= ih_start);
gpu!(scope, contributed_tmp = ih < ih_end);
gpu!(scope, contributed_h = contributed_h && contributed_tmp);

gpu!(scope, if(contributed_h).then(|scope|{
gpu!(
scope,
range(ow_start, ow_end).for_each(|ow, scope| {
let iw_start = Self::start_index(scope, ow, grad_shape_3, output_shape_3);
let iw_end = Self::end_index(scope, ow, grad_shape_3, output_shape_3);

gpu!(scope, contributed_w = iw >= iw_start);
gpu!(scope, contributed_tmp = iw < iw_end);
gpu!(scope, contributed_w = contributed_w && contributed_tmp);


gpu!(scope, if(contributed_w).then(|scope|{
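// This output cell's window does contain (ih, iw): the forward pass
// averaged over `count` input cells, so each one receives grad / count.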
gpu!(scope, count = ih_end - ih_start);
gpu!(scope, count_tmp = iw_end - iw_start);
gpu!(scope, count *= count_tmp);
gpu!(scope, count_float = cast(count));

gpu!(scope, index = index_base);
gpu!(scope, index_tmp = oh * grad_stride_2);
gpu!(scope, index += index_tmp);
gpu!(scope, index_tmp = ow * grad_stride_3);
gpu!(scope, index += index_tmp);

gpu!(scope, the_grad = grad[index]);
gpu!(scope, avg = the_grad / count_float);
gpu!(scope, grad_acc += avg);
}));
})
);
}));
})
);

gpu!(scope, output[id] = grad_acc);
}

fn start_index(
scope: &mut Scope,
output_size_index: Variable,
output_size: Variable,
input_size: Variable,
) -> Variable {
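// start = floor(output_size_index * input_size / output_size), computed in
// f32 so the floor is explicit.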
let numerator_float = scope.create_local(Elem::Float);
let div = scope.create_local(Elem::Float);
let index = scope.create_local(Elem::UInt);

gpu!(scope, index = output_size_index * input_size);
gpu!(scope, numerator_float = cast(index));
gpu!(scope, div = cast(output_size));
gpu!(scope, div = numerator_float / div);
gpu!(scope, div = floor(div));
gpu!(scope, index = cast(div));
index
}

fn end_index(
scope: &mut Scope,
output_size_index: Variable,
output_size: Variable,
input_size: Variable,
) -> Variable {
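// end = min(ceil((output_size_index + 1) * input_size / output_size), input_size).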
let numerator_float = scope.create_local(Elem::Float);
let div = scope.create_local(Elem::Float);
let index = scope.create_local(Elem::UInt);
let min = scope.create_local(Elem::Bool);
let end_index = scope.create_local(Elem::UInt);

gpu!(scope, index = output_size_index + 1u32);
gpu!(scope, index *= input_size);
gpu!(scope, numerator_float = cast(index));
gpu!(scope, div = cast(output_size));
gpu!(scope, div = numerator_float / div);
gpu!(scope, div = ceil(div));
gpu!(scope, index = cast(div));

gpu!(scope, min = input_size < index);
gpu!(scope, if(min).then(|scope|{
gpu!(scope, end_index = input_size);
}).else(|scope|{
gpu!(scope, end_index = index);
}));
end_index
}
}

impl<R: Runtime, E: JitElement> DynamicKernelSource for AdaptiveAvgPool2dBackwardEagerKernel<R, E> {
fn source(&self) -> SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();

let grad = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);

scope.write_global_custom(output);
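// (write_global_custom registers that `output` is stored manually via
// output[id] in `expand`, rather than by the compiler's default
// elementwise write.)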

AdaptiveAvgPool2dBackwardComputeShader { grad, output }.expand(&mut scope);

let grad = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let scalars = InputInfo::Scalar {
elem: Elem::UInt,
size: 6,
};
let output = OutputInfo::Array { item };

let info = CompilationInfo {
inputs: vec![grad, scalars],
outputs: vec![output],
scope,
};

let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}

fn id(&self) -> String {
format!("{:?}", core::any::TypeId::of::<Self>(),)
}
}
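For intuition, here is a minimal CPU sketch of the rule this kernel implements (not part of the PR; the function name and the scatter-style loops are illustrative only). The forward pass averages each adaptive window, so the backward pass gives every input cell in a window an equal share, grad / count, of that window's gradient:

// Reference sketch (hypothetical helper): adaptive avg pool backward for one
// (batch, channel) plane, scatter style.
fn adaptive_avg_pool2d_backward_ref(
    grad: &[f32],                     // output gradient, [out_h, out_w], row-major
    (out_h, out_w): (usize, usize),
    (in_h, in_w): (usize, usize),
) -> Vec<f32> {
    // start = floor(i * in / out); end = min(ceil((i + 1) * in / out), in)
    let start = |i: usize, out: usize, inp: usize| i * inp / out;
    let end = |i: usize, out: usize, inp: usize| (((i + 1) * inp + out - 1) / out).min(inp);

    let mut dx = vec![0.0f32; in_h * in_w];
    for oh in 0..out_h {
        let (hs, he) = (start(oh, out_h, in_h), end(oh, out_h, in_h));
        for ow in 0..out_w {
            let (ws, we) = (start(ow, out_w, in_w), end(ow, out_w, in_w));
            let count = ((he - hs) * (we - ws)) as f32;
            let g = grad[oh * out_w + ow] / count; // equal share per covered cell
            for ih in hs..he {
                for iw in ws..we {
                    dx[ih * in_w + iw] += g;
                }
            }
        }
    }
    dx
}

The GPU kernel inverts these loops: each thread owns one input element and gathers from the output cells whose windows contain it, so no two threads write the same location and no atomics are needed.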

pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
x: JitTensor<R, E, 4>,
@@ -27,45 +254,24 @@ pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
output_buffer,
);

- let kernel = StaticKernel::<
- KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
- >::new(elemwise_workgroup(
- output.shape.num_elements(),
- WORKGROUP_DEFAULT,
- ));
- let info_handle = build_info(&x, &out_grad);
- x.client.execute(
- Box::new(kernel),
- &[&out_grad.handle, &output.handle, &info_handle],
- );
+ let kernel = AdaptiveAvgPool2dBackwardEagerKernel::new();

+ execute_dynamic::<R, AdaptiveAvgPool2dBackwardEagerKernel<R, E>, RuntimeInt<R>>(
+ &[EagerHandle::new(
+ &out_grad.handle,
+ &out_grad.strides,
+ &out_grad.shape.dims,
+ )],
+ &[EagerHandle::new(
+ &output.handle,
+ &output.strides,
+ &output.shape.dims,
+ )],
+ None,
+ kernel,
+ WorkgroupLaunch::Output { pos: 0 },
+ x.client,
+ );

output
}
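Two things change in the launch path: the kernel source is now generated from the IR at runtime rather than loaded from a WGSL template, and WorkgroupLaunch::Output { pos: 0 } sizes the grid from the first output handle, one thread per element, replacing the manual elemwise_workgroup call. Roughly, and as a simplification of burn-jit's launch helpers rather than their actual code (workgroup_count is a hypothetical name), the grid size works out to:

// Hypothetical sketch: threads are grouped into wg x wg workgroups, and the
// grid holds enough groups to cover every output element.
fn workgroup_count(num_elems: usize, wg: usize) -> usize {
    (num_elems + wg * wg - 1) / (wg * wg) // ceil division
}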

- fn build_info<R: Runtime, E: JitElement>(
- x: &JitTensor<R, E, 4>,
- output: &JitTensor<R, E, 4>,
- ) -> Handle<R::Server> {
- let mut info: [u32; 16] = [0; 16];
- info[0] = x.strides[0] as u32;
- info[1] = x.strides[1] as u32;
- info[2] = x.strides[2] as u32;
- info[3] = x.strides[3] as u32;
- info[4] = x.shape.dims[0] as u32;
- info[5] = x.shape.dims[1] as u32;
- info[6] = x.shape.dims[2] as u32;
- info[7] = x.shape.dims[3] as u32;
-
- info[8] = output.strides[0] as u32;
- info[9] = output.strides[1] as u32;
- info[10] = output.strides[2] as u32;
- info[11] = output.strides[3] as u32;
- info[12] = output.shape.dims[0] as u32;
- info[13] = output.shape.dims[1] as u32;
- info[14] = output.shape.dims[2] as u32;
- info[15] = output.shape.dims[3] as u32;
-
- output.client.create(bytemuck::cast_slice(&info))
- }
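The hand-packed 16-word info buffer above is no longer needed: each EagerHandle carries its tensor's strides and shape dims, and the codegen runtime assembles the equivalent metadata buffer when the kernel is launched.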
crates/burn-jit/src/kernel/pool/mod.rs — 2 changes: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ pub(crate) use adaptive_pool2d_shader::*;
pub(crate) use pool2d_shader::*;

pub(crate) use adaptive_avg_pool2d::*;
- pub use adaptive_avg_pool2d_backward::*;
+ pub(crate) use adaptive_avg_pool2d_backward::*;
pub(crate) use avg_pool2d::*;
pub(crate) use avg_pool2d_backward::*;
pub(super) use base::*;