Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
install:
uv sync --all-extras

update:
rm -f uv.lock
uv sync

test:
uv run --all-extras pytest

Expand Down
36 changes: 36 additions & 0 deletions datashield_opal/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def __init__(self, name: str, loginInfo: OpalClient.LoginInfo, profile: str = "d
self.rsession = None
self.rsession_started = False

def get_name(self) -> str:
"""Get the name of the connection."""
return self.name

def check_user(self) -> bool:
"""Check if the user can authenticate by trying to retrieve the current subject profile."""
try:
Expand All @@ -165,10 +169,40 @@ def list_tables(self) -> list:
return names

def has_table(self, name: str) -> bool:
# name is in format "datasource.table"
if "." not in name:
raise OpalDSError(ValueError(f"Invalid table name: {name}. Expected format 'datasource.table'"))
parts = name.split(".")
response = self._get(UriBuilder(["datasource", parts[0], "table", parts[1]]).build()).send()
return response.code == 200

def list_table_variables(self, table) -> list:
# table is in format "datasource.table"
if "." not in table:
raise OpalDSError(ValueError(f"Invalid table name: {table}. Expected format 'datasource.table'"))
tokens = table.split(".")
project_name = tokens[0]
table_name = tokens[1]
return (
self
._get(UriBuilder(["datasource", project_name, "table", table_name, "variables"]).build())
.fail_on_error()
.send()
.from_json()
)

def list_taxonomies(self) -> list:
return self._get(UriBuilder(["system", "conf", "taxonomies"]).build()).fail_on_error().send().from_json()

def search_variables(self, query) -> dict:
return (
self
._get(UriBuilder(["datasources", "variables", "_search"]).query("query", query).build())
.fail_on_error()
.send()
.from_json()
)

def list_resources(self) -> list:
response = self._get("/projects").fail_on_error().send()
projects = response.from_json()
Expand All @@ -181,6 +215,8 @@ def list_resources(self) -> list:
return names

def has_resource(self, name: str) -> bool:
if "." not in name:
raise OpalDSError(ValueError(f"Invalid resource name: {name}. Expected format 'project.resource'"))
parts = name.split(".")
response = self._get(UriBuilder(["project", parts[0], "resource", parts[1]]).build()).send()
return response.code == 200
Expand Down
7 changes: 7 additions & 0 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ def test_tables(self):
assert "CNSIM.CNSIM1" in tables
assert conn.has_table("CNSIM.CNSIM1")

@pytest.mark.integration
def test_table_variables(self):
conn = self.conn
variables = conn.list_table_variables("CNSIM.CNSIM1")
assert type(variables) is list
assert "LAB_TSC" in [v.get("name") for v in variables]

@pytest.mark.integration
def test_resources(self):
conn = self.conn
Expand Down
Loading
Loading