From 832f56bff74fee50da863a2a95ffbe1181647262 Mon Sep 17 00:00:00 2001 From: Yannick Marcon Date: Sat, 21 Mar 2026 19:21:46 +0100 Subject: [PATCH 1/2] feat: added filter by connection names to operations --- datashield/api.py | 130 +++++++++++++++++++++----------- tests/test_session_filters.py | 136 ++++++++++++++++++++++++++++++++++ 2 files changed, 222 insertions(+), 44 deletions(-) create mode 100644 tests/test_session_filters.py diff --git a/datashield/api.py b/datashield/api.py index 59ec3a0..31fc604 100644 --- a/datashield/api.py +++ b/datashield/api.py @@ -129,15 +129,17 @@ def open(self, restore: str = None, failSafe: bool = False) -> None: for name in self.errors: logging.error(f"Connection to {name} has failed") - def close(self, save: str = None) -> None: + def close(self, save: str = None, conn_names: list[str] = None) -> None: """ Close connections with remote servers. - :param cons: The list of connections to close. :param save: The name of the workspace to save before closing the connections. + :param conn_names: The optional list of connection names to close. If not defined, all opened connections are closed. """ self.errors = {} - for conn in self.conns: + selected_conns = self._get_selected_connections(conn_names) + selected_names = {conn.get_name() for conn in selected_conns} + for conn in selected_conns: try: if save: conn.save_workspace(f"{conn.get_name()}:{save}") @@ -145,7 +147,10 @@ def close(self, save: str = None) -> None: except DSError: # silently fail pass - self.conns = None + if conn_names is None: + self.conns = None + else: + self.conns = [conn for conn in self.conns if conn.get_name() not in selected_names] def has_connections(self) -> bool: """ @@ -161,10 +166,7 @@ def get_connection_names(self) -> list[str]: :return: The list of opened connection names """ - if self.conns: - return [conn.get_name() for conn in self.conns] - else: - return [] + return [conn.get_name() for conn in self.conns] def has_errors(self) -> bool: """ @@ -186,27 +188,29 @@ def get_errors(self) -> dict: # Environment # - def tables(self) -> dict: + def tables(self, conn_names: list[str] = None) -> dict: """ List available table names from the data repository. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The available table names from the data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_tables() return rval - def variables(self, table: str = None, tables: dict = None) -> dict: + def variables(self, table: str = None, tables: dict = None, conn_names: list[str] = None) -> dict: """ List available variables from the data repository, for a given table. :param table: The default name of the table to list variables for :param tables: The name of the table to list variables for, per server name. If not defined, 'table' is used. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The available variables from the data repository, for a given table, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): name = table if tables and conn.get_name() in tables: name = tables[conn.get_name()] @@ -216,75 +220,81 @@ def variables(self, table: str = None, tables: dict = None) -> dict: rval[conn.get_name()] = None return rval - def taxonomies(self) -> dict: + def taxonomies(self, conn_names: list[str] = None) -> dict: """ List available taxonomies from the data repository. A taxonomy is a hierarchical structure of vocabulary terms that can be used to annotate variables in the data repository. Depending on the data repository's capabilities, taxonomies can be used to perform structured queries when searching for variables. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The available taxonomies from the data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_taxonomies() return rval - def search_variables(self, query: str) -> dict: + def search_variables(self, query: str, conn_names: list[str] = None) -> dict: """ Search for variable names matching a given query across all tables in the data repository. :param query: The query to search for in variable names, e.g., a full-text search and/or structured query (based on taxonomy terms), depending on the data repository's capabilities + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The matching variable names from the data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.search_variables(query) return rval - def resources(self) -> dict: + def resources(self, conn_names: list[str] = None) -> dict: """ List available resource names from the data repository. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The available resource names from the data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_resources() return rval - def profiles(self) -> dict: + def profiles(self, conn_names: list[str] = None) -> dict: """ List available DataSHIELD profile names in the data repository. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The available DataSHIELD profile names in the data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_profiles() return rval - def packages(self) -> dict: + def packages(self, conn_names: list[str] = None) -> dict: """ Get the list of DataSHIELD packages with their version, that have been configured on the remote data repository. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The list of DataSHIELD packages with their version, that have been configured on the remote data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_packages() return rval - def methods(self, type: str = "aggregate") -> dict: + def methods(self, type: str = "aggregate", conn_names: list[str] = None) -> dict: """ Get the list of DataSHIELD methods that have been configured on the remote data repository. :param type: The type of method, either "aggregate" (default) or "assign" + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The list of DataSHIELD methods that have been configured on the remote data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_methods(type) return rval @@ -292,44 +302,48 @@ def methods(self, type: str = "aggregate") -> dict: # Workspaces # - def workspaces(self) -> dict: + def workspaces(self, conn_names: list[str] = None) -> dict: """ Get the list of DataSHIELD workspaces, that have been saved on the remote data repository. + :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried. :return: The list of DataSHIELD workspaces, that have been saved on the remote data repository, per remote server name """ rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): rval[conn.get_name()] = conn.list_workspaces() return rval - def workspace_save(self, name: str) -> None: + def workspace_save(self, name: str, conn_names: list[str] = None) -> None: """ Save the DataSHIELD R session in a workspace on the remote data repository. :param name: The name of the workspace + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): conn.save_workspace(f"{conn.get_name()}:{name}") - def workspace_restore(self, name: str) -> None: + def workspace_restore(self, name: str, conn_names: list[str] = None) -> None: """ Restore a saved DataSHIELD R session from the remote data repository. When restoring a workspace, any existing symbol or file with same name will be overridden. :param name: The name of the workspace + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): conn.restore_workspace(f"{conn.get_name()}:{name}") - def workspace_rm(self, name: str) -> None: + def workspace_rm(self, name: str, conn_names: list[str] = None) -> None: """ Remove a DataSHIELD workspace from the remote data repository. Ignored if no such workspace exists. :param name: The name of the workspace + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): conn.rm_workspace(f"{conn.get_name()}:{name}") # @@ -358,6 +372,9 @@ def sessions(self) -> dict: """ rval = {} self._init_errors() + if len(self.conns) == 0: + return rval + started_conns = [] excluded_conns = [] @@ -405,11 +422,11 @@ def sessions(self) -> dict: if len(excluded_conns) > 0: logging.error(f"Some sessions have been excluded due to errors: {', '.join(excluded_conns)}") self.conns = [conn for conn in self.conns if conn.get_name() not in excluded_conns] - if len(self.conns) == 0: + if len(self.conns) == len(excluded_conns): raise DSError("No sessions could be started successfully.") return rval - def ls(self) -> dict: + def ls(self, conn_names: list[str] = None) -> dict: """ After assignments have been performed, list the symbols that live in the DataSHIELD R session on the server side. @@ -418,7 +435,7 @@ def ls(self) -> dict: # ensure sessions are started and available self.sessions() rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): try: rval[conn.get_name()] = conn.list_symbols() except Exception as e: @@ -427,15 +444,16 @@ def ls(self) -> dict: self._check_errors() return rval - def rm(self, symbol: str) -> None: + def rm(self, symbol: str, conn_names: list[str] = None) -> None: """ Remove a symbol from remote servers. :param symbol: The name of the symbol to remove + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ # ensure sessions are started and available self.sessions() - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): try: conn.rm_symbol(symbol) except Exception as e: @@ -452,6 +470,7 @@ def assign_table( identifiers: str = None, id_name: str = None, asynchronous: bool = True, + conn_names: list[str] = None, ) -> None: """ Assign a data table from the data repository to a symbol in the DataSHIELD R session. @@ -460,11 +479,12 @@ def assign_table( :param table: The default name of the table to assign :param tables: The name of the table to assign, per server name. If not defined, 'table' is used. :param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server) + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ # ensure sessions are started and available self.sessions() cmd = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): name = table if tables and conn.get_name() in tables: name = tables[conn.get_name()] @@ -478,7 +498,12 @@ def assign_table( self._check_errors() def assign_resource( - self, symbol: str, resource: str = None, resources: dict = None, asynchronous: bool = True + self, + symbol: str, + resource: str = None, + resources: dict = None, + asynchronous: bool = True, + conn_names: list[str] = None, ) -> None: """ Assign a resource from the data repository to a symbol in the DataSHIELD R session. @@ -487,11 +512,12 @@ def assign_resource( :param resource: The default name of the resource to assign :param resources: The name of the resource to assign, per server name. If not defined, 'resource' is used. :param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server) + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ # ensure sessions are started and available self.sessions() cmd = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): name = resource if resources and conn.get_name() in resources: name = resources[conn.get_name()] @@ -504,18 +530,19 @@ def assign_resource( self._do_wait(cmd) self._check_errors() - def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None: + def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True, conn_names: list[str] = None) -> None: """ Assign the result of the evaluation of an expression to a symbol in the DataSHIELD R session. :param symbol: The name of the destination symbol :param expr: The R expression to evaluate and which result will be assigned :param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server) + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. """ # ensure sessions are started and available self.sessions() cmd = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): try: res = conn.assign_expr(symbol, expr, asynchronous) cmd[conn.get_name()] = res @@ -524,20 +551,21 @@ def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None self._do_wait(cmd) self._check_errors() - def aggregate(self, expr: str, asynchronous: bool = True) -> dict: + def aggregate(self, expr: str, asynchronous: bool = True, conn_names: list[str] = None) -> dict: """ Aggregate some data from the DataSHIELD R session using a valid R expression. The aggregation expression must satisfy the data repository's DataSHIELD configuration. :param expr: The R expression to evaluate and which result will be returned :param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server) + :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used. :return: The result of the aggregation expression evaluation, per remote server name """ # ensure sessions are started and available self.sessions() cmd = {} rval = {} - for conn in self.conns: + for conn in self._get_selected_connections(conn_names): try: res = conn.aggregate(expr, asynchronous) cmd[conn.get_name()] = res @@ -573,6 +601,20 @@ def _do_wait(self, cmd: dict) -> dict: time.sleep(0.1) return rval + def _get_selected_connections(self, conn_names: list[str] = None) -> list[DSConnection]: + """ + Get the list of opened connections, optionally filtered by connection names. + + :param conn_names: The optional list of connection names to select. + :return: The list of selected opened connections + """ + if not self.conns: + return [] + if conn_names is None: + return self.conns + selected_names = set(conn_names) + return [conn for conn in self.conns if conn.get_name() in selected_names] + def _init_errors(self) -> None: """ Prepare for storing errors. diff --git a/tests/test_session_filters.py b/tests/test_session_filters.py new file mode 100644 index 0000000..be71e4c --- /dev/null +++ b/tests/test_session_filters.py @@ -0,0 +1,136 @@ +from datashield import DSSession + + +class FakeResult: + def __init__(self, value): + self.value = value + + def is_completed(self) -> bool: + return True + + def fetch(self): + return self.value + + +class FakeConn: + def __init__(self, name: str): + self._name = name + self.started = False + self.disconnected = False + self.saved_workspaces = [] + self.restored_workspaces = [] + self.removed_workspaces = [] + self.rm_symbols = [] + self.assign_expr_calls = [] + self.keep_alive_calls = 0 + + def get_name(self) -> str: + return self._name + + def list_tables(self) -> list: + return [f"{self._name}_table"] + + def has_session(self) -> bool: + return self.started + + def start_session(self, asynchronous: bool = True): + self.started = True + return {"started": True, "async": asynchronous} + + def is_session_started(self) -> bool: + return self.started + + def get_session(self): + return {"name": self._name} + + def list_symbols(self) -> list: + return [f"{self._name}_symbol"] + + def rm_symbol(self, name: str) -> None: + self.rm_symbols.append(name) + + def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> FakeResult: + self.assign_expr_calls.append((symbol, expr, asynchronous)) + return FakeResult({"symbol": symbol, "expr": expr, "conn": self._name}) + + def aggregate(self, expr: str, asynchronous: bool = True) -> FakeResult: + return FakeResult({"expr": expr, "conn": self._name, "async": asynchronous}) + + def save_workspace(self, name: str) -> list: + self.saved_workspaces.append(name) + return self.saved_workspaces + + def restore_workspace(self, name: str) -> list: + self.restored_workspaces.append(name) + return self.restored_workspaces + + def rm_workspace(self, name: str) -> list: + self.removed_workspaces.append(name) + return self.removed_workspaces + + def keep_alive(self) -> None: + self.keep_alive_calls += 1 + + def disconnect(self) -> None: + self.disconnected = True + + +def make_session() -> tuple[DSSession, FakeConn, FakeConn]: + conn1 = FakeConn("server1") + conn2 = FakeConn("server2") + session = DSSession([]) + session.conns = [conn1, conn2] + session.errors = {} + return session, conn1, conn2 + + +def test_tables_filters_connections(): + session, _, _ = make_session() + + result = session.tables(conn_names=["server2", "unknown"]) + + assert result == {"server2": ["server2_table"]} + + +def test_assign_expr_filters_connections(): + session, conn1, conn2 = make_session() + + session.assign_expr("x", "1+1", conn_names=["server1", "unknown"]) + + assert conn1.assign_expr_calls == [("x", "1+1", True)] + assert conn2.assign_expr_calls == [] + + +def test_aggregate_filters_connections(): + session, conn1, conn2 = make_session() + + result = session.aggregate("2+2", conn_names=["server2"]) + + assert result == {"server2": {"expr": "2+2", "conn": "server2", "async": True}} + assert conn1.assign_expr_calls == [] + assert conn2.assign_expr_calls == [] + + +def test_workspace_methods_filter_connections(): + session, conn1, conn2 = make_session() + + session.workspace_save("wk", conn_names=["server1"]) + session.workspace_restore("wk", conn_names=["server2"]) + session.workspace_rm("wk", conn_names=["server2", "missing"]) + + assert conn1.saved_workspaces == ["server1:wk"] + assert conn2.saved_workspaces == [] + assert conn1.restored_workspaces == [] + assert conn2.restored_workspaces == ["server2:wk"] + assert conn1.removed_workspaces == [] + assert conn2.removed_workspaces == ["server2:wk"] + + +def test_close_filters_connections_and_keeps_others_open(): + session, conn1, conn2 = make_session() + + session.close(conn_names=["server1", "unknown"]) + + assert conn1.disconnected is True + assert conn2.disconnected is False + assert session.get_connection_names() == ["server2"] From 5b8cc290d8dce7707b6b56703342935f793f5b33 Mon Sep 17 00:00:00 2001 From: Yannick Marcon Date: Sat, 21 Mar 2026 19:48:12 +0100 Subject: [PATCH 2/2] chore: code review --- datashield/api.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/datashield/api.py b/datashield/api.py index 31fc604..08251ab 100644 --- a/datashield/api.py +++ b/datashield/api.py @@ -137,6 +137,8 @@ def close(self, save: str = None, conn_names: list[str] = None) -> None: :param conn_names: The optional list of connection names to close. If not defined, all opened connections are closed. """ self.errors = {} + if not self.conns: + return selected_conns = self._get_selected_connections(conn_names) selected_names = {conn.get_name() for conn in selected_conns} for conn in selected_conns: @@ -158,7 +160,7 @@ def has_connections(self) -> bool: :return: True if some connections were opened, False otherwise """ - return len(self.conns) > 0 + return self.conns and len(self.conns) > 0 def get_connection_names(self) -> list[str]: """ @@ -166,7 +168,10 @@ def get_connection_names(self) -> list[str]: :return: The list of opened connection names """ - return [conn.get_name() for conn in self.conns] + if self.conns: + return [conn.get_name() for conn in self.conns] + else: + return [] def has_errors(self) -> bool: """ @@ -372,7 +377,7 @@ def sessions(self) -> dict: """ rval = {} self._init_errors() - if len(self.conns) == 0: + if not self.conns or len(self.conns) == 0: return rval started_conns = [] @@ -422,7 +427,7 @@ def sessions(self) -> dict: if len(excluded_conns) > 0: logging.error(f"Some sessions have been excluded due to errors: {', '.join(excluded_conns)}") self.conns = [conn for conn in self.conns if conn.get_name() not in excluded_conns] - if len(self.conns) == len(excluded_conns): + if len(self.conns) == 0: raise DSError("No sessions could be started successfully.") return rval