From 03ae771d918a2b117565ed95aa303a917ec556c7 Mon Sep 17 00:00:00 2001 From: Manas-7854 Date: Thu, 19 Mar 2026 01:35:02 +0530 Subject: [PATCH 1/2] added list_estimation_procedure in init files and returns a dict[int, str] --- openml/__init__.py | 3 ++- openml/evaluations/__init__.py | 8 ++++++- openml/evaluations/functions.py | 14 +++++------ .../test_evaluation_functions.py | 23 +++++++++++++++++++ 4 files changed, 39 insertions(+), 9 deletions(-) diff --git a/openml/__init__.py b/openml/__init__.py index 9a457c146..d1892f609 100644 --- a/openml/__init__.py +++ b/openml/__init__.py @@ -36,7 +36,7 @@ ) from .__version__ import __version__ from .datasets import OpenMLDataFeature, OpenMLDataset -from .evaluations import OpenMLEvaluation +from .evaluations import OpenMLEvaluation, list_estimation_procedures from .flows import OpenMLFlow from .runs import OpenMLRun from .setups import OpenMLParameter, OpenMLSetup @@ -122,6 +122,7 @@ def populate_cache( "exceptions", "extensions", "flows", + "list_estimation_procedures", "runs", "setups", "study", diff --git a/openml/evaluations/__init__.py b/openml/evaluations/__init__.py index b56d0c2d5..29344b03a 100644 --- a/openml/evaluations/__init__.py +++ b/openml/evaluations/__init__.py @@ -1,10 +1,16 @@ # License: BSD 3-Clause from .evaluation import OpenMLEvaluation -from .functions import list_evaluation_measures, list_evaluations, list_evaluations_setups +from .functions import ( + list_estimation_procedures, + list_evaluation_measures, + list_evaluations, + list_evaluations_setups, +) __all__ = [ "OpenMLEvaluation", + "list_estimation_procedures", "list_evaluation_measures", "list_evaluations", "list_evaluations_setups", diff --git a/openml/evaluations/functions.py b/openml/evaluations/functions.py index 61c95a480..2b608e0ed 100644 --- a/openml/evaluations/functions.py +++ b/openml/evaluations/functions.py @@ -298,15 +298,15 @@ def list_evaluation_measures() -> list[str]: return qualities["oml:evaluation_measures"]["oml:measures"][0]["oml:measure"] -def list_estimation_procedures() -> list[str]: - """Return list of evaluation procedures available. +def list_estimation_procedures() -> dict[int, str]: + """Return dictionary of evaluation procedures available. The function performs an API call to retrieve the entire list of - evaluation procedures' names that are available. + evaluation procedures' ids and names that are available. Returns ------- - list + dict[int, str] """ api_call = "estimationprocedure/list" xml_string = openml._api_calls._perform_api_call(api_call, "get") @@ -322,10 +322,10 @@ def list_estimation_procedures() -> list[str]: if not isinstance(api_results["oml:estimationprocedures"]["oml:estimationprocedure"], list): raise TypeError('Error in return XML, does not contain "oml:estimationprocedure" as a list') - return [ - prod["oml:name"] + return { + int(prod["oml:id"]): prod["oml:name"] for prod in api_results["oml:estimationprocedures"]["oml:estimationprocedure"] - ] + } def list_evaluations_setups( diff --git a/tests/test_evaluations/test_evaluation_functions.py b/tests/test_evaluations/test_evaluation_functions.py index e15556d7b..6d09a0875 100644 --- a/tests/test_evaluations/test_evaluation_functions.py +++ b/tests/test_evaluations/test_evaluation_functions.py @@ -264,3 +264,26 @@ def test_list_evaluations_setups_filter_task(self): task_id = [6] size = 121 self._check_list_evaluation_setups(tasks=task_id, size=size) + + @pytest.mark.test_server() + def test_list_estimation_procedures_return_type(self): + procedures = openml.evaluations.list_estimation_procedures() + assert isinstance(procedures, dict) + assert len(procedures) > 0 + assert all(isinstance(k, int) for k in procedures.keys()) + assert all(isinstance(v, str) for v in procedures.values()) + + @pytest.mark.test_server() + def test_list_estimation_procedures_top_level_accessible(self): + procedures = openml.list_estimation_procedures() + assert isinstance(procedures, dict) + assert len(procedures) > 0 + assert all(isinstance(k, int) for k in procedures.keys()) + assert all(isinstance(v, str) for v in procedures.values()) + + @pytest.mark.test_server() + def test_list_estimation_procedures_valid_id_for_task_creation(self): + procedures = openml.evaluations.list_estimation_procedures() + first_id = list(procedures.keys())[0] + assert isinstance(first_id, int) + assert first_id > 0 From a3dda64438d4c1ab7190a2ae8eeda22e72672c68 Mon Sep 17 00:00:00 2001 From: Manas-7854 Date: Thu, 19 Mar 2026 01:56:50 +0530 Subject: [PATCH 2/2] added tag to get dict --- openml/evaluations/functions.py | 45 ++++++++++++++++--- .../test_evaluation_functions.py | 21 +++++++-- 2 files changed, 55 insertions(+), 11 deletions(-) diff --git a/openml/evaluations/functions.py b/openml/evaluations/functions.py index 2b608e0ed..b5661c4df 100644 --- a/openml/evaluations/functions.py +++ b/openml/evaluations/functions.py @@ -298,16 +298,42 @@ def list_evaluation_measures() -> list[str]: return qualities["oml:evaluation_measures"]["oml:measures"][0]["oml:measure"] -def list_estimation_procedures() -> dict[int, str]: - """Return dictionary of evaluation procedures available. +@overload +def list_estimation_procedures(include_ids: Literal[True]) -> dict[int, str]: ... + + +@overload +def list_estimation_procedures(include_ids: Literal[False] = ...) -> list[str]: ... + + +def list_estimation_procedures(include_ids: bool = False) -> dict[int, str] | list[str]: # noqa: FBT002 + """Return dictionary or list of estimation procedures available. The function performs an API call to retrieve the entire list of - evaluation procedures' ids and names that are available. + estimation procedures' ids and names that are available. + + Parameters + ---------- + include_ids : bool, optional (default=False) + If True, return a dictionary mapping estimation procedure id to name. + If False, return a list of estimation procedure names. Returns ------- - dict[int, str] + list of estimation procedure names (default), or dict mapping + estimation procedure id to name if include_ids=True """ + if not include_ids: + import warnings + + warnings.warn( + "Returning a list from list_estimation_procedures is deprecated " + "and will be removed in a future release. " + "Use include_ids=True to get a dict of {id: name} instead.", + DeprecationWarning, + stacklevel=2, + ) + api_call = "estimationprocedure/list" xml_string = openml._api_calls._perform_api_call(api_call, "get") api_results = xmltodict.parse(xml_string) @@ -322,10 +348,15 @@ def list_estimation_procedures() -> dict[int, str]: if not isinstance(api_results["oml:estimationprocedures"]["oml:estimationprocedure"], list): raise TypeError('Error in return XML, does not contain "oml:estimationprocedure" as a list') - return { - int(prod["oml:id"]): prod["oml:name"] + if include_ids: + return { + int(prod["oml:id"]): prod["oml:name"] + for prod in api_results["oml:estimationprocedures"]["oml:estimationprocedure"] + } + return [ + prod["oml:name"] for prod in api_results["oml:estimationprocedures"]["oml:estimationprocedure"] - } + ] def list_evaluations_setups( diff --git a/tests/test_evaluations/test_evaluation_functions.py b/tests/test_evaluations/test_evaluation_functions.py index 6d09a0875..6a1429830 100644 --- a/tests/test_evaluations/test_evaluation_functions.py +++ b/tests/test_evaluations/test_evaluation_functions.py @@ -267,7 +267,7 @@ def test_list_evaluations_setups_filter_task(self): @pytest.mark.test_server() def test_list_estimation_procedures_return_type(self): - procedures = openml.evaluations.list_estimation_procedures() + procedures = openml.evaluations.list_estimation_procedures(include_ids=True) assert isinstance(procedures, dict) assert len(procedures) > 0 assert all(isinstance(k, int) for k in procedures.keys()) @@ -275,15 +275,28 @@ def test_list_estimation_procedures_return_type(self): @pytest.mark.test_server() def test_list_estimation_procedures_top_level_accessible(self): - procedures = openml.list_estimation_procedures() + procedures = openml.list_estimation_procedures(include_ids=True) assert isinstance(procedures, dict) assert len(procedures) > 0 assert all(isinstance(k, int) for k in procedures.keys()) assert all(isinstance(v, str) for v in procedures.values()) @pytest.mark.test_server() - def test_list_estimation_procedures_valid_id_for_task_creation(self): - procedures = openml.evaluations.list_estimation_procedures() + def test_list_estimation_procedures_ids_are_positive_ints(self): + procedures = openml.evaluations.list_estimation_procedures(include_ids=True) first_id = list(procedures.keys())[0] assert isinstance(first_id, int) assert first_id > 0 + + @pytest.mark.test_server() + def test_list_estimation_procedures_default_returns_list(self): + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + procedures = openml.evaluations.list_estimation_procedures() + assert isinstance(procedures, list) + assert len(procedures) > 0 + assert all(isinstance(s, str) for s in procedures) + # confirm deprecation warning was raised + assert any(issubclass(warning.category, DeprecationWarning) for warning in w)