Skip to content

Commit

Permalink
improve script parsing/logging
Browse files Browse the repository at this point in the history
  • Loading branch information
granawkins committed Feb 2, 2024
1 parent 0d43899 commit 7340430
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 31 deletions.
6 changes: 4 additions & 2 deletions src/rawdog/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
def rawdog(prompt: str, verbose: bool=False):
_continue = True
while _continue is True:
error, output = "", ""
error, script, output = "", "", ""
try:
error, script = llm_client.get_script(prompt)
message, script = llm_client.get_script(prompt)
if script:
if verbose:
if input("Proceed with execution? (Y/n):").strip().lower() == "n":
Expand All @@ -23,6 +23,8 @@ def rawdog(prompt: str, verbose: bool=False):
with io.StringIO() as buf, redirect_stdout(buf):
exec(script, globals())
output = buf.getvalue()
elif message:
print(message)
except (Exception, KeyboardInterrupt) as e:
error = str(e)

Expand Down
59 changes: 30 additions & 29 deletions src/rawdog/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,33 @@
from rawdog.prompts import script_prompt, script_examples


def parse_script(response: str) -> tuple[str, str]:
"""Split the response into a message and a script.
Expected use is: run the script if there is one, otherwise print the message.
"""
# Parse delimiter
n_delimiters = response.count("```")
if n_delimiters < 2:
return f"Error: No script found in response:\n{response}", ""
segments = response.split("```")
message = f'{segments[0]}\n{segments[-1]}'
script = "```".join(segments[1:-1]).strip() # Leave 'inner' delimiters alone

# Check for common mistakes
if script.split("\n")[0].startswith("python"):
script = "\n".join(script.split("\n")[1:])
try: # Make sure it isn't json
script = json.loads(script)
except Exception as e:
pass
try: # Make sure it's valid python
ast.parse(script)
except SyntaxError as e:
return f"Script contains invalid Python:\n{response}", ""
return message, script


class LLMClient:

def __init__(self):
Expand Down Expand Up @@ -61,13 +88,12 @@ def get_response(
log["cost"] = f"{float(cost):.10f}"
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
script_filename = self.log_path.parent / f"script_{timestamp}.py"

script = None if not text else re.search(r'```(.*?)```', text, re.DOTALL)
_, script = ("", "") if not text else parse_script(text)
script_content = dedent(f"""\
# Model: {log['model']}
# Prompt: {log['prompt']}
# Response Cost: {log.get('cost', 'N/A')}
""") + script.group(1) if script else f"INVALID SCRIPT:\n{text}"
""") + script if script else f"INVALID SCRIPT:\n{text}"
with open(script_filename, "w") as script_file:
script_file.write(script_content)
return text
Expand All @@ -83,29 +109,4 @@ def get_script(self, prompt: str):
self.conversation.append({"role": "user", "content": f"PROMPT: {prompt}"})
response = self.get_response(self.conversation)
self.conversation.append({"role": "system", "content": response})

# Parse script from response
error = None
script = None
if response.count("```") > 1:
script = re.search(r'```(.*?)```', response, re.DOTALL)
if not script:
error = f"No script found in response: {response}"
script = script.group(1)
script = dedent(script).strip()
if script.split("\n")[0].startswith("python"):
script = "\n".join(script.split("\n")[1:])
# Make sure it isn't json
try:
script = json.loads(script)
except Exception as e:
pass
# Make sure it's valid python
try:
ast.parse(script)
except SyntaxError as e:
error == f"Invalid script:\n{script}"
script = None
else:
error = f"No script found in response: {response}"
return error, script
return parse_script(response)

0 comments on commit 7340430

Please sign in to comment.