Skip to content

Commit

Permalink
core[patch]: Add unit test when catching generator exit (#23402)
Browse files Browse the repository at this point in the history
This pr adds a unit test for:
#22662
And narrows the scope where the exception is caught.
  • Loading branch information
eyurtsev committed Jun 27, 2024
1 parent 5e6d23f commit da7beb1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
4 changes: 1 addition & 3 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,7 1878,7 @@ def _transform_stream_with_config(
final_output_supported = False
else:
final_output = chunk
except StopIteration:
except (StopIteration, GeneratorExit):
pass
for ichunk in input_for_tracing:
if final_input_supported:
Expand All @@ -1892,8 1892,6 @@ def _transform_stream_with_config(
final_input_supported = False
else:
final_input = ichunk
except GeneratorExit:
run_manager.on_chain_end(final_output, inputs=final_input)
except BaseException as e:
run_manager.on_chain_error(e, inputs=final_input)
raise
Expand Down
36 changes: 36 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5706,3 5706,39 @@ def on_end(run: Run) -> None:
assert len(shared_state) == 2
assert value1 in shared_state.values(), "Value not found in the dictionary."
assert value2 in shared_state.values(), "Value not found in the dictionary."


async def test_closing_iterator_doesnt_raise_error() -> None:
"""Test that closing an iterator calls on_chain_end rather than on_chain_error."""
import time

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.output_parsers import StrOutputParser

on_chain_error_triggered = False

class MyHandler(BaseCallbackHandler):
async def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain errors."""
nonlocal on_chain_error_triggered
on_chain_error_triggered = True

llm = GenericFakeChatModel(messages=iter(["hi there"]))
chain = llm | StrOutputParser()
chain_ = chain.with_config({"callbacks": [MyHandler()]})
st = chain_.stream("hello")
next(st)
# This is a generator so close is defined on it.
st.close() # type: ignore
# Wait for a bit to make sure that the callback is called.
time.sleep(0.05)
assert on_chain_error_triggered is False

0 comments on commit da7beb1

Please sign in to comment.