Skip to content

Commit

Permalink
feat: grouping paragraphs in documents and samples in audio (#810)
Browse files Browse the repository at this point in the history
Added support for 
`GROUP BY '<> samples'`
`GROUP BY '<> paragraphs'`

---------

Co-authored-by: xzdandy <[email protected]>
Co-authored-by: Gaurav <[email protected]>
  • Loading branch information
3 people authored Jun 6, 2023
1 parent 17441cf commit 70613f8
Show file tree
Hide file tree
Showing 15 changed files with 126 additions and 48 deletions.
36 changes: 28 additions & 8 deletions eva/binder/binder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from eva.catalog.catalog_type import TableType
from eva.catalog.catalog_utils import (
get_video_table_column_definitions,
is_document_table,
is_string_col,
is_video_table,
)
Expand Down Expand Up @@ -94,23 +95,42 @@ def extend_star(
return target_list


def check_groupby_pattern(groupby_string: str) -> None:
# match the pattern of group by clause (e.g., 16f or 8s)
pattern = re.search(r"^\d [fs]$", groupby_string)
def check_groupby_pattern(table_ref: TableRef, groupby_string: str) -> None:
    """Validate a GROUP BY clause such as '16 frames', '8 samples', or '2 paragraphs'.

    Checks that the clause matches '<count><unit>' (whitespace between count
    and unit is optional) and that the unit is compatible with the table being
    grouped: 'frames' and 'samples' require a video table (audio samples are
    read from video files), 'paragraphs' requires a document table.

    Args:
        table_ref: table referenced in the FROM clause of the query.
        groupby_string: literal value of the GROUP BY clause.

    Raises:
        BinderError: if the pattern is malformed or the unit does not match
            the table type.
    """
    # match the pattern of group by clause (e.g., 16 frames or 8 samples)
    # NOTE: '\d+' (one or more digits) is required — a bare '\d ' would reject
    # multi-digit counts such as '16 frames' or '240 samples'
    pattern = re.search(r"^\d+\s*(?:frames|samples|paragraphs)$", groupby_string)
    # if valid pattern
    if not pattern:
        err_msg = "Incorrect GROUP BY pattern: {}".format(groupby_string)
        raise BinderError(err_msg)
    match_string = pattern.group(0)
    # strip the leading count (and optional whitespace) to isolate the unit
    suffix_string = re.sub(r"^\d+\s*", "", match_string)

    if suffix_string not in ["frames", "samples", "paragraphs"]:
        err_msg = "Grouping only supported by frames for videos, by samples for audio, and by paragraphs for documents"
        raise BinderError(err_msg)

    if suffix_string == "frames" and not is_video_table(table_ref.table.table_obj):
        err_msg = "Grouping by frames only supported for videos"
        raise BinderError(err_msg)

    # audio samples are extracted from video containers, so 'samples' also
    # requires a video table
    if suffix_string == "samples" and not is_video_table(table_ref.table.table_obj):
        err_msg = "Grouping by samples only supported for videos"
        raise BinderError(err_msg)

    if suffix_string == "paragraphs" and not is_document_table(
        table_ref.table.table_obj
    ):
        err_msg = "Grouping by paragraphs only supported for documents"
        raise BinderError(err_msg)

    # TODO ACTION condition on segment length?


def check_table_object_is_video(table_ref: TableRef) -> None:
if not is_video_table(table_ref.table.table_obj):
err_msg = "GROUP BY only supported for video tables"
def check_table_object_is_groupable(table_ref: TableRef) -> None:
    """Raise BinderError unless the table supports GROUP BY (video or document)."""
    table_obj = table_ref.table.table_obj
    if is_video_table(table_obj) or is_document_table(table_obj):
        return
    raise BinderError("GROUP BY only supported for video and document tables")


Expand Down
6 changes: 3 additions & 3 deletions eva/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 21,7 @@
bind_table_info,
check_column_name_is_string,
check_groupby_pattern,
check_table_object_is_video,
check_table_object_is_groupable,
extend_star,
handle_bind_extract_object_function,
resolve_alias_table_value_expression,
Expand Down Expand Up @@ -122,8 122,8 @@ def _bind_select_statement(self, node: SelectStatement):
self.bind(expr)
if node.groupby_clause:
self.bind(node.groupby_clause)
check_groupby_pattern(node.groupby_clause.value)
check_table_object_is_video(node.from_table)
check_table_object_is_groupable(node.from_table)
check_groupby_pattern(node.from_table, node.groupby_clause.value)
if node.orderby_list:
for expr in node.orderby_list:
self.bind(expr[0])
Expand Down
7 changes: 7 additions & 0 deletions eva/catalog/catalog_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 55,13 @@ def is_video_table(table: TableCatalogEntry):
return table.table_type == TableType.VIDEO_DATA


def is_document_table(table: TableCatalogEntry):
    """Return True when the catalog entry stores document data (documents or PDFs)."""
    return table.table_type in (TableType.DOCUMENT_DATA, TableType.PDF_DATA)


def is_string_col(col: ColumnCatalogEntry):
return col.type == ColumnType.TEXT or col.array_type == NdArrayType.STR

Expand Down
2 changes: 1 addition & 1 deletion eva/eva.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 28,4 @@ experimental:
ray: True

third_party:
openai_api_key: ""
OPENAI_KEY: ""
6 changes: 4 additions & 2 deletions eva/executor/groupby_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Iterator

import pandas as pd
Expand All @@ -25,7 26,7 @@
class GroupByExecutor(AbstractExecutor):
"""
Group inputs into 4d segments of length provided in the query
E.g., "GROUP BY '8f'" groups every 8 frames into one segment
E.g., "GROUP BY '8 frames'" groups every 8 frames into one segment
Arguments:
node (AbstractPlan): The GroupBy Plan
Expand All @@ -34,7 35,8 @@ class GroupByExecutor(AbstractExecutor):

    def __init__(self, db: EVADatabase, node: GroupByPlan):
        """Initialize the executor and extract the segment length from the plan.

        The GROUP BY clause value is a string such as '8 frames' or
        '240 samples' (presumably already validated by the binder's
        check_groupby_pattern — verify against binder_utils); removing every
        non-digit character leaves the numeric segment length.
        """
        super().__init__(db, node)
        # e.g. '8 frames' -> '8'
        numbers_only = re.sub(r"\D", "", node.groupby_clause.value)
        self._segment_length = int(numbers_only)

def exec(self, *args, **kwargs) -> Iterator[Batch]:
child_executor = self.children[0]
Expand Down
12 changes: 10 additions & 2 deletions eva/models/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 312,18 @@ def stack(cls, batch: Batch, copy=True) -> Batch:
if len(batch.columns) > 1:
raise ValueError("Stack can only be called on single-column batches")
frame_data_col = batch.columns[0]
data_to_stack = batch.frames[frame_data_col].values.tolist()

stacked_array = np.array(batch.frames[frame_data_col].values.tolist())
stacked_frame = pd.DataFrame([{frame_data_col: stacked_array}])
if isinstance(data_to_stack[0], np.ndarray) and len(data_to_stack[0].shape) > 1:
# if data_to_stack has more than 1 axis, we add a new axis
# [(3, 224, 224) * 10] -> (10, 3, 224, 224)
stacked_array = np.array(batch.frames[frame_data_col].values.tolist())
else:
# we concatenate along the zeroth axis
# this makes sense for audio and text
stacked_array = np.hstack(batch.frames[frame_data_col].values)

stacked_frame = pd.DataFrame([{frame_data_col: stacked_array}])
return Batch(stacked_frame)

@classmethod
Expand Down
8 changes: 3 additions & 5 deletions eva/udfs/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 47,13 @@ def setup(
temperature: float = 0,
) -> None:
# Try Configuration Manager
openai.api_key = ConfigurationManager().get_value(
"third_party", "openai_api_key"
)
openai.api_key = ConfigurationManager().get_value("third_party", "OPENAI_KEY")
# If not found, try OS Environment Variable
if len(openai.api_key) == 0:
openai.api_key = os.environ["openai_api_key"]
openai.api_key = os.environ.get("OPENAI_KEY", "")
assert (
len(openai.api_key) != 0
), "Please set your OpenAI API key in eva.yml file (third_party, open_api_key)"
), "Please set your OpenAI API key in eva.yml file (third_party, open_api_key) or environment variable (OPENAI_KEY)"

assert model in _VALID_CHAT_COMPLETION_MODEL, f"Unsupported ChatGPT {model}"

Expand Down
9 changes: 5 additions & 4 deletions test/binder/test_statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,22 225,23 @@ def test_bind_func_expr(
)
self.assertEqual(str(cm.exception), err_msg)

@patch("eva.binder.statement_binder.check_table_object_is_video")
def test_bind_select_statement(self, is_video_mock):
@patch("eva.binder.statement_binder.check_table_object_is_groupable")
@patch("eva.binder.statement_binder.check_groupby_pattern")
def test_bind_select_statement(self, is_groupable_mock, groupby_mock):
with patch.object(StatementBinder, "bind") as mock_binder:
binder = StatementBinder(StatementBinderContext(MagicMock()))
select_statement = MagicMock()
mocks = [MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock()]
select_statement.target_list = mocks[:2]
select_statement.orderby_list = [(mocks[2], 0), (mocks[3], 0)]
select_statement.groupby_clause = mocks[4]
select_statement.groupby_clause.value = "8f"
select_statement.groupby_clause.value = "8 frames"
binder._bind_select_statement(select_statement)
mock_binder.assert_any_call(select_statement.from_table)
mock_binder.assert_any_call(select_statement.where_clause)
mock_binder.assert_any_call(select_statement.groupby_clause)
mock_binder.assert_any_call(select_statement.union_link)
is_video_mock.assert_called()
is_groupable_mock.assert_called()
for mock in mocks:
mock_binder.assert_any_call(mock)

Expand Down
29 changes: 28 additions & 1 deletion test/integration_tests/test_huggingface_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 228,16 @@ def test_automatic_speech_recognition(self):
# verify that speech was converted to text correctly
self.assertTrue(output.frames.iloc[0][0].count("touchdown") == 2)

select_query_with_group_by = (
f"SELECT {udf_name}(SEGMENT(audio)) FROM VIDEOS GROUP BY '240 samples';"
)
output = execute_query_fetch_all(self.evadb, select_query_with_group_by)

# verify that output has one row and one column only
self.assertEquals(output.frames.shape, (4, 1))
# verify that speech was converted to text correctly
self.assertEquals(output.frames.iloc[0][0].count("touchdown"), 1)

drop_udf_query = f"DROP UDF {udf_name};"
execute_query_fetch_all(self.evadb, drop_udf_query)

Expand All @@ -243,7 253,7 @@ def test_summarization_from_video(self):
summary_udf = "Summarizer"
create_udf = (
f"CREATE UDF {summary_udf} TYPE HuggingFace "
"'task' 'summarization' 'model' 'philschmid/bart-large-cnn-samsum' 'min_length' 10 'max_length' 100;"
"'task' 'summarization' 'model' 'philschmid/bart-large-cnn-samsum' 'min_length' 10 'max_new_tokens' 100;"
)
execute_query_fetch_all(self.evadb, create_udf)

Expand Down Expand Up @@ -386,6 396,16 @@ def test_named_entity_recognition_model_all_pdf_data(self):
drop_udf_query = f"DROP UDF {udf_name};"
execute_query_fetch_all(self.evadb, drop_udf_query)

def test_select_and_groupby_with_paragraphs(self):
segment_size = 10
select_query = (
"SELECT SEGMENT(data) FROM MyPDFs GROUP BY '{}paragraphs';".format(
segment_size
)
)
output = execute_query_fetch_all(self.evadb, select_query)
self.assertEqual(len(output.frames), 3)

@pytest.mark.benchmark
def test_named_entity_recognition_model_no_ner_data_exists(self):
udf_name = "HFNERModel"
Expand All @@ -410,3 430,10 @@ def test_named_entity_recognition_model_no_ner_data_exists(self):

drop_udf_query = f"DROP UDF {udf_name};"
execute_query_fetch_all(self.evadb, drop_udf_query)


if __name__ == "__main__":
suite = unittest.TestSuite()
suite.addTest(HuggingFaceTests("test_automatic_speech_recognition"))
runner = unittest.TextTestRunner()
runner.run(suite)
4 changes: 2 additions & 2 deletions test/integration_tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 167,7 @@ def test_should_run_pytorch_and_yolo_and_mvit(self):
MVITActionRecognition(SEGMENT(data))
FROM Actions
WHERE id < 32
GROUP BY '16f'; """
GROUP BY '16 frames'; """
actual_batch = execute_query_fetch_all(self.evadb, select_query)
self.assertEqual(len(actual_batch), 2)

Expand All @@ -184,7 184,7 @@ def test_should_run_pytorch_and_asl(self):
select_query = """SELECT FIRST(id), ASLActionRecognition(SEGMENT(data))
FROM Asl_actions
SAMPLE 5
GROUP BY '16f';"""
GROUP BY '16 frames';"""
actual_batch = execute_query_fetch_all(self.evadb, select_query)

res = actual_batch.frames
Expand Down
16 changes: 7 additions & 9 deletions test/integration_tests/test_select_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 323,7 @@ def test_select_and_groupby_first(self):
# only applies to video data which is already sorted
segment_size = 3
select_query = (
"SELECT FIRST(id), SEGMENT(data) FROM MyVideo GROUP BY '{}f';".format(
"SELECT FIRST(id), SEGMENT(data) FROM MyVideo GROUP BY '{} frames';".format(
segment_size
)
)
Expand All @@ -348,7 348,7 @@ def test_select_and_groupby_with_last(self):
# only applies to video data which is already sorted
segment_size = 3
select_query = (
"SELECT LAST(id), SEGMENT(data) FROM MyVideo GROUP BY '{}f';".format(
"SELECT LAST(id), SEGMENT(data) FROM MyVideo GROUP BY '{}frames';".format(
segment_size
)
)
Expand All @@ -373,7 373,7 @@ def test_select_and_groupby_with_last(self):
def test_select_and_groupby_should_fail_with_incorrect_pattern(self):
segment_size = "4a"
select_query = (
"SELECT FIRST(id), SEGMENT(data) FROM MyVideo GROUP BY '{}f';".format(
"SELECT FIRST(id), SEGMENT(data) FROM MyVideo GROUP BY '{} frames';".format(
segment_size
)
)
Expand All @@ -383,18 383,16 @@ def test_select_and_groupby_should_fail_with_incorrect_pattern(self):

def test_select_and_groupby_should_fail_with_seconds(self):
segment_size = 4
select_query = (
"SELECT FIRST(id), SEGMENT(data) FROM MyVideo GROUP BY '{}s';".format(
segment_size
)
select_query = "SELECT FIRST(id), SEGMENT(data) FROM MyVideo GROUP BY '{} seconds';".format(
segment_size
)
self.assertRaises(
BinderError, execute_query_fetch_all, self.evadb, select_query
)

def test_select_and_groupby_should_fail_with_non_video_table(self):
segment_size = 4
select_query = "SELECT FIRST(a1) FROM table1 GROUP BY '{}f';".format(
select_query = "SELECT FIRST(a1) FROM table1 GROUP BY '{} frames';".format(
segment_size
)
self.assertRaises(
Expand All @@ -406,7 404,7 @@ def test_select_and_groupby_with_sample(self):
# only applies to video data which is already sorted
segment_size = 2
sampling_rate = 2
select_query = "SELECT FIRST(id), SEGMENT(data) FROM MyVideo SAMPLE {} GROUP BY '{}f';".format(
select_query = "SELECT FIRST(id), SEGMENT(data) FROM MyVideo SAMPLE {} GROUP BY '{} frames';".format(
sampling_rate, segment_size
)
actual_batch = execute_query_fetch_all(self.evadb, select_query)
Expand Down
6 changes: 3 additions & 3 deletions test/parser/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 1,5 @@
# coding=utf-8
# Copyright 2018-2022 EVA
# Copyright 2018-2023 EVA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -395,7 395,7 @@ def test_select_statement_groupby_class(self):

parser = Parser()

select_query = "SELECT FIRST(id) FROM TAIPAI GROUP BY '8f';"
select_query = "SELECT FIRST(id) FROM TAIPAI GROUP BY '8 frames';"

eva_statement_list = parser.parse(select_query)
self.assertIsInstance(eva_statement_list, list)
Expand All @@ -419,7 419,7 @@ def test_select_statement_groupby_class(self):
# sample_freq
self.assertEqual(
select_stmt.groupby_clause,
ConstantValueExpression("8f", v_type=ColumnType.TEXT),
ConstantValueExpression("8 frames", v_type=ColumnType.TEXT),
)

def test_select_statement_orderby_class(self):
Expand Down
4 changes: 2 additions & 2 deletions test/parser/test_parser_statements.py
Original file line number Diff line number Diff line change
@@ -1,5 1,5 @@
# coding=utf-8
# Copyright 2018-2022 EVA
# Copyright 2018-2023 EVA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,7 44,7 @@ def test_parser_statement_types(self):
UNION ALL SELECT CLASS, REDNESS FROM SHANGHAI;",
"SELECT CLASS, REDNESS FROM TAIPAI \
UNION SELECT CLASS, REDNESS FROM SHANGHAI;",
"SELECT FIRST(id) FROM TAIPAI GROUP BY '8f';",
"SELECT FIRST(id) FROM TAIPAI GROUP BY '8 frames';",
"SELECT CLASS, REDNESS FROM TAIPAI \
WHERE (CLASS = 'VAN' AND REDNESS < 400 ) OR REDNESS > 700 \
ORDER BY CLASS, REDNESS DESC;",
Expand Down
Loading

0 comments on commit 70613f8

Please sign in to comment.