Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom samplers, better collection, and better plotting to sinter #804

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift click to select a range
5813470
Rearrange sinter's pieces into sub packages
Strilanc Jul 27, 2024
7ad07b9
- Add sinter.Sampler and sinter.CompiledSampler
Strilanc Jul 27, 2024
b8703b9
- `sinter collect --processes` now defaults to `auto` if not specified
Strilanc Jul 27, 2024
f103255
- Add `sinter plot --point_label_func`
Strilanc Jul 27, 2024
44206e2
- Added 'perfectionist' sampler
Strilanc Jul 27, 2024
c2557a9
regen api docs
Strilanc Jul 27, 2024
be49478
fix gen
Strilanc Jul 27, 2024
07a56f7
Merge branch 'main' of github.com:quantumlib/Stim into sinter2
Strilanc Jul 27, 2024
1802b72
Show traditional error bars when only one data point is present
Strilanc Jul 30, 2024
1325e6b
- Add `sinter plot --preprocess_stats_func`
Strilanc Jul 30, 2024
da00d11
Add safety error when adding ambiguous strong id stats
Strilanc Jul 30, 2024
9cb172e
Add `sinter.TaskStats.with_edits`
Strilanc Jul 30, 2024
d524e60
Merge branch 'main' into sinter2
Strilanc Jul 30, 2024
0fdc850
Merge branch 'main' into sinter2
Strilanc Jul 31, 2024
c90d62d
Merge branch 'main' of github.com:quantumlib/Stim into sinter2
Strilanc Aug 3, 2024
a215126
Allow oversampling, fix json equality being tuple-vs-array sensitive
Strilanc Aug 6, 2024
351752d
Merge branch 'sinter2' of github.com:quantumlib/Stim into sinter2
Strilanc Aug 6, 2024
3283161
More fixes of comparisons when sorting
Strilanc Aug 6, 2024
a4574fb
soft error threshold, hopeful fixing
Strilanc Aug 7, 2024
4269062
Iter
Strilanc Aug 7, 2024
88c1e4e
test fix
Strilanc Aug 7, 2024
98b47e8
Believe the force
Strilanc Aug 7, 2024
7390e70
docstring for sinter plot group_func dict api, add curve sorting, fix…
mmcewen-g Aug 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion dev/gen_sinter_api_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 48,25 @@ def main():
```
'''.strip())

replace_rules = []
for package in ['stim', 'sinter']:
p = __import__(package)
for name in dir(p):
x = getattr(p, name)
if isinstance(x, type) and 'class' in str(x):
desired_name = f'{package}.{name}'
if '._' in str(x):
bad_name = str(x).split("'")[1]
replace_rules.append((bad_name, desired_name))
lonely_name = desired_name.split(".")[-1]
for q in ['"', "'"]:
replace_rules.append(('ForwardRef(' q lonely_name q ')', desired_name))
replace_rules.append(('ForwardRef(' q desired_name q ')', desired_name))
replace_rules.append((q desired_name q, desired_name))
replace_rules.append((q lonely_name q, desired_name))
replace_rules.append(('ForwardRef(' desired_name ')', desired_name))
replace_rules.append(('ForwardRef(' lonely_name ')', desired_name))

for obj in objects:
print()
print(f'<a name="{obj.full_name}"></a>')
Expand All @@ -58,7 77,10 @@ def main():
print(f'# (in class {".".join(obj.full_name.split(".")[:-1])})')
else:
print(f'# (at top-level in the sinter module)')
print('\n'.join(obj.lines))
for line in obj.lines:
for a, b in replace_rules:
line = line.replace(a, b)
print(line)
print("```")


Expand Down
13 changes: 1 addition & 12 deletions dev/util_gen_stub_file.py
Original file line number Diff line number Diff line change
@@ -1,5 1,4 @@
import dataclasses
import sys
import types
from typing import Any
from typing import Optional, Iterator, List
Expand All @@ -9,6 8,7 @@

keep = {
"__add__",
"__radd__",
"__eq__",
"__call__",
"__ge__",
Expand Down Expand Up @@ -224,17 224,6 @@ def print_doc(*, full_name: str, parent: object, obj: object, level: int) -> Opt
text = '@abc.abstractmethod\n'
sig_name = f'{term_name}{inspect.signature(obj)}'
text = "\n".join(splay_signature(f"def {sig_name}:"))
text = text.replace('''ForwardRef('sinter.TaskStats')''', 'sinter.TaskStats')
text = text.replace('''ForwardRef('sinter.Task')''', 'sinter.Task')
text = text.replace('''ForwardRef('sinter.Progress')''', 'sinter.Progress')
text = text.replace('''ForwardRef('sinter.Decoder')''', 'sinter.Decoder')
text = text.replace("'AnonTaskStats'", "sinter.AnonTaskStats")
text = text.replace('sinter._decoding_decoder_class.CompiledDecoder', 'sinter.CompiledDecoder')
text = text.replace("'AnonTaskStats'", "sinter.AnonTaskStats")
text = text.replace("'stim.Circuit'", "stim.Circuit")
text = text.replace("'stim.DetectorErrorModel'", "stim.DetectorErrorModel")
text = text.replace("'sinter.CollectionOptions'", "sinter.CollectionOptions")
text = text.replace("'sinter.Fit'", 'sinter.Fit')

# Replace default value lambdas with their source.
if 'lambda' in str(text):
Expand Down
128 changes: 118 additions & 10 deletions doc/sinter_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 12,16 @@ API references for stable versions are kept on the [stim github wiki](https://gi
- [`sinter.CollectionOptions.combine`](#sinter.CollectionOptions.combine)
- [`sinter.CompiledDecoder`](#sinter.CompiledDecoder)
- [`sinter.CompiledDecoder.decode_shots_bit_packed`](#sinter.CompiledDecoder.decode_shots_bit_packed)
- [`sinter.CompiledSampler`](#sinter.CompiledSampler)
- [`sinter.CompiledSampler.handles_throttling`](#sinter.CompiledSampler.handles_throttling)
- [`sinter.CompiledSampler.sample`](#sinter.CompiledSampler.sample)
- [`sinter.Decoder`](#sinter.Decoder)
- [`sinter.Decoder.compile_decoder_for_dem`](#sinter.Decoder.compile_decoder_for_dem)
- [`sinter.Decoder.decode_via_files`](#sinter.Decoder.decode_via_files)
- [`sinter.Fit`](#sinter.Fit)
- [`sinter.Progress`](#sinter.Progress)
- [`sinter.Sampler`](#sinter.Sampler)
- [`sinter.Sampler.compiled_sampler_for_task`](#sinter.Sampler.compiled_sampler_for_task)
- [`sinter.Task`](#sinter.Task)
- [`sinter.Task.__init__`](#sinter.Task.__init__)
- [`sinter.Task.strong_id`](#sinter.Task.strong_id)
Expand All @@ -26,6 31,7 @@ API references for stable versions are kept on the [stim github wiki](https://gi
- [`sinter.TaskStats`](#sinter.TaskStats)
- [`sinter.TaskStats.to_anon_stats`](#sinter.TaskStats.to_anon_stats)
- [`sinter.TaskStats.to_csv_line`](#sinter.TaskStats.to_csv_line)
- [`sinter.TaskStats.with_edits`](#sinter.TaskStats.with_edits)
- [`sinter.better_sorted_str_terms`](#sinter.better_sorted_str_terms)
- [`sinter.collect`](#sinter.collect)
- [`sinter.comma_separated_key_values`](#sinter.comma_separated_key_values)
Expand Down Expand Up @@ -257,6 263,50 @@ def decode_shots_bit_packed(
"""
```

<a name="sinter.CompiledSampler"></a>
```python
# sinter.CompiledSampler

# (at top-level in the sinter module)
class CompiledSampler(metaclass=abc.ABCMeta):
"""A sampler that has been configured for efficiently sampling some task.
"""
```

<a name="sinter.CompiledSampler.handles_throttling"></a>
```python
# sinter.CompiledSampler.handles_throttling

# (in class sinter.CompiledSampler)
def handles_throttling(
self,
) -> bool:
"""Return True to disable sinter wrapping samplers with throttling.

By default, sinter will wrap samplers so that they initially only do
a small number of shots then slowly ramp up. Sometimes this behavior
is not desired (e.g. in unit tests). Override this method to return True
to disable it.
"""
```

<a name="sinter.CompiledSampler.sample"></a>
```python
# sinter.CompiledSampler.sample

# (in class sinter.CompiledSampler)
@abc.abstractmethod
def sample(
self,
shots: int,
) -> sinter.AnonTaskStats:
"""Perform the given number of samples, and return statistics.

This method is permitted to perform fewer shots than specified, but must
indicate this in its returned statistics.
"""
```

<a name="sinter.Decoder"></a>
```python
# sinter.Decoder
Expand Down Expand Up @@ -385,9 435,9 @@ class Fit:
of the best fit's square error, or whose likelihood was within some
maximum Bayes factor of the max likelihood hypothesis.
"""
low: float
best: float
high: float
low: Optional[float]
best: Optional[float]
high: Optional[float]
```

<a name="sinter.Progress"></a>
Expand All @@ -409,10 459,45 @@ class Progress:
collection status, such as the number of tasks left and the
estimated time to completion for each task.
"""
new_stats: Tuple[sinter._task_stats.TaskStats, ...]
new_stats: Tuple[sinter.TaskStats, ...]
status_message: str
```

<a name="sinter.Sampler"></a>
```python
# sinter.Sampler

# (at top-level in the sinter module)
class Sampler(metaclass=abc.ABCMeta):
"""A strategy for producing stats from tasks.

Call `sampler.compiled_sampler_for_task(task)` to get a compiled sampler for
a task, then call `compiled_sampler.sample(shots)` to collect statistics.

A sampler differs from a `sinter.Decoder` because the sampler is responsible
for the full sampling process (e.g. simulating the circuit), whereas a
decoder can do nothing except predict observable flips from detection event
data. This prevents the decoders from cheating, but makes them less flexible
overall. A sampler can do things like use simulators other than stim, or
really anything at all as long as it ends with returning statistics about
shot counts, error counts, and etc.
"""
```

<a name="sinter.Sampler.compiled_sampler_for_task"></a>
```python
# sinter.Sampler.compiled_sampler_for_task

# (in class sinter.Sampler)
@abc.abstractmethod
def compiled_sampler_for_task(
self,
task: sinter.Task,
) -> sinter.CompiledSampler:
"""Creates, configures, and returns an object for sampling the task.
"""
```

<a name="sinter.Task"></a>
```python
# sinter.Task
Expand Down Expand Up @@ -475,9 560,9 @@ class Task:
def __init__(
self,
*,
circuit: Optional[ForwardRef(stim.Circuit)] = None,
circuit: Optional[stim.Circuit] = None,
decoder: Optional[str] = None,
detector_error_model: Optional[ForwardRef(stim.DetectorErrorModel)] = None,
detector_error_model: Optional[stim.DetectorErrorModel] = None,
postselection_mask: Optional[np.ndarray] = None,
postselected_observables_mask: Optional[np.ndarray] = None,
json_metadata: Any = None,
Expand Down Expand Up @@ -699,7 784,7 @@ class TaskStats:
# (in class sinter.TaskStats)
def to_anon_stats(
self,
) -> sinter._anon_task_stats.AnonTaskStats:
) -> sinter.AnonTaskStats:
"""Returns a `sinter.AnonTaskStats` with the same statistics.

Examples:
Expand Down Expand Up @@ -745,6 830,25 @@ def to_csv_line(
"""
```

<a name="sinter.TaskStats.with_edits"></a>
```python
# sinter.TaskStats.with_edits

# (in class sinter.TaskStats)
def with_edits(
self,
*,
strong_id: Optional[str] = None,
decoder: Optional[str] = None,
json_metadata: Optional[Any] = None,
shots: Optional[int] = None,
errors: Optional[int] = None,
discards: Optional[int] = None,
seconds: Optional[float] = None,
custom_counts: Optional[Counter[str]] = None,
) -> sinter.TaskStats:
```

<a name="sinter.better_sorted_str_terms"></a>
```python
# sinter.better_sorted_str_terms
Expand Down Expand Up @@ -1124,7 1228,7 @@ def iter_collect(
num_workers: int,
tasks: Union[Iterator[sinter.Task], Iterable[sinter.Task]],
hint_num_tasks: Optional[int] = None,
additional_existing_data: Optional[sinter._existing_data.ExistingData] = None,
additional_existing_data: Union[NoneType, Dict[str, sinter.TaskStats], Iterable[sinter.TaskStats]] = None,
max_shots: Optional[int] = None,
max_errors: Optional[int] = None,
decoders: Optional[Iterable[str]] = None,
Expand Down Expand Up @@ -1337,6 1441,7 @@ def plot_discard_rate(
filter_func: Callable[[sinter.TaskStats], Any] = lambda _: True,
plot_args_func: Callable[[int, ~TCurveId, List[sinter.TaskStats]], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
highlight_max_likelihood_factor: Optional[float] = 1000.0,
point_label_func: Callable[[sinter.TaskStats], Any] = lambda _: None,
) -> None:
"""Plots discard rates in curves with uncertainty highlights.

Expand Down Expand Up @@ -1370,6 1475,7 @@ def plot_discard_rate(
highlight_max_likelihood_factor: Controls how wide the uncertainty highlight region around curves is.
Must be 1 or larger. Hypothesis probabilities at most that many times as unlikely as the max likelihood
hypothesis will be highlighted.
point_label_func: Optional. Specifies text to draw next to data points.
"""
```

Expand All @@ -1390,6 1496,7 @@ def plot_error_rate(
plot_args_func: Callable[[int, ~TCurveId, List[sinter.TaskStats]], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
highlight_max_likelihood_factor: Optional[float] = 1000.0,
line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None,
point_label_func: Callable[[sinter.TaskStats], Any] = lambda _: None,
) -> None:
"""Plots error rates in curves with uncertainty highlights.

Expand Down Expand Up @@ -1430,6 1537,7 @@ def plot_error_rate(
line_fits: Defaults to None. Set this to a tuple (x_scale, y_scale) to include a dashed line
fit to every curve. The scales determine how to transform the coordinates before
performing the fit, and can be set to 'linear', 'sqrt', or 'log'.
point_label_func: Optional. Specifies text to draw next to data points.
"""
```

Expand Down Expand Up @@ -1712,11 1820,11 @@ def read_stats_from_csv_files(

# (at top-level in the sinter module)
def shot_error_rate_to_piece_error_rate(
shot_error_rate: Union[float, ForwardRef(sinter.Fit)],
shot_error_rate: Union[float, sinter.Fit],
*,
pieces: float,
values: float = 1,
) -> Union[float, ForwardRef(sinter.Fit)]:
) -> Union[float, sinter.Fit]:
"""Convert from total error rate to per-piece error rate.

Args:
Expand Down
2 changes: 1 addition & 1 deletion glue/sample/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 37,6 @@
install_requires=requirements,
tests_require=['pytest', 'pymatching'],
entry_points={
'console_scripts': ['sinter=sinter._main:main'],
'console_scripts': ['sinter=sinter._command._main:main'],
},
)
33 changes: 12 additions & 21 deletions glue/sample/src/sinter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 1,27 @@
__version__ = '1.14.dev0'

from sinter._anon_task_stats import (
AnonTaskStats,
)
from sinter._collection import (
collect,
iter_collect,
post_selection_mask_from_4th_coord,
Progress,
)
from sinter._collection_options import (
from sinter._data import (
AnonTaskStats,
CollectionOptions,
)
from sinter._csv_out import (
CSV_HEADER,
)
from sinter._decoding_all_built_in_decoders import (
BUILT_IN_DECODERS,
)
from sinter._existing_data import (
read_stats_from_csv_files,
stats_from_csv_files,
Task,
TaskStats,
)
from sinter._decoding import (
CompiledDecoder,
Decoder,
BUILT_IN_DECODERS,
BUILT_IN_SAMPLERS,
Sampler,
CompiledSampler,
)
from sinter._probability_util import (
comma_separated_key_values,
Expand All @@ -38,19 39,9 @@
plot_error_rate,
group_by,
)
from sinter._task import (
Task,
)
from sinter._task_stats import (
TaskStats,
)
from sinter._predict import (
predict_discards_bit_packed,
predict_observables_bit_packed,
predict_on_disk,
predict_observables,
)
from sinter._decoding_decoder_class import (
CompiledDecoder,
Decoder,
)
10 changes: 10 additions & 0 deletions glue/sample/src/sinter/_collection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 1,10 @@
from sinter._collection._collection import (
collect,
iter_collect,
post_selection_mask_from_4th_coord,
post_selection_mask_from_predicate,
Progress,
)
from sinter._collection._printer import (
ThrottledProgressPrinter,
)
Loading
Loading