Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor loopblock value #1381

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion skyvern/forge/sdk/workflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 110,17 @@ def __init__(self, workflow_parameter_type: str, workflow_parameter_key: str, re


class InvalidWaitBlockTime(SkyvernException):
def __init__(self, max_sec: int):
def __init__(self, max_sec: int) -> None:
super().__init__(f"Invalid wait time for wait block, it should be a number between 0 and {max_sec}.")


class FailedToFormatJinjaStyleParameter(SkyvernException):
def __init__(self, template: str, msg: str) -> None:
super().__init__(
f"Failed to format Jinja style parameter {template}. Please make sure the variable reference is correct. reason: {msg}"
)


class NoIterableValueFound(SkyvernException):
def __init__(self) -> None:
super().__init__("No iterable value found for the loop block")
79 changes: 55 additions & 24 deletions skyvern/forge/sdk/workflow/models/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 47,10 @@
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
from skyvern.forge.sdk.workflow.context_manager import BlockMetadata, WorkflowRunContext
from skyvern.forge.sdk.workflow.exceptions import (
FailedToFormatJinjaStyleParameter,
InvalidEmailClientConfiguration,
InvalidFileType,
NoIterableValueFound,
NoValidEmailRecipient,
)
from skyvern.forge.sdk.workflow.models.parameter import (
Expand Down Expand Up @@ -576,14 578,17 @@ def get_failure_reason(self) -> str | None:
class ForLoopBlock(Block):
block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP

loop_over: PARAMETER_TYPE
loop_blocks: list[BlockTypeVar]
loop_over: PARAMETER_TYPE | None = None
loop_variable_reference: str | None = None

def get_all_parameters(
self,
workflow_run_id: str,
) -> list[PARAMETER_TYPE]:
parameters = {self.loop_over}
parameters = set()
if self.loop_over is not None:
parameters.add(self.loop_over)

for loop_block in self.loop_blocks:
for parameter in loop_block.get_all_parameters(workflow_run_id):
Expand All @@ -600,6 605,9 @@ def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any
if isinstance(parameter, ContextParameter):
context_parameters.append(parameter)

if self.loop_over is None:
return context_parameters

for context_parameter in context_parameters:
if context_parameter.source.key != self.loop_over.key:
continue
Expand All @@ -620,29 628,44 @@ def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any
return context_parameters

def get_loop_over_parameter_values(self, workflow_run_context: WorkflowRunContext) -> list[Any]:
if isinstance(self.loop_over, WorkflowParameter):
parameter_value = workflow_run_context.get_value(self.loop_over.key)
elif isinstance(self.loop_over, OutputParameter):
# If the output parameter is for a TaskBlock, it will be a TaskOutput object. We need to extract the
# value from the TaskOutput object's extracted_information field.
output_parameter_value = workflow_run_context.get_value(self.loop_over.key)
if isinstance(output_parameter_value, dict) and "extracted_information" in output_parameter_value:
parameter_value = output_parameter_value["extracted_information"]
else:
parameter_value = output_parameter_value
elif isinstance(self.loop_over, ContextParameter):
parameter_value = self.loop_over.value
if not parameter_value:
source_parameter_value = workflow_run_context.get_value(self.loop_over.source.key)
if isinstance(source_parameter_value, dict):
if "extracted_information" in source_parameter_value:
parameter_value = source_parameter_value["extracted_information"].get(self.loop_over.key)
else:
parameter_value = source_parameter_value.get(self.loop_over.key)
# parse the value from self.loop_variable_reference and then from self.loop_over
if self.loop_variable_reference:
value_template = f'{{{{ {self.loop_variable_reference.strip(" {}")} | tojson }}}}'
try:
value_json = self.format_block_parameter_template_from_workflow_run_context(
value_template, workflow_run_context
)
except Exception as e:
raise FailedToFormatJinjaStyleParameter(value_template, str(e))
parameter_value = json.loads(value_json)

elif self.loop_over is not None:
if isinstance(self.loop_over, WorkflowParameter):
parameter_value = workflow_run_context.get_value(self.loop_over.key)
elif isinstance(self.loop_over, OutputParameter):
# If the output parameter is for a TaskBlock, it will be a TaskOutput object. We need to extract the
# value from the TaskOutput object's extracted_information field.
output_parameter_value = workflow_run_context.get_value(self.loop_over.key)
if isinstance(output_parameter_value, dict) and "extracted_information" in output_parameter_value:
parameter_value = output_parameter_value["extracted_information"]
else:
raise ValueError("ContextParameter source value should be a dict")
parameter_value = output_parameter_value
elif isinstance(self.loop_over, ContextParameter):
parameter_value = self.loop_over.value
if not parameter_value:
source_parameter_value = workflow_run_context.get_value(self.loop_over.source.key)
if isinstance(source_parameter_value, dict):
if "extracted_information" in source_parameter_value:
parameter_value = source_parameter_value["extracted_information"].get(self.loop_over.key)
else:
parameter_value = source_parameter_value.get(self.loop_over.key)
else:
raise ValueError("ContextParameter source value should be a dict")
else:
raise NotImplementedError()

else:
raise NotImplementedError
raise NoIterableValueFound()

if isinstance(parameter_value, list):
return parameter_value
Expand Down Expand Up @@ -725,7 748,15 @@ async def execute_loop_helper(

async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
try:
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
except Exception as e:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Catching a generic Exception is not recommended. Consider catching specific exceptions that are expected to occur when calling get_loop_over_parameter_values.

return self.build_block_result(
success=False,
failure_reason=f"failed to get loop values: {str(e)}",
status=BlockStatus.failed,
)

LOG.info(
f"Number of loop_over values: {len(loop_over_values)}",
block_type=self.block_type,
Expand Down
1 change: 1 addition & 0 deletions skyvern/forge/sdk/workflow/models/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 142,7 @@ class ForLoopBlockYAML(BlockYAML):

loop_over_parameter_key: str
loop_blocks: list["BLOCK_YAML_SUBCLASSES"]
loop_variable_reference: str | None = None


class CodeBlockYAML(BlockYAML):
Expand Down
19 changes: 18 additions & 1 deletion skyvern/forge/sdk/workflow/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,10 1319,27 @@ async def block_yaml_to_block(
await WorkflowService.block_yaml_to_block(workflow, loop_block, parameters)
for loop_block in block_yaml.loop_blocks
]
loop_over_parameter = parameters[block_yaml.loop_over_parameter_key]

loop_over_parameter: Parameter | None = None
if block_yaml.loop_over_parameter_key:
loop_over_parameter = parameters[block_yaml.loop_over_parameter_key]

if block_yaml.loop_variable_reference:
# it's backaward compatible with jinja style parameter and context paramter
# we trim the format like {{ loop_key }} into loop_key to initialize the context parater,
# otherwise it might break the context parameter initialization chain, blow up the worklofw parameters
# TODO: consider remove this if we totally give up context parameter
trimmed_key = block_yaml.loop_variable_reference.strip(" {}")
if trimmed_key in parameters:
loop_over_parameter = parameters[trimmed_key]

if loop_over_parameter is None and not block_yaml.loop_variable_reference:
raise Exception("empty loop value parameter")

return ForLoopBlock(
label=block_yaml.label,
loop_over=loop_over_parameter,
loop_variable_reference=block_yaml.loop_variable_reference,
loop_blocks=loop_blocks,
output_parameter=output_parameter,
continue_on_failure=block_yaml.continue_on_failure,
Expand Down
Loading