Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Neuralforecast #1115

Merged
merged 27 commits into staging from neuralforecast
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8f49397
Set the write output column type for forecast functions
xzdandy Sep 13, 2023
043d671
Fix forecast integration test
xzdandy Sep 13, 2023
0977c1f
Move the generic utils test
xzdandy Sep 13, 2023
092c03f
Fix ludwig unittest cases and add unittestcase for normal forecasting
xzdandy Sep 13, 2023
96e40db
Add unittest cases for forecast with rename in binder.
xzdandy Sep 13, 2023
5648371
Add unittest when an expected column is passed to forecasting
xzdandy Sep 13, 2023
8692ff1
Add unittest when required columns are missing in binder
xzdandy Sep 13, 2023
0679200
Merge branch 'staging' into neuralforecast
americast Sep 13, 2023
1fd3c02
Add neuralforecast support
americast Sep 14, 2023
65ed6e1
less horizon no retrain
americast Sep 15, 2023
5fd8af7
Merge branch 'staging' into neuralforecast
americast Sep 24, 2023
be242ee
add support for exogenous variables
americast Sep 25, 2023
583e778
Fix exogenous support; add tests
americast Sep 25, 2023
52c563e
add tests
americast Sep 25, 2023
84a159e
wip: fix test
americast Sep 25, 2023
06a7db0
remove strict column check in test
americast Sep 25, 2023
32a204b
Fix GPU issue with neuralforecast; fixed auto exog variables
americast Sep 28, 2023
fda2b40
Merge remote-tracking branch 'origin/staging' into neuralforecast
americast Sep 28, 2023
736d9e0
added auto support; updated docs
americast Sep 29, 2023
06fb001
Update forecasting notebook.
xzdandy Sep 29, 2023
a36a1f5
fixes
americast Sep 29, 2023
eee78c9
Merge branch 'neuralforecast' of github.com:georgia-tech-db/evadb int…
americast Sep 29, 2023
09bee12
Fix horizon issue for multi uniqueids
americast Sep 29, 2023
b422000
update docs
americast Sep 29, 2023
e176bd4
fix exogenous for auto; made default
americast Sep 30, 2023
68265d3
turn auto off for neuralforecast test to avoid TLE error
americast Sep 30, 2023
267443d
Update the Notebook
xzdandy Sep 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
less horizon no retrain
  • Loading branch information
americast committed Sep 15, 2023
commit 65ed6e1521018304be1e9759a18ed979e7d0ab63
102 changes: 57 additions & 45 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,28 +152,30 @@ def handle_forecasting_function(self):
impl_path = self.node.impl_path.absolute().as_posix()
library = "statsforecast"
supported_libraries = ["statsforecast", "neuralforecast"]

if "horizon" not in arg_map.keys():
raise ValueError(
"Horizon must be provided while creating function of type FORECASTING"
)
try:
horizon = int(arg_map["horizon"])
except:
raise ValueError(
"Parameter horizon must be integral."
)
except Exception as e:
err_msg = f"{str(e)}. HORIZON must be integral."
logger.error(err_msg)
raise FunctionIODefinitionError(err_msg)

if "library" in arg_map.keys():
try:
assert arg_map["library"].lower() in supported_libraries
except:
raise ValueError(
"EvaDB currently supports "+str(supported_libraries)+" only."
except Exception:
err_msg = (
"EvaDB currently supports " + str(supported_libraries) + " only."
)
logger.error(err_msg)
raise FunctionIODefinitionError(err_msg)

library = arg_map["library"].lower()


"""
The following rename is needed for statsforecast/neuralforecast, which requires the column name to be the following:
- The unique_id (string, int or category) represents an identifier for the series.
Expand All @@ -196,7 +198,7 @@ def handle_forecasting_function(self):

"""
Set or infer data frequency
"""
"""

if "frequency" not in arg_map.keys():
arg_map["frequency"] = pd.infer_freq(data["ds"])
Expand All @@ -222,17 +224,16 @@ def handle_forecasting_function(self):
) # shortens longer frequencies like Q-DEC
season_length = season_dict[new_freq] if new_freq in season_dict else 1


try_to_import_forecast()
americast marked this conversation as resolved.
Show resolved Hide resolved

"""
Neuralforecast implementation
"""
if library == "neuralforecast":
from neuralforecast import NeuralForecast
from neuralforecast.models import NBEATS
from neuralforecast.auto import AutoNBEATS

from neuralforecast.models import NBEATS

model_dict = {
"AutoNBEATS": AutoNBEATS,
"NBEATS": NBEATS,
Expand All @@ -242,18 +243,17 @@ def handle_forecasting_function(self):
arg_map["model"] = "NBEATS"

try:
model_name = arg_map["model"]
except:
raise ValueError(
"Supported models: "+str(model_dict.keys())
)
model_here = model_dict[arg_map["model"]]
except Exception:
err_msg = "Supported models: " + str(model_dict.keys())
logger.error(err_msg)
raise FunctionIODefinitionError(err_msg)

model = NeuralForecast(
[model_dict[model_name](input_size=2 * horizon, h=horizon, max_steps=50)], freq=new_freq
[model_here(input_size=2 * horizon, h=horizon, max_steps=50)],
freq=new_freq,
)



# """
# Statsforecast implementation
# """
Expand All @@ -272,40 +272,56 @@ def handle_forecasting_function(self):
arg_map["model"] = "AutoARIMA"

try:
model_name = arg_map["model"]
except:
raise ValueError(
"Supported models: "+str(model_dict.keys())
)


model_here = model_dict[arg_map["model"]]
except Exception:
err_msg = "Supported models: " + str(model_dict.keys())
logger.error(err_msg)
raise FunctionIODefinitionError(err_msg)

model = StatsForecast(
[model_dict[model_name](season_length=season_length)], freq=new_freq
[model_here(season_length=season_length)], freq=new_freq
)

data["ds"] = pd.to_datetime(data["ds"])

model_dir = os.path.join(
self.db.config.get_value("storage", "model_dir"), self.node.name
)
Path(model_dir).mkdir(parents=True, exist_ok=True)
model_path = os.path.join(
self.db.config.get_value("storage", "model_dir"),
self.node.name,
library+"_"+str(hashlib.sha256(data.to_string().encode()).hexdigest()) + ".pkl",
)
Path(model_dir).mkdir(parents=True, exist_ok=True)

weight_file = Path(model_path)
data["ds"] = pd.to_datetime(data["ds"])
if not weight_file.exists():
model_save_name = (
library
+ "_"
+ str(hashlib.sha256(data.to_string().encode()).hexdigest())
+ "_horizon"
+ str(horizon)
+ ".pkl"
)

model_path = os.path.join(model_dir, model_save_name)

existing_model_files = sorted(
os.listdir(model_dir),
key=lambda x: int(x.split("horizon")[1].split(".pkl")[0]),
)
existing_model_files = [
x
for x in existing_model_files
if int(x.split("horizon")[1].split(".pkl")[0]) >= horizon
]
if len(existing_model_files) == 0:
model.fit(df=data)
f = open(model_path, "wb")
pickle.dump(model, f)
f.close()
elif not Path(model_path).exists():
model_path = os.path.join(model_dir, existing_model_files[-1])

io_list = self._resolve_function_io(None)

metadata_here = [
FunctionMetadataCatalogEntry("model_name", model_name),
FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
FunctionMetadataCatalogEntry("model_path", model_path),
FunctionMetadataCatalogEntry(
"predict_column_rename", arg_map.get("predict", "y")
Expand All @@ -316,12 +332,8 @@ def handle_forecasting_function(self):
FunctionMetadataCatalogEntry(
"id_column_rename", arg_map.get("id", "unique_id")
),
FunctionMetadataCatalogEntry(
"horizon", horizon
),
FunctionMetadataCatalogEntry(
"library", library
),
FunctionMetadataCatalogEntry("horizon", horizon),
FunctionMetadataCatalogEntry("library", library),
]

return (
Expand Down
5 changes: 3 additions & 2 deletions evadb/functions/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import setup


class ForecastModel(AbstractFunction):
@property
def name(self) -> str:
Expand All @@ -35,7 +36,7 @@ def setup(
time_column_rename: str,
id_column_rename: str,
horizon: int,
library: str
library: str,
):
f = open(model_path, "rb")
loaded_model = pickle.load(f)
Expand All @@ -60,5 +61,5 @@ def forward(self, data) -> pd.DataFrame:
"ds": self.time_column_rename,
self.model_name: self.predict_column_rename,
}
)
)[: self.horizon]
americast marked this conversation as resolved.
Show resolved Hide resolved
return forecast_df