Skip to content

Commit

Permalink
Merge pull request stanfordnlp#509 from thomasahle/main
Browse files Browse the repository at this point in the history
Type improvements
  • Loading branch information
thomasahle authored Mar 1, 2024
2 parents a3a37bf 702a89c commit 5f0dbc8
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 87 deletions.
123 changes: 68 additions & 55 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
@@ -1,5 1,5 @@
import inspect, os, openai, dspy, typing, pydantic
from typing import Annotated
from typing import Annotated, List, Tuple
import typing
from dsp.templates import passages2text
import json
Expand Down Expand Up @@ -27,7 27,7 @@ def __init__(self, predictor, output_key):
super().__init__()
self.predictor = predictor
self.output_key = output_key

def copy(self):
    """Return a new _StripOutput wrapping a copy of the predictor, with the same output key."""
    return _StripOutput(self.predictor.copy(), self.output_key)

Expand All @@ -37,7 37,8 @@ def forward(self, **kwargs):


class FunctionalModule(dspy.Module):
""" To use the @cot and @predictor decorators, your module needs to inheret form this class. """
"""To use the @cot and @predictor decorators, your module needs to inherit from this class."""

def __init__(self):
super().__init__()
for name in dir(self):
Expand All @@ -46,38 47,45 @@ def __init__(self):
self.__dict__[name] = attr.copy()


def TypedChainOfThought(signature):
    """Just like TypedPredictor, but adds a ChainOfThought OutputField.

    Args:
        signature: Anything ``ensure_signature`` accepts.

    Returns:
        A ``TypedPredictor`` built from ``signature`` with an extra
        ``reasoning`` output field added via ``signature.prepend``.
    """
    signature = ensure_signature(signature)
    # The reasoning prompt names every output field the model is expected to produce.
    output_keys = ", ".join(signature.output_fields.keys())
    return TypedPredictor(
        signature.prepend(
            "reasoning",
            dspy.OutputField(
                prefix="Reasoning: Let's think step by step in order to",
                desc="${produce the " + output_keys + "}. We ...",
            ),
        )
    )


class TypedPredictor(dspy.Module):
def __init__(self, signature):
    """Store ``signature`` and build the underlying ``dspy.Predict`` for it.

    Args:
        signature: The signature handed to ``dspy.Predict``; it is also kept
            on the instance so ``copy`` can rebuild the predictor.
    """
    super().__init__()
    self.signature = signature
    self.predictor = dspy.Predict(signature)

def copy(self):
    """Return a fresh TypedPredictor constructed from this instance's signature."""
    return TypedPredictor(self.signature)

@staticmethod
def _make_example(type_):
    """Generate a short JSON example for ``type_`` via a ``dspy.Predict`` call.

    Args:
        type_: A class exposing ``model_json_schema()`` and
            ``model_validate_json()`` (the pydantic model API used below).

    Returns:
        The generated JSON text, or ``""`` when the text does not validate
        against ``type_``.
    """
    # Note: DSPy will cache this call so we only pay the first time TypedPredictor is called.
    json_object = dspy.Predict(
        dspy.Signature(
            "json_schema -> json_object",
            "Make a very succinct json object that validates with the following schema",
        )
    )(json_schema=json.dumps(type_.model_json_schema())).json_object
    # We use the model_validate_json method to make sure the example is valid
    try:
        type_.model_validate_json(_unwrap_json(json_object))
    except (pydantic.ValidationError, ValueError):
        return ""  # Unable to make an example
    return json_object
# TODO: Another fun idea is to only (but automatically) do this if the output fails.
# We could also have a more general "suggest solution" prompt that tries to fix the output
# More directly.
Expand All @@ -103,29 111,25 @@ def _prepare_signature(self):
else:
# Anything else we wrap in a pydantic object
unwrap = lambda x: x
wrap = lambda x: x
if not (inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel)):
type_ = pydantic.create_model(
"Output", value=(type_, ...), __base__=pydantic.BaseModel
)
type_ = pydantic.create_model("Output", value=(type_, ...), __base__=pydantic.BaseModel)
wrap = lambda x: type_(value=x)
unwrap = lambda x: x.value
signature = signature.with_updated_fields(
name,
desc=field.json_schema_extra.get("desc", "")
(
f". Respond with a single JSON object using the schema "
f". Respond with a single JSON object. JSON Schema: "
json.dumps(type_.model_json_schema())
(". For example: " self._make_example(type_) if self.make_example else "")
),
format=lambda x: (
x if isinstance(x, str) else x.model_dump_json()
),
parser=lambda x: unwrap(
type_.model_validate_json(_unwrap_json(x))
),
format=lambda x: (x if isinstance(x, str) else wrap(x).model_dump_json()),
parser=lambda x: unwrap(type_.model_validate_json(_unwrap_json(x))),
type_=type_,
)
else: # If input field
format = lambda x: x if isinstance(x, str) else str(x)
if type_ in (list[str], tuple[str]):
if type_ in (List[str], list[str], Tuple[str], tuple[str]):
format = passages2text
elif inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
format = lambda x: x if isinstance(x, str) else x.model_dump_json()
Expand All @@ -149,16 153,27 @@ def forward(self, **kwargs):
parser = field.json_schema_extra.get("parser", lambda x: x)
parsed_results[name] = parser(value)
except (pydantic.ValidationError, ValueError) as e:
errors[name] = e
errors[name] = _format_error(e)
# If we can, we add an example to the error message
current_desc = field.json_schema_extra.get("desc", "")
i = current_desc.find("JSON Schema: ")
if i == -1:
continue # Only add examples to JSON objects
suffix, current_desc = current_desc[i:], current_desc[:i]
prefix = "You MUST use this format: "
if try_i 1 < MAX_RETRIES and prefix not in current_desc:
if example := self._make_example(field.annotation):
signature = signature.with_updated_fields(
name, desc=current_desc "\n" prefix example "\n" suffix
)
if errors:
# Add new fields for each error
for name, error in errors.items():
modified_kwargs[f"error_{name}_{try_i}"] = str(error)
modified_kwargs[f"error_{name}_{try_i}"] = error
signature = signature.append(
f"error_{name}_{try_i}",
dspy.InputField(
prefix=f"Past Error "
(f"({name}):" if try_i == 0 else f"({name}, {try_i 1}):"),
prefix=f"Past Error " (f"({name}):" if try_i == 0 else f"({name}, {try_i 1}):"),
desc="An error to avoid in the future",
),
)
Expand All @@ -167,7 182,20 @@ def forward(self, **kwargs):
for name, value in parsed_results.items():
setattr(result, name, value)
return result
raise ValueError("Too many retries")
raise ValueError(
"Too many retries trying to get the correct output format. " "Try simplifying the requirements.", errors
)


def _format_error(error: Exception) -> str:
    """Render an exception as a compact one-line message.

    For a ``pydantic.ValidationError`` each entry of ``error.errors()`` is
    formatted as ``"<msg>: <loc parts> (error type: <type>)"`` and the entries
    are joined with ``"; "``. Any other exception is rendered with ``str()``.

    Args:
        error: The exception to describe.

    Returns:
        The formatted message.
    """
    if not isinstance(error, pydantic.ValidationError):
        return str(error)

    messages = []
    for e in error.errors():
        # ``loc`` may mix str field names with int positions (e.g. an index
        # into a list field), so every part is stringified before joining;
        # joining the raw tuple raises TypeError on the first int.
        fields = ", ".join(str(part) for part in e["loc"])
        messages.append(f"{e['msg']}: {fields} (error type: {e['type']})")
    return "; ".join(messages)


def _func_to_signature(func):
Expand Down Expand Up @@ -221,9 249,7 @@ def main():
class Answer(pydantic.BaseModel):
value: float
certainty: float
comments: list[str] = pydantic.Field(
description="At least two comments about the answer"
)
comments: list[str] = pydantic.Field(description="At least two comments about the answer")

class QA(dspy.Module):
@predictor
Expand Down Expand Up @@ -262,26 288,19 @@ def validate_context_and_answer_and_hops(example, pred, trace=None):
if not dspy.evaluate.answer_passage_match(example, pred):
return False

hops = [example.question] [
outputs.query for *_, outputs in trace if "query" in outputs
]
hops = [example.question] [outputs.query for *_, outputs in trace if "query" in outputs]

if max([len(h) for h in hops]) > 100:
return False
if any(
dspy.evaluate.answer_exact_match_str(hops[idx], hops[:idx], frac=0.8)
for idx in range(2, len(hops))
):
if any(dspy.evaluate.answer_exact_match_str(hops[idx], hops[:idx], frac=0.8) for idx in range(2, len(hops))):
return False

return True


def gold_passages_retrieved(example, pred, trace=None):
    """Report whether every gold title is among the titles of the retrieved passages.

    The title of a passage is the text before its first ``" | "``; gold and
    retrieved titles are both passed through ``dspy.evaluate.normalize_text``
    before comparison. ``trace`` is accepted but not used.
    """
    normalize = dspy.evaluate.normalize_text
    gold_titles = {normalize(title) for title in example["gold_titles"]}
    found_titles = {normalize(passage.split(" | ")[0]) for passage in pred.context}
    return gold_titles.issubset(found_titles)

Expand All @@ -294,9 313,7 @@ def hotpot():
from dspy.teleprompt.bootstrap import BootstrapFewShot

print("Load the dataset.")
dataset = HotPotQA(
train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0
)
dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0)
trainset = [x.with_inputs("question") for x in dataset.train]
devset = [x.with_inputs("question") for x in dataset.dev]
print("Done")
Expand Down Expand Up @@ -333,9 350,7 @@ def forward(self, question):
lm = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=4000)
dspy.settings.configure(lm=lm, rm=rm, trace=[])

evaluate_on_hotpotqa = Evaluate(
devset=devset, num_threads=10, display_progress=True, display_table=5
)
evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=10, display_progress=True, display_table=5)

# uncompiled (i.e., zero-shot) program
uncompiled_baleen = SimplifiedBaleen()
Expand All @@ -345,9 360,7 @@ def forward(self, question):
)

# compiled (i.e., few-shot) program
compiled_baleen = BootstrapFewShot(
metric=validate_context_and_answer_and_hops
).compile(
compiled_baleen = BootstrapFewShot(metric=validate_context_and_answer_and_hops).compile(
SimplifiedBaleen(),
teacher=SimplifiedBaleen(passages_per_hop=2),
trainset=trainset,
Expand Down
4 changes: 3 additions & 1 deletion dspy/signatures/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 77,16 @@ def fields(cls):
# Make sure to give input fields before output fields
return {**cls.input_fields, **cls.output_fields}

def with_updated_fields(cls, name, type_=None, **kwargs):
    """Return a new Signature type with the field ``name`` updated.

    The fields of ``cls`` are deep-copied, so the update is applied to the
    copy only. Each keyword argument is merged into the copied field's
    ``json_schema_extra`` (``fields[name].json_schema_extra[key] = value``).

    Args:
        name: Key of the field to update in ``cls.fields``.
        type_: When not ``None``, replaces the copied field's ``annotation``.
        **kwargs: Entries added to, or overriding, ``json_schema_extra``.

    Returns:
        ``Signature(fields_copy, cls.instructions)``.
    """
    fields_copy = deepcopy(cls.fields)
    target = fields_copy[name]
    target.json_schema_extra = {**target.json_schema_extra, **kwargs}
    if type_ is not None:
        target.annotation = type_
    return Signature(fields_copy, cls.instructions)

@property
Expand Down
2 changes: 1 addition & 1 deletion examples/functional/functional.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 175,7 @@
" entry_point: str = InputField()\n",
" solution: PythonCode = OutputField()\n",
"\n",
"predictor = TypedPredictor(CodeSignature, make_example=True)\n",
"predictor = TypedPredictor(CodeSignature)\n",
"prediction = predictor(\n",
" prompt=PythonCode(code=ds['test'][0]['prompt']),\n",
" test=PythonCode(code=ds['test'][0]['test']),\n",
Expand Down
Loading

0 comments on commit 5f0dbc8

Please sign in to comment.