Skip to content

Commit

Permalink
Use Flaml zero-shot automl
Browse files Browse the repository at this point in the history
to mine good hyperparameter configurations offline
  • Loading branch information
HellenNamulinda committed May 9, 2024
1 parent d2ff580 commit f5d5ad8
Show file tree
Hide file tree
Showing 9 changed files with 18 additions and 7 deletions.
Binary file modified assets/evaluation_data.joblib
Binary file not shown.
10 changes: 5 additions & 5 deletions assets/evaluation_metrics.json
Original file line number Diff line number Diff line change
@@ -1,7 1,7 @@
{
"Mean Squared Error": 0.3952,
"Root Mean Squared Error": 0.6287,
"Mean Absolute Error": 0.4569,
"R-squared Score": 0.7018,
"Explained Variance Score": 0.7018
"Mean Squared Error": 0.3913,
"Root Mean Squared Error": 0.6256,
"Mean Absolute Error": 0.4472,
"R-squared Score": 0.7047,
"Explained Variance Score": 0.7053
}
Binary file modified assets/evaluation_scatter_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified assets/interpretability_bar_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified assets/interpretability_beeswarm_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified assets/interpretability_sample1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 9,8 @@ mordred==1.2.0
pandas<=2.1
matplotlib==3.8.3
featurewiz==0.5.7
tensorflow==2.16.1
flaml==2.1.2
imblearn
pydantic
aiohttp
Expand Down
2 changes: 1 addition & 1 deletion scripts/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 41,7 @@
smiles_valid_transformed = descriptor.transform(smiles_valid)

# Instantiate the regressor
regressor = Regressor(output_folder, algorithm='catboost')
regressor = Regressor(output_folder, algorithm='xgboost')

# Train the model
regressor.fit(smiles_train_transformed, y_train)
Expand Down
11 changes: 10 additions & 1 deletion xai4chem/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 5,15 @@
import matplotlib.pyplot as plt
import json
import optuna
from xgboost import XGBRegressor
# from xgboost import XGBRegressor
from catboost import CatBoostRegressor
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import SelectKBest, mutual_info_regression
from featurewiz import FeatureWiz
from sklearn.svm import SVR
from flaml.default import LGBMRegressor, XGBRegressor
from flaml.default import preprocess_and_suggest_hyperparams


class Regressor:
Expand Down Expand Up @@ -124,6 127,12 @@ def fit(self, X_train, y_train, default_params=True):
else:
self.model = CatBoostRegressor()
self.model.fit(X_train, y_train, verbose=False)
elif self.algorithm == 'svr':
self.model = SVR()
self.model.fit(X_train, y_train)
elif self.algorithm == 'lgbm':
self.model = LGBMRegressor()
self.model.fit(X_train, y_train)
else:
raise ValueError("Invalid Algorithm. Supported Algorithms: xgboost, catboost")

Expand Down

0 comments on commit f5d5ad8

Please sign in to comment.