Skip to content

Commit

Permalink
refactor loopblock value (#1381)
Browse files Browse the repository at this point in the history
  • Loading branch information
LawyZheng authored Dec 12, 2024
1 parent f5691d5 commit f028b48
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 26 deletions.
14 changes: 13 additions & 1 deletion skyvern/forge/sdk/workflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 110,17 @@ def __init__(self, workflow_parameter_type: str, workflow_parameter_key: str, re


class InvalidWaitBlockTime(SkyvernException):
def __init__(self, max_sec: int):
def __init__(self, max_sec: int) -> None:
super().__init__(f"Invalid wait time for wait block, it should be a number between 0 and {max_sec}.")


class FailedToFormatJinjaStyleParameter(SkyvernException):
def __init__(self, template: str, msg: str) -> None:
super().__init__(
f"Failed to format Jinja style parameter {template}. Please make sure the variable reference is correct. reason: {msg}"
)


class NoIterableValueFound(SkyvernException):
def __init__(self) -> None:
super().__init__("No iterable value found for the loop block")
79 changes: 55 additions & 24 deletions skyvern/forge/sdk/workflow/models/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 47,10 @@
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
from skyvern.forge.sdk.workflow.context_manager import BlockMetadata, WorkflowRunContext
from skyvern.forge.sdk.workflow.exceptions import (
FailedToFormatJinjaStyleParameter,
InvalidEmailClientConfiguration,
InvalidFileType,
NoIterableValueFound,
NoValidEmailRecipient,
)
from skyvern.forge.sdk.workflow.models.parameter import (
Expand Down Expand Up @@ -576,14 578,17 @@ def get_failure_reason(self) -> str | None:
class ForLoopBlock(Block):
block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP

loop_over: PARAMETER_TYPE
loop_blocks: list[BlockTypeVar]
loop_over: PARAMETER_TYPE | None = None
loop_variable_reference: str | None = None

def get_all_parameters(
self,
workflow_run_id: str,
) -> list[PARAMETER_TYPE]:
parameters = {self.loop_over}
parameters = set()
if self.loop_over is not None:
parameters.add(self.loop_over)

for loop_block in self.loop_blocks:
for parameter in loop_block.get_all_parameters(workflow_run_id):
Expand All @@ -600,6 605,9 @@ def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any
if isinstance(parameter, ContextParameter):
context_parameters.append(parameter)

if self.loop_over is None:
return context_parameters

for context_parameter in context_parameters:
if context_parameter.source.key != self.loop_over.key:
continue
Expand All @@ -620,29 628,44 @@ def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any
return context_parameters

def get_loop_over_parameter_values(self, workflow_run_context: WorkflowRunContext) -> list[Any]:
if isinstance(self.loop_over, WorkflowParameter):
parameter_value = workflow_run_context.get_value(self.loop_over.key)
elif isinstance(self.loop_over, OutputParameter):
# If the output parameter is for a TaskBlock, it will be a TaskOutput object. We need to extract the
# value from the TaskOutput object's extracted_information field.
output_parameter_value = workflow_run_context.get_value(self.loop_over.key)
if isinstance(output_parameter_value, dict) and "extracted_information" in output_parameter_value:
parameter_value = output_parameter_value["extracted_information"]
else:
parameter_value = output_parameter_value
elif isinstance(self.loop_over, ContextParameter):
parameter_value = self.loop_over.value
if not parameter_value:
source_parameter_value = workflow_run_context.get_value(self.loop_over.source.key)
if isinstance(source_parameter_value, dict):
if "extracted_information" in source_parameter_value:
parameter_value = source_parameter_value["extracted_information"].get(self.loop_over.key)
else:
parameter_value = source_parameter_value.get(self.loop_over.key)
# parse the value from self.loop_variable_reference and then from self.loop_over
if self.loop_variable_reference:
value_template = f'{{{{ {self.loop_variable_reference.strip(" {}")} | tojson }}}}'
try:
value_json = self.format_block_parameter_template_from_workflow_run_context(
value_template, workflow_run_context
)
except Exception as e:
raise FailedToFormatJinjaStyleParameter(value_template, str(e))
parameter_value = json.loads(value_json)

elif self.loop_over is not None:
if isinstance(self.loop_over, WorkflowParameter):
parameter_value = workflow_run_context.get_value(self.loop_over.key)
elif isinstance(self.loop_over, OutputParameter):
# If the output parameter is for a TaskBlock, it will be a TaskOutput object. We need to extract the
# value from the TaskOutput object's extracted_information field.
output_parameter_value = workflow_run_context.get_value(self.loop_over.key)
if isinstance(output_parameter_value, dict) and "extracted_information" in output_parameter_value:
parameter_value = output_parameter_value["extracted_information"]
else:
raise ValueError("ContextParameter source value should be a dict")
parameter_value = output_parameter_value
elif isinstance(self.loop_over, ContextParameter):
parameter_value = self.loop_over.value
if not parameter_value:
source_parameter_value = workflow_run_context.get_value(self.loop_over.source.key)
if isinstance(source_parameter_value, dict):
if "extracted_information" in source_parameter_value:
parameter_value = source_parameter_value["extracted_information"].get(self.loop_over.key)
else:
parameter_value = source_parameter_value.get(self.loop_over.key)
else:
raise ValueError("ContextParameter source value should be a dict")
else:
raise NotImplementedError()

else:
raise NotImplementedError
raise NoIterableValueFound()

if isinstance(parameter_value, list):
return parameter_value
Expand Down Expand Up @@ -725,7 748,15 @@ async def execute_loop_helper(

async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
try:
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
except Exception as e:
return self.build_block_result(
success=False,
failure_reason=f"failed to get loop values: {str(e)}",
status=BlockStatus.failed,
)

LOG.info(
f"Number of loop_over values: {len(loop_over_values)}",
block_type=self.block_type,
Expand Down
1 change: 1 addition & 0 deletions skyvern/forge/sdk/workflow/models/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 142,7 @@ class ForLoopBlockYAML(BlockYAML):

loop_over_parameter_key: str
loop_blocks: list["BLOCK_YAML_SUBCLASSES"]
loop_variable_reference: str | None = None


class CodeBlockYAML(BlockYAML):
Expand Down
19 changes: 18 additions & 1 deletion skyvern/forge/sdk/workflow/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,10 1319,27 @@ async def block_yaml_to_block(
await WorkflowService.block_yaml_to_block(workflow, loop_block, parameters)
for loop_block in block_yaml.loop_blocks
]
loop_over_parameter = parameters[block_yaml.loop_over_parameter_key]

loop_over_parameter: Parameter | None = None
if block_yaml.loop_over_parameter_key:
loop_over_parameter = parameters[block_yaml.loop_over_parameter_key]

if block_yaml.loop_variable_reference:
# it's backaward compatible with jinja style parameter and context paramter
# we trim the format like {{ loop_key }} into loop_key to initialize the context parater,
# otherwise it might break the context parameter initialization chain, blow up the worklofw parameters
# TODO: consider remove this if we totally give up context parameter
trimmed_key = block_yaml.loop_variable_reference.strip(" {}")
if trimmed_key in parameters:
loop_over_parameter = parameters[trimmed_key]

if loop_over_parameter is None and not block_yaml.loop_variable_reference:
raise Exception("empty loop value parameter")

return ForLoopBlock(
label=block_yaml.label,
loop_over=loop_over_parameter,
loop_variable_reference=block_yaml.loop_variable_reference,
loop_blocks=loop_blocks,
output_parameter=output_parameter,
continue_on_failure=block_yaml.continue_on_failure,
Expand Down

0 comments on commit f028b48

Please sign in to comment.