diff --git a/data/forecasting/AirPassengersPanel.csv b/data/forecasting/AirPassengersPanel.csv new file mode 100644 index 0000000000..a62fe6ef60 --- /dev/null +++ b/data/forecasting/AirPassengersPanel.csv @@ -0,0 +1,289 @@ +ds,unique_id,y,trend,ylagged +1949-01-31,Airline1,112.0,0,112.0 +1949-02-28,Airline1,118.0,1,118.0 +1949-03-31,Airline1,132.0,2,132.0 +1949-04-30,Airline1,129.0,3,129.0 +1949-05-31,Airline1,121.0,4,121.0 +1949-06-30,Airline1,135.0,5,135.0 +1949-07-31,Airline1,148.0,6,148.0 +1949-08-31,Airline1,148.0,7,148.0 +1949-09-30,Airline1,136.0,8,136.0 +1949-10-31,Airline1,119.0,9,119.0 +1949-11-30,Airline1,104.0,10,104.0 +1949-12-31,Airline1,118.0,11,118.0 +1950-01-31,Airline1,115.0,12,112.0 +1950-02-28,Airline1,126.0,13,118.0 +1950-03-31,Airline1,141.0,14,132.0 +1950-04-30,Airline1,135.0,15,129.0 +1950-05-31,Airline1,125.0,16,121.0 +1950-06-30,Airline1,149.0,17,135.0 +1950-07-31,Airline1,170.0,18,148.0 +1950-08-31,Airline1,170.0,19,148.0 +1950-09-30,Airline1,158.0,20,136.0 +1950-10-31,Airline1,133.0,21,119.0 +1950-11-30,Airline1,114.0,22,104.0 +1950-12-31,Airline1,140.0,23,118.0 +1951-01-31,Airline1,145.0,24,115.0 +1951-02-28,Airline1,150.0,25,126.0 +1951-03-31,Airline1,178.0,26,141.0 +1951-04-30,Airline1,163.0,27,135.0 +1951-05-31,Airline1,172.0,28,125.0 +1951-06-30,Airline1,178.0,29,149.0 +1951-07-31,Airline1,199.0,30,170.0 +1951-08-31,Airline1,199.0,31,170.0 +1951-09-30,Airline1,184.0,32,158.0 +1951-10-31,Airline1,162.0,33,133.0 +1951-11-30,Airline1,146.0,34,114.0 +1951-12-31,Airline1,166.0,35,140.0 +1952-01-31,Airline1,171.0,36,145.0 +1952-02-29,Airline1,180.0,37,150.0 +1952-03-31,Airline1,193.0,38,178.0 +1952-04-30,Airline1,181.0,39,163.0 +1952-05-31,Airline1,183.0,40,172.0 +1952-06-30,Airline1,218.0,41,178.0 +1952-07-31,Airline1,230.0,42,199.0 +1952-08-31,Airline1,242.0,43,199.0 +1952-09-30,Airline1,209.0,44,184.0 +1952-10-31,Airline1,191.0,45,162.0 +1952-11-30,Airline1,172.0,46,146.0 +1952-12-31,Airline1,194.0,47,166.0 +1953-01-31,Airline1,196.0,48,171.0 +1953-02-28,Airline1,196.0,49,180.0 +1953-03-31,Airline1,236.0,50,193.0 +1953-04-30,Airline1,235.0,51,181.0 +1953-05-31,Airline1,229.0,52,183.0 +1953-06-30,Airline1,243.0,53,218.0 +1953-07-31,Airline1,264.0,54,230.0 +1953-08-31,Airline1,272.0,55,242.0 +1953-09-30,Airline1,237.0,56,209.0 +1953-10-31,Airline1,211.0,57,191.0 +1953-11-30,Airline1,180.0,58,172.0 +1953-12-31,Airline1,201.0,59,194.0 +1954-01-31,Airline1,204.0,60,196.0 +1954-02-28,Airline1,188.0,61,196.0 +1954-03-31,Airline1,235.0,62,236.0 +1954-04-30,Airline1,227.0,63,235.0 +1954-05-31,Airline1,234.0,64,229.0 +1954-06-30,Airline1,264.0,65,243.0 +1954-07-31,Airline1,302.0,66,264.0 +1954-08-31,Airline1,293.0,67,272.0 +1954-09-30,Airline1,259.0,68,237.0 +1954-10-31,Airline1,229.0,69,211.0 +1954-11-30,Airline1,203.0,70,180.0 +1954-12-31,Airline1,229.0,71,201.0 +1955-01-31,Airline1,242.0,72,204.0 +1955-02-28,Airline1,233.0,73,188.0 +1955-03-31,Airline1,267.0,74,235.0 +1955-04-30,Airline1,269.0,75,227.0 +1955-05-31,Airline1,270.0,76,234.0 +1955-06-30,Airline1,315.0,77,264.0 +1955-07-31,Airline1,364.0,78,302.0 +1955-08-31,Airline1,347.0,79,293.0 +1955-09-30,Airline1,312.0,80,259.0 +1955-10-31,Airline1,274.0,81,229.0 +1955-11-30,Airline1,237.0,82,203.0 +1955-12-31,Airline1,278.0,83,229.0 +1956-01-31,Airline1,284.0,84,242.0 +1956-02-29,Airline1,277.0,85,233.0 +1956-03-31,Airline1,317.0,86,267.0 +1956-04-30,Airline1,313.0,87,269.0 +1956-05-31,Airline1,318.0,88,270.0 +1956-06-30,Airline1,374.0,89,315.0 +1956-07-31,Airline1,413.0,90,364.0 +1956-08-31,Airline1,405.0,91,347.0 
+1956-09-30,Airline1,355.0,92,312.0 +1956-10-31,Airline1,306.0,93,274.0 +1956-11-30,Airline1,271.0,94,237.0 +1956-12-31,Airline1,306.0,95,278.0 +1957-01-31,Airline1,315.0,96,284.0 +1957-02-28,Airline1,301.0,97,277.0 +1957-03-31,Airline1,356.0,98,317.0 +1957-04-30,Airline1,348.0,99,313.0 +1957-05-31,Airline1,355.0,100,318.0 +1957-06-30,Airline1,422.0,101,374.0 +1957-07-31,Airline1,465.0,102,413.0 +1957-08-31,Airline1,467.0,103,405.0 +1957-09-30,Airline1,404.0,104,355.0 +1957-10-31,Airline1,347.0,105,306.0 +1957-11-30,Airline1,305.0,106,271.0 +1957-12-31,Airline1,336.0,107,306.0 +1958-01-31,Airline1,340.0,108,315.0 +1958-02-28,Airline1,318.0,109,301.0 +1958-03-31,Airline1,362.0,110,356.0 +1958-04-30,Airline1,348.0,111,348.0 +1958-05-31,Airline1,363.0,112,355.0 +1958-06-30,Airline1,435.0,113,422.0 +1958-07-31,Airline1,491.0,114,465.0 +1958-08-31,Airline1,505.0,115,467.0 +1958-09-30,Airline1,404.0,116,404.0 +1958-10-31,Airline1,359.0,117,347.0 +1958-11-30,Airline1,310.0,118,305.0 +1958-12-31,Airline1,337.0,119,336.0 +1959-01-31,Airline1,360.0,120,340.0 +1959-02-28,Airline1,342.0,121,318.0 +1959-03-31,Airline1,406.0,122,362.0 +1959-04-30,Airline1,396.0,123,348.0 +1959-05-31,Airline1,420.0,124,363.0 +1959-06-30,Airline1,472.0,125,435.0 +1959-07-31,Airline1,548.0,126,491.0 +1959-08-31,Airline1,559.0,127,505.0 +1959-09-30,Airline1,463.0,128,404.0 +1959-10-31,Airline1,407.0,129,359.0 +1959-11-30,Airline1,362.0,130,310.0 +1959-12-31,Airline1,405.0,131,337.0 +1960-01-31,Airline1,417.0,132,360.0 +1960-02-29,Airline1,391.0,133,342.0 +1960-03-31,Airline1,419.0,134,406.0 +1960-04-30,Airline1,461.0,135,396.0 +1960-05-31,Airline1,472.0,136,420.0 +1960-06-30,Airline1,535.0,137,472.0 +1960-07-31,Airline1,622.0,138,548.0 +1960-08-31,Airline1,606.0,139,559.0 +1960-09-30,Airline1,508.0,140,463.0 +1960-10-31,Airline1,461.0,141,407.0 +1960-11-30,Airline1,390.0,142,362.0 +1960-12-31,Airline1,432.0,143,405.0 +1949-01-31,Airline2,412.0,144,412.0 +1949-02-28,Airline2,418.0,145,418.0 +1949-03-31,Airline2,432.0,146,432.0 +1949-04-30,Airline2,429.0,147,429.0 +1949-05-31,Airline2,421.0,148,421.0 +1949-06-30,Airline2,435.0,149,435.0 +1949-07-31,Airline2,448.0,150,448.0 +1949-08-31,Airline2,448.0,151,448.0 +1949-09-30,Airline2,436.0,152,436.0 +1949-10-31,Airline2,419.0,153,419.0 +1949-11-30,Airline2,404.0,154,404.0 +1949-12-31,Airline2,418.0,155,418.0 +1950-01-31,Airline2,415.0,156,412.0 +1950-02-28,Airline2,426.0,157,418.0 +1950-03-31,Airline2,441.0,158,432.0 +1950-04-30,Airline2,435.0,159,429.0 +1950-05-31,Airline2,425.0,160,421.0 +1950-06-30,Airline2,449.0,161,435.0 +1950-07-31,Airline2,470.0,162,448.0 +1950-08-31,Airline2,470.0,163,448.0 +1950-09-30,Airline2,458.0,164,436.0 +1950-10-31,Airline2,433.0,165,419.0 +1950-11-30,Airline2,414.0,166,404.0 +1950-12-31,Airline2,440.0,167,418.0 +1951-01-31,Airline2,445.0,168,415.0 +1951-02-28,Airline2,450.0,169,426.0 +1951-03-31,Airline2,478.0,170,441.0 +1951-04-30,Airline2,463.0,171,435.0 +1951-05-31,Airline2,472.0,172,425.0 +1951-06-30,Airline2,478.0,173,449.0 +1951-07-31,Airline2,499.0,174,470.0 +1951-08-31,Airline2,499.0,175,470.0 +1951-09-30,Airline2,484.0,176,458.0 +1951-10-31,Airline2,462.0,177,433.0 +1951-11-30,Airline2,446.0,178,414.0 +1951-12-31,Airline2,466.0,179,440.0 +1952-01-31,Airline2,471.0,180,445.0 +1952-02-29,Airline2,480.0,181,450.0 +1952-03-31,Airline2,493.0,182,478.0 +1952-04-30,Airline2,481.0,183,463.0 +1952-05-31,Airline2,483.0,184,472.0 +1952-06-30,Airline2,518.0,185,478.0 +1952-07-31,Airline2,530.0,186,499.0 +1952-08-31,Airline2,542.0,187,499.0 
+1952-09-30,Airline2,509.0,188,484.0 +1952-10-31,Airline2,491.0,189,462.0 +1952-11-30,Airline2,472.0,190,446.0 +1952-12-31,Airline2,494.0,191,466.0 +1953-01-31,Airline2,496.0,192,471.0 +1953-02-28,Airline2,496.0,193,480.0 +1953-03-31,Airline2,536.0,194,493.0 +1953-04-30,Airline2,535.0,195,481.0 +1953-05-31,Airline2,529.0,196,483.0 +1953-06-30,Airline2,543.0,197,518.0 +1953-07-31,Airline2,564.0,198,530.0 +1953-08-31,Airline2,572.0,199,542.0 +1953-09-30,Airline2,537.0,200,509.0 +1953-10-31,Airline2,511.0,201,491.0 +1953-11-30,Airline2,480.0,202,472.0 +1953-12-31,Airline2,501.0,203,494.0 +1954-01-31,Airline2,504.0,204,496.0 +1954-02-28,Airline2,488.0,205,496.0 +1954-03-31,Airline2,535.0,206,536.0 +1954-04-30,Airline2,527.0,207,535.0 +1954-05-31,Airline2,534.0,208,529.0 +1954-06-30,Airline2,564.0,209,543.0 +1954-07-31,Airline2,602.0,210,564.0 +1954-08-31,Airline2,593.0,211,572.0 +1954-09-30,Airline2,559.0,212,537.0 +1954-10-31,Airline2,529.0,213,511.0 +1954-11-30,Airline2,503.0,214,480.0 +1954-12-31,Airline2,529.0,215,501.0 +1955-01-31,Airline2,542.0,216,504.0 +1955-02-28,Airline2,533.0,217,488.0 +1955-03-31,Airline2,567.0,218,535.0 +1955-04-30,Airline2,569.0,219,527.0 +1955-05-31,Airline2,570.0,220,534.0 +1955-06-30,Airline2,615.0,221,564.0 +1955-07-31,Airline2,664.0,222,602.0 +1955-08-31,Airline2,647.0,223,593.0 +1955-09-30,Airline2,612.0,224,559.0 +1955-10-31,Airline2,574.0,225,529.0 +1955-11-30,Airline2,537.0,226,503.0 +1955-12-31,Airline2,578.0,227,529.0 +1956-01-31,Airline2,584.0,228,542.0 +1956-02-29,Airline2,577.0,229,533.0 +1956-03-31,Airline2,617.0,230,567.0 +1956-04-30,Airline2,613.0,231,569.0 +1956-05-31,Airline2,618.0,232,570.0 +1956-06-30,Airline2,674.0,233,615.0 +1956-07-31,Airline2,713.0,234,664.0 +1956-08-31,Airline2,705.0,235,647.0 +1956-09-30,Airline2,655.0,236,612.0 +1956-10-31,Airline2,606.0,237,574.0 +1956-11-30,Airline2,571.0,238,537.0 +1956-12-31,Airline2,606.0,239,578.0 +1957-01-31,Airline2,615.0,240,584.0 +1957-02-28,Airline2,601.0,241,577.0 +1957-03-31,Airline2,656.0,242,617.0 +1957-04-30,Airline2,648.0,243,613.0 +1957-05-31,Airline2,655.0,244,618.0 +1957-06-30,Airline2,722.0,245,674.0 +1957-07-31,Airline2,765.0,246,713.0 +1957-08-31,Airline2,767.0,247,705.0 +1957-09-30,Airline2,704.0,248,655.0 +1957-10-31,Airline2,647.0,249,606.0 +1957-11-30,Airline2,605.0,250,571.0 +1957-12-31,Airline2,636.0,251,606.0 +1958-01-31,Airline2,640.0,252,615.0 +1958-02-28,Airline2,618.0,253,601.0 +1958-03-31,Airline2,662.0,254,656.0 +1958-04-30,Airline2,648.0,255,648.0 +1958-05-31,Airline2,663.0,256,655.0 +1958-06-30,Airline2,735.0,257,722.0 +1958-07-31,Airline2,791.0,258,765.0 +1958-08-31,Airline2,805.0,259,767.0 +1958-09-30,Airline2,704.0,260,704.0 +1958-10-31,Airline2,659.0,261,647.0 +1958-11-30,Airline2,610.0,262,605.0 +1958-12-31,Airline2,637.0,263,636.0 +1959-01-31,Airline2,660.0,264,640.0 +1959-02-28,Airline2,642.0,265,618.0 +1959-03-31,Airline2,706.0,266,662.0 +1959-04-30,Airline2,696.0,267,648.0 +1959-05-31,Airline2,720.0,268,663.0 +1959-06-30,Airline2,772.0,269,735.0 +1959-07-31,Airline2,848.0,270,791.0 +1959-08-31,Airline2,859.0,271,805.0 +1959-09-30,Airline2,763.0,272,704.0 +1959-10-31,Airline2,707.0,273,659.0 +1959-11-30,Airline2,662.0,274,610.0 +1959-12-31,Airline2,705.0,275,637.0 +1960-01-31,Airline2,717.0,276,660.0 +1960-02-29,Airline2,691.0,277,642.0 +1960-03-31,Airline2,719.0,278,706.0 +1960-04-30,Airline2,761.0,279,696.0 +1960-05-31,Airline2,772.0,280,720.0 +1960-06-30,Airline2,835.0,281,772.0 +1960-07-31,Airline2,922.0,282,848.0 +1960-08-31,Airline2,906.0,283,859.0 
+1960-09-30,Airline2,808.0,284,763.0 +1960-10-31,Airline2,761.0,285,707.0 +1960-11-30,Airline2,690.0,286,662.0 +1960-12-31,Airline2,732.0,287,705.0 diff --git a/docs/source/reference/ai/model-forecasting.rst b/docs/source/reference/ai/model-forecasting.rst index f88462be17..8285ad76b6 100644 --- a/docs/source/reference/ai/model-forecasting.rst +++ b/docs/source/reference/ai/model-forecasting.rst @@ -47,16 +47,24 @@ EvaDB's default forecast framework is `statsforecast `_ for details. If not provided, an auto increasing ID column will be used. - * - ID - - The name of column that represents an identifier for the series. If not provided, the whole table is considered as one series of data. - * - MODEL - - We can select one of AutoARIMA, AutoCES, AutoETS, AutoTheta. The default is AutoARIMA. Check `Automatic Forecasting `_ to learn details about these models. - * - Frequency - - A string indicating the frequency of the data. The common used ones are D, W, M, Y, which repestively represents day-, week-, month- and year- end frequency. The default value is M. Check `pandas available frequencies `_ for all available frequencies. + * - HORIZON (int, required) + - The number of steps into the future we wish to forecast. + * - TIME (str, default: 'ds') + - The name of the column that contains the datestamp, which should be of a format expected by Pandas, ideally YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp. Please visit the `pandas documentation `_ for details. If the relevant column is not found, an auto-increasing ID column will be used. + * - ID (str, default: 'unique_id') + - The name of the column that represents an identifier for the series. If the relevant column is not found, the whole table is treated as a single series. + * - LIBRARY (str, default: 'statsforecast') + - We can select one of `statsforecast` (default) or `neuralforecast`. `statsforecast` provides access to statistical forecasting methods, while `neuralforecast` gives access to deep-learning based forecasting methods. + * - MODEL (str, default: 'ARIMA') + - If LIBRARY is `statsforecast`, we can select one of ARIMA, CES, ETS, Theta. The default is ARIMA. Check `Automatic Forecasting `_ to learn details about these models. If LIBRARY is `neuralforecast`, we can select one of NHITS or NBEATS. The default is NBEATS. Check `NBEATS docs `_ for details. + * - AUTO (str, default: 'T') + - If set to 'T', it enables automatic hyperparameter optimization. It must be set to 'T' when LIBRARY is `statsforecast`; when LIBRARY is `neuralforecast`, it may be set to `false` for faster (but less reliable) results. + * - Frequency (str, default: 'auto') + - A string indicating the frequency of the data. The commonly used ones are D, W, M, and Y, which respectively represent day-, week-, month-, and year-end frequency. Check `pandas available frequencies `_ for all available frequencies. If it is not provided, an attempt is made to infer the frequency automatically. + +Note: Any columns beyond the required ones described above that are passed while creating the function are treated as exogenous variables when LIBRARY is `neuralforecast`; otherwise, they are ignored.
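In addition to the example queries that follow, here is a minimal sketch of how the parameters above combine on the default `statsforecast` path. It is illustrative only: the `AirForecast` name is arbitrary, the `AirData` table (columns `unique_id`, `ds`, `y`) is the one loaded by the project's integration tests, and `Theta` is simply one of the documented model choices.

.. code-block:: sql

   CREATE FUNCTION IF NOT EXISTS AirForecast FROM
   (SELECT unique_id, ds, y FROM AirData)
   TYPE Forecasting
   HORIZON 12
   PREDICT 'y'
   LIBRARY 'statsforecast'
   MODEL 'Theta'
   FREQUENCY 'M';

   SELECT AirForecast() ORDER BY y;

Because the horizon is fixed when the function is created, the forecast is retrieved with a parameterless call; with AUTO left at its default 'T', the Theta model's hyperparameters are tuned automatically.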
Below is an example query specifying the above parameters: @@ -65,8 +73,21 @@ Below is an example query specifying the above parameters: CREATE FUNCTION IF NOT EXISTS HomeRentalForecast FROM (SELECT saledate, ma, type FROM HomeData) TYPE Forecasting + HORIZON 12 PREDICT 'ma' TIME 'saledate' ID 'type' - MODEL 'AutoCES' Frequency 'W'; + +Below is an example query with `neuralforecast` with `trend` column as exogenous and without automatic hyperparameter optimization: + +.. code-block:: sql + + CREATE FUNCTION AirPanelForecast FROM + (SELECT unique_id, ds, y, trend FROM AirDataPanel) + TYPE Forecasting + HORIZON 12 + PREDICT 'y' + LIBRARY 'neuralforecast' + AUTO 'f' + FREQUENCY 'M'; \ No newline at end of file diff --git a/evadb/binder/statement_binder.py b/evadb/binder/statement_binder.py index 1b85ecafcb..eb881c483d 100644 --- a/evadb/binder/statement_binder.py +++ b/evadb/binder/statement_binder.py @@ -126,10 +126,6 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement): elif column.name == arg_map.get("predict", "y"): outputs.append(column) required_columns.remove(column.name) - else: - raise BinderError( - f"Unexpected column {column.name} found for forecasting function." - ) assert ( len(required_columns) == 0 ), f"Missing required {required_columns} columns for forecasting function." diff --git a/evadb/executor/create_function_executor.py b/evadb/executor/create_function_executor.py index 500d1f6869..8f4b5ad274 100644 --- a/evadb/executor/create_function_executor.py +++ b/evadb/executor/create_function_executor.py @@ -38,9 +38,10 @@ from evadb.utils.generic_utils import ( load_function_class_from_file, string_comparison_case_insensitive, - try_to_import_forecast, try_to_import_ludwig, + try_to_import_neuralforecast, try_to_import_sklearn, + try_to_import_statsforecast, try_to_import_torch, try_to_import_ultralytics, ) @@ -183,6 +184,7 @@ def handle_ultralytics_function(self): def handle_forecasting_function(self): """Handle forecasting functions""" + os.environ["CUDA_VISIBLE_DEVICES"] = "" aggregated_batch_list = [] child = self.children[0] for batch in child.exec(): @@ -195,14 +197,34 @@ def handle_forecasting_function(self): impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix() else: impl_path = self.node.impl_path.absolute().as_posix() + library = "statsforecast" + supported_libraries = ["statsforecast", "neuralforecast"] - if "model" not in arg_map.keys(): - arg_map["model"] = "AutoARIMA" + if "horizon" not in arg_map.keys(): + raise ValueError( + "Horizon must be provided while creating function of type FORECASTING" + ) + try: + horizon = int(arg_map["horizon"]) + except Exception as e: + err_msg = f"{str(e)}. HORIZON must be integral." + logger.error(err_msg) + raise FunctionIODefinitionError(err_msg) + + if "library" in arg_map.keys(): + try: + assert arg_map["library"].lower() in supported_libraries + except Exception: + err_msg = ( + "EvaDB currently supports " + str(supported_libraries) + " only." + ) + logger.error(err_msg) + raise FunctionIODefinitionError(err_msg) - model_name = arg_map["model"] + library = arg_map["library"].lower() """ - The following rename is needed for statsforecast, which requires the column name to be the following: + The following rename is needed for statsforecast/neuralforecast, which requires the column name to be the following: - The unique_id (string, int or category) represents an identifier for the series. 
- The ds (datestamp) column should be of a format expected by Pandas, ideally YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp. - The y (numeric) represents the measurement we wish to forecast. @@ -221,7 +243,11 @@ def handle_forecasting_function(self): if "ds" not in list(data.columns): data["ds"] = [x + 1 for x in range(len(data))] - if "frequency" not in arg_map.keys(): + """ + Set or infer data frequency + """ + + if "frequency" not in arg_map.keys() or arg_map["frequency"] == "auto": arg_map["frequency"] = pd.infer_freq(data["ds"]) frequency = arg_map["frequency"] if frequency is None: @@ -229,17 +255,6 @@ def handle_forecasting_function(self): f"Can not infer the frequency for {self.node.name}. Please explictly set it." ) - try_to_import_forecast() - from statsforecast import StatsForecast - from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta - - model_dict = { - "AutoARIMA": AutoARIMA, - "AutoCES": AutoCES, - "AutoETS": AutoETS, - "AutoTheta": AutoTheta, - } - season_dict = { # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases "H": 24, "M": 12, @@ -255,32 +270,144 @@ def handle_forecasting_function(self): frequency.split("-")[0] if "-" in frequency else frequency ) # shortens longer frequencies like Q-DEC season_length = season_dict[new_freq] if new_freq in season_dict else 1 - model = StatsForecast( - [model_dict[model_name](season_length=season_length)], freq=new_freq - ) + + """ + Neuralforecast implementation + """ + if library == "neuralforecast": + try_to_import_neuralforecast() + from neuralforecast import NeuralForecast + from neuralforecast.auto import AutoNBEATS, AutoNHITS + from neuralforecast.models import NBEATS, NHITS + + model_dict = { + "AutoNBEATS": AutoNBEATS, + "AutoNHITS": AutoNHITS, + "NBEATS": NBEATS, + "NHITS": NHITS, + } + + if "model" not in arg_map.keys(): + arg_map["model"] = "NBEATS" + + if "auto" not in arg_map.keys() or ( + arg_map["auto"].lower()[0] == "t" + and "auto" not in arg_map["model"].lower() + ): + arg_map["model"] = "Auto" + arg_map["model"] + + try: + model_here = model_dict[arg_map["model"]] + except Exception: + err_msg = "Supported models: " + str(model_dict.keys()) + logger.error(err_msg) + raise FunctionIODefinitionError(err_msg) + model_args = {} + + if "auto" not in arg_map["model"].lower(): + model_args["input_size"] = 2 * horizon + model_args["early_stop_patience_steps"] = 20 + else: + model_args["config"] = { + "input_size": 2 * horizon, + "early_stop_patience_steps": 20, + } + + if len(data.columns) >= 4: + exogenous_columns = [ + x for x in list(data.columns) if x not in ["ds", "y", "unique_id"] + ] + if "auto" not in arg_map["model"].lower(): + model_args["hist_exog_list"] = exogenous_columns + else: + model_args["config"]["hist_exog_list"] = exogenous_columns + + model_args["h"] = horizon + + model = NeuralForecast( + [model_here(**model_args)], + freq=new_freq, + ) + + # """ + # Statsforecast implementation + # """ + else: + if "auto" in arg_map.keys() and arg_map["auto"].lower()[0] != "t": + raise RuntimeError( + "Statsforecast implementation only supports automatic hyperparameter optimization. Please set AUTO to true." 
+ ) + try_to_import_statsforecast() + from statsforecast import StatsForecast + from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta + + model_dict = { + "AutoARIMA": AutoARIMA, + "AutoCES": AutoCES, + "AutoETS": AutoETS, + "AutoTheta": AutoTheta, + } + + if "model" not in arg_map.keys(): + arg_map["model"] = "ARIMA" + + if "auto" not in arg_map["model"].lower(): + arg_map["model"] = "Auto" + arg_map["model"] + + try: + model_here = model_dict[arg_map["model"]] + except Exception: + err_msg = "Supported models: " + str(model_dict.keys()) + logger.error(err_msg) + raise FunctionIODefinitionError(err_msg) + + model = StatsForecast( + [model_here(season_length=season_length)], freq=new_freq + ) + + data["ds"] = pd.to_datetime(data["ds"]) + + model_save_dir_name = library + "_" + arg_map["model"] + "_" + new_freq + if len(data.columns) >= 4 and library == "neuralforecast": + model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns)) model_dir = os.path.join( - self.db.config.get_value("storage", "model_dir"), self.node.name - ) - Path(model_dir).mkdir(parents=True, exist_ok=True) - model_path = os.path.join( self.db.config.get_value("storage", "model_dir"), - self.node.name, - str(hashlib.sha256(data.to_string().encode()).hexdigest()) + ".pkl", + "tsforecasting", + model_save_dir_name, + str(hashlib.sha256(data.to_string().encode()).hexdigest()), ) + Path(model_dir).mkdir(parents=True, exist_ok=True) - weight_file = Path(model_path) - data["ds"] = pd.to_datetime(data["ds"]) - if not weight_file.exists(): - model.fit(data) + model_save_name = "horizon" + str(horizon) + ".pkl" + + model_path = os.path.join(model_dir, model_save_name) + + existing_model_files = sorted( + os.listdir(model_dir), + key=lambda x: int(x.split("horizon")[1].split(".pkl")[0]), + ) + existing_model_files = [ + x + for x in existing_model_files + if int(x.split("horizon")[1].split(".pkl")[0]) >= horizon + ] + if len(existing_model_files) == 0: + print("Training, please wait...") + if library == "neuralforecast": + model.fit(df=data, val_size=horizon) + else: + model.fit(df=data[["ds", "y", "unique_id"]]) f = open(model_path, "wb") pickle.dump(model, f) f.close() + elif not Path(model_path).exists(): + model_path = os.path.join(model_dir, existing_model_files[-1]) io_list = self._resolve_function_io(None) metadata_here = [ - FunctionMetadataCatalogEntry("model_name", model_name), + FunctionMetadataCatalogEntry("model_name", arg_map["model"]), FunctionMetadataCatalogEntry("model_path", model_path), FunctionMetadataCatalogEntry( "predict_column_rename", arg_map.get("predict", "y") @@ -291,8 +418,12 @@ def handle_forecasting_function(self): FunctionMetadataCatalogEntry( "id_column_rename", arg_map.get("id", "unique_id") ), + FunctionMetadataCatalogEntry("horizon", horizon), + FunctionMetadataCatalogEntry("library", library), ] + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + return ( self.node.name, impl_path, diff --git a/evadb/functions/forecast.py b/evadb/functions/forecast.py index f7cfb72f9c..1571f6c4fc 100644 --- a/evadb/functions/forecast.py +++ b/evadb/functions/forecast.py @@ -35,6 +35,8 @@ def setup( predict_column_rename: str, time_column_rename: str, id_column_rename: str, + horizon: int, + library: str, ): f = open(model_path, "rb") loaded_model = pickle.load(f) @@ -44,13 +46,14 @@ def setup( self.predict_column_rename = predict_column_rename self.time_column_rename = time_column_rename self.id_column_rename = id_column_rename + self.horizon = int(horizon) + self.library = library 
def forward(self, data) -> pd.DataFrame: - horizon = list(data.iloc[:, -1])[0] - assert ( - type(horizon) is int - ), "Forecast UDF expects integral horizon in parameter." - forecast_df = self.model.predict(h=horizon) + if self.library == "statsforecast": + forecast_df = self.model.predict(h=self.horizon) + else: + forecast_df = self.model.predict() forecast_df.reset_index(inplace=True) forecast_df = forecast_df.rename( columns={ @@ -58,5 +61,5 @@ def forward(self, data) -> pd.DataFrame: "ds": self.time_column_rename, self.model_name: self.predict_column_rename, } - ) + )[: self.horizon * forecast_df["unique_id"].nunique()] return forecast_df diff --git a/evadb/utils/generic_utils.py b/evadb/utils/generic_utils.py index 3abc78b288..e99f8a06a9 100644 --- a/evadb/utils/generic_utils.py +++ b/evadb/utils/generic_utils.py @@ -270,7 +270,7 @@ def try_to_import_ray(): ) -def try_to_import_forecast(): +def try_to_import_statsforecast(): try: from statsforecast import StatsForecast # noqa: F401 except ImportError: @@ -280,6 +280,16 @@ def try_to_import_forecast(): ) +def try_to_import_neuralforecast(): + try: + from neuralforecast import NeuralForecast # noqa: F401 + except ImportError: + raise ValueError( + """Could not import NeuralForecast python package. + Please install it with `pip install neuralforecast`.""" + ) + + def is_ray_available() -> bool: try: try_to_import_ray() @@ -319,7 +329,8 @@ def is_ludwig_available() -> bool: def is_forecast_available() -> bool: try: - try_to_import_forecast() + try_to_import_statsforecast() + try_to_import_neuralforecast() return True except ValueError: # noqa: E722 return False diff --git a/setup.py b/setup.py index 5b475ea9f8..1006466443 100644 --- a/setup.py +++ b/setup.py @@ -121,7 +121,8 @@ def read(path, encoding="utf-8"): sklearn_libs = ["scikit-learn"] forecasting_libs = [ - "statsforecast" # MODEL TRAIN AND FINE TUNING + "statsforecast", # MODEL TRAIN AND FINE TUNING + "neuralforecast" # MODEL TRAIN AND FINE TUNING ] ### NEEDED FOR DEVELOPER TESTING ONLY diff --git a/test/integration_tests/long/test_model_forecasting.py b/test/integration_tests/long/test_model_forecasting.py index 2a9b266c7e..47ffe65a83 100644 --- a/test/integration_tests/long/test_model_forecasting.py +++ b/test/integration_tests/long/test_model_forecasting.py @@ -37,6 +37,15 @@ def setUpClass(cls): y INTEGER);""" execute_query_fetch_all(cls.evadb, create_table_query) + create_table_query = """ + CREATE TABLE AirDataPanel (\ + unique_id TEXT(30),\ + ds TEXT(30),\ + y INTEGER,\ + trend INTEGER,\ + ylagged INTEGER);""" + execute_query_fetch_all(cls.evadb, create_table_query) + create_table_query = """ CREATE TABLE HomeData (\ saledate TEXT(30),\ @@ -49,6 +58,10 @@ def setUpClass(cls): load_query = f"LOAD CSV '{path}' INTO AirData;" execute_query_fetch_all(cls.evadb, load_query) + path = f"{EvaDB_ROOT_DIR}/data/forecasting/AirPassengersPanel.csv" + load_query = f"LOAD CSV '{path}' INTO AirDataPanel;" + execute_query_fetch_all(cls.evadb, load_query) + path = f"{EvaDB_ROOT_DIR}/data/forecasting/home_sales.csv" load_query = f"LOAD CSV '{path}' INTO HomeData;" execute_query_fetch_all(cls.evadb, load_query) @@ -70,12 +83,13 @@ def test_forecast(self): CREATE FUNCTION AirForecast FROM (SELECT unique_id, ds, y FROM AirData) TYPE Forecasting + HORIZON 12 PREDICT 'y'; """ execute_query_fetch_all(self.evadb, create_predict_udf) predict_query = """ - SELECT AirForecast(12) order by y; + SELECT AirForecast() order by y; """ result = execute_query_fetch_all(self.evadb, predict_query) 
self.assertEqual(len(result), 12) @@ -83,6 +97,28 @@ def test_forecast(self): result.columns, ["airforecast.unique_id", "airforecast.ds", "airforecast.y"] ) + create_predict_udf = """ + CREATE FUNCTION AirPanelForecast FROM + (SELECT unique_id, ds, y, trend FROM AirDataPanel) + TYPE Forecasting + HORIZON 12 + PREDICT 'y' + LIBRARY 'neuralforecast' + AUTO 'false' + FREQUENCY 'M'; + """ + execute_query_fetch_all(self.evadb, create_predict_udf) + + predict_query = """ + SELECT AirPanelForecast() order by y; + """ + result = execute_query_fetch_all(self.evadb, predict_query) + self.assertEqual(len(result), 24) + self.assertEqual( + result.columns, + ["airpanelforecast.unique_id", "airpanelforecast.ds", "airpanelforecast.y"], + ) + @forecast_skip_marker def test_forecast_with_column_rename(self): create_predict_udf = """ @@ -92,6 +128,7 @@ def test_forecast_with_column_rename(self): WHERE bedrooms = 2 ) TYPE Forecasting + HORIZON 12 PREDICT 'ma' ID 'type' TIME 'saledate' @@ -100,7 +137,7 @@ def test_forecast_with_column_rename(self): execute_query_fetch_all(self.evadb, create_predict_udf) predict_query = """ - SELECT HomeForecast(12); + SELECT HomeForecast(); """ result = execute_query_fetch_all(self.evadb, predict_query) self.assertEqual(len(result), 24) diff --git a/test/unit_tests/binder/test_statement_binder.py b/test/unit_tests/binder/test_statement_binder.py index 6a4ee08deb..3de8d1745a 100644 --- a/test/unit_tests/binder/test_statement_binder.py +++ b/test/unit_tests/binder/test_statement_binder.py @@ -568,51 +568,6 @@ def test_bind_create_function_should_bind_forecast_with_renaming_columns(self): self.assertEqual(create_function_statement.inputs, expected_inputs) self.assertEqual(create_function_statement.outputs, expected_outputs) - def test_bind_create_function_should_raise_forecast_with_unexpected_columns(self): - with patch.object(StatementBinder, "bind"): - create_function_statement = MagicMock() - create_function_statement.function_type = "forecasting" - id_col_obj = ColumnCatalogEntry( - name="type", - type=MagicMock(), - array_type=MagicMock(), - array_dimensions=MagicMock(), - ) - ds_col_obj = ColumnCatalogEntry( - name="saledate", - type=MagicMock(), - array_type=MagicMock(), - array_dimensions=MagicMock(), - ) - y_col_obj = ColumnCatalogEntry( - name="ma", - type=MagicMock(), - array_type=MagicMock(), - array_dimensions=MagicMock(), - ) - create_function_statement.query.target_list = [ - TupleValueExpression( - name=id_col_obj.name, table_alias="a", col_object=id_col_obj - ), - TupleValueExpression( - name=ds_col_obj.name, table_alias="a", col_object=ds_col_obj - ), - TupleValueExpression( - name=y_col_obj.name, table_alias="a", col_object=y_col_obj - ), - ] - create_function_statement.metadata = [ - ("predict", "ma"), - ("time", "saledate"), - ] - binder = StatementBinder(StatementBinderContext(MagicMock())) - - with self.assertRaises(BinderError) as cm: - binder._bind_create_function_statement(create_function_statement) - - err_msg = "Unexpected column type found for forecasting function." 
- self.assertEqual(str(cm.exception), err_msg) - def test_bind_create_function_should_raise_forecast_missing_required_columns(self): with patch.object(StatementBinder, "bind"): create_function_statement = MagicMock() diff --git a/tutorials/16-homesale-forecasting.ipynb b/tutorials/16-homesale-forecasting.ipynb index f2d81c9ff7..82f0eafc1d 100644 --- a/tutorials/16-homesale-forecasting.ipynb +++ b/tutorials/16-homesale-forecasting.ipynb @@ -8,7 +8,7 @@ "collapsed_sections": [ "GHToaA_NKiHY" ], - "authorship_tag": "ABX9TyPOmDfDbnc8CP+70g/FkjHR" + "authorship_tag": "ABX9TyOZ3w3qGhWQ6hZO9onIutni" }, "kernelspec": { "name": "python3", @@ -70,7 +70,7 @@ "base_uri": "https://localhost:8080/" }, "id": "Z7PodOEEEDsQ", - "outputId": "0dcaa531-ae05-4c13-ab74-6dacdf6d8739" + "outputId": "7b01f944-0b6a-4c91-e7cd-0251d2b66ab1" }, "execution_count": 1, "outputs": [ @@ -93,7 +93,7 @@ "After this operation, 51.5 MB of additional disk space will be used.\n", "Preconfiguring packages ...\n", "Selecting previously unselected package logrotate.\n", - "(Reading database ... 120901 files and directories currently installed.)\n", + "(Reading database ... 120895 files and directories currently installed.)\n", "Preparing to unpack .../00-logrotate_3.19.0-1ubuntu1.1_amd64.deb ...\n", "Unpacking logrotate (3.19.0-1ubuntu1.1) ...\n", "Selecting previously unselected package netbase.\n", @@ -211,7 +211,7 @@ "base_uri": "https://localhost:8080/" }, "id": "UrlfWZOkEa4V", - "outputId": "1fc62319-0d3f-4f2a-bcc4-e408587e50fb" + "outputId": "af5143e2-9e7c-45c9-fc2b-247f1d241ae4" }, "execution_count": 2, "outputs": [ @@ -264,7 +264,7 @@ { "cell_type": "code", "source": [ - "%pip install --quiet \"evadb[postgres,forecasting] @ git+https://github.com/georgia-tech-db/evadb.git@a40c72ed6cb18993e2ae5bda28c7195f4de4f109\"\n", + "%pip install --quiet \"evadb[postgres,forecasting] @ git+https://github.com/georgia-tech-db/evadb.git@68265d3b138babfe4a20091bc5fa7a67b56072f5\"\n", "\n", "import evadb\n", "cursor = evadb.connect().cursor()" @@ -274,7 +274,7 @@ "base_uri": "https://localhost:8080/" }, "id": "NoAykveeElqm", - "outputId": "de6547e5-670d-4fba-d081-30ffecc74849" + "outputId": "e6a44278-1571-461f-c673-572400d3962f" }, "execution_count": 4, "outputs": [ @@ -285,22 +285,26 @@ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.6/92.6 kB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m108.9/108.9 kB\u001b[0m \u001b[31m12.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m137.6/137.6 kB\u001b[0m \u001b[31m15.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.5/45.5 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.9/110.9 kB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.7/98.7 kB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m275.0/275.0 kB\u001b[0m \u001b[31m26.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m32.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[33mWARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))': /simple/triad/\u001b[0m\u001b[33m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.6/92.6 kB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m108.9/108.9 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m137.6/137.6 kB\u001b[0m \u001b[31m8.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.5/45.5 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.9/110.9 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m162.6/162.6 kB\u001b[0m \u001b[31m12.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.7/98.7 kB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m727.7/727.7 kB\u001b[0m \u001b[31m19.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.5/62.5 MB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m275.0/275.0 kB\u001b[0m \u001b[31m21.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m79.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.0/57.0 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m169.2/169.2 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m154.6/154.6 kB\u001b[0m \u001b[31m15.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m169.2/169.2 kB\u001b[0m \u001b[31m18.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m154.7/154.7 kB\u001b[0m \u001b[31m19.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m303.2/303.2 kB\u001b[0m \u001b[31m32.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m144.2/144.2 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m135.3/135.3 kB\u001b[0m \u001b[31m17.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m313.3/313.3 kB\u001b[0m \u001b[31m28.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m805.2/805.2 kB\u001b[0m \u001b[31m63.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m101.7/101.7 kB\u001b[0m \u001b[31m12.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m144.2/144.2 kB\u001b[0m \u001b[31m17.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m135.3/135.3 kB\u001b[0m \u001b[31m13.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Building wheel for evadb (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", " Building wheel for fugue-sql-antlr (setup.py) ... \u001b[?25l\u001b[?25hdone\n" ] @@ -310,7 +314,7 @@ "name": "stderr", "text": [ "Downloading: \"http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/mnist-b07bb66b.pth\" to /root/.cache/torch/hub/checkpoints/mnist-b07bb66b.pth\n", - "100%|██████████| 1.03M/1.03M [00:01<00:00, 898kB/s]\n", + "100%|██████████| 1.03M/1.03M [00:01<00:00, 745kB/s]\n", "Downloading: \"https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth\" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth\n" ] } @@ -352,10 +356,10 @@ "metadata": { "colab": { "base_uri": "https://localhost:8080/", - "height": 81 + "height": 0 }, "id": "IsP6rLZ2Ftxo", - "outputId": "eee82699-fd4f-4aa8-edac-a9f0e0575e98" + "outputId": "9260345f-eb8f-4f19-e1a4-9078ca6e1da3" }, "execution_count": 5, "outputs": [ @@ -368,7 +372,7 @@ ], "text/html": [ "\n", - "
\n", + "
\n", "
\n", "