Sorting Error in find_max_value_trial() #5231

Open
unolife opened this issue Feb 7, 2024 · 2 comments
Labels
bug Issue/PR about behavior that is broken. Not for typos/examples/CI/test but for Optuna itself.

unolife commented Feb 7, 2024

Expected behavior

I want to use RDB storage with Optuna, but it fails.

In optuna/storages/_rdb/models.py, the method TrialModel.find_max_value_trial builds a query, and the case() expression in that query raises an error:
case({"INF_NEG": -1, "FINITE": 0, "INF_POS": 1}, value=TrialValueModel.value_type)
If I comment out this part, it works.

@classmethod
def find_max_value_trial(
    cls, study_id: int, objective: int, session: orm.Session
) -> "TrialModel":
    trial = (
        session.query(cls)
        .filter(cls.study_id == study_id)
        .filter(cls.state == TrialState.COMPLETE)
        .join(TrialValueModel)
        .filter(TrialValueModel.objective == objective)
        .order_by(
            desc(
                case(
                    {"INF_NEG": -1, "FINITE": 0, "INF_POS": 1},
                    value=TrialValueModel.value_type,
                )
            ),
            desc(TrialValueModel.value),
        )
        .limit(1)
        .one_or_none()
    )
    if trial is None:
        raise ValueError(NOT_FOUND_MSG)
    return trial
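
For readers unfamiliar with this query, the sketch below illustrates what its ORDER BY appears intended to do: rank rows by value type first (INF_POS above FINITE above INF_NEG), then by value. It uses Python's built-in sqlite3 module rather than SQLAlchemy, and the `trial_values` table here is a simplified, hypothetical stand-in, not Optuna's actual schema.

```python
import sqlite3

# Hypothetical, simplified table; column names are assumptions for illustration.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE trial_values (trial_id INTEGER, value REAL, value_type TEXT)"
)
rows = [
    (1, 0.91, "FINITE"),
    (2, None, "INF_NEG"),
    (3, 0.97, "FINITE"),
    (4, None, "INF_POS"),
]
conn.executemany("INSERT INTO trial_values VALUES (?, ?, ?)", rows)

# Same idea as the mapping passed to case(): translate the value type to a
# rank, sort by rank descending, then by value descending.
query = """
SELECT trial_id FROM trial_values
ORDER BY
    CASE value_type
        WHEN 'INF_NEG' THEN -1
        WHEN 'FINITE' THEN 0
        WHEN 'INF_POS' THEN 1
    END DESC,
    value DESC
LIMIT 1
"""
best_trial_id = conn.execute(query).fetchone()[0]
print(best_trial_id)
```

With these sample rows the trial marked INF_POS is selected; without it, the finite trial with the largest value would come first.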

Environment

  • Optuna version: 3.4.0
  • Python version: 3.10.10
  • OS: MacOS Ventura 13.5.2
  • (Optional) Other libraries and their versions:
    scikit-learn: 1.3.2
    catboost: 1.2.2
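
The traceback below ends inside the sqlalchemy package, whose version is not listed above. A minimal sketch for collecting it together with the other versions, using only the standard library (the helper name `installed_version` is made up for this example):

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Distribution names as published on PyPI.
packages = ["optuna", "SQLAlchemy", "scikit-learn", "catboost"]
versions = {name: installed_version(name) for name in packages}
for name, found in versions.items():
    print(f"{name}: {found if found is not None else 'not installed'}")
```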

Error messages, stack traces, or logs

Traceback (most recent call last):
  File "/Users/warren/projects/automl/optuna_rdb_test.py", line 20, in <module>
    study.optimize(objective, n_trials=10)
  File "/Users/warren/miniforge3/envs/hfcss/lib/python3.10/site-packages/optuna/study/study.py", line 451, in optimize
    _optimize(
  File "/Users/warren/miniforge3/envs/hfcss/lib/python3.10/site-packages/optuna/study/_optimize.py", line 76, in _optimize
    _optimize_sequential(
  File "/Users/warren/miniforge3/envs/hfcss/lib/python3.10/site-packages/optuna/study/_optimize.py", line 173, in _optimize_sequential
    frozen_trial = _run_trial(study, func, catch)
  File "/Users/warren/miniforge3/envs/hfcss/lib/python3.10/site-packages/optuna/study/_optimize.py", line 234, in _run_trial
    study._log_completed_trial(frozen_trial)
  File "/Users/warren/miniforge3/envs/hfcss/lib/python3.10/site-packages/optuna/study/study.py", line 1105, in _log_completed_trial
    best_trial = self.best_trial
  File "/Users/warren/miniforge3/envs/hfcss/lib/python3.10/site-packages/optuna/study/study.py", line 157, in best_trial
    return copy.deepcopy(self._storage.get_best_trial(self._study_id))
  File "/Users/warren/miniforge3/envs/hfcss/lib/python3.10/site-packages/optuna/storages/_cached_storage.py", line 183, in get_best_trial
    return self._backend.get_best_trial(study_id)
  File "/Users/warren/miniforge3/envs/hfcss/lib/python3.10/site-packages/optuna/storages/_rdb/storage.py", line 915, in get_best_trial
    trial = models.TrialModel.find_max_value_trial(study_id, 0, session)
  File "/Users/warren/miniforge3/envs/hfcss/lib/python3.10/site-packages/optuna/storages/_rdb/models.py", line 214, in find_max_value_trial
    case(
  File "<string>", line 2, in case
  File "/Users/warren/miniforge3/envs/hfcss/lib/python3.10/site-packages/sqlalchemy/sql/elements.py", line 2755, in __init__
    whenlist = [
  File "/Users/warren/miniforge3/envs/hfcss/lib/python3.10/site-packages/sqlalchemy/sql/elements.py", line 2764, in <listcomp>
    for (c, r) in whens
ValueError: too many values to unpack (expected 2)
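
One possible reading of the last two frames: the comprehension `for (c, r) in whens` expects every element of `whens` to be a two-item pair. The plain-Python sketch below reproduces the same message under the assumption, not confirmed in this thread, that the installed SQLAlchemy iterated over the positional arguments without first converting the dict into key/value pairs.

```python
mapping = {"INF_NEG": -1, "FINITE": 0, "INF_POS": 1}

# Assumption: the dict arrives as the single element of a tuple of
# positional arguments and is not unwrapped before iteration.
whens = (mapping,)

try:
    # Each element is unpacked into (c, r); the dict yields three keys.
    pairs = [(c, r) for (c, r) in whens]
except ValueError as exc:
    error_message = str(exc)
    print(error_message)

# Iterating over key/value pairs instead unpacks cleanly.
pairs = [(c, r) for (c, r) in mapping.items()]
print(pairs)
```

If this reading is right, the problem would depend on how the installed SQLAlchemy version handles the dict form of case(), which would also explain why the error is hard to reproduce in another environment.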

Steps to reproduce

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier
import optuna

X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

def objective(trial):
    params = {
        'depth': trial.suggest_int('depth', 3, 10),
        'verbose': False,
    }
    model = CatBoostClassifier(**params)
    model.fit(X_train, y_train)
    score = model.score(X_test, y_test)
    return score

study = optuna.create_study(direction='maximize', study_name='catboost', storage='sqlite:///test.db', load_if_exists=True)
study.optimize(objective, n_trials=10)
best_params = study.best_params
print("Best params:", best_params)
best_model = CatBoostClassifier(**best_params)
best_model.fit(X_train, y_train)
test_score = best_model.score(X_test, y_test)
print("Test set accuracy:", test_score)

Additional context (optional)

No response

unolife added the bug label Feb 7, 2024
nzw0301 (Member) commented Feb 7, 2024

Hmm, I could not reproduce the error; the script worked fine without any error.

nzw0301 (Member) commented Feb 7, 2024

Does this also happen with load_if_exists=False?

study = optuna.create_study(direction='maximize', study_name='catboost', storage='sqlite:///test.db', load_if_exists=False)

I'm not sure, but I suspect the database contains data written by an older Optuna version, or is otherwise broken.
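
To check what an existing SQLite file already contains before reusing it, something like the following sketch may help. It uses only the standard library; `describe_sqlite_db` is a hypothetical helper, and the `studies` table in the demo is invented for illustration rather than taken from Optuna's schema.

```python
import sqlite3


def describe_sqlite_db(conn):
    """Return {table_name: row_count} for every user table in the database."""
    names = [
        row[0]
        for row in conn.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
            "ORDER BY name"
        )
    ]
    return {
        name: conn.execute(f'SELECT COUNT(*) FROM "{name}"').fetchone()[0]
        for name in names
    }


# Demo on an in-memory database with a made-up table.
demo = sqlite3.connect(":memory:")
demo.execute("CREATE TABLE studies (study_id INTEGER PRIMARY KEY, study_name TEXT)")
demo.execute("INSERT INTO studies (study_name) VALUES ('catboost')")
summary = describe_sqlite_db(demo)
print(summary)
```

For the file from this report, the connection would be opened with sqlite3.connect('test.db') instead of the in-memory database.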
