Pytorch config extraction #1323

Merged: 7 commits, Feb 20, 2024
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -95,6 +95,7 @@ serial_test = "3.0.0"
web-time = "1.0.0"
hound = "3.5.1"
image = "0.24.7"
zip = "0.6.6"

# Terminal UI
ratatui = "0.25"
@@ -134,7 +135,7 @@ num-traits = { version = "0.2.18", default-features = false, features = [
] } # libm is for no_std
rand = { version = "0.8.5", default-features = false, features = [
"std_rng",
] } # std_rng is for no_std
] } # std_rng is for no_std
rand_distr = { version = "0.4.3", default-features = false }
serde = { version = "1.0.192", default-features = false, features = [
"derive",
Expand All @@ -150,6 151,5 @@ sysinfo = "0.29.10"
systemstat = "0.2.3"



[profile.dev]
debug = 0 # Speeds up compilation; debug info is not necessary.
44 changes: 44 additions & 0 deletions burn-book/src/import/pytorch-model.md
@@ -139,6 +139,50 @@ something like this:
}
```

## Extract Configuration

Some models require additional configuration settings, which are often included in the `.pt` file
during export. The `config_from_file` function from the `burn-import` crate extracts these settings
directly from the `.pt` file, and the extracted configuration can then be used to initialize the
model in Burn. Here is an example of how to extract the configuration from a `.pt` file:

```rust
use std::collections::HashMap;

use burn::config::Config;
use burn_import::pytorch::config_from_file;

#[derive(Debug, Config)]
struct NetConfig {
    n_head: usize,
    n_layer: usize,
    d_model: usize,
    // Candle's pickle has a bug with float serialization
    // https://github.com/huggingface/candle/issues/1729
    // some_float: f64,
    some_int: i32,
    some_bool: bool,
    some_str: String,
    some_list_int: Vec<i32>,
    some_list_str: Vec<String>,
    // Candle's pickle has a bug with float serialization
    // https://github.com/huggingface/candle/issues/1729
    // some_list_float: Vec<f64>,
    some_dict: HashMap<String, String>,
}

fn main() {
    let path = "weights_with_config.pt";
    let top_level_key = Some("my_config");
    let config: NetConfig = config_from_file(path, top_level_key).unwrap();
    println!("{:#?}", config);

    // After extracting the configuration, it is recommended to save it as a
    // JSON file so the .pt file does not have to be re-parsed on every run.
    config.save("my_config.json").unwrap();
}
```
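
Once extracted and saved, the configuration can be loaded back from JSON on subsequent runs. A
minimal sketch, assuming `Config::load` as the counterpart of the `save` call above and reusing the
`NetConfig` definition from the previous example:

```rust
use burn::config::Config;

fn load_saved_config() -> NetConfig {
    // Assumed sketch: reads the JSON written by `config.save(...)` above,
    // so the .pt file does not need to be parsed again.
    NetConfig::load("my_config.json").expect("config file should exist")
}
```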

## Troubleshooting

### Adjusting the source model architecture
7 changes: 4 additions & 3 deletions burn-core/Cargo.toml
@@ -83,7 +83,7 @@ candle = ["burn-candle"]
wgpu = ["burn-wgpu"]

# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
record-item-custom-serde = ["thiserror", "regex"]
record-item-custom-serde = ["thiserror", "regex", "num-traits"]

# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
@@ -110,7 +110,7 @@ burn-candle = { path = "../burn-candle", version = "0.13.0", optional = true }
derive-new = { workspace = true }
libm = { workspace = true }
log = { workspace = true, optional = true }
rand = { workspace = true, features = ["std_rng"] } # Default enables std
# Using in place of use std::sync::Mutex when std is disabled
spin = { workspace = true, features = ["mutex", "spin_mutex"] }

@@ -124,9 +124,10 @@ serde = { workspace = true, features = ["derive"] }
bincode = { workspace = true }
half = { workspace = true }
rmp-serde = { workspace = true, optional = true }
serde_json = { workspace = true, features = ["alloc"] } # Default enables std
thiserror = { workspace = true, optional = true }
regex = { workspace = true, optional = true }
num-traits = { workspace = true, optional = true }

[dev-dependencies]
tempfile = { workspace = true }
23 changes: 23 additions & 0 deletions burn-core/src/record/serde/data.rs
@@ -7,6 +7,7 @@ use super::ser::Serializer;
use crate::record::{PrecisionSettings, Record};
use crate::tensor::backend::Backend;

use num_traits::cast::ToPrimitive;
use regex::Regex;
use serde::Deserialize;

@@ -82,6 +83,7 @@ impl NestedValue {
    pub fn as_f32(self) -> Option<f32> {
        match self {
            NestedValue::F32(f32) => Some(f32),
            NestedValue::F64(f) => f.to_f32(),
            _ => None,
        }
    }
@@ -90,6 +92,7 @@ impl NestedValue {
    pub fn as_f64(self) -> Option<f64> {
        match self {
            NestedValue::F64(f64) => Some(f64),
            NestedValue::F32(f) => f.to_f64(),
            _ => None,
        }
    }
@@ -98,6 +101,10 @@ impl NestedValue {
    pub fn as_i16(self) -> Option<i16> {
        match self {
            NestedValue::I16(i16) => Some(i16),
            NestedValue::I32(i) => i.to_i16(),
            NestedValue::I64(i) => i.to_i16(),
            NestedValue::U16(u) => u.to_i16(),
            NestedValue::U64(u) => u.to_i16(),
            _ => None,
        }
    }
@@ -106,6 +113,10 @@ impl NestedValue {
    pub fn as_i32(self) -> Option<i32> {
        match self {
            NestedValue::I32(i32) => Some(i32),
            NestedValue::I16(i) => i.to_i32(),
            NestedValue::I64(i) => i.to_i32(),
            NestedValue::U16(u) => u.to_i32(),
            NestedValue::U64(u) => u.to_i32(),
            _ => None,
        }
    }
@@ -114,6 +125,10 @@ impl NestedValue {
    pub fn as_i64(self) -> Option<i64> {
        match self {
            NestedValue::I64(i64) => Some(i64),
            NestedValue::I16(i) => i.to_i64(),
            NestedValue::I32(i) => i.to_i64(),
            NestedValue::U16(u) => u.to_i64(),
            NestedValue::U64(u) => u.to_i64(),
            _ => None,
        }
    }
@@ -122,6 +137,10 @@ impl NestedValue {
    pub fn as_u16(self) -> Option<u16> {
        match self {
            NestedValue::U16(u16) => Some(u16),
            NestedValue::I16(i) => i.to_u16(),
            NestedValue::I32(i) => i.to_u16(),
            NestedValue::I64(i) => i.to_u16(),
            NestedValue::U64(u) => u.to_u16(),
            _ => None,
        }
    }
@@ -130,6 +149,10 @@ impl NestedValue {
    pub fn as_u64(self) -> Option<u64> {
        match self {
            NestedValue::U64(u64) => Some(u64),
            NestedValue::I16(i) => i.to_u64(),
            NestedValue::I32(i) => i.to_u64(),
            NestedValue::I64(i) => i.to_u64(),
            NestedValue::U16(u) => u.to_u64(),
            _ => None,
        }
    }
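An editor's aside on the conversions above (not part of the diff): `num_traits::ToPrimitive` performs checked casts and returns `None` when a value does not fit the target type, so the `NestedValue` accessors can widen and narrow between integer variants without silent wrap-around. A minimal, self-contained sketch:

```rust
use num_traits::cast::ToPrimitive;

fn main() {
    // In-range values convert losslessly.
    assert_eq!(42_i64.to_i16(), Some(42_i16));
    // Out-of-range and sign-mismatched casts yield None instead of
    // wrapping, matching the Option-returning accessors above.
    assert_eq!(70_000_i64.to_i16(), None);
    assert_eq!((-1_i32).to_u16(), None);
}
```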
3 changes: 2 additions & 1 deletion burn-import/Cargo.toml
@@ -16,7 +16,7 @@ default-run = "onnx2burn"
[features]
default = ["onnx", "pytorch"]
onnx = []
pytorch = ["burn/record-item-custom-serde", "thiserror"]
pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"]

[dependencies]
burn = { path = "../burn", version = "0.13.0", features = ["ndarray"] }
@@ -39,6 +39,7 @@ syn = { workspace = true, features = ["parsing"] }
thiserror = { workspace = true, optional = true }
tracing-core = { workspace = true }
tracing-subscriber = { workspace = true }
zip = { workspace = true, optional = true }

[build-dependencies]
protobuf-codegen = { workspace = true }
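An aside on the new `zip` dependency (not part of the diff): files produced by `torch.save` are zip archives that bundle a pickle with tensor data, which is presumably why `burn-import` needs `zip` to reach the embedded config. A minimal sketch of inspecting such an archive with the `zip` crate:

```rust
use std::fs::File;
use zip::ZipArchive;

fn main() -> zip::result::ZipResult<()> {
    // A .pt file written by torch.save is a zip archive; list its entries.
    let file = File::open("weights_with_config.pt")?;
    let mut archive = ZipArchive::new(file)?;
    for i in 0..archive.len() {
        let entry = archive.by_index(i)?;
        println!("{} ({} bytes)", entry.name(), entry.size());
    }
    Ok(())
}
```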
Binary file removed burn-import/data/mnist.pt
60 changes: 60 additions & 0 deletions burn-import/pytorch-tests/tests/config/export_weights.py
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 4, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)  # Add relu so that PyTorch optimizer does not combine fc1 and fc2
        x = self.fc2(x)

        return x

CONFIG = {
    "n_head": 2,
    "n_layer": 3,
    "d_model": 512,
    "some_float": 0.1,
    "some_int": 1,
    "some_bool": True,
    "some_str": "hello",
    "some_list_int": [1, 2, 3],
    "some_list_str": ["hello", "world"],
    "some_list_float": [0.1, 0.2, 0.3],
    "some_dict": {
        "some_key": "some_value"
    }
}

class ModelWithBias(nn.Module):
    def __init__(self):
        super(ModelWithBias, self).__init__()
        self.fc1 = nn.Linear(2, 3)

    def forward(self, x):
        x = self.fc1(x)

        return x


def main():

    model = Model().to(torch.device("cpu"))

    weights_with_config = {
        "my_model": model.state_dict(),
        "my_config": CONFIG
    }

    torch.save(weights_with_config, "weights_with_config.pt")


if __name__ == '__main__':
    main()
61 changes: 61 additions & 0 deletions burn-import/pytorch-tests/tests/config/mod.rs
@@ -0,0 +1,61 @@
#![allow(clippy::too_many_arguments)] // To mute derive Config warning
use std::collections::HashMap;

use burn::config::Config;

#[allow(clippy::too_many_arguments)]
#[derive(Debug, PartialEq, Config)]
struct NetConfig {
    n_head: usize,
    n_layer: usize,
    d_model: usize,
    // Candle's pickle has a bug with float serialization
    // https://github.com/huggingface/candle/issues/1729
    // some_float: f64,
    some_int: i32,
    some_bool: bool,
    some_str: String,
    some_list_int: Vec<i32>,
    some_list_str: Vec<String>,
    // Candle's pickle has a bug with float serialization
    // https://github.com/huggingface/candle/issues/1729
    // some_list_float: Vec<f64>,
    some_dict: HashMap<String, String>,
}

#[cfg(test)]
mod tests {
    use burn_import::pytorch::config_from_file;

    use super::*;

    #[test]
    fn test_net_config() {
        let config_expected = NetConfig {
            n_head: 2,
            n_layer: 3,
            d_model: 512,
            // Candle's pickle has a bug with float serialization
            // https://github.com/huggingface/candle/issues/1729
            // some_float: 0.1,
            some_int: 1,
            some_bool: true,
            some_str: "hello".to_string(),
            some_list_int: vec![1, 2, 3],
            some_list_str: vec!["hello".to_string(), "world".to_string()],
            // Candle's pickle has a bug with float serialization
            // https://github.com/huggingface/candle/issues/1729
            // some_list_float: vec![0.1, 0.2, 0.3],
            some_dict: {
                let mut map = HashMap::new();
                map.insert("some_key".to_string(), "some_value".to_string());
                map
            },
        };
        let path = "tests/config/weights_with_config.pt";
        let top_level_key = Some("my_config");
        let config: NetConfig = config_from_file(path, top_level_key).unwrap();

        assert_eq!(config, config_expected);
    }
}
Binary file added burn-import/pytorch-tests/tests/config/weights_with_config.pt (not shown).
1 change: 1 addition & 0 deletions burn-import/pytorch-tests/tests/mod.rs
@@ -6,6 +6,7 @@ cfg_if::cfg_if! {
    mod boolean;
    mod buffer;
    mod complex_nested;
    mod config;
    mod conv1d;
    mod conv2d;
    mod conv_transpose1d;