diff --git a/README.md b/README.md
index 838bf2b..476021a 100644
--- a/README.md
+++ b/README.md
@@ -97,14 +97,15 @@ from transformplan.backends.duckdb import DuckDBBackend
 con = duckdb.connect()
 rel = con.sql("SELECT * FROM 'patients.parquet'")
 
+# Same plan — backend chosen at execution time
 plan = (
-    TransformPlan(backend=DuckDBBackend(con))
+    TransformPlan()
     .col_rename(column="PatientID", new_name="patient_id")
     .rows_filter(Col("age") >= 18)
     .math_round(column="score", decimals=2)
 )
 
-result, protocol = plan.process(rel)
+result, protocol = plan.process(rel, backend=DuckDBBackend(con))
 ```
 
 ## Available Operations
diff --git a/docs/api/backends.md b/docs/api/backends.md
index 1decd9d..8537f18 100644
--- a/docs/api/backends.md
+++ b/docs/api/backends.md
@@ -9,23 +9,22 @@ The backend determines how data is stored and transformed:
 - **PolarsBackend** (default): Operates on Polars DataFrames using native Polars expressions
 - **DuckDBBackend** (optional): Operates on DuckDB relations using SQL generation
 
-All operations, validation, dry-run, and serialization work identically regardless of backend. Pipelines serialized with one backend can be loaded and executed with another.
+A `TransformPlan` is a pure, backend-agnostic recipe of operations. The backend is chosen at execution time by passing it to `process()`, `validate()`, or `dry_run()`. If no backend is specified, `PolarsBackend` is used by default. Pipelines serialized with one backend can be loaded and executed with another.
 
 ```python
 from transformplan import TransformPlan
 
-# Default — uses PolarsBackend
-plan = TransformPlan()
+# Build a plan — no backend needed
+plan = TransformPlan().col_drop("temp").math_add("age", 1)
 
-# Explicit Polars backend
-from transformplan.backends.polars import PolarsBackend
-plan = TransformPlan(backend=PolarsBackend())
+# Execute with the default PolarsBackend (polars_df is an existing Polars DataFrame)
+result, protocol = plan.process(polars_df)
 
-# DuckDB backend
+# Execute with the DuckDB backend
 import duckdb
 from transformplan.backends.duckdb import DuckDBBackend
 con = duckdb.connect()
-plan = TransformPlan(backend=DuckDBBackend(con))
+duckdb_rel = con.sql("SELECT * FROM 'data.parquet'")
+result, protocol = plan.process(duckdb_rel, backend=DuckDBBackend(con))
 ```
 
 ## Backend ABC
@@ -96,26 +95,25 @@ con = duckdb.connect()
 rel = con.sql("SELECT * FROM 'data.parquet'")
 
 plan = (
-    TransformPlan(backend=DuckDBBackend(con))
+    TransformPlan()
     .col_rename(column="ID", new_name="id")
     .rows_filter(Col("age") >= 18)
     .math_standardize(column="score", new_column="z_score")
 )
 
-result, protocol = plan.process(rel)
+result, protocol = plan.process(rel, backend=DuckDBBackend(con))
 ```
 
 ## Cross-Backend Serialization
 
-Pipelines are backend-agnostic when serialized. You can build a pipeline with one backend and execute it with another:
+Pipelines are inherently backend-agnostic. The same serialized plan can be executed with any backend:
 
 ```python
-import polars as pl
 import duckdb
 from transformplan import TransformPlan, Col
 from transformplan.backends.duckdb import DuckDBBackend
 
-# Build and serialize with Polars (default)
+# Build and serialize
 plan = (
     TransformPlan()
     .col_rename(column="ID", new_name="id")
@@ -123,11 +121,14 @@ plan = (
 )
 plan.to_json("pipeline.json")
 
-# Load and execute with DuckDB
+# Load and execute with Polars (default; polars_df is an existing Polars DataFrame)
+restored = TransformPlan.from_json("pipeline.json")
+result, protocol = restored.process(polars_df)
+
+# Or execute with DuckDB
 con = duckdb.connect()
 rel = con.sql("SELECT * FROM 'data.parquet'")
-plan_duckdb = TransformPlan.from_json("pipeline.json", backend=DuckDBBackend(con))
-result, protocol = plan_duckdb.process(rel)
+result, protocol = restored.process(rel, backend=DuckDBBackend(con))
 ```
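+
+The serialized form records only operation names and parameters, never a
+backend, which is why any backend can execute it. A minimal sketch (the
+dict layout shown in the comment is illustrative, not a guaranteed format):
+
+```python
+plan = TransformPlan().col_rename(column="ID", new_name="id")
+d = plan.to_dict()
+assert "backend" not in d
+# d resembles:
+# {"steps": [{"operation": "col_rename",
+#             "params": {"column": "ID", "new_name": "id"}}]}
+```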
 
 ## Type System
diff --git a/docs/api/plan.md b/docs/api/plan.md
index 7d0530f..6aa713f 100644
--- a/docs/api/plan.md
+++ b/docs/api/plan.md
@@ -4,7 +4,7 @@ The main class for building and executing transformation pipelines.
 
 ## Overview
 
-`TransformPlan` uses a deferred execution model: operations are registered via method chaining, then executed together when you call `process()`, `validate()`, or `dry_run()`. An optional `backend` parameter selects the execution engine (defaults to `PolarsBackend`).
+`TransformPlan` uses a deferred execution model: operations are registered via method chaining, then executed together when you call `process()`, `validate()`, or `dry_run()`. The plan itself is backend-agnostic — the backend is chosen at execution time (defaults to `PolarsBackend`).
 
 ```python
 from transformplan import TransformPlan, Col
@@ -22,15 +22,19 @@ df_result, protocol = plan.process(df)
 
 ## Backend Selection
 
+The backend is passed at execution time, not at construction:
+
 ```python
+from transformplan.backends.duckdb import DuckDBBackend
+
+plan = TransformPlan().col_drop("temp").math_add("age", 1)
+
 # Default (Polars)
-plan = TransformPlan()
+# (polars_df: an existing Polars DataFrame)
+result, protocol = plan.process(polars_df)
 
 # DuckDB
 import duckdb
-from transformplan.backends.duckdb import DuckDBBackend
 con = duckdb.connect()
-plan = TransformPlan(backend=DuckDBBackend(con))
+duckdb_rel = con.sql("SELECT * FROM 'data.parquet'")
+result, protocol = plan.process(duckdb_rel, backend=DuckDBBackend(con))
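+
+# validate() and dry_run() accept the same backend argument (a sketch):
+report = plan.validate(duckdb_rel, backend=DuckDBBackend(con))
+preview = plan.dry_run(duckdb_rel, backend=DuckDBBackend(con))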
 ```
 
 See [Backends](backends.md) for details on each backend.
diff --git a/docs/api/validation.md b/docs/api/validation.md
index efc2da2..3a7f2ce 100644
--- a/docs/api/validation.md
+++ b/docs/api/validation.md
@@ -69,7 +69,7 @@ if not result.is_valid:
 
 ## DuckDB Validation
 
-Validation works identically with DuckDB relations:
+Validation works identically with DuckDB relations — pass the backend at validation time:
 
 ```python
 import duckdb
@@ -80,12 +80,12 @@ con = duckdb.connect()
 rel = con.sql("SELECT 'Alice' AS name, 25 AS age, 50000 AS salary")
 
 plan = (
-    TransformPlan(backend=DuckDBBackend(con))
+    TransformPlan()
     .col_drop("age")
     .rows_filter(Col("age") > 18)  # Error: age was dropped!
 )
 
-result = plan.validate(rel)
+result = plan.validate(rel, backend=DuckDBBackend(con))
 # ValidationResult(valid=False, errors=1)
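+
+# result.errors holds one entry per problem found:
+for error in result.errors:
+    print(error)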
 ```
diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md
index c9ccb1a..c1e37f4 100644
--- a/docs/getting-started/quickstart.md
+++ b/docs/getting-started/quickstart.md
@@ -74,7 +74,7 @@ print(df_result)
 
 ## Using the DuckDB Backend
 
-TransformPlan supports DuckDB as an alternative backend. All 88 operations, validation, and dry-run work identically — only the data type changes from Polars DataFrames to DuckDB relations.
+TransformPlan supports DuckDB as an alternative backend. All 88 operations, validation, and dry-run work identically — the same plan works with both Polars DataFrames and DuckDB relations. Simply pass the backend at execution time:
 
 ```python
 import duckdb
@@ -89,18 +89,20 @@ rel = con.sql("""
     UNION ALL SELECT 'Diana', 'Sales', 70000, 2
 """)
 
+# Same plan as before — no backend in constructor
 plan = (
-    TransformPlan(backend=DuckDBBackend(con))
+    TransformPlan()
     .col_rename(column="name", new_name="employee")
     .math_multiply(column="salary", value=1.05)
     .math_round(column="salary", decimals=0)
     .rows_filter(Col("years") >= 3)
 )
 
-# Validate and execute — same API as Polars
-result = plan.validate(rel)
+# Pass backend at execution time
+backend = DuckDBBackend(con)
+result = plan.validate(rel, backend=backend)
 if result.is_valid:
-    df_result, protocol = plan.process(rel)
+    df_result, protocol = plan.process(rel, backend=backend)
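+
+# `protocol` records each executed step plus a hash of the input
+# (see the next section for viewing it)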
 ```
 
 ## Viewing the Audit Protocol
diff --git a/docs/index.md b/docs/index.md
index 6aeeecf..1190eec 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -96,14 +96,15 @@ from transformplan.backends.duckdb import DuckDBBackend
 con = duckdb.connect()
 rel = con.sql("SELECT * FROM 'patients.parquet'")
 
+# Same plan — backend chosen at execution time
 plan = (
-    TransformPlan(backend=DuckDBBackend(con))
+    TransformPlan()
     .col_rename(column="PatientID", new_name="patient_id")
     .rows_filter(Col("age") >= 18)
     .math_round(column="score", decimals=2)
 )
 
-result, protocol = plan.process(rel)
+result, protocol = plan.process(rel, backend=DuckDBBackend(con))
 ```
 
 ## Available Operations
diff --git a/pyproject.toml b/pyproject.toml
index c5e0b26..8365f59 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "transformplan"
-version = "0.1.2"
+version = "0.1.3"
 description = "Safe, reproducible data transformations with built-in auditing and validation"
 readme = "README.md"
 requires-python = ">=3.10"
diff --git a/tests/test_duckdb.py b/tests/test_duckdb.py
index 3032021..ce69b16 100644
--- a/tests/test_duckdb.py
+++ b/tests/test_duckdb.py
@@ -104,11 +104,6 @@ def null_rel(con: duckdb.DuckDBPyConnection) -> duckdb.DuckDBPyRelation:
     )
 
 
-def _plan(backend: DuckDBBackend) -> TransformPlan:
-    """Create a TransformPlan with DuckDB backend."""
-    return TransformPlan(backend=backend)
-
-
 def _col_values(rel: duckdb.DuckDBPyRelation, col: str) -> list[Any]:
     """Fetch values of a single column."""
     idx = list(rel.columns).index(col)
@@ -162,14 +157,14 @@ class TestColDrop:
     def test_col_drop(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).col_drop("age").process(basic_rel)
+        result, _ = TransformPlan().col_drop("age").process(basic_rel, backend=backend)
         assert "age" not in result.columns
         assert len(result.columns) == 4
 
     def test_col_drop_preserves_other_columns(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).col_drop("age").process(basic_rel)
+        result, _ = TransformPlan().col_drop("age").process(basic_rel, backend=backend)
         assert "id" in result.columns
         assert "name" in result.columns
@@ -178,14 +173,22 @@ class TestColRename:
     def test_col_rename(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).col_rename("name", "full_name").process(basic_rel)
+        result, _ = (
+            TransformPlan()
+            .col_rename("name", "full_name")
+            .process(basic_rel, backend=backend)
+        )
         assert "full_name" in result.columns
         assert "name" not in result.columns
 
     def test_col_rename_preserves_data(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).col_rename("name", "full_name").process(basic_rel)
+        result, _ = (
+            TransformPlan()
+            .col_rename("name", "full_name")
+            .process(basic_rel, backend=backend)
+        )
         vals = _col_values(result, "full_name")
         assert vals == ["Alice", "Bob", "Charlie", "David", "Eve"]
@@ -194,14 +197,18 @@ class TestColCast:
     def test_col_cast_int_to_float(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).col_cast("age", float).process(basic_rel)
+        result, _ = (
+            TransformPlan().col_cast("age", float).process(basic_rel, backend=backend)
+        )
         vals = _col_values(result, "age")
         assert all(isinstance(v, float) for v in vals)
 
     def test_col_cast_int_to_string(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).col_cast("id", str).process(basic_rel)
+        result, _ = (
+            TransformPlan().col_cast("id", str).process(basic_rel, backend=backend)
+        )
         vals = _col_values(result, "id")
         assert vals[0] == "1"
@@ -211,9 +218,9 @@ def test_col_reorder(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .col_reorder(["salary", "name", "id", "age", "active"])
-            .process(basic_rel)
+            .process(basic_rel, backend=backend)
         )
         assert list(result.columns) == ["salary", "name", "id", "age", "active"]
@@ -222,7 +229,11 @@ class TestColSelect:
     def test_col_select(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).col_select(["id", "name"]).process(basic_rel)
+        result, _ = (
+            TransformPlan()
+            .col_select(["id", "name"])
+            .process(basic_rel, backend=backend)
+        )
         assert list(result.columns) == ["id", "name"]
         assert backend.get_shape(result) == (5, 2)
@@ -231,7 +242,11 @@ class TestColDuplicate:
     def test_col_duplicate(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).col_duplicate("name", "name_copy").process(basic_rel)
+        result, _ = (
+            TransformPlan()
+            .col_duplicate("name", "name_copy")
+            .process(basic_rel, backend=backend)
+        )
         assert "name_copy" in result.columns
         assert _col_values(result, "name") == _col_values(result, "name_copy")
@@ -241,7 +256,9 @@ def test_col_fill_null_with_value(
         self, backend: DuckDBBackend, null_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).col_fill_null("name", value="Unknown").process(null_rel)
+            TransformPlan()
+            .col_fill_null("name", value="Unknown")
+            .process(null_rel, backend=backend)
         )
         vals = _col_values(result, "name")
         assert None not in vals
@@ -251,7 +268,9 @@ def test_col_fill_null_zero_strategy(
         self, backend: DuckDBBackend, null_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).col_fill_null("age", strategy="zero").process(null_rel)
+            TransformPlan()
+            .col_fill_null("age", strategy="zero")
+            .process(null_rel, backend=backend)
         )
         vals = _col_values(result, "age")
         assert None not in vals
@@ -261,7 +280,11 @@ class TestColDropNull:
     def test_col_drop_null(
         self, backend: DuckDBBackend, null_rel: duckdb.DuckDBPyRelation
    ) -> None:
-        result, _ = _plan(backend).col_drop_null(columns=["name"]).process(null_rel)
+        result, _ = (
+            TransformPlan()
+            .col_drop_null(columns=["name"])
+            .process(null_rel, backend=backend)
+        )
         assert backend.get_shape(result)[0] == 3  # rows 2 and 5 had null names
@@ -270,7 +293,7 @@ def test_col_drop_zero(
         self, backend: DuckDBBackend, con: duckdb.DuckDBPyConnection
     ) -> None:
         rel = con.sql("SELECT * FROM (VALUES (1, 10), (2, 0), (3, 30)) AS t(id, val)")
-        result, _ = _plan(backend).col_drop_zero("val").process(rel)
+        result, _ = TransformPlan().col_drop_zero("val").process(rel, backend=backend)
         assert backend.get_shape(result)[0] == 2
@@ -278,7 +301,11 @@ class TestColAdd:
     def test_col_add_with_value(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).col_add("status", value="active").process(basic_rel)
+        result, _ = (
+            TransformPlan()
+            .col_add("status", value="active")
+            .process(basic_rel, backend=backend)
+        )
         assert "status" in result.columns
         vals = _col_values(result, "status")
         assert all(v == "active" for v in vals)
@@ -286,7 +313,11 @@ def test_col_add_with_value(
     def test_col_add_from_column(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).col_add("name_copy", expr="name").process(basic_rel)
+        result, _ = (
+            TransformPlan()
+            .col_add("name_copy", expr="name")
+            .process(basic_rel, backend=backend)
+        )
         assert _col_values(result, "name_copy") == _col_values(result, "name")
@@ -294,7 +325,11 @@ class TestColAddUuid:
     def test_col_add_uuid(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).col_add_uuid("uuid", length=8).process(basic_rel)
+        result, _ = (
+            TransformPlan()
+            .col_add_uuid("uuid", length=8)
+            .process(basic_rel, backend=backend)
+        )
         assert "uuid" in result.columns
         vals = _col_values(result, "uuid")
         assert all(len(v) == 8 for v in vals)
@@ -306,9 +341,9 @@ def test_col_hash(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .col_hash(columns=["id", "name"], new_column="hash", salt="test")
-            .process(basic_rel)
+            .process(basic_rel, backend=backend)
         )
         assert "hash" in result.columns
         vals = _col_values(result, "hash")
@@ -320,9 +355,9 @@ def test_col_coalesce(
         self, backend: DuckDBBackend, null_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .col_coalesce(columns=["name", "id"], new_column="first_non_null")
-            .process(null_rel)
+            .process(null_rel, backend=backend)
         )
         assert "first_non_null" in result.columns
@@ -336,39 +371,47 @@ class TestMathScalar:
     def test_math_add(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).math_add("a", 10).process(numeric_rel)
+        result, _ = (
+            TransformPlan().math_add("a", 10).process(numeric_rel, backend=backend)
+        )
         assert _col_values(result, "a") == [11, 12, 13, 14, 15]
 
     def test_math_subtract(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).math_subtract("a", 1).process(numeric_rel)
+        result, _ = (
+            TransformPlan().math_subtract("a", 1).process(numeric_rel, backend=backend)
+        )
         assert _col_values(result, "a") == [0, 1, 2, 3, 4]
 
     def test_math_multiply(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).math_multiply("a", 2).process(numeric_rel)
+        result, _ = (
+            TransformPlan().math_multiply("a", 2).process(numeric_rel, backend=backend)
+        )
         assert _col_values(result, "a") == [2, 4, 6, 8, 10]
 
     def test_math_divide(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).math_divide("b", 10).process(numeric_rel)
+        result, _ = (
+            TransformPlan().math_divide("b", 10).process(numeric_rel, backend=backend)
+        )
         assert _col_values(result, "b") == [1.0, 2.0, 3.0, 4.0, 5.0]
 
     def test_math_abs(
         self, backend: DuckDBBackend, con: duckdb.DuckDBPyConnection
     ) -> None:
         rel = con.sql("SELECT * FROM (VALUES (-1,), (2,), (-3,)) AS t(a)")
-        result, _ = _plan(backend).math_abs("a").process(rel)
+        result, _ = TransformPlan().math_abs("a").process(rel, backend=backend)
         assert _col_values(result, "a") == [1, 2, 3]
 
     def test_math_round(
         self, backend: DuckDBBackend, con: duckdb.DuckDBPyConnection
     ) -> None:
         rel = con.sql("SELECT * FROM (VALUES (1.234,), (5.678,)) AS t(a)")
-        result, _ = _plan(backend).math_round("a", 1).process(rel)
+        result, _ = TransformPlan().math_round("a", 1).process(rel, backend=backend)
         vals = _col_values(result, "a")
         assert [float(v) for v in vals] == [1.2, 5.7]
@@ -376,20 +419,26 @@ def test_math_clamp(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).math_clamp("a", lower=2, upper=4).process(numeric_rel)
+            TransformPlan()
+            .math_clamp("a", lower=2, upper=4)
+            .process(numeric_rel, backend=backend)
         )
         assert _col_values(result, "a") == [2, 2, 3, 4, 4]
 
     def test_math_set_min(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).math_set_min("a", 3).process(numeric_rel)
+        result, _ = (
+            TransformPlan().math_set_min("a", 3).process(numeric_rel, backend=backend)
+        )
         assert _col_values(result, "a") == [3, 3, 3, 4, 5]
 
     def test_math_set_max(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).math_set_max("a", 3).process(numeric_rel)
+        result, _ = (
+            TransformPlan().math_set_max("a", 3).process(numeric_rel, backend=backend)
+        )
         assert _col_values(result, "a") == [1, 2, 3, 3, 3]
@@ -398,7 +447,9 @@ def test_math_add_columns(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).math_add_columns("a", "b", "sum").process(numeric_rel)
+            TransformPlan()
+            .math_add_columns("a", "b", "sum")
+            .process(numeric_rel, backend=backend)
         )
         assert _col_values(result, "sum") == [11, 22, 33, 44, 55]
@@ -406,7 +457,9 @@ def test_math_subtract_columns(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).math_subtract_columns("b", "a", "diff").process(numeric_rel)
+            TransformPlan()
+            .math_subtract_columns("b", "a", "diff")
+            .process(numeric_rel, backend=backend)
         )
         assert _col_values(result, "diff") == [9, 18, 27, 36, 45]
@@ -414,7 +467,9 @@ def test_math_multiply_columns(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).math_multiply_columns("a", "b", "prod").process(numeric_rel)
+            TransformPlan()
+            .math_multiply_columns("a", "b", "prod")
+            .process(numeric_rel, backend=backend)
         )
         assert _col_values(result, "prod") == [10, 40, 90, 160, 250]
@@ -422,7 +477,9 @@ def test_math_divide_columns(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).math_divide_columns("b", "a", "ratio").process(numeric_rel)
+            TransformPlan()
+            .math_divide_columns("b", "a", "ratio")
+            .process(numeric_rel, backend=backend)
         )
         assert _col_values(result, "ratio") == [10.0, 10.0, 10.0, 10.0, 10.0]
@@ -430,9 +487,9 @@ def test_math_percent_of(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .math_percent_of("a", "c", "pct", multiply_by=100.0)
-            .process(numeric_rel)
+            .process(numeric_rel, backend=backend)
         )
         assert _col_values(result, "pct") == [1.0, 1.0, 1.0, 1.0, 1.0]
@@ -441,16 +498,20 @@ class TestMathWindow:
     def test_math_cumsum(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).math_cumsum("a", "cumulative").process(numeric_rel)
+        result, _ = (
+            TransformPlan()
+            .math_cumsum("a", "cumulative")
+            .process(numeric_rel, backend=backend)
+        )
         assert _col_values(result, "cumulative") == [1, 3, 6, 10, 15]
 
     def test_math_rank(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .math_rank("a", "rank", method="ordinal", descending=False)
-            .process(numeric_rel)
+            .process(numeric_rel, backend=backend)
         )
         assert _col_values(result, "rank") == [1, 2, 3, 4, 5]
@@ -460,9 +521,9 @@ def test_math_standardize(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .math_standardize("a", new_column="z_score")
-            .process(numeric_rel)
+            .process(numeric_rel, backend=backend)
         )
         vals = _col_values(result, "z_score")
         assert len(vals) == 5
@@ -473,9 +534,9 @@ def test_math_minmax(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .math_minmax("a", feature_range=(0.0, 1.0), new_column="scaled")
-            .process(numeric_rel)
+            .process(numeric_rel, backend=backend)
         )
         vals = _col_values(result, "scaled")
         assert vals[0] == pytest.approx(0.0)
@@ -485,9 +546,9 @@ def test_math_robust_scale(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .math_robust_scale("a", new_column="robust")
-            .process(numeric_rel)
+            .process(numeric_rel, backend=backend)
         )
         assert "robust" in result.columns
@@ -497,9 +558,9 @@ def test_math_log(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .math_log("a", offset=0, new_column="log_a")
-            .process(numeric_rel)
+            .process(numeric_rel, backend=backend)
         )
         vals = _col_values(result, "log_a")
         assert vals[0] == pytest.approx(0.0, abs=0.01)  # ln(1) = 0
@@ -508,14 +569,20 @@ def test_math_sqrt(
         self, backend: DuckDBBackend, con: duckdb.DuckDBPyConnection
     ) -> None:
         rel = con.sql("SELECT * FROM (VALUES (4,), (9,), (16,)) AS t(a)")
-        result, _ = _plan(backend).math_sqrt("a", new_column="sqrt_a").process(rel)
+        result, _ = (
+            TransformPlan()
+            .math_sqrt("a", new_column="sqrt_a")
+            .process(rel, backend=backend)
+        )
         assert _col_values(result, "sqrt_a") == [2.0, 3.0, 4.0]
 
     def test_math_power(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).math_power("a", 2, new_column="sq").process(numeric_rel)
+            TransformPlan()
+            .math_power("a", 2, new_column="sq")
+            .process(numeric_rel, backend=backend)
        )
         assert _col_values(result, "sq") == [1.0, 4.0, 9.0, 16.0, 25.0]
@@ -523,9 +590,9 @@ def test_math_winsorize(
         self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
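+            # winsorize clamps outliers: values below lower_value or above
+            # upper_value are pulled to those bounds (checked by the asserts)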
             .math_winsorize("a", lower_value=2, upper_value=4, new_column="w")
-            .process(numeric_rel)
+            .process(numeric_rel, backend=backend)
         )
         vals = _col_values(result, "w")
         assert min(vals) >= 2
@@ -541,14 +608,20 @@ class TestRowsFilter:
     def test_rows_filter_ge(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).rows_filter(Col("age") >= 35).process(basic_rel)
+        result, _ = (
+            TransformPlan()
+            .rows_filter(Col("age") >= 35)
+            .process(basic_rel, backend=backend)
+        )
         assert backend.get_shape(result)[0] == 3
 
     def test_rows_filter_eq(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).rows_filter(Col("name") == "Alice").process(basic_rel)
+            TransformPlan()
+            .rows_filter(Col("name") == "Alice")
+            .process(basic_rel, backend=backend)
         )
         assert backend.get_shape(result)[0] == 1
@@ -556,9 +629,9 @@ def test_rows_filter_combined(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .rows_filter((Col("age") >= 30) & (Col("age") <= 40))
-            .process(basic_rel)
+            .process(basic_rel, backend=backend)
         )
         assert backend.get_shape(result)[0] == 3
@@ -567,7 +640,11 @@ class TestRowsDrop:
     def test_rows_drop(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).rows_drop(Col("age") < 30).process(basic_rel)
+        result, _ = (
+            TransformPlan()
+            .rows_drop(Col("age") < 30)
+            .process(basic_rel, backend=backend)
+        )
         assert backend.get_shape(result)[0] == 4  # only age=25 dropped
@@ -576,7 +653,9 @@ def test_rows_drop_nulls(
         self, backend: DuckDBBackend, null_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).rows_drop_nulls(columns=["name", "age"]).process(null_rel)
+            TransformPlan()
+            .rows_drop_nulls(columns=["name", "age"])
+            .process(null_rel, backend=backend)
         )
         # Rows with null name (2,5) or null age (3) are dropped
         assert backend.get_shape(result)[0] == 2
@@ -587,9 +666,9 @@ def test_rows_flag(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .rows_flag(Col("age") >= 35, "senior", true_value="yes", false_value="no")
-            .process(basic_rel)
+            .process(basic_rel, backend=backend)
         )
         assert "senior" in result.columns
         vals = _col_values(result, "senior")
@@ -604,9 +683,9 @@ def test_rows_unique_first(
             "SELECT * FROM (VALUES (1, 'A'), (1, 'A'), (2, 'B')) AS t(id, name)"
         )
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .rows_unique(columns=["id", "name"], keep="first")
-            .process(rel)
+            .process(rel, backend=backend)
         )
         assert backend.get_shape(result)[0] == 2
@@ -617,7 +696,9 @@ def test_rows_unique_none(
             "SELECT * FROM (VALUES (1, 'A'), (1, 'A'), (2, 'B')) AS t(id, name)"
         )
         result, _ = (
-            _plan(backend).rows_unique(columns=["id", "name"], keep="none").process(rel)
+            TransformPlan()
+            .rows_unique(columns=["id", "name"], keep="none")
+            .process(rel, backend=backend)
         )
         assert backend.get_shape(result)[0] == 1  # only (2, 'B') kept
@@ -632,11 +713,11 @@ def test_rows_deduplicate(
             ") AS t(id, name, val)"
         )
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .rows_deduplicate(
                 columns=["id"], sort_by="val", keep="first", descending=False
             )
-            .process(rel)
+            .process(rel, backend=backend)
         )
         assert backend.get_shape(result)[0] == 2
         vals = _col_values(result, "val")
@@ -648,7 +729,9 @@ def test_rows_sort_ascending(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).rows_sort(by=["age"], descending=False).process(basic_rel)
+            TransformPlan()
+            .rows_sort(by=["age"], descending=False)
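+            # ascending sort: the fetched values should come back ordered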
+            .process(basic_rel, backend=backend)
         )
         vals = _col_values(result, "age")
         assert vals == sorted(vals)
@@ -657,7 +740,9 @@ def test_rows_sort_descending(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).rows_sort(by=["age"], descending=True).process(basic_rel)
+            TransformPlan()
+            .rows_sort(by=["age"], descending=True)
+            .process(basic_rel, backend=backend)
         )
         vals = _col_values(result, "age")
         assert vals == sorted(vals, reverse=True)
@@ -667,13 +752,13 @@ class TestRowsHeadTail:
     def test_rows_head(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).rows_head(3).process(basic_rel)
+        result, _ = TransformPlan().rows_head(3).process(basic_rel, backend=backend)
         assert backend.get_shape(result)[0] == 3
 
     def test_rows_tail(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).rows_tail(2).process(basic_rel)
+        result, _ = TransformPlan().rows_tail(2).process(basic_rel, backend=backend)
         assert backend.get_shape(result)[0] == 2
@@ -681,7 +766,7 @@ class TestRowsSample:
     def test_rows_sample_n(
         self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).rows_sample(n=3).process(basic_rel)
+        result, _ = TransformPlan().rows_sample(n=3).process(basic_rel, backend=backend)
         assert backend.get_shape(result)[0] == 3
@@ -694,7 +779,7 @@ def test_rows_explode(
             "(1, ['a', 'b', 'c']), (2, ['d', 'e'])"
             ") AS t(id, tags)"
         )
-        result, _ = _plan(backend).rows_explode("tags").process(rel)
+        result, _ = TransformPlan().rows_explode("tags").process(rel, backend=backend)
         assert backend.get_shape(result)[0] == 5
@@ -707,14 +792,14 @@ def test_rows_melt(
             "AS t(id, name, q1, q2)"
         )
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .rows_melt(
                 id_columns=["id", "name"],
                 value_columns=["q1", "q2"],
                 variable_name="quarter",
                 value_name="value",
             )
-            .process(rel)
+            .process(rel, backend=backend)
         )
         assert backend.get_shape(result)[0] == 4
         assert "quarter" in result.columns
@@ -732,14 +817,14 @@ def test_rows_pivot(
             ") AS t(id, quarter, value)"
         )
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .rows_pivot(
                 index=["id"],
                 columns="quarter",
                 values="value",
                 aggregate_function="sum",
             )
-            .process(rel)
+            .process(rel, backend=backend)
         )
         assert backend.get_shape(result)[0] == 2
@@ -754,9 +839,9 @@ def test_str_replace_literal(
         self, backend: DuckDBBackend, string_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .str_replace("code", "PRD", "PROD", literal=True)
-            .process(string_rel)
+            .process(string_rel, backend=backend)
         )
         vals = _col_values(result, "code")
         assert vals[0] == "PROD-001"
@@ -767,7 +852,9 @@ def test_str_slice(
         self, backend: DuckDBBackend, string_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).str_slice("code", offset=0, length=3).process(string_rel)
+            TransformPlan()
+            .str_slice("code", offset=0, length=3)
+            .process(string_rel, backend=backend)
         )
         vals = _col_values(result, "code")
         assert vals[0] == "PRD"
@@ -778,9 +865,9 @@ def test_str_truncate(
         self, backend: DuckDBBackend, string_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .str_truncate("text", max_length=10, suffix="...")
-            .process(string_rel)
+            .process(string_rel, backend=backend)
         )
         vals = _col_values(result, "text")
         # "  Hello World  " (15 chars) should be truncated
@@ -791,14 +878,18 @@ class TestStrCase:
     def test_str_lower(
         self, backend: DuckDBBackend, string_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).str_lower("code").process(string_rel)
+        result, _ = (
+            TransformPlan().str_lower("code").process(string_rel, backend=backend)
+        )
         vals = _col_values(result, "code")
         assert vals[0] == "prd-001"
 
     def test_str_upper(
         self, backend: DuckDBBackend, string_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).str_upper("first_name").process(string_rel)
+        result, _ = (
+            TransformPlan().str_upper("first_name").process(string_rel, backend=backend)
+        )
         vals = _col_values(result, "first_name")
         assert vals[0] == "JOHN"
@@ -807,7 +898,9 @@ class TestStrStrip:
     def test_str_strip(
         self, backend: DuckDBBackend, string_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).str_strip("text").process(string_rel)
+        result, _ = (
+            TransformPlan().str_strip("text").process(string_rel, backend=backend)
+        )
         vals = _col_values(result, "text")
         assert vals[0] == "Hello World"
@@ -818,9 +911,9 @@ def test_str_pad_left(
     ) -> None:
         rel = con.sql("SELECT * FROM (VALUES ('1',), ('22',), ('333',)) AS t(a)")
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .str_pad("a", length=5, fill_char="0", side="left")
-            .process(rel)
+            .process(rel, backend=backend)
         )
         vals = _col_values(result, "a")
         assert vals[0] == "00001"
@@ -830,9 +923,9 @@ def test_str_pad_right(
     ) -> None:
         rel = con.sql("SELECT * FROM (VALUES ('1',), ('22',)) AS t(a)")
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .str_pad("a", length=5, fill_char=".", side="right")
-            .process(rel)
+            .process(rel, backend=backend)
         )
         vals = _col_values(result, "a")
         assert vals[0] == "1...."
@@ -843,9 +936,9 @@ def test_str_concat(
         self, backend: DuckDBBackend, string_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .str_concat(["first_name", "last_name"], "full_name", separator=" ")
-            .process(string_rel)
+            .process(string_rel, backend=backend)
         )
         vals = _col_values(result, "full_name")
         assert vals[0] == "John Doe"
@@ -856,9 +949,9 @@ def test_str_extract(
         self, backend: DuckDBBackend, string_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .str_extract("code", r"(\w+)-(\d+)", group_index=1, new_column="prefix")
-            .process(string_rel)
+            .process(string_rel, backend=backend)
         )
         vals = _col_values(result, "prefix")
         assert vals[0] == "PRD"
@@ -870,9 +963,9 @@ def test_str_split_with_columns(
     ) -> None:
         rel = con.sql("SELECT * FROM (VALUES ('a-b-c',), ('d-e-f',)) AS t(text)")
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .str_split("text", separator="-", new_columns=["p1", "p2", "p3"])
-            .process(rel)
+            .process(rel, backend=backend)
         )
         assert "p1" in result.columns
         assert "p2" in result.columns
@@ -890,14 +983,22 @@ class TestDtExtract:
     def test_dt_year(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).dt_year("date_col", "year").process(datetime_rel)
+        result, _ = (
+            TransformPlan()
+            .dt_year("date_col", "year")
+            .process(datetime_rel, backend=backend)
+        )
         vals = _col_values(result, "year")
         assert all(v == 2024 for v in vals)
 
     def test_dt_month(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).dt_month("date_col", "month").process(datetime_rel)
+        result, _ = (
+            TransformPlan()
+            .dt_month("date_col", "month")
+            .process(datetime_rel, backend=backend)
+        )
         vals = _col_values(result, "month")
         assert vals[0] == 1
         assert vals[1] == 3
@@ -905,14 +1006,22 @@ def test_dt_month(
     def test_dt_day(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).dt_day("date_col", "day").process(datetime_rel)
+        result, _ = (
+            TransformPlan()
+            .dt_day("date_col", "day")
+            .process(datetime_rel, backend=backend)
+        )
         vals = _col_values(result, "day")
         assert vals[0] == 15
 
     def test_dt_quarter(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).dt_quarter("date_col", "q").process(datetime_rel)
+        result, _ = (
+            TransformPlan()
+            .dt_quarter("date_col", "q")
+            .process(datetime_rel, backend=backend)
+        )
         vals = _col_values(result, "q")
         assert vals[0] == 1
         assert vals[-1] == 4
@@ -920,7 +1029,11 @@ def test_dt_week(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
-        result, _ = _plan(backend).dt_week("date_col", "week").process(datetime_rel)
+        result, _ = (
+            TransformPlan()
+            .dt_week("date_col", "week")
+            .process(datetime_rel, backend=backend)
+        )
         vals = _col_values(result, "week")
         assert all(isinstance(v, int) for v in vals)
@@ -930,9 +1043,9 @@ def test_dt_year_month(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .dt_year_month("date_col", "ym", fmt="%Y-%m")
-            .process(datetime_rel)
+            .process(datetime_rel, backend=backend)
         )
         vals = _col_values(result, "ym")
         assert vals[0] == "2024-01"
@@ -941,7 +1054,9 @@ def test_dt_quarter_year(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).dt_quarter_year("date_col", "qy").process(datetime_rel)
+            TransformPlan()
+            .dt_quarter_year("date_col", "qy")
+            .process(datetime_rel, backend=backend)
         )
         vals = _col_values(result, "qy")
         assert vals[0] == "Q1-2024"
@@ -951,7 +1066,9 @@ def test_dt_calendar_week(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend).dt_calendar_week("date_col", "cw").process(datetime_rel)
+            TransformPlan()
+            .dt_calendar_week("date_col", "cw")
+            .process(datetime_rel, backend=backend)
         )
         vals = _col_values(result, "cw")
         assert vals[0].startswith("2024-W")
@@ -960,9 +1077,9 @@ def test_dt_format(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .dt_format("date_col", "%Y/%m/%d", "formatted")
-            .process(datetime_rel)
+            .process(datetime_rel, backend=backend)
         )
         vals = _col_values(result, "formatted")
         assert vals[0] == "2024/01/15"
@@ -973,9 +1090,9 @@ def test_dt_parse(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .dt_parse("date_str", "%Y-%m-%d", "parsed")
-            .process(datetime_rel)
+            .process(datetime_rel, backend=backend)
         )
         vals = _col_values(result, "parsed")
         assert vals[0] == date(2024, 1, 15)
@@ -986,9 +1103,9 @@ def test_dt_diff_days(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .dt_diff_days("date_col", "birth_date", "age_days")
-            .process(datetime_rel)
+            .process(datetime_rel, backend=backend)
         )
         vals = _col_values(result, "age_days")
         assert all(v > 0 for v in vals)
@@ -997,9 +1114,9 @@ def test_dt_age_years(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .dt_age_years("birth_date", reference_column="date_col", new_column="age")
-            .process(datetime_rel)
+            .process(datetime_rel, backend=backend)
         )
         vals = _col_values(result, "age")
         assert all(v > 0 for v in vals)
@@ -1010,9 +1127,9 @@ def test_dt_truncate_month(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .dt_truncate("date_col", "1mo", "month_start")
-            .process(datetime_rel)
+            .process(datetime_rel, backend=backend)
         )
         vals = _col_values(result, "month_start")
         # All dates should be 1st of month
@@ -1024,7 +1141,7 @@ def test_dt_is_between(
         self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .dt_is_between(
                 "date_col",
                 start="2024-01-01",
@@ -1032,7 +1149,7 @@ def test_dt_is_between(
                 new_column="in_h1",
                 closed="both",
             )
-            .process(datetime_rel)
+            .process(datetime_rel, backend=backend)
         )
         vals = _col_values(result, "in_h1")
         assert vals[0] is True  # Jan 15
@@ -1050,13 +1167,13 @@ def test_map_values(
         self, backend: DuckDBBackend, map_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .map_values(
                 "status",
                 mapping={"A": "Active", "B": "Beta", "C": "Closed"},
                 default="Unknown",
             )
-            .process(map_rel)
+            .process(map_rel, backend=backend)
         )
         vals = _col_values(result, "status")
         assert vals[0] == "Active"
@@ -1069,14 +1186,14 @@ def test_map_case(
         self, backend: DuckDBBackend, map_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .map_case(
                 "status",
                 cases=[("A", "High"), ("B", "Medium"), ("C", "Low")],
                 default="Unknown",
                 new_column="priority",
             )
-            .process(map_rel)
+            .process(map_rel, backend=backend)
         )
         vals = _col_values(result, "priority")
         assert vals[0] == "High"
@@ -1088,7 +1205,7 @@ def test_map_from_column(
         self, backend: DuckDBBackend, map_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .map_from_column(
                 "lookup_key",
                 lookup_column="lookup_key",
@@ -1096,7 +1213,7 @@ def test_map_from_column(
                 new_column="resolved",
                 default="N/A",
             )
-            .process(map_rel)
+            .process(map_rel, backend=backend)
         )
         vals = _col_values(result, "resolved")
         assert vals[0] == "One"
@@ -1108,7 +1225,7 @@ def test_map_discretize(
         self, backend: DuckDBBackend, map_rel: duckdb.DuckDBPyRelation
     ) -> None:
         result, _ = (
-            _plan(backend)
+            TransformPlan()
             .map_discretize(
                 "score",
                 bins=[70, 80, 90],
@@ -1116,7 +1233,7 @@ def test_map_discretize(
                 new_column="grade",
                 right=True,
             )
-            .process(map_rel)
+            .process(map_rel, backend=backend)
         )
         vals = _col_values(result, "grade")
         assert "A" in vals  # score=91
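+        # with right=True the bin edges are right-inclusive, so 91 falls
+        # past the last edge (90) and lands in the top bin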
new_column="label", ) - .process(map_rel) + .process(map_rel, backend=backend) ) vals = _col_values(result, "label") assert vals[0] == 0 @@ -1178,7 +1295,11 @@ class TestMapBoolToInt: def test_map_bool_to_int( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: - result, _ = _plan(backend).map_bool_to_int("active").process(basic_rel) + result, _ = ( + TransformPlan() + .map_bool_to_int("active") + .process(basic_rel, backend=backend) + ) vals = _col_values(result, "active") assert vals[0] == 1 # True -> 1 assert vals[2] == 0 # False -> 0 @@ -1189,7 +1310,9 @@ def test_map_null_to_value( self, backend: DuckDBBackend, null_rel: duckdb.DuckDBPyRelation ) -> None: result, _ = ( - _plan(backend).map_null_to_value("name", "Unknown").process(null_rel) + TransformPlan() + .map_null_to_value("name", "Unknown") + .process(null_rel, backend=backend) ) vals = _col_values(result, "name") assert None not in vals @@ -1200,7 +1323,11 @@ class TestMapValueToNull: def test_map_value_to_null( self, backend: DuckDBBackend, map_rel: duckdb.DuckDBPyRelation ) -> None: - result, _ = _plan(backend).map_value_to_null("status", "C").process(map_rel) + result, _ = ( + TransformPlan() + .map_value_to_null("status", "C") + .process(map_rel, backend=backend) + ) vals = _col_values(result, "status") assert vals[3] is None # C -> null @@ -1215,12 +1342,12 @@ def test_multi_step_pipeline( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: result, _ = ( - _plan(backend) + TransformPlan() .rows_filter(Col("age") >= 30) .col_drop("active") .math_multiply("salary", 1.1) .col_rename("salary", "adjusted_salary") - .process(basic_rel) + .process(basic_rel, backend=backend) ) assert backend.get_shape(result)[0] == 4 assert "active" not in result.columns @@ -1230,7 +1357,10 @@ def test_protocol_tracking( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: _, protocol = ( - _plan(backend).col_drop("active").math_add("age", 1).process(basic_rel) + TransformPlan() + .col_drop("active") + .math_add("age", 1) + .process(basic_rel, backend=backend) ) assert protocol.input_hash is not None assert protocol._input_shape == (5, 5) @@ -1239,12 +1369,12 @@ def test_protocol_tracking( def test_from_dict_duckdb( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: - plan = _plan(backend).col_drop("active").math_add("age", 1) + plan = TransformPlan().col_drop("active").math_add("age", 1) d = plan.to_dict() - assert d["backend"] == "duckdb" + assert "backend" not in d restored = TransformPlan.from_dict(d) - result, _ = restored.process(basic_rel) + result, _ = restored.process(basic_rel, backend=backend) assert "active" not in result.columns @@ -1260,16 +1390,16 @@ def test_validate_valid_plan( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: """validate() works for a valid DuckDB plan.""" - plan = _plan(backend).math_add("age", 1) - result = plan.validate(basic_rel) + plan = TransformPlan().math_add("age", 1) + result = plan.validate(basic_rel, backend=backend) assert result.is_valid def test_validate_missing_column( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: """validate() catches missing column for DuckDB.""" - plan = _plan(backend).math_add("nonexistent", 1) - result = plan.validate(basic_rel) + plan = TransformPlan().math_add("nonexistent", 1) + result = plan.validate(basic_rel, backend=backend) assert not result.is_valid assert "does not exist" in str(result.errors[0]) @@ -1277,8 +1407,8 
@@ def test_validate_wrong_type( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: """validate() catches type mismatch for DuckDB.""" - plan = _plan(backend).math_add("name", 10) - result = plan.validate(basic_rel) + plan = TransformPlan().math_add("name", 10) + result = plan.validate(basic_rel, backend=backend) assert not result.is_valid assert "expected numeric" in str(result.errors[0]) @@ -1286,8 +1416,8 @@ def test_validate_string_on_numeric( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: """validate() catches string op on numeric column for DuckDB.""" - plan = _plan(backend).str_lower("age") - result = plan.validate(basic_rel) + plan = TransformPlan().str_lower("age") + result = plan.validate(basic_rel, backend=backend) assert not result.is_valid assert "expected string" in str(result.errors[0]) @@ -1295,8 +1425,8 @@ def test_validate_multi_step( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: """validate() tracks schema changes across steps for DuckDB.""" - plan = _plan(backend).col_drop("name").col_drop("name") - result = plan.validate(basic_rel) + plan = TransformPlan().col_drop("name").col_drop("name") + result = plan.validate(basic_rel, backend=backend) assert not result.is_valid assert len(result.errors) == 1 @@ -1304,8 +1434,8 @@ def test_dry_run( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: """dry_run() works for DuckDB.""" - plan = _plan(backend).col_drop("active").math_add("age", 1) - preview = plan.dry_run(basic_rel) + plan = TransformPlan().col_drop("active").math_add("age", 1) + preview = plan.dry_run(basic_rel, backend=backend) assert preview.is_valid assert len(preview.steps) == 2 assert "active" in preview.steps[0].columns_removed @@ -1314,16 +1444,16 @@ def test_dry_run_invalid( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: """dry_run() reports errors for DuckDB.""" - plan = _plan(backend).col_drop("nonexistent") - preview = plan.dry_run(basic_rel) + plan = TransformPlan().col_drop("nonexistent") + preview = plan.dry_run(basic_rel, backend=backend) assert not preview.is_valid def test_process_with_validation( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: """process(validate=True) works for DuckDB (was previously skipped).""" - plan = _plan(backend).math_add("age", 1) - result, _protocol = plan.process(basic_rel, validate=True) + plan = TransformPlan().math_add("age", 1) + result, _protocol = plan.process(basic_rel, validate=True, backend=backend) rows = result.fetchall() assert len(rows) == 5 @@ -1333,55 +1463,53 @@ def test_process_validation_catches_error( """process(validate=True) raises for invalid DuckDB plan.""" from transformplan.validation import SchemaValidationError - plan = _plan(backend).math_add("nonexistent", 1) + plan = TransformPlan().math_add("nonexistent", 1) with pytest.raises(SchemaValidationError): - plan.process(basic_rel, validate=True) + plan.process(basic_rel, validate=True, backend=backend) def test_validate_datetime_op( self, backend: DuckDBBackend, datetime_rel: duckdb.DuckDBPyRelation ) -> None: """validate() works for datetime ops with DuckDB.""" - plan = _plan(backend).dt_year("date_col", "year") - result = plan.validate(datetime_rel) + plan = TransformPlan().dt_year("date_col", "year") + result = plan.validate(datetime_rel, backend=backend) assert result.is_valid def test_validate_col_rename_chain( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) 
-> None: """validate() correctly tracks renames for DuckDB.""" - plan = _plan(backend).col_rename("age", "years").math_add("years", 1) - result = plan.validate(basic_rel) + plan = TransformPlan().col_rename("age", "years").math_add("years", 1) + result = plan.validate(basic_rel, backend=backend) assert result.is_valid class TestCrossBackendSerialization: - """Tests for cross-backend plan serialization.""" + """Tests for cross-backend plan serialization. - def test_polars_plan_to_duckdb( + Plans are inherently backend-agnostic now — just serialize/deserialize + and run with appropriate backend at process time. + """ + + def test_plan_to_duckdb( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: - """Plan built with Polars can be serialized and run on DuckDB.""" - # Build plan with default (Polars) backend - polars_plan = TransformPlan().col_drop("active").math_add("age", 1) - d = polars_plan.to_dict() + """Plan can be serialized and run on DuckDB.""" + plan = TransformPlan().col_drop("active").math_add("age", 1) + d = plan.to_dict() + assert "backend" not in d - # Restore with DuckDB backend - d.pop("backend", None) # remove backend key if present - d["backend"] = "duckdb" restored = TransformPlan.from_dict(d) - result, _ = restored.process(basic_rel) + result, _ = restored.process(basic_rel, backend=backend) assert "active" not in result.columns - def test_duckdb_plan_to_polars(self, backend: DuckDBBackend) -> None: - """Plan built with DuckDB can be serialized and run on Polars.""" + def test_plan_to_polars(self) -> None: + """Plan can be serialized and run on Polars.""" import polars as pl - # Build plan with DuckDB backend - duckdb_plan = _plan(backend).col_drop("active").math_add("age", 1) - d = duckdb_plan.to_dict() + plan = TransformPlan().col_drop("active").math_add("age", 1) + d = plan.to_dict() - # Restore with default (Polars) backend - d.pop("backend", None) # force polars restored = TransformPlan.from_dict(d) df = pl.DataFrame( @@ -1405,9 +1533,9 @@ def test_numeric_min( self, backend: DuckDBBackend, numeric_rel: duckdb.DuckDBPyRelation ) -> None: result, _ = ( - _plan(backend) + TransformPlan() .math_diff_from_agg("a", "min", "diff_min") - .process(numeric_rel) + .process(numeric_rel, backend=backend) ) assert _col_values(result, "diff_min") == [0, 1, 2, 3, 4] @@ -1416,7 +1544,9 @@ def test_numeric_mean( ) -> None: rel = con.sql("SELECT * FROM (VALUES (10), (20), (30)) AS t(val)") result, _ = ( - _plan(backend).math_diff_from_agg("val", "mean", "diff_mean").process(rel) + TransformPlan() + .math_diff_from_agg("val", "mean", "diff_mean") + .process(rel, backend=backend) ) vals = _col_values(result, "diff_mean") assert vals == pytest.approx([-10.0, 0.0, 10.0]) @@ -1430,9 +1560,9 @@ def test_grouped( ") AS t(dept, val)" ) result, _ = ( - _plan(backend) + TransformPlan() .math_diff_from_agg("val", "min", "diff", group_by="dept") - .process(rel) + .process(rel, backend=backend) ) assert _col_values(result, "diff") == [0, 20, 0, 100] @@ -1447,7 +1577,9 @@ def test_datetime_column( ") AS t(ts)" ) result, _ = ( - _plan(backend).math_diff_from_agg("ts", "min", "since_first").process(rel) + TransformPlan() + .math_diff_from_agg("ts", "min", "since_first") + .process(rel, backend=backend) ) vals = _col_values(result, "since_first") # DuckDB returns INTERVAL; check the timedelta total_seconds @@ -1466,9 +1598,9 @@ def test_datetime_grouped( ") AS t(patient, ts)" ) result, _ = ( - _plan(backend) + TransformPlan() .math_diff_from_agg("ts", "min", "since_first", 
group_by="patient") - .process(rel) + .process(rel, backend=backend) ) vals = _col_values(result, "since_first") hours = [v.total_seconds() / 3600 for v in vals] @@ -1479,9 +1611,9 @@ def test_invalid_agg_raises( ) -> None: with pytest.raises(ValueError, match="Invalid aggregate"): ( - _plan(backend) + TransformPlan() .math_diff_from_agg("a", "invalid", "diff") # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - .process(numeric_rel) + .process(numeric_rel, backend=backend) ) @@ -1492,7 +1624,9 @@ def test_col_expr_arithmetic( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: result, _ = ( - _plan(backend).col_expr("age_plus_10", "age + 10").process(basic_rel) + TransformPlan() + .col_expr("age_plus_10", "age + 10") + .process(basic_rel, backend=backend) ) assert "age_plus_10" in result.columns vals = _col_values(result, "age_plus_10") @@ -1503,15 +1637,96 @@ def test_col_expr_case_when( self, backend: DuckDBBackend, basic_rel: duckdb.DuckDBPyRelation ) -> None: result, _ = ( - _plan(backend) + TransformPlan() .col_expr( "category", "CASE WHEN age > 30 THEN 'senior' ELSE 'junior' END", ) - .process(basic_rel) + .process(basic_rel, backend=backend) ) assert "category" in result.columns vals = _col_values(result, "category") ages = _col_values(basic_rel, "age") for age, cat in zip(ages, vals): assert cat == ("senior" if age > 30 else "junior") + + +# ============================================================================= +# Join operations +# ============================================================================= + + +class TestJoin: + """Tests for join operations with DuckDB backend.""" + + def test_inner_join( + self, con: duckdb.DuckDBPyConnection, backend: DuckDBBackend + ) -> None: + main_rel = con.sql( + "SELECT * FROM (VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie')) " + "AS t(id, name)" + ) + cohort_rel = con.sql("SELECT * FROM (VALUES (1,), (3,)) AS t(id)") + result, _ = ( + TransformPlan() + .join(on="id", right_name="cohort", how="inner") + .process(main_rel, references={"cohort": cohort_rel}, backend=backend) + ) + rows = result.fetchall() + assert len(rows) == 2 + ids = [r[0] for r in rows] + assert sorted(ids) == [1, 3] + + def test_left_join( + self, con: duckdb.DuckDBPyConnection, backend: DuckDBBackend + ) -> None: + main_rel = con.sql( + "SELECT * FROM (VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie')) " + "AS t(id, name)" + ) + ref_rel = con.sql("SELECT * FROM (VALUES (1, 100), (3, 300)) AS t(id, score)") + result, _ = ( + TransformPlan() + .join(on="id", right_name="ref", how="left") + .process(main_rel, references={"ref": ref_rel}, backend=backend) + ) + rows = result.fetchall() + assert len(rows) == 3 + assert "score" in result.columns + + def test_select_columns( + self, con: duckdb.DuckDBPyConnection, backend: DuckDBBackend + ) -> None: + main_rel = con.sql("SELECT * FROM (VALUES (1, 85), (2, 72)) AS t(id, score)") + ref_rel = con.sql( + "SELECT * FROM (VALUES (85, 'Excellent', 'A'), (72, 'Good', 'B')) " + "AS t(concept_id, concept_name, category)" + ) + result, _ = ( + TransformPlan() + .join( + on="score", + right_name="ref", + how="left", + right_on="concept_id", + select_columns=["concept_name"], + ) + .process(main_rel, references={"ref": ref_rel}, backend=backend) + ) + assert "concept_name" in result.columns + assert "category" not in result.columns + + def test_lazy_composition( + self, con: duckdb.DuckDBPyConnection, backend: DuckDBBackend + ) -> None: + """Verify that the result is a relation (lazy), not 
materialized.""" + main_rel = con.sql( + "SELECT * FROM (VALUES (1, 'Alice'), (2, 'Bob')) AS t(id, name)" + ) + ref_rel = con.sql("SELECT * FROM (VALUES (1, 100)) AS t(id, score)") + result, _ = ( + TransformPlan() + .join(on="id", right_name="ref", how="left") + .process(main_rel, references={"ref": ref_rel}, backend=backend) + ) + assert isinstance(result, duckdb.DuckDBPyRelation) diff --git a/tests/test_join.py b/tests/test_join.py new file mode 100644 index 0000000..4f3c758 --- /dev/null +++ b/tests/test_join.py @@ -0,0 +1,332 @@ +"""Tests for join operations.""" + +from __future__ import annotations + +import polars as pl +import pytest + +from transformplan import TransformPlan + + +@pytest.fixture +def main_df() -> pl.DataFrame: + return pl.DataFrame( + { + "person_id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", "Charlie", "David", "Eve"], + "score": [85, 72, 91, 68, 79], + } + ) + + +@pytest.fixture +def cohort_df() -> pl.DataFrame: + return pl.DataFrame( + { + "person_id": [1, 3, 5], + } + ) + + +@pytest.fixture +def concepts_df() -> pl.DataFrame: + return pl.DataFrame( + { + "concept_id": [85, 72, 91, 68, 79], + "concept_name": ["Excellent", "Good", "Outstanding", "Fair", "Good+"], + "category": ["A", "B", "A", "C", "B"], + } + ) + + +class TestJoinInner: + """Tests for inner join (cohort filtering).""" + + def test_inner_join_filters_rows( + self, main_df: pl.DataFrame, cohort_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join(on="person_id", right_name="cohort", how="inner") + result, _ = plan.process(main_df, references={"cohort": cohort_df}) + assert len(result) == 3 + assert set(result["person_id"].to_list()) == {1, 3, 5} + + def test_inner_join_preserves_columns( + self, main_df: pl.DataFrame, cohort_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join(on="person_id", right_name="cohort", how="inner") + result, _ = plan.process(main_df, references={"cohort": cohort_df}) + assert result.columns == ["person_id", "name", "score"] + + +class TestJoinLeft: + """Tests for left join (enrichment).""" + + def test_left_join_keeps_all_rows( + self, main_df: pl.DataFrame, concepts_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join( + on="score", + right_name="concepts", + how="left", + right_on="concept_id", + select_columns=["concept_name"], + ) + result, _ = plan.process(main_df, references={"concepts": concepts_df}) + assert len(result) == 5 + assert "concept_name" in result.columns + + def test_left_join_enriches_data( + self, main_df: pl.DataFrame, concepts_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join( + on="score", + right_name="concepts", + how="left", + right_on="concept_id", + select_columns=["concept_name"], + ) + result, _ = plan.process(main_df, references={"concepts": concepts_df}) + # Alice has score 85 -> "Excellent" + alice_row = result.filter(pl.col("name") == "Alice") + assert alice_row["concept_name"][0] == "Excellent" + + +class TestJoinLeftOnRightOn: + """Tests for left_on/right_on (different column names).""" + + def test_different_column_names( + self, main_df: pl.DataFrame, concepts_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join( + on="score", + right_name="concepts", + how="left", + left_on="score", + right_on="concept_id", + ) + result, _ = plan.process(main_df, references={"concepts": concepts_df}) + assert len(result) == 5 + assert "concept_name" in result.columns + assert "category" in result.columns + + +class TestJoinSuffix: + """Tests for suffix handling with duplicate columns.""" + + def 
test_suffix_on_duplicate_columns(self) -> None: + left = pl.DataFrame({"id": [1, 2], "value": [10, 20]}) + right = pl.DataFrame({"id": [1, 2], "value": [100, 200]}) + plan = TransformPlan().join( + on="id", right_name="right", how="left", suffix="_r" + ) + result, _ = plan.process(left, references={"right": right}) + assert "value" in result.columns + assert "value_r" in result.columns + + def test_default_suffix(self) -> None: + left = pl.DataFrame({"id": [1, 2], "value": [10, 20]}) + right = pl.DataFrame({"id": [1, 2], "value": [100, 200]}) + plan = TransformPlan().join(on="id", right_name="right", how="left") + result, _ = plan.process(left, references={"right": right}) + assert "value_right" in result.columns + + +class TestJoinSelectColumns: + """Tests for select_columns parameter.""" + + def test_select_specific_columns( + self, main_df: pl.DataFrame, concepts_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join( + on="score", + right_name="concepts", + how="left", + right_on="concept_id", + select_columns=["concept_name"], + ) + result, _ = plan.process(main_df, references={"concepts": concepts_df}) + assert "concept_name" in result.columns + assert "category" not in result.columns + + def test_no_select_columns_gets_all( + self, main_df: pl.DataFrame, concepts_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join( + on="score", + right_name="concepts", + how="left", + right_on="concept_id", + ) + result, _ = plan.process(main_df, references={"concepts": concepts_df}) + assert "concept_name" in result.columns + assert "category" in result.columns + + +class TestJoinMissingReference: + """Tests for missing reference error.""" + + def test_missing_reference_raises(self, main_df: pl.DataFrame) -> None: + plan = TransformPlan().join(on="person_id", right_name="cohort", how="inner") + with pytest.raises(ValueError, match="Reference 'cohort' not found"): + plan.process(main_df) + + def test_wrong_reference_name_raises( + self, main_df: pl.DataFrame, cohort_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join(on="person_id", right_name="cohort", how="inner") + with pytest.raises(ValueError, match="Reference 'cohort' not found"): + plan.process(main_df, references={"wrong_name": cohort_df}) + + +class TestJoinSerialization: + """Tests for serialization round-trip.""" + + def test_to_json_from_json_roundtrip( + self, main_df: pl.DataFrame, cohort_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join(on="person_id", right_name="cohort", how="inner") + json_str = plan.to_json() + + plan2 = TransformPlan.from_json(json_str) + result1, _ = plan.process(main_df, references={"cohort": cohort_df}) + result2, _ = plan2.process(main_df, references={"cohort": cohort_df}) + assert result1.equals(result2) + + def test_to_dict_contains_join(self) -> None: + plan = TransformPlan().join( + on="id", + right_name="ref", + how="left", + select_columns=["col1"], + ) + d = plan.to_dict() + assert d["steps"][0]["operation"] == "join" + assert d["steps"][0]["params"]["right_name"] == "ref" + assert d["steps"][0]["params"]["select_columns"] == ["col1"] + + def test_to_python_contains_join(self) -> None: + plan = TransformPlan().join(on="id", right_name="ref", how="inner") + code = plan.to_python() + assert ".join(" in code + assert 'right_name="ref"' in code + + +class TestJoinValidation: + """Tests for validation with and without references.""" + + def test_validation_without_references(self, main_df: pl.DataFrame) -> None: + plan = TransformPlan().join(on="person_id", 
right_name="cohort", how="inner") + result = plan.validate(main_df) + # Should pass — left-side columns exist + assert result.is_valid + + def test_validation_with_references( + self, main_df: pl.DataFrame, cohort_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join(on="person_id", right_name="cohort", how="inner") + result = plan.validate(main_df, references={"cohort": cohort_df}) + assert result.is_valid + + def test_validation_fails_missing_left_column( + self, main_df: pl.DataFrame, cohort_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join(on="nonexistent", right_name="cohort", how="inner") + result = plan.validate(main_df, references={"cohort": cohort_df}) + assert not result.is_valid + + def test_validation_fails_missing_right_column(self, main_df: pl.DataFrame) -> None: + ref = pl.DataFrame({"other_col": [1, 2]}) + plan = TransformPlan().join(on="person_id", right_name="ref", how="inner") + result = plan.validate(main_df, references={"ref": ref}) + assert not result.is_valid + + def test_validation_with_select_columns_adds_to_schema( + self, main_df: pl.DataFrame, concepts_df: pl.DataFrame + ) -> None: + plan = ( + TransformPlan() + .join( + on="score", + right_name="concepts", + how="left", + right_on="concept_id", + select_columns=["concept_name"], + ) + .str_upper("concept_name") + ) + result = plan.validate(main_df, references={"concepts": concepts_df}) + assert result.is_valid + + +class TestJoinDryRun: + """Tests for dry run with join.""" + + def test_dry_run_shows_added_columns( + self, main_df: pl.DataFrame, concepts_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join( + on="score", + right_name="concepts", + how="left", + right_on="concept_id", + select_columns=["concept_name"], + ) + preview = plan.dry_run(main_df, references={"concepts": concepts_df}) + assert preview.is_valid + assert "concept_name" in preview.output_columns + + def test_dry_run_without_references(self, main_df: pl.DataFrame) -> None: + plan = TransformPlan().join(on="person_id", right_name="cohort", how="inner") + preview = plan.dry_run(main_df) + # Should pass — just checks left-side columns + assert preview.is_valid + + +class TestJoinChunking: + """Tests for chunking compatibility.""" + + def test_join_is_global(self) -> None: + plan = TransformPlan().join(on="id", right_name="ref", how="inner") + validation = plan.validate_chunked( + schema={"id": pl.Int64()}, + ) + assert not validation.is_valid + assert "join" in validation.global_operations + + +class TestJoinProtocol: + """Tests for protocol recording.""" + + def test_protocol_records_reference_hashes( + self, main_df: pl.DataFrame, cohort_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join(on="person_id", right_name="cohort", how="inner") + _, protocol = plan.process(main_df, references={"cohort": cohort_df}) + meta = protocol.metadata + assert "references" in meta + assert "cohort" in meta["references"] + assert "hash" in meta["references"]["cohort"] + assert "shape" in meta["references"]["cohort"] + + def test_protocol_step_has_right_name( + self, main_df: pl.DataFrame, cohort_df: pl.DataFrame + ) -> None: + plan = TransformPlan().join(on="person_id", right_name="cohort", how="inner") + _, protocol = plan.process(main_df, references={"cohort": cohort_df}) + steps = protocol.to_dict()["steps"] + assert steps[0]["operation"] == "join" + assert steps[0]["params"]["right_name"] == "cohort" + # right_data should NOT be in params (not serializable) + assert "right_data" not in steps[0]["params"] + + +class 
TestJoinMultiColumn: + """Tests for multi-column joins.""" + + def test_join_on_multiple_columns(self) -> None: + left = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"], "val": [10, 20, 30]}) + right = pl.DataFrame({"a": [1, 2], "b": ["x", "y"], "extra": [100, 200]}) + plan = TransformPlan().join(on=["a", "b"], right_name="right", how="inner") + result, _ = plan.process(left, references={"right": right}) + assert len(result) == 2 + assert "extra" in result.columns diff --git a/transformplan/backends/base.py b/transformplan/backends/base.py index 947da0b..c0ac565 100644 --- a/transformplan/backends/base.py +++ b/transformplan/backends/base.py @@ -1,11 +1,11 @@ """Abstract base class for TransformPlan backends. -This module defines the Backend ABC with all 88 operation methods that each +This module defines the Backend ABC with all 89 operation methods that each backend must implement. Using ABC (not typing.Protocol) gives runtime enforcement that all methods are implemented, catching mistakes at instantiation. Classes: - Backend: Abstract base class with 88 abstract methods. + Backend: Abstract base class with 89 abstract methods. """ from __future__ import annotations @@ -24,7 +24,7 @@ class Backend(ABC): """Abstract base class defining the operation interface for backends. - Each backend must implement all 88 operations. Methods receive data + Each backend must implement all 89 operations. Methods receive data and operation-specific parameters, and return transformed data. Subclasses must set the ``name`` class variable to a unique identifier @@ -699,3 +699,20 @@ def map_null_to_value(self, data: Any, column: str, value: Any) -> Any: ... @abstractmethod def map_value_to_null(self, data: Any, column: str, value: Any) -> Any: ... + + # ========================================================================= + # Join operations (1) + # ========================================================================= + + @abstractmethod + def join( + self, + data: Any, + right_data: Any, + on: list[str], + how: str, + suffix: str, + left_on: list[str] | None = None, + right_on: list[str] | None = None, + select_columns: list[str] | None = None, + ) -> Any: ... diff --git a/transformplan/backends/duckdb.py b/transformplan/backends/duckdb.py index bf0210c..7c0deda 100644 --- a/transformplan/backends/duckdb.py +++ b/transformplan/backends/duckdb.py @@ -1,6 +1,6 @@ """DuckDB backend for TransformPlan. -This module implements all 88 operations using DuckDB's ``DuckDBPyRelation`` +This module implements all 89 operations using DuckDB's ``DuckDBPyRelation`` as the data type. Every operation takes a relation, generates SQL using the relation's ``sql_query()`` as a subquery, and returns a new relation — keeping the pipeline composable and lazy. 
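The module docstring above captures the backend's core strategy: every operation wraps the previous relation's `sql_query()` output in a subquery, so the chain stays lazy until fetched. A minimal standalone sketch of that pattern (illustrative only; plain DuckDB with hypothetical table data, not transformplan internals):

```python
# Illustrative sketch of the subquery-composition strategy described
# above: each step wraps the previous relation's SQL text, and nothing
# is materialized until the relation is actually fetched.
import duckdb

con = duckdb.connect()
rel = con.sql("SELECT * FROM (VALUES (1, 'Alice'), (2, 'Bob')) AS t(id, name)")

# One composed step: a projection over the previous relation as a subquery.
step = con.sql(f'SELECT "id", upper("name") AS "name" FROM ({rel.sql_query()}) AS _t')

print(step.sql_query())  # still just SQL text; no rows computed yet
print(step.fetchall())   # [(1, 'ALICE'), (2, 'BOB')]
```

The new `join` operation in the next hunk follows the same recipe, wrapping both sides as `_left` and `_right` subqueries.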
@@ -1440,6 +1440,60 @@ def map_value_to_null(
         expr = f"CASE WHEN {qc} = {_v(value)} THEN NULL ELSE {qc} END"
         return self._replace_col(data, column, expr)
 
+    # =========================================================================
+    # Join operations (1)
+    # =========================================================================
+
+    def join(
+        self,
+        data: duckdb.DuckDBPyRelation,
+        right_data: duckdb.DuckDBPyRelation,
+        on: list[str],
+        how: str,
+        suffix: str,
+        left_on: list[str] | None = None,
+        right_on: list[str] | None = None,
+        select_columns: list[str] | None = None,
+    ) -> duckdb.DuckDBPyRelation:
+        left_cols = left_on or on
+        right_cols = right_on or on
+        right_join_col_set = set(right_cols)
+
+        # Determine which right-side columns to include
+        right_schema_cols = list(right_data.columns)
+        if select_columns is not None:
+            right_keep = list(dict.fromkeys(right_cols + list(select_columns)))
+            right_schema_cols = [c for c in right_schema_cols if c in set(right_keep)]
+
+        left_col_set = set(data.columns)
+
+        # Build SELECT list
+        select_parts = [f"_left.{_q(c)}" for c in data.columns]
+        for c in right_schema_cols:
+            if c in right_join_col_set:
+                continue
+            alias = f"{c}{suffix}" if c in left_col_set else c
+            select_parts.append(f"_right.{_q(c)} AS {_q(alias)}")
+
+        select_sql = ", ".join(select_parts)
+
+        # Build ON clause
+        on_parts = [
+            f"_left.{_q(lc)} = _right.{_q(rc)}"
+            for lc, rc in zip(left_cols, right_cols, strict=False)
+        ]
+        on_sql = " AND ".join(on_parts)
+
+        join_type = "INNER" if how == "inner" else "LEFT"
+        left_sub = f"({data.sql_query()}) AS _left"
+        right_sub = f"({right_data.sql_query()}) AS _right"
+
+        sql = (
+            f"SELECT {select_sql} FROM {left_sub} "
+            f"{join_type} JOIN {right_sub} ON {on_sql}"
+        )
+        return self._con.sql(sql)
+
     # =========================================================================
     # Internal helpers
     # =========================================================================
diff --git a/transformplan/backends/polars.py b/transformplan/backends/polars.py
index 6ba0371..1d6768a 100644
--- a/transformplan/backends/polars.py
+++ b/transformplan/backends/polars.py
@@ -1,6 +1,6 @@
 """Polars backend for TransformPlan.
 
-This module implements all 88 operations using the Polars DataFrame library.
+This module implements all 89 operations using the Polars DataFrame library.
 It is the default backend and the reference implementation.
 
 Classes:
@@ -1057,3 +1057,35 @@ def map_value_to_null(
         .otherwise(pl.col(column))
         .alias(column)
     )
+
+    # =========================================================================
+    # Join operations (1)
+    # =========================================================================
+
+    def join(
+        self,
+        data: pl.DataFrame,
+        right_data: pl.DataFrame,
+        on: list[str],
+        how: str,
+        suffix: str,
+        left_on: list[str] | None = None,
+        right_on: list[str] | None = None,
+        select_columns: list[str] | None = None,
+    ) -> pl.DataFrame:
+        right = right_data
+        effective_left_on = left_on or on
+        effective_right_on = right_on or on
+
+        if select_columns is not None:
+            keep = list(dict.fromkeys(effective_right_on + list(select_columns)))
+            right = right.select(keep)
+
+        join_kwargs: dict[str, Any] = {"how": how, "suffix": suffix}
+        if effective_left_on != effective_right_on:
+            join_kwargs["left_on"] = effective_left_on
+            join_kwargs["right_on"] = effective_right_on
+        else:
+            join_kwargs["on"] = effective_left_on
+
+        return data.join(right, **join_kwargs)
diff --git a/transformplan/chunking.py b/transformplan/chunking.py
index 071970d..453366b 100644
--- a/transformplan/chunking.py
+++ b/transformplan/chunking.py
@@ -147,6 +147,8 @@ class OperationMeta:
     "rows_sample": OperationMeta(ChunkMode.GLOBAL),
     "rows_head": OperationMeta(ChunkMode.GLOBAL),
     "rows_tail": OperationMeta(ChunkMode.GLOBAL),
+    # Join operations - global (reference table needed in full)
+    "join": OperationMeta(ChunkMode.GLOBAL),
 }
diff --git a/transformplan/core.py b/transformplan/core.py
index e6712c5..0d6d4f4 100644
--- a/transformplan/core.py
+++ b/transformplan/core.py
@@ -61,14 +61,21 @@ class TransformPlanBase:
 
     VERSION = "1.0"
 
-    def __init__(self, backend: Backend | None = None) -> None:
-        """Initialize an empty TransformPlanBase.
+    def __init__(self) -> None:
+        """Initialize an empty TransformPlanBase."""
+        self._operations: list[tuple[str, dict[str, Any]]] = []
+
+    @staticmethod
+    def _resolve_backend(backend: Backend | None = None) -> Backend:
+        """Resolve backend, defaulting to PolarsBackend.
 
         Args:
-            backend: Backend to use for execution. Defaults to PolarsBackend.
+            backend: Optional backend override.
+
+        Returns:
+            The resolved backend instance.
         """
-        self._operations: list[tuple[str, dict[str, Any]]] = []
-        self._backend: Backend = backend or PolarsBackend()
+        return backend or PolarsBackend()
 
     def _register(
         self,
@@ -88,6 +95,8 @@ def process(
         data: Any,  # noqa: ANN401
         *,
         validate: bool = True,
+        references: dict[str, Any] | None = None,
+        backend: Backend | None = None,
     ) -> tuple[Any, Protocol]:
         """Execute all registered operations and return transformed data with protocol.
 
@@ -96,43 +105,100 @@
             validate: If True, validate schema before execution (default).
                 Set to False for performance in hot loops with pre-validated
                 pipelines.
+            references: Named reference tables for join operations. Keys are
+                symbolic names used in join(right_name=...), values are the
+                actual data (DataFrame or relation).
+            backend: Backend to use for execution. Defaults to PolarsBackend.
 
         Returns:
             Tuple of (processed data, Protocol).
+
+        Raises:
+            ValueError: If a join operation references a table not in references.
         """
+        resolved = self._resolve_backend(backend)
+
         if validate:
+            ref_schemas = self._extract_reference_schemas(references, resolved)
             validate_schema(
-                self._operations, self._backend.get_schema(data), self._backend
+                self._operations,
+                resolved.get_schema(data),
+                resolved,
+                references=ref_schemas,
             ).raise_if_invalid()
 
         protocol = Protocol()
-        protocol.set_input(
-            self._backend.compute_hash(data), self._backend.get_shape(data)
-        )
+        protocol.set_input(resolved.compute_hash(data), resolved.get_shape(data))
+
+        # Record reference hashes in protocol metadata
+        if references:
+            ref_meta = {}
+            for name, ref_data in references.items():
+                ref_meta[name] = {
+                    "hash": resolved.compute_hash(ref_data),
+                    "shape": list(resolved.get_shape(ref_data)),
+                }
+            protocol.set_metadata(references=ref_meta)
 
         for op_name, params in self._operations:
-            old_shape = self._backend.get_shape(data)
+            old_shape = resolved.get_shape(data)
             start = time.perf_counter()
-            data = getattr(self._backend, op_name)(data, **params)
+            if op_name == "join":
+                right_name = params["right_name"]
+                if references is None or right_name not in references:
+                    msg = (
+                        f"Reference '{right_name}' not found. Pass it via "
+                        f"references={{'{right_name}': ...}} in process()."
+                    )
+                    raise ValueError(msg)
+                dispatch_params = {
+                    k: v for k, v in params.items() if k != "right_name"
+                }
+                dispatch_params["right_data"] = references[right_name]
+                data = getattr(resolved, op_name)(data, **dispatch_params)
+            else:
+                data = getattr(resolved, op_name)(data, **params)
             elapsed = time.perf_counter() - start
             protocol.add_step(
                 operation=op_name,
                 params=params,
                 old_shape=old_shape,
-                new_shape=self._backend.get_shape(data),
+                new_shape=resolved.get_shape(data),
                 elapsed=elapsed,
-                output_hash=self._backend.compute_hash(data),
+                output_hash=resolved.compute_hash(data),
             )
 
         return data, protocol
 
-    def validate(self, data: Any) -> ValidationResult:  # noqa: ANN401
+    @staticmethod
+    def _extract_reference_schemas(
+        references: dict[str, Any] | None, backend: Backend
+    ) -> dict[str, dict[str, Any]] | None:
+        """Extract schemas from reference tables for validation.
+
+        Args:
+            references: Named reference tables.
+            backend: Backend to use for schema extraction.
+
+        Returns:
+            Dict mapping reference names to their schemas, or None.
+        """
+        if references is None:
+            return None
+        return {
+            name: backend.get_schema(ref_data) for name, ref_data in references.items()
+        }
+
+    def validate(
+        self,
+        data: Any,  # noqa: ANN401
+        *,
+        references: dict[str, Any] | None = None,
+        backend: Backend | None = None,
+    ) -> ValidationResult:
         """Validate all operations against the data schema without executing.
 
         Args:
             data: Input data (Polars DataFrame, DuckDB relation, etc.).
+            references: Named reference tables for join operations.
+            backend: Backend to use for validation. Defaults to PolarsBackend.
 
         Returns:
             ValidationResult with any errors found.
@@ -146,11 +212,22 @@ def validate(self, data: Any) -> ValidationResult:  # noqa: ANN401
         else:
             df, protocol = plan.process(df)
         """
+        resolved = self._resolve_backend(backend)
+        ref_schemas = self._extract_reference_schemas(references, resolved)
         return validate_schema(
-            self._operations, self._backend.get_schema(data), self._backend
+            self._operations,
+            resolved.get_schema(data),
+            resolved,
+            references=ref_schemas,
         )
 
-    def dry_run(self, data: Any) -> DryRunResult:  # noqa: ANN401
+    def dry_run(
+        self,
+        data: Any,  # noqa: ANN401
+        *,
+        references: dict[str, Any] | None = None,
+        backend: Backend | None = None,
+    ) -> DryRunResult:
         """Preview what the pipeline will do without executing it.
 
         Performs validation and shows step-by-step schema changes,
@@ -158,6 +235,8 @@
 
         Args:
             data: DataFrame to preview against.
+            references: Named reference tables for join operations.
+            backend: Backend to use for dry run. Defaults to PolarsBackend.
 
         Returns:
             DryRunResult with step-by-step preview.
@@ -174,8 +253,13 @@
         if preview.is_valid:
             df, protocol = plan.process(df)
         """
+        resolved = self._resolve_backend(backend)
+        ref_schemas = self._extract_reference_schemas(references, resolved)
         return dry_run_schema(
-            self._operations, self._backend.get_schema(data), self._backend
+            self._operations,
+            resolved.get_schema(data),
+            resolved,
+            references=ref_schemas,
         )
 
     def to_dict(self) -> dict[str, Any]:
@@ -193,26 +277,18 @@
             }
         )
 
-        result: dict[str, Any] = {
+        return {
             "version": self.VERSION,
             "steps": steps,
         }
 
-        # Include backend identifier for forward compatibility
-        if not isinstance(self._backend, PolarsBackend):
-            result["backend"] = self._backend.name
-
-        return result
-
     @classmethod
-    def from_dict(cls, data: dict[str, Any], backend: Backend | None = None) -> Self:
+    def from_dict(cls, data: dict[str, Any]) -> Self:
         """Deserialize a pipeline from a dictionary.
 
         Args:
-            data: Dictionary with 'steps' list.
-            backend: Optional backend override. If provided, uses this backend
-                instead of the one stored in the serialized data. This is
-                required for DuckDB when you need a specific connection.
+            data: Dictionary with 'steps' list. Any 'backend' key from
+                older serialized plans is silently ignored.
 
         Returns:
             New TransformPlan instance with operations loaded.
@@ -220,20 +296,7 @@
         Raises:
             ValueError: If an unknown operation or invalid parameters are encountered.
         """
-        if backend is None:
-            # Read backend from serialized data (default: polars)
-            backend_name = data.get("backend", "polars")
-            if backend_name == "polars":
-                backend = None  # default
-            elif backend_name == "duckdb":
-                from transformplan.backends.duckdb import DuckDBBackend
-
-                backend = DuckDBBackend()
-            else:
-                msg = f"Unsupported backend: {backend_name}"
-                raise ValueError(msg)
-
-        plan = cls(backend=backend)
+        plan = cls()
 
         for step in data.get("steps", []):
             op_name = step["operation"]
@@ -272,13 +335,11 @@
         return json_str
 
     @classmethod
-    def from_json(cls, source: str | Path, backend: Backend | None = None) -> Self:
+    def from_json(cls, source: str | Path) -> Self:
         """Deserialize a pipeline from JSON.
 
         Args:
             source: Either a JSON string or a path to a JSON file.
-            backend: Optional backend override. If provided, uses this backend
-                instead of the one stored in the serialized data.
 
         Returns:
             New TransformPlan instance.
@@ -288,7 +349,7 @@
         else:
             content = source
 
-        return cls.from_dict(json.loads(content), backend=backend)
+        return cls.from_dict(json.loads(content))
 
     def __len__(self) -> int:
         """Return number of registered operations.
@@ -458,6 +519,8 @@ def process_chunked(
         partition_key: str | list[str] | None = None,
         chunk_size: int = 100_000,
         validate: bool = True,
+        references: dict[str, Any] | None = None,
+        backend: Backend | None = None,
     ) -> tuple[pl.DataFrame, ChunkedProtocol]:
         """Process a large Parquet file in chunks.
 
@@ -475,6 +538,9 @@
                 partition_key is set, as chunks are sized to respect group
                 boundaries).
             validate: Whether to validate operations before processing.
+            references: Named reference tables for join operations (reserved
+                for forward compatibility).
+            backend: Backend to use for execution. Defaults to PolarsBackend.
 
         Returns:
             Tuple of (result DataFrame, ChunkedProtocol with processing details).
@@ -499,6 +565,7 @@
             )
             protocol.print()
         """
+        resolved = self._resolve_backend(backend)
         source_path = Path(source)
         if not source_path.exists():
             msg = f"Source file not found: {source_path}"
@@ -527,7 +594,7 @@
         if validate:
             from transformplan.validation import validate_schema
 
-            schema_validation = validate_schema(self._operations, schema, self._backend)
+            schema_validation = validate_schema(self._operations, schema, resolved)
             schema_validation.raise_if_invalid()
 
         # Initialize protocol
@@ -553,14 +620,14 @@
         for chunk_index, chunk_df in enumerate(chunk_iter):
             start = time.perf_counter()
 
-            input_hash = self._backend.compute_hash(chunk_df)
+            input_hash = resolved.compute_hash(chunk_df)
             input_rows = len(chunk_df)
 
             # Apply all operations to this chunk
             for op_name, params in self._operations:
-                chunk_df = getattr(self._backend, op_name)(chunk_df, **params)
+                chunk_df = getattr(resolved, op_name)(chunk_df, **params)
 
-            output_hash = self._backend.compute_hash(chunk_df)
+            output_hash = resolved.compute_hash(chunk_df)
             elapsed = time.perf_counter() - start
 
             protocol.add_chunk(
diff --git a/transformplan/ops/__init__.py b/transformplan/ops/__init__.py
index 219b696..4255e24 100644
--- a/transformplan/ops/__init__.py
+++ b/transformplan/ops/__init__.py
@@ -11,6 +11,7 @@
     StrOps: String operations (replace, split, concat, extract, etc.).
     DatetimeOps: Date/time operations (extract year/month, parse, format, etc.).
     MapOps: Value mapping operations (map_values, discretize, etc.).
+    JoinOps: Join operations (join with reference tables).
 
 The TransformPlan class combines all these mixins with TransformPlanBase
 to provide the complete transformation API.
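Because `JoinOps` (listed above, imported in the next hunk, and defined in the new `ops/join.py` below) merely records parameters, a join step serializes to plain data. A quick sketch with hypothetical column names; the asserted values follow from the builder's defaults:

```python
# Sketch: join() registers a plain, JSON-serializable step. The right
# table itself never enters the plan; only its symbolic name does.
from transformplan import TransformPlan

plan = TransformPlan().join(
    on="score",
    right_name="concepts",
    how="left",
    right_on="concept_id",
    select_columns=["concept_name"],
)
step = plan.to_dict()["steps"][0]
assert step["operation"] == "join"
assert step["params"]["right_name"] == "concepts"
assert step["params"]["suffix"] == "_right"  # default, always recorded
```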
@@ -18,6 +19,7 @@
 from transformplan.ops.column import ColumnOps
 from transformplan.ops.datetime import DatetimeOps
+from transformplan.ops.join import JoinOps
 from transformplan.ops.map import MapOps
 from transformplan.ops.math import MathOps
 from transformplan.ops.rows import RowOps
@@ -26,6 +28,7 @@
 __all__ = [
     "ColumnOps",
     "DatetimeOps",
+    "JoinOps",
     "MapOps",
     "MathOps",
     "RowOps",
diff --git a/transformplan/ops/join.py b/transformplan/ops/join.py
new file mode 100644
index 0000000..809594e
--- /dev/null
+++ b/transformplan/ops/join.py
@@ -0,0 +1,58 @@
+"""Join operations mixin."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Literal, Sequence
+
+if TYPE_CHECKING:
+    from typing_extensions import Self
+
+
+class JoinOps:
+    """Mixin providing join operations."""
+
+    if TYPE_CHECKING:
+
+        def _register(self, op_name: str, params: dict[str, Any]) -> Self: ...
+
+    def join(
+        self,
+        on: str | Sequence[str],
+        right_name: str,
+        how: Literal["inner", "left"] = "inner",
+        *,
+        left_on: str | Sequence[str] | None = None,
+        right_on: str | Sequence[str] | None = None,
+        suffix: str = "_right",
+        select_columns: Sequence[str] | None = None,
+    ) -> Self:
+        """Join with a reference table resolved at execution time.
+
+        Args:
+            on: Join key column(s) (both sides if left_on/right_on not set).
+            right_name: Symbolic name resolved via references in process().
+            how: "inner" for filtering, "left" for enrichment.
+            left_on: Left-side join columns (overrides on).
+            right_on: Right-side join columns (overrides on).
+            suffix: Suffix for duplicate column names from right table.
+            select_columns: Columns to keep from right table (None = all).
+
+        Returns:
+            Self for method chaining.
+        """
+        on_list = [on] if isinstance(on, str) else list(on)
+        params: dict[str, Any] = {
+            "on": on_list,
+            "right_name": right_name,
+            "how": how,
+            "suffix": suffix,
+        }
+        if left_on is not None:
+            params["left_on"] = [left_on] if isinstance(left_on, str) else list(left_on)
+        if right_on is not None:
+            params["right_on"] = (
+                [right_on] if isinstance(right_on, str) else list(right_on)
+            )
+        if select_columns is not None:
+            params["select_columns"] = list(select_columns)
+        return self._register("join", params)
diff --git a/transformplan/plan.py b/transformplan/plan.py
index 7835979..4b1bf72 100644
--- a/transformplan/plan.py
+++ b/transformplan/plan.py
@@ -21,19 +21,23 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from transformplan.core import TransformPlanBase
-from transformplan.ops import ColumnOps, DatetimeOps, MapOps, MathOps, RowOps, StrOps
-
-if TYPE_CHECKING:
-    from transformplan.backends.base import Backend
+from transformplan.ops import (
+    ColumnOps,
+    DatetimeOps,
+    JoinOps,
+    MapOps,
+    MathOps,
+    RowOps,
+    StrOps,
+)
 
 
 class TransformPlan(
     TransformPlanBase,
     ColumnOps,
     DatetimeOps,
+    JoinOps,
     MapOps,
     MathOps,
     RowOps,
@@ -51,10 +55,6 @@
     )
     """
 
-    def __init__(self, backend: Backend | None = None) -> None:
-        """Initialize TransformPlan with optional backend.
-
-        Args:
-            backend: Backend to use for execution. Defaults to PolarsBackend.
-        """
-        super().__init__(backend=backend)
+    def __init__(self) -> None:
+        """Initialize TransformPlan."""
+        super().__init__()
diff --git a/transformplan/validation.py b/transformplan/validation.py
index 45d00d3..b2f3a10 100644
--- a/transformplan/validation.py
+++ b/transformplan/validation.py
@@ -1656,10 +1656,71 @@ def _validate_map_label(
 }
 
 
+# =============================================================================
+# Join operation validator
+# =============================================================================
+
+
+def _validate_join(
+    tracker: SchemaTracker,
+    params: dict[str, Any],
+    result: ValidationResult,
+    step: int,
+    references: dict[str, dict[str, Any]] | None = None,
+) -> None:
+    on = params["on"]
+    left_on = params.get("left_on", on)
+    right_on = params.get("right_on", on)
+    right_name = params["right_name"]
+    suffix = params.get("suffix", "_right")
+    select_columns = params.get("select_columns")
+
+    # Check left-side join columns
+    for col in left_on:
+        _check_column_exists(tracker, col, result, step, "join")
+
+    # If reference schema is available, validate right side and update tracker
+    if references is not None and right_name in references:
+        ref_schema = references[right_name]
+        right_col_set = set(ref_schema.keys())
+
+        # Check right-side join columns exist in reference
+        for col in right_on:
+            if col not in right_col_set:
+                result.add_error(
+                    step,
+                    "join",
+                    f"Column '{col}' does not exist in reference '{right_name}'",
+                )
+                return
+
+        # Determine which right-side columns to add
+        right_join_col_set = set(right_on)
+        if select_columns is not None:
+            right_output_cols = [
+                c for c in select_columns if c not in right_join_col_set
+            ]
+        else:
+            right_output_cols = [c for c in ref_schema if c not in right_join_col_set]
+
+        left_col_set = tracker.columns
+        for col in right_output_cols:
+            if col not in ref_schema:
+                result.add_error(
+                    step,
+                    "join",
+                    f"Column '{col}' does not exist in reference '{right_name}'",
+                )
+                continue
+            alias = f"{col}{suffix}" if col in left_col_set else col
+            tracker.add_column(alias, ref_schema[col])
+
+
 def validate_schema(
     operations: list[tuple[str, dict[str, Any]]],
     schema: dict[str, Any],
     backend: Backend | None = None,
+    references: dict[str, dict[str, Any]] | None = None,
 ) -> ValidationResult:
     """Validate all operations against the given schema.
 
@@ -1667,6 +1728,7 @@
         operations: List of (op_name, params) tuples from TransformPlan.
         schema: Initial DataFrame schema.
         backend: Backend for type classification. Defaults to PolarsBackend.
+        references: Schema dicts for reference tables (for join validation).
 
     Returns:
         ValidationResult with any errors found.
@@ -1675,9 +1737,12 @@
     tracker = SchemaTracker(schema, backend=backend)
 
     for step, (op_name, params) in enumerate(operations, start=1):
-        validator = _VALIDATORS.get(op_name)
-        if validator:
-            validator(tracker, params, result, step)
+        if op_name == "join":
+            _validate_join(tracker, params, result, step, references=references)
+        else:
+            validator = _VALIDATORS.get(op_name)
+            if validator:
+                validator(tracker, params, result, step)
 
     return result
 
@@ -1686,6 +1751,7 @@
     operations: list[tuple[str, dict[str, Any]]],
     schema: dict[str, Any],
     backend: Backend | None = None,
+    references: dict[str, dict[str, Any]] | None = None,
 ) -> DryRunResult:
     """Perform a dry run showing what each operation will do.
 
@@ -1693,6 +1759,7 @@
         operations: List of (op_name, params) tuples from TransformPlan.
         schema: Initial DataFrame schema.
         backend: Backend for type classification. Defaults to PolarsBackend.
+        references: Schema dicts for reference tables (for join validation).
 
     Returns:
         DryRunResult with step-by-step preview and validation.
@@ -1708,9 +1775,14 @@
         # Run validation (which also updates tracker)
         step_errors_before = len(validation_result.errors)
-        validator = _VALIDATORS.get(op_name)
-        if validator:
-            validator(tracker, params, validation_result, step_num)
+        if op_name == "join":
+            _validate_join(
+                tracker, params, validation_result, step_num, references=references
+            )
+        else:
+            validator = _VALIDATORS.get(op_name)
+            if validator:
+                validator(tracker, params, validation_result, step_num)
 
         # Capture schema after
         schema_after = {k: tracker.type_name(v) for k, v in tracker._schema.items()}
diff --git a/uv.lock b/uv.lock
index 48a2c9d..25fb328 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1281,7 +1281,7 @@ wheels = [
 
 [[package]]
 name = "transformplan"
-version = "0.1.2"
+version = "0.1.3"
 source = { editable = "." }
 dependencies = [
     { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
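Taken together, these validator changes make joins previewable: `_validate_join` feeds the schema tracker, so `dry_run` can report the columns a join will add, including suffix aliasing on collisions. A closing sketch with hypothetical data:

```python
# Sketch: the duplicate "value" column from the right table is reported
# under the default "_right" suffix in the dry-run preview.
import polars as pl
from transformplan import TransformPlan

left = pl.DataFrame({"id": [1, 2], "value": [10, 20]})
right = pl.DataFrame({"id": [1, 2], "value": [100, 200]})

plan = TransformPlan().join(on="id", right_name="right", how="left")
preview = plan.dry_run(left, references={"right": right})

assert preview.is_valid
assert "value_right" in preview.output_columns
```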