Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 87 additions & 40 deletions datashield/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,31 +129,38 @@ def open(self, restore: str = None, failSafe: bool = False) -> None:
for name in self.errors:
logging.error(f"Connection to {name} has failed")

def close(self, save: str = None, conn_names: list[str] = None) -> None:
    """
    Close connections with remote servers.

    Closing is best effort: a server that fails to save its workspace or to disconnect
    does not prevent the other connections from being closed, and is still removed from
    the list of opened connections.

    :param save: The name of the workspace to save before closing the connections.
    :param conn_names: The optional list of connection names to close. If not defined, all opened connections are closed.
    """
    self.errors = {}
    if not self.conns:
        return
    selected_conns = self._get_selected_connections(conn_names)
    selected_names = {conn.get_name() for conn in selected_conns}
    for conn in selected_conns:
        if save:
            try:
                conn.save_workspace(f"{conn.get_name()}:{save}")
            except DSError:
                # silently fail: a failed save must not prevent the disconnection below
                pass
        try:
            conn.disconnect()
        except DSError:
            # silently fail
            pass
    if conn_names is None:
        # everything was closed: back to the "no connections opened" state
        self.conns = None
    else:
        # keep only the connections that were not selected for closing
        self.conns = [conn for conn in self.conns if conn.get_name() not in selected_names]

def has_connections(self) -> bool:
    """
    Check if some connections were opened.

    :return: True if some connections were opened, False otherwise
    """
    # self.conns may be None (nothing opened yet, or everything closed) or an empty list;
    # bool() maps both to False so that a real boolean is always returned
    return bool(self.conns)

def get_connection_names(self) -> list[str]:
"""
Expand Down Expand Up @@ -186,27 +193,29 @@ def get_errors(self) -> dict:
# Environment
#

def tables(self, conn_names: list[str] = None) -> dict:
    """
    List available table names from the data repository.

    :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
    :return: The available table names from the data repository, per remote server name
    """
    return {conn.get_name(): conn.list_tables() for conn in self._get_selected_connections(conn_names)}

def variables(self, table: str = None, tables: dict = None) -> dict:
def variables(self, table: str = None, tables: dict = None, conn_names: list[str] = None) -> dict:
"""
List available variables from the data repository, for a given table.

:param table: The default name of the table to list variables for
:param tables: The name of the table to list variables for, per server name. If not defined, 'table' is used.
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
:return: The available variables from the data repository, for a given table, per remote server name
"""
rval = {}
for conn in self.conns:
for conn in self._get_selected_connections(conn_names):
name = table
if tables and conn.get_name() in tables:
name = tables[conn.get_name()]
Expand All @@ -216,120 +225,130 @@ def variables(self, table: str = None, tables: dict = None) -> dict:
rval[conn.get_name()] = None
return rval

def taxonomies(self, conn_names: list[str] = None) -> dict:
    """
    List available taxonomies from the data repository. A taxonomy is a hierarchical structure of vocabulary
    terms that can be used to annotate variables in the data repository.
    Depending on the data repository's capabilities, taxonomies can be used to perform structured
    queries when searching for variables.

    :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
    :return: The available taxonomies from the data repository, per remote server name
    """
    return {conn.get_name(): conn.list_taxonomies() for conn in self._get_selected_connections(conn_names)}

def search_variables(self, query: str, conn_names: list[str] = None) -> dict:
    """
    Search for variable names matching a given query across all tables in the data repository.

    :param query: The query to search for in variable names, e.g., a full-text search and/or structured
    query (based on taxonomy terms), depending on the data repository's capabilities
    :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
    :return: The matching variable names from the data repository, per remote server name
    """
    return {conn.get_name(): conn.search_variables(query) for conn in self._get_selected_connections(conn_names)}

def resources(self, conn_names: list[str] = None) -> dict:
    """
    List available resource names from the data repository.

    :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
    :return: The available resource names from the data repository, per remote server name
    """
    return {conn.get_name(): conn.list_resources() for conn in self._get_selected_connections(conn_names)}

def profiles(self, conn_names: list[str] = None) -> dict:
    """
    List available DataSHIELD profile names in the data repository.

    :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
    :return: The available DataSHIELD profile names in the data repository, per remote server name
    """
    return {conn.get_name(): conn.list_profiles() for conn in self._get_selected_connections(conn_names)}

def packages(self, conn_names: list[str] = None) -> dict:
    """
    Get the list of DataSHIELD packages with their version, that have been configured on the remote data repository.

    :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
    :return: The list of DataSHIELD packages with their version, that have been configured on the remote data repository, per remote server name
    """
    return {conn.get_name(): conn.list_packages() for conn in self._get_selected_connections(conn_names)}

def methods(self, type: str = "aggregate", conn_names: list[str] = None) -> dict:
    """
    Get the list of DataSHIELD methods that have been configured on the remote data repository.

    :param type: The type of method, either "aggregate" (default) or "assign"
    :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
    :return: The list of DataSHIELD methods that have been configured on the remote data repository, per remote server name
    """
    # NOTE: 'type' shadows the builtin, but it is part of the public signature and is kept for compatibility
    return {conn.get_name(): conn.list_methods(type) for conn in self._get_selected_connections(conn_names)}

#
# Workspaces
#

def workspaces(self, conn_names: list[str] = None) -> dict:
    """
    Get the list of DataSHIELD workspaces, that have been saved on the remote data repository.

    :param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
    :return: The list of DataSHIELD workspaces, that have been saved on the remote data repository, per remote server name
    """
    return {conn.get_name(): conn.list_workspaces() for conn in self._get_selected_connections(conn_names)}

def workspace_save(self, name: str, conn_names: list[str] = None) -> None:
    """
    Save the DataSHIELD R session in a workspace on the remote data repository.

    :param name: The name of the workspace
    :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
    """
    for conn in self._get_selected_connections(conn_names):
        # the workspace name is sent prefixed with the connection name
        conn.save_workspace(f"{conn.get_name()}:{name}")

def workspace_restore(self, name: str, conn_names: list[str] = None) -> None:
    """
    Restore a saved DataSHIELD R session from the remote data repository. When restoring a workspace,
    any existing symbol or file with same name will be overridden.

    :param name: The name of the workspace
    :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
    """
    for conn in self._get_selected_connections(conn_names):
        # the workspace name is sent prefixed with the connection name
        conn.restore_workspace(f"{conn.get_name()}:{name}")

def workspace_rm(self, name: str, conn_names: list[str] = None) -> None:
    """
    Remove a DataSHIELD workspace from the remote data repository. Ignored if no
    such workspace exists.

    :param name: The name of the workspace
    :param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
    """
    for conn in self._get_selected_connections(conn_names):
        # the workspace name is sent prefixed with the connection name
        conn.rm_workspace(f"{conn.get_name()}:{name}")

#
Expand Down Expand Up @@ -358,6 +377,9 @@ def sessions(self) -> dict:
"""
rval = {}
self._init_errors()
if not self.conns or len(self.conns) == 0:
return rval

started_conns = []
excluded_conns = []

Expand Down Expand Up @@ -409,7 +431,7 @@ def sessions(self) -> dict:
raise DSError("No sessions could be started successfully.")
return rval

def ls(self) -> dict:
def ls(self, conn_names: list[str] = None) -> dict:
"""
After assignments have been performed, list the symbols that live in the DataSHIELD R session on the server side.

Expand All @@ -418,7 +440,7 @@ def ls(self) -> dict:
# ensure sessions are started and available
self.sessions()
rval = {}
for conn in self.conns:
for conn in self._get_selected_connections(conn_names):
try:
rval[conn.get_name()] = conn.list_symbols()
except Exception as e:
Expand All @@ -427,15 +449,16 @@ def ls(self) -> dict:
self._check_errors()
return rval

def rm(self, symbol: str) -> None:
def rm(self, symbol: str, conn_names: list[str] = None) -> None:
"""
Remove a symbol from remote servers.

:param symbol: The name of the symbol to remove
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
"""
# ensure sessions are started and available
self.sessions()
for conn in self.conns:
for conn in self._get_selected_connections(conn_names):
try:
conn.rm_symbol(symbol)
except Exception as e:
Expand All @@ -452,6 +475,7 @@ def assign_table(
identifiers: str = None,
id_name: str = None,
asynchronous: bool = True,
conn_names: list[str] = None,
) -> None:
"""
Assign a data table from the data repository to a symbol in the DataSHIELD R session.
Expand All @@ -460,11 +484,12 @@ def assign_table(
:param table: The default name of the table to assign
:param tables: The name of the table to assign, per server name. If not defined, 'table' is used.
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
"""
# ensure sessions are started and available
self.sessions()
cmd = {}
for conn in self.conns:
for conn in self._get_selected_connections(conn_names):
name = table
if tables and conn.get_name() in tables:
name = tables[conn.get_name()]
Expand All @@ -478,7 +503,12 @@ def assign_table(
self._check_errors()

def assign_resource(
self, symbol: str, resource: str = None, resources: dict = None, asynchronous: bool = True
self,
symbol: str,
resource: str = None,
resources: dict = None,
asynchronous: bool = True,
conn_names: list[str] = None,
) -> None:
"""
Assign a resource from the data repository to a symbol in the DataSHIELD R session.
Expand All @@ -487,11 +517,12 @@ def assign_resource(
:param resource: The default name of the resource to assign
:param resources: The name of the resource to assign, per server name. If not defined, 'resource' is used.
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
"""
# ensure sessions are started and available
self.sessions()
cmd = {}
for conn in self.conns:
for conn in self._get_selected_connections(conn_names):
name = resource
if resources and conn.get_name() in resources:
name = resources[conn.get_name()]
Expand All @@ -504,18 +535,19 @@ def assign_resource(
self._do_wait(cmd)
self._check_errors()

def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None:
def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True, conn_names: list[str] = None) -> None:
"""
Assign the result of the evaluation of an expression to a symbol in the DataSHIELD R session.

:param symbol: The name of the destination symbol
:param expr: The R expression to evaluate and which result will be assigned
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
"""
# ensure sessions are started and available
self.sessions()
cmd = {}
for conn in self.conns:
for conn in self._get_selected_connections(conn_names):
try:
res = conn.assign_expr(symbol, expr, asynchronous)
cmd[conn.get_name()] = res
Expand All @@ -524,20 +556,21 @@ def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None
self._do_wait(cmd)
self._check_errors()

def aggregate(self, expr: str, asynchronous: bool = True) -> dict:
def aggregate(self, expr: str, asynchronous: bool = True, conn_names: list[str] = None) -> dict:
"""
Aggregate some data from the DataSHIELD R session using a valid R expression. The
aggregation expression must satisfy the data repository's DataSHIELD configuration.

:param expr: The R expression to evaluate and which result will be returned
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
:return: The result of the aggregation expression evaluation, per remote server name
"""
# ensure sessions are started and available
self.sessions()
cmd = {}
rval = {}
for conn in self.conns:
for conn in self._get_selected_connections(conn_names):
try:
res = conn.aggregate(expr, asynchronous)
cmd[conn.get_name()] = res
Expand Down Expand Up @@ -573,6 +606,20 @@ def _do_wait(self, cmd: dict) -> dict:
time.sleep(0.1)
return rval

def _get_selected_connections(self, conn_names: list[str] = None) -> list[DSConnection]:
    """
    Get the list of opened connections, optionally filtered by connection names.

    The returned list is always a new list, so callers can iterate over it while the
    list of opened connections is being replaced or modified.

    :param conn_names: The optional list of connection names to select. A single name given as a plain
    string is treated as a one-element list. If not defined, all opened connections are selected.
    :return: The list of selected opened connections, in the order they were opened
    """
    if not self.conns:
        return []
    if conn_names is None:
        return list(self.conns)
    if isinstance(conn_names, str):
        # a bare string would otherwise be split into single characters by set()
        conn_names = [conn_names]
    selected_names = set(conn_names)
    return [conn for conn in self.conns if conn.get_name() in selected_names]

def _init_errors(self) -> None:
"""
Prepare for storing errors.
Expand Down
Loading
Loading