diff --git a/src/conductor/client/automator/async_task_runner.py b/src/conductor/client/automator/async_task_runner.py index 97c2f1ac..94b5b686 100644 --- a/src/conductor/client/automator/async_task_runner.py +++ b/src/conductor/client/automator/async_task_runner.py @@ -24,7 +24,7 @@ from conductor.client.http.models.task_result import TaskResult from conductor.client.http.models.task_result_status import TaskResultStatus from conductor.client.http.models.schema_def import SchemaDef, SchemaType -from conductor.client.http.rest import AuthorizationException +from conductor.client.http.rest import ApiException, AuthorizationException from conductor.client.orkes.orkes_metadata_client import OrkesMetadataClient from conductor.client.orkes.orkes_schema_client import OrkesSchemaClient from conductor.client.telemetry.metrics_collector import MetricsCollector @@ -111,6 +111,7 @@ def __init__( # Semaphore will be created in run() within the event loop self._semaphore = None self._shutdown = False # Flag to indicate graceful shutdown + self._v2_available = True # Tracks whether server supports update-task-v2 async def run(self) -> None: """Main async loop - runs continuously in single event loop.""" @@ -566,11 +567,15 @@ async def __async_execute_and_update_task(self, task: Task) -> None: """Execute task and update result in a tight loop (async version). Uses the v2 update endpoint which returns the next task to process. - Loops: execute -> update_v2 (get next task) -> execute -> ... - The semaphore is held for the entire loop duration, keeping the slot occupied. + Loops: execute → update_v2 (get next task) → execute → … + + The semaphore is held for the entire loop. This is correct because + ``_running_tasks`` (which gates polling) tracks the *coroutine*, not + individual tasks — releasing the semaphore mid-loop would not allow + new coroutines to be created (the capacity gate would still block). + For async workers a slow ``await`` naturally yields to the event loop, + so other coroutines make progress regardless. """ - # Acquire semaphore for entire task lifecycle (execution + update) - # This ensures we never exceed thread_count tasks in any stage of processing async with self._semaphore: try: while task is not None and not self._shutdown: @@ -793,7 +798,10 @@ def __merge_context_modifications(self, task_result: TaskResult, context_result: task_result.output_data = context_result.output_data async def __async_update_task(self, task_result: TaskResult): - """Async update task result using v2 endpoint. Returns the next Task to process, or None.""" + """Update task result. Uses v2 endpoint if available, falls back to v1 otherwise. + + v2 returns the next Task to process (tight loop). v1 returns None (poll-based). + """ if not isinstance(task_result, TaskResult): return None task_definition_name = self.worker.get_task_definition_name() @@ -815,15 +823,53 @@ async def __async_update_task(self, task_result: TaskResult): # Exponential backoff: [10s, 20s, 30s] before retry await asyncio.sleep(attempt * 10) try: - next_task = await self.async_task_client.update_task_v2(body=task_result) - logger.debug( - "Updated async task (v2), id: %s, workflow_instance_id: %s, task_definition_name: %s, next_task: %s", + if self._v2_available: + next_task = await self.async_task_client.update_task_v2(body=task_result) + logger.debug( + "Updated async task (v2), id: %s, workflow_instance_id: %s, task_definition_name: %s, next_task: %s", + task_result.task_id, + task_result.workflow_instance_id, + task_definition_name, + next_task.task_id if next_task else None + ) + return next_task + else: + await self.async_task_client.update_task(body=task_result) + logger.debug( + "Updated async task (v1), id: %s, workflow_instance_id: %s, task_definition_name: %s", + task_result.task_id, + task_result.workflow_instance_id, + task_definition_name, + ) + return None + except ApiException as e: + if self._v2_available and e.status in (404, 501): + logger.warning( + "update-task-v2 not supported by server (HTTP %s), falling back to v1 for task_definition: %s", + e.status, task_definition_name + ) + self._v2_available = False + # Immediately retry this attempt with v1 + try: + await self.async_task_client.update_task(body=task_result) + return None + except Exception as fallback_err: + last_exception = fallback_err + continue + last_exception = e + if self.metrics_collector is not None: + self.metrics_collector.increment_task_update_error( + task_definition_name, type(e) + ) + logger.error( + "Failed to update async task (attempt %d/%d), id: %s, workflow_instance_id: %s, task_definition_name: %s, reason: %s", + attempt + 1, + retry_count, task_result.task_id, task_result.workflow_instance_id, task_definition_name, - next_task.task_id if next_task else None + traceback.format_exc() ) - return next_task except Exception as e: last_exception = e if self.metrics_collector is not None: diff --git a/src/conductor/client/automator/task_runner.py b/src/conductor/client/automator/task_runner.py index 16c8f432..10b6b890 100644 --- a/src/conductor/client/automator/task_runner.py +++ b/src/conductor/client/automator/task_runner.py @@ -25,7 +25,7 @@ from conductor.client.http.models.task_result import TaskResult from conductor.client.http.models.task_result_status import TaskResultStatus from conductor.client.http.models.schema_def import SchemaDef, SchemaType -from conductor.client.http.rest import AuthorizationException +from conductor.client.http.rest import ApiException, AuthorizationException from conductor.client.orkes.orkes_metadata_client import OrkesMetadataClient from conductor.client.orkes.orkes_schema_client import OrkesSchemaClient from conductor.client.telemetry.metrics_collector import MetricsCollector @@ -92,6 +92,7 @@ def __init__( self._last_poll_time = 0 # Track last poll to avoid excessive polling when queue is empty self._consecutive_empty_polls = 0 # Track empty polls to implement backoff self._shutdown = False # Flag to indicate graceful shutdown + self._v2_available = True # Tracks whether server supports update-task-v2 def run(self) -> None: if self.configuration is not None: @@ -506,18 +507,18 @@ def __execute_and_update_task(self, task: Task) -> None: """Execute task and update result in a tight loop (runs in thread pool). Uses the v2 update endpoint which returns the next task to process. - Loops: execute -> update_v2 (get next task) -> execute -> ... - The loop breaks when no next task is available, the task is async/in-progress, - or shutdown is requested. + Loops: execute → update_v2 (get next task) → execute → … + The loop breaks when no next task is available, the task is async / + in-progress, or shutdown is requested. """ try: while task is not None and not self._shutdown: task_result = self.__execute_task(task) - # If task returned None, it's an async task running in background - don't update yet + # If task returned None, it's an async task running in background if task_result is None: logger.debug("Task %s is running async, will update when complete", task.task_id) return - # If task returned TaskInProgress, it's running async - don't update yet + # If task returned TaskInProgress, it's running async if isinstance(task_result, TaskInProgress): logger.debug("Task %s is in progress, will update when complete", task.task_id) return @@ -824,7 +825,10 @@ def __merge_context_modifications(self, task_result: TaskResult, context_result: task_result.output_data = context_result.output_data def __update_task(self, task_result: TaskResult): - """Update task result using v2 endpoint. Returns the next Task to process, or None.""" + """Update task result. Uses v2 endpoint if available, falls back to v1 otherwise. + + v2 returns the next Task to process (tight loop). v1 returns None (poll-based). + """ if not isinstance(task_result, TaskResult): return None task_definition_name = self.worker.get_task_definition_name() @@ -845,15 +849,53 @@ def __update_task(self, task_result: TaskResult): # Exponential backoff: [10s, 20s, 30s] before retry time.sleep(attempt * 10) try: - next_task = self.task_client.update_task_v2(body=task_result) - logger.debug( - "Updated task (v2), id: %s, workflow_instance_id: %s, task_definition_name: %s, next_task: %s", + if self._v2_available: + next_task = self.task_client.update_task_v2(body=task_result) + logger.debug( + "Updated task (v2), id: %s, workflow_instance_id: %s, task_definition_name: %s, next_task: %s", + task_result.task_id, + task_result.workflow_instance_id, + task_definition_name, + next_task.task_id if next_task else None + ) + return next_task + else: + self.task_client.update_task(body=task_result) + logger.debug( + "Updated task (v1), id: %s, workflow_instance_id: %s, task_definition_name: %s", + task_result.task_id, + task_result.workflow_instance_id, + task_definition_name, + ) + return None + except ApiException as e: + if self._v2_available and e.status in (404, 501): + logger.warning( + "update-task-v2 not supported by server (HTTP %s), falling back to v1 for task_definition: %s", + e.status, task_definition_name + ) + self._v2_available = False + # Immediately retry this attempt with v1 + try: + self.task_client.update_task(body=task_result) + return None + except Exception as fallback_err: + last_exception = fallback_err + continue + last_exception = e + if self.metrics_collector is not None: + self.metrics_collector.increment_task_update_error( + task_definition_name, type(e) + ) + logger.error( + "Failed to update task (attempt %d/%d), id: %s, workflow_instance_id: %s, task_definition_name: %s, reason: %s", + attempt + 1, + retry_count, task_result.task_id, task_result.workflow_instance_id, task_definition_name, - next_task.task_id if next_task else None + traceback.format_exc() ) - return next_task except Exception as e: last_exception = e if self.metrics_collector is not None: diff --git a/tests/integration/test_update_task_v2_perf.py b/tests/integration/test_update_task_v2_perf.py index de2a39d0..671427a5 100644 --- a/tests/integration/test_update_task_v2_perf.py +++ b/tests/integration/test_update_task_v2_perf.py @@ -73,17 +73,17 @@ # --------------------------------------------------------------------------- @worker_task(task_definition_name="perf_type_a", thread_count=WORKER_THREADS, register_task_def=True) -def perf_worker_a(task_index: int = 0) -> dict: +async def perf_worker_a(task_index: int = 0) -> dict: return {"worker": "perf_type_a", "task_index": task_index} @worker_task(task_definition_name="perf_type_b", thread_count=WORKER_THREADS, register_task_def=True) -def perf_worker_b(task_index: int = 0) -> dict: +async def perf_worker_b(task_index: int = 0) -> dict: return {"worker": "perf_type_b", "task_index": task_index} @worker_task(task_definition_name="perf_type_c", thread_count=WORKER_THREADS, register_task_def=True) -def perf_worker_c(task_index: int = 0) -> dict: +async def perf_worker_c(task_index: int = 0) -> dict: return {"worker": "perf_type_c", "task_index": task_index} diff --git a/tests/integration/test_v2_fallback_intg.py b/tests/integration/test_v2_fallback_intg.py new file mode 100644 index 00000000..4f2258cf --- /dev/null +++ b/tests/integration/test_v2_fallback_intg.py @@ -0,0 +1,174 @@ +""" +Integration test for update-task-v2 graceful degradation. + +Verifies that when update-task-v2 is unavailable (or available), the SDK +correctly auto-detects and falls back to v1 while still completing workflows. + +Run: + python -m pytest tests/integration/test_v2_fallback_intg.py -v -s +""" + +import logging +import os +import sys +import time +import threading +import unittest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +from conductor.client.automator.task_handler import TaskHandler +from conductor.client.configuration.configuration import Configuration +from conductor.client.http.models.start_workflow_request import StartWorkflowRequest +from conductor.client.http.models.workflow_def import WorkflowDef +from conductor.client.http.models.workflow_task import WorkflowTask +from conductor.client.orkes.orkes_metadata_client import OrkesMetadataClient +from conductor.client.orkes.orkes_workflow_client import OrkesWorkflowClient +from conductor.client.worker.worker_task import worker_task + +logger = logging.getLogger(__name__) + +WORKFLOW_NAME = "test_v2_fallback_workflow" +WORKFLOW_VERSION = 1 + + +# --------------------------------------------------------------------------- +# Workers +# --------------------------------------------------------------------------- + +@worker_task(task_definition_name="v2_fallback_task_a", thread_count=3, register_task_def=True) +async def fallback_worker_a(task_index: int = 0) -> dict: + return {"worker": "v2_fallback_task_a", "task_index": task_index} + + +@worker_task(task_definition_name="v2_fallback_task_b", thread_count=3, register_task_def=True) +async def fallback_worker_b(task_index: int = 0) -> dict: + return {"worker": "v2_fallback_task_b", "task_index": task_index} + + +# --------------------------------------------------------------------------- +# Test +# --------------------------------------------------------------------------- + +class TestV2FallbackIntegration(unittest.TestCase): + + @classmethod + def setUpClass(cls): + from tests.integration.conftest import skip_if_server_unavailable + skip_if_server_unavailable() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(process)d] %(name)-45s %(levelname)-10s %(message)s", + ) + logging.getLogger("conductor.client").setLevel(logging.WARNING) + + cls.config = Configuration() + cls.workflow_client = OrkesWorkflowClient(cls.config) + cls.metadata_client = OrkesMetadataClient(cls.config) + + def test_0_register_workflow(self): + """Register workflow with 2 task types (3 tasks each).""" + tasks = [] + idx = 0 + for task_type, count in [("v2_fallback_task_a", 3), ("v2_fallback_task_b", 3)]: + for i in range(count): + idx += 1 + tasks.append( + WorkflowTask( + name=task_type, + task_reference_name=f"{task_type}_{i + 1}", + input_parameters={"task_index": idx}, + ) + ) + + workflow = WorkflowDef(name=WORKFLOW_NAME, version=WORKFLOW_VERSION) + workflow._tasks = tasks + try: + self.metadata_client.update_workflow_def(workflow, overwrite=True) + except Exception: + self.metadata_client.register_workflow_def(workflow, overwrite=True) + print(f"\n Registered workflow '{WORKFLOW_NAME}' with {len(tasks)} tasks") + + def test_1_workflows_complete_with_v2_or_fallback(self): + """Start workers and verify workflows complete regardless of v2 support. + + This test doesn't force a 404 — it runs against the real server. + If v2 is available, it uses v2. If not, it auto-detects and falls back. + Either way, all workflows should complete successfully. + """ + workflow_count = 5 + + handler_ready = threading.Event() + handler_ref = {} + + def _run_workers(): + handler = TaskHandler( + configuration=self.config, + scan_for_annotated_workers=True, + ) + handler_ref["h"] = handler + handler.start_processes() + handler_ready.set() + handler_ref["stop"] = threading.Event() + handler_ref["stop"].wait() + handler.stop_processes() + + worker_thread = threading.Thread(target=_run_workers, daemon=True) + worker_thread.start() + handler_ready.wait(timeout=30) + self.assertTrue(handler_ready.is_set(), "Workers failed to start within 30s") + time.sleep(3) # Warm up + + try: + # Submit workflows + workflow_ids = [] + for i in range(workflow_count): + req = StartWorkflowRequest() + req.name = WORKFLOW_NAME + req.version = WORKFLOW_VERSION + req.input = {"run_index": i} + wf_id = self.workflow_client.start_workflow(start_workflow_request=req) + workflow_ids.append(wf_id) + + print(f"\n Submitted {len(workflow_ids)} workflows") + + # Wait for completion + deadline = time.time() + 60 # 60s timeout + pending = set(workflow_ids) + completed = 0 + failed = 0 + + while pending and time.time() < deadline: + still_pending = set() + for wf_id in pending: + try: + wf = self.workflow_client.get_workflow(wf_id, include_tasks=False) + except Exception: + still_pending.add(wf_id) + continue + + if wf.status == "COMPLETED": + completed += 1 + elif wf.status in ("FAILED", "TERMINATED", "TIMED_OUT"): + failed += 1 + logger.warning("Workflow %s ended with status %s", wf_id, wf.status) + else: + still_pending.add(wf_id) + + pending = still_pending + if pending: + time.sleep(1) + + print(f" Results: {completed} completed, {failed} failed, {len(pending)} pending") + + self.assertEqual(len(pending), 0, f"{len(pending)} workflows did not complete in time") + self.assertEqual(completed, workflow_count, f"Expected {workflow_count} completed, got {completed}") + + finally: + handler_ref.get("stop", threading.Event()).set() + worker_thread.join(timeout=15) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/unit/automator/test_v2_fallback.py b/tests/unit/automator/test_v2_fallback.py new file mode 100644 index 00000000..968f7a48 --- /dev/null +++ b/tests/unit/automator/test_v2_fallback.py @@ -0,0 +1,394 @@ +""" +Unit tests for update-task-v2 graceful degradation to v1. + +Tests both sync TaskRunner and async AsyncTaskRunner to verify: +- On 404/501 from update_task_v2, falls back to update_task (v1) +- The _v2_available flag is set to False after first fallback +- Subsequent calls go directly to v1 (skip v2) +- The current task result is still persisted via v1 during fallback +""" + +import asyncio +import logging +import unittest +from unittest.mock import patch, Mock, AsyncMock + +from conductor.client.automator.task_runner import TaskRunner +from conductor.client.automator.async_task_runner import AsyncTaskRunner +from conductor.client.configuration.configuration import Configuration +from conductor.client.http.api.task_resource_api import TaskResourceApi +from conductor.client.http.models.task import Task +from conductor.client.http.models.task_result import TaskResult +from conductor.client.http.models.task_result_status import TaskResultStatus +from conductor.client.http.rest import ApiException +from conductor.client.worker.worker import Worker +from tests.unit.resources.workers import ClassWorker + + +class TestTaskRunnerV2Fallback(unittest.TestCase): + """Tests for sync TaskRunner v2 -> v1 fallback.""" + + TASK_ID = 'test_task_id' + WORKFLOW_INSTANCE_ID = 'test_workflow_id' + + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + @patch('time.sleep', Mock(return_value=None)) + def test_fallback_on_404(self): + """On 404 from update_task_v2, should fall back to update_task and return None.""" + with patch.object( + TaskResourceApi, 'update_task_v2', + side_effect=ApiException(status=404, reason="Not Found") + ): + with patch.object( + TaskResourceApi, 'update_task', + return_value='task_id_confirmation' + ) as mock_v1: + runner = self._create_runner() + self.assertTrue(runner._v2_available) + + result = runner._TaskRunner__update_task(self._create_task_result()) + + self.assertIsNone(result) + self.assertFalse(runner._v2_available) + mock_v1.assert_called_once() + + @patch('time.sleep', Mock(return_value=None)) + def test_fallback_on_501(self): + """On 501 from update_task_v2, should fall back to update_task and return None.""" + with patch.object( + TaskResourceApi, 'update_task_v2', + side_effect=ApiException(status=501, reason="Not Implemented") + ): + with patch.object( + TaskResourceApi, 'update_task', + return_value='task_id_confirmation' + ) as mock_v1: + runner = self._create_runner() + result = runner._TaskRunner__update_task(self._create_task_result()) + + self.assertIsNone(result) + self.assertFalse(runner._v2_available) + mock_v1.assert_called_once() + + @patch('time.sleep', Mock(return_value=None)) + def test_subsequent_calls_use_v1_directly(self): + """After fallback, subsequent calls should go to v1 directly, skipping v2.""" + with patch.object( + TaskResourceApi, 'update_task_v2', + side_effect=ApiException(status=404, reason="Not Found") + ) as mock_v2: + with patch.object( + TaskResourceApi, 'update_task', + return_value='ok' + ) as mock_v1: + runner = self._create_runner() + + # First call triggers fallback + runner._TaskRunner__update_task(self._create_task_result()) + self.assertEqual(mock_v2.call_count, 1) + self.assertEqual(mock_v1.call_count, 1) + + # Second call should skip v2 entirely + runner._TaskRunner__update_task(self._create_task_result()) + self.assertEqual(mock_v2.call_count, 1) # Still 1 — not called again + self.assertEqual(mock_v1.call_count, 2) + + @patch('time.sleep', Mock(return_value=None)) + def test_v2_success_no_fallback(self): + """When v2 succeeds, should return next task and not touch v1.""" + next_task = Task(task_id='next_task', workflow_instance_id='wf_2') + with patch.object( + TaskResourceApi, 'update_task_v2', + return_value=next_task + ): + with patch.object( + TaskResourceApi, 'update_task', + return_value='ok' + ) as mock_v1: + runner = self._create_runner() + result = runner._TaskRunner__update_task(self._create_task_result()) + + self.assertEqual(result, next_task) + self.assertTrue(runner._v2_available) + mock_v1.assert_not_called() + + @patch('time.sleep', Mock(return_value=None)) + def test_non_404_error_does_not_trigger_fallback(self): + """A 500 error should retry normally, not trigger v1 fallback.""" + with patch.object( + TaskResourceApi, 'update_task_v2', + side_effect=ApiException(status=500, reason="Internal Server Error") + ): + runner = self._create_runner() + result = runner._TaskRunner__update_task(self._create_task_result()) + + # All retries exhausted, still v2_available (not a 404/501) + self.assertTrue(runner._v2_available) + self.assertIsNone(result) + + @patch('time.sleep', Mock(return_value=None)) + def test_v1_fallback_failure_retries(self): + """If v1 also fails during fallback, should retry with backoff.""" + call_count = {'v1': 0} + + def v1_side_effect(**kwargs): + call_count['v1'] += 1 + if call_count['v1'] <= 2: + raise Exception("v1 also down") + return 'ok' + + with patch.object( + TaskResourceApi, 'update_task_v2', + side_effect=ApiException(status=404, reason="Not Found") + ): + with patch.object( + TaskResourceApi, 'update_task', + side_effect=v1_side_effect + ): + runner = self._create_runner() + result = runner._TaskRunner__update_task(self._create_task_result()) + + self.assertFalse(runner._v2_available) + # First v1 call fails (immediate fallback), then retries succeed + self.assertIsNone(result) + + def _create_runner(self): + return TaskRunner( + configuration=Configuration(), + worker=ClassWorker('task') + ) + + def _create_task_result(self): + return TaskResult( + task_id=self.TASK_ID, + workflow_instance_id=self.WORKFLOW_INSTANCE_ID, + worker_id='test_worker', + status=TaskResultStatus.COMPLETED, + output_data={'result': 42} + ) + + +class TestAsyncTaskRunnerV2Fallback(unittest.TestCase): + """Tests for async AsyncTaskRunner v2 -> v1 fallback.""" + + TASK_ID = 'test_task_id' + WORKFLOW_INSTANCE_ID = 'test_workflow_id' + AUTH_TOKEN = 'test_token' + + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + def test_fallback_on_404(self): + """On 404 from async update_task_v2, should fall back to update_task.""" + + async def simple_worker(value: int) -> dict: + return {'result': value} + + worker = Worker( + task_definition_name='test_v2_fallback', + execute_function=simple_worker, + thread_count=1 + ) + + config = Configuration() + config.AUTH_TOKEN = self.AUTH_TOKEN + runner = AsyncTaskRunner(worker=worker, configuration=config) + + async def run_test(): + runner.async_api_client = AsyncMock() + runner.async_task_client = AsyncMock() + runner._semaphore = asyncio.Semaphore(1) + + runner.async_task_client.update_task_v2 = AsyncMock( + side_effect=ApiException(status=404, reason="Not Found") + ) + runner.async_task_client.update_task = AsyncMock(return_value='ok') + + self.assertTrue(runner._v2_available) + + result = await runner._AsyncTaskRunner__async_update_task(self._create_task_result()) + + self.assertIsNone(result) + self.assertFalse(runner._v2_available) + runner.async_task_client.update_task.assert_called_once() + + asyncio.run(run_test()) + + def test_fallback_on_501(self): + """On 501 from async update_task_v2, should fall back to update_task.""" + + async def simple_worker(value: int) -> dict: + return {'result': value} + + worker = Worker( + task_definition_name='test_v2_fallback_501', + execute_function=simple_worker, + thread_count=1 + ) + + config = Configuration() + config.AUTH_TOKEN = self.AUTH_TOKEN + runner = AsyncTaskRunner(worker=worker, configuration=config) + + async def run_test(): + runner.async_api_client = AsyncMock() + runner.async_task_client = AsyncMock() + runner._semaphore = asyncio.Semaphore(1) + + runner.async_task_client.update_task_v2 = AsyncMock( + side_effect=ApiException(status=501, reason="Not Implemented") + ) + runner.async_task_client.update_task = AsyncMock(return_value='ok') + + result = await runner._AsyncTaskRunner__async_update_task(self._create_task_result()) + + self.assertIsNone(result) + self.assertFalse(runner._v2_available) + runner.async_task_client.update_task.assert_called_once() + + asyncio.run(run_test()) + + def test_subsequent_calls_use_v1_directly(self): + """After fallback, subsequent async calls should go to v1 directly.""" + + async def simple_worker(value: int) -> dict: + return {'result': value} + + worker = Worker( + task_definition_name='test_v2_subsequent', + execute_function=simple_worker, + thread_count=1 + ) + + config = Configuration() + config.AUTH_TOKEN = self.AUTH_TOKEN + runner = AsyncTaskRunner(worker=worker, configuration=config) + + async def run_test(): + runner.async_api_client = AsyncMock() + runner.async_task_client = AsyncMock() + runner._semaphore = asyncio.Semaphore(1) + + runner.async_task_client.update_task_v2 = AsyncMock( + side_effect=ApiException(status=404, reason="Not Found") + ) + runner.async_task_client.update_task = AsyncMock(return_value='ok') + + # First call triggers fallback + await runner._AsyncTaskRunner__async_update_task(self._create_task_result()) + self.assertEqual(runner.async_task_client.update_task_v2.call_count, 1) + self.assertEqual(runner.async_task_client.update_task.call_count, 1) + + # Second call skips v2 + await runner._AsyncTaskRunner__async_update_task(self._create_task_result()) + self.assertEqual(runner.async_task_client.update_task_v2.call_count, 1) # Still 1 + self.assertEqual(runner.async_task_client.update_task.call_count, 2) + + asyncio.run(run_test()) + + def test_v2_success_no_fallback(self): + """When async v2 succeeds, should return next task and not touch v1.""" + + async def simple_worker(value: int) -> dict: + return {'result': value} + + worker = Worker( + task_definition_name='test_v2_success', + execute_function=simple_worker, + thread_count=1 + ) + + config = Configuration() + config.AUTH_TOKEN = self.AUTH_TOKEN + runner = AsyncTaskRunner(worker=worker, configuration=config) + + next_task = Task(task_id='next_task', workflow_instance_id='wf_2') + + async def run_test(): + runner.async_api_client = AsyncMock() + runner.async_task_client = AsyncMock() + runner._semaphore = asyncio.Semaphore(1) + + runner.async_task_client.update_task_v2 = AsyncMock(return_value=next_task) + runner.async_task_client.update_task = AsyncMock(return_value='ok') + + result = await runner._AsyncTaskRunner__async_update_task(self._create_task_result()) + + self.assertEqual(result, next_task) + self.assertTrue(runner._v2_available) + runner.async_task_client.update_task.assert_not_called() + + asyncio.run(run_test()) + + def test_end_to_end_with_fallback(self): + """Full end-to-end: poll -> execute -> update_v2 fails -> fallback to v1.""" + + async def async_worker_fn(value: int) -> dict: + return {'result': value * 2} + + worker = Worker( + task_definition_name='test_e2e_fallback', + execute_function=async_worker_fn, + thread_count=1 + ) + + config = Configuration() + config.AUTH_TOKEN = self.AUTH_TOKEN + runner = AsyncTaskRunner(worker=worker, configuration=config) + + mock_task = Task() + mock_task.task_id = self.TASK_ID + mock_task.workflow_instance_id = self.WORKFLOW_INSTANCE_ID + mock_task.task_def_name = 'test_e2e_fallback' + mock_task.input_data = {'value': 10} + mock_task.status = 'SCHEDULED' + + async def run_test(): + runner.async_api_client = AsyncMock() + runner.async_task_client = AsyncMock() + runner._semaphore = asyncio.Semaphore(1) + + runner.async_task_client.batch_poll = AsyncMock(return_value=[mock_task]) + runner.async_task_client.update_task_v2 = AsyncMock( + side_effect=ApiException(status=404, reason="Not Found") + ) + runner.async_task_client.update_task = AsyncMock(return_value='ok') + + await runner.run_once() + await asyncio.sleep(0.1) + + # v2 was attempted, then fell back to v1 + runner.async_task_client.update_task_v2.assert_called_once() + runner.async_task_client.update_task.assert_called_once() + + # Task result should have correct output + v1_call = runner.async_task_client.update_task.call_args + task_result = v1_call.kwargs['body'] + self.assertEqual(task_result.status, TaskResultStatus.COMPLETED) + self.assertEqual(task_result.output_data, {'result': 20}) + + # Flag should be flipped + self.assertFalse(runner._v2_available) + + asyncio.run(run_test()) + + def _create_task_result(self): + return TaskResult( + task_id=self.TASK_ID, + workflow_instance_id=self.WORKFLOW_INSTANCE_ID, + worker_id='test_worker', + status=TaskResultStatus.COMPLETED, + output_data={'result': 42} + ) + + +if __name__ == '__main__': + unittest.main()