Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions .github/actions/pylint/action.yml

This file was deleted.

9 changes: 6 additions & 3 deletions .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/pylint
name: "Lint ASR worker"
- uses: astral-sh/ruff-action@v3
with:
path: asr-worker
args: "--version" # skips test by displaying the version
- name: Check formatting
run: ruff format --config qa/ruff.toml --check asr-worker
- name: Lint test
run: ruff check --config qa/ruff.toml asr-worker

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down
37 changes: 29 additions & 8 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@ jobs:
key: docker-${{ runner.os }}-${{ hashFiles('docker-compose.yml') }}
- name: Start test services
run: docker compose up -d datashare temporal-post-init elasticsearch
- name: test temporal setup
run: |
curl "https://temporal.download/cli/archive/latest?platform=linux&arch=amd64" --output temporal.tar.gz
tar xzvf temporal.tar.gz
sudo mv temporal /usr/local/bin
temporal operator namespace describe -n datashare-default --address localhost:7233
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
Expand All @@ -41,7 +35,7 @@ jobs:
- name: Run tests
run: |
cd datashare-python
uv sync --frozen --all-extras
uv sync --frozen --all-extras --dev
uv run --frozen python -m pytest -vvv --cache-clear --show-capture=all -r A

test-worker-template:
Expand All @@ -68,7 +62,34 @@ jobs:
- name: Run tests
run: |
cd worker-template
uv sync --frozen --all-extras
uv sync --frozen --all-extras --dev
uv run --frozen python -m pytest --timeout=180 -vvv --cache-clear --show-capture=all -r A

test-asr-worker:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Setup Python project
uses: actions/setup-python@v6
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Cache Docker images
uses: ScribeMD/docker-cache@0.5.0
with:
key: docker-${{ runner.os }}-${{ hashFiles('docker-compose.yml') }}
- name: Start test services
run: docker compose up -d datashare temporal-post-init elasticsearch
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
version: ${{ env.ASTRAL_VERSION }}
python-version: ${{ env.PYTHON_VERSION }}
enable-cache: true
working-directory: asr-worker
- name: Run tests
run: |
cd asr-worker
uv sync --frozen --all-extras --dev
uv run --frozen python -m pytest --timeout=180 -vvv --cache-clear --show-capture=all -r A

concurrency:
Expand Down
231 changes: 178 additions & 53 deletions asr-worker/asr_worker/activities.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,184 @@
import torchaudio
from caul.configs.parakeet import ParakeetConfig
from caul.model_handlers.helpers import ParakeetModelHandlerResult
from caul.tasks.preprocessing.helpers import PreprocessedInput
import uuid
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path

from caul.configs import ParakeetConfig
from caul.model_handlers.asr_model_handler import ASRModelHandler
from caul.model_handlers.objects import ASRModelHandlerResult
from caul.tasks.preprocessing.objects import PreprocessedInput
from config import ASRWorkerConfig
from datashare_python.types_ import ProgressRateHandler
from datashare_python.utils import (
ActivityWithProgress,
activity_defn,
debuggable_name,
read_artifact_metadata,
safe_dir,
to_raw_progress,
write_artifact_metadata,
)
from datashare_python.utils import artifacts_dir as get_artifacts_dir
from pydantic import TypeAdapter
from temporalio import activity

from .constants import (
POSTPROCESS_ACTIVITY,
PREPROCESS_ACTIVITY,
RUN_INFERENCE_ACTIVITY,
TRANSCRIPTION_JSON,
TRANSCRIPTION_METADATA_KEY,
)
from .models import Transcription

# Relative progress weights of the three activities: inference dominates the
# runtime, preprocessing is next, postprocessing is cheap.
_BASE_WEIGHT = 1.0
_PREPROCESS_WEIGHT = 5 * _BASE_WEIGHT
_INFERENCE_WEIGHT = 10 * _PREPROCESS_WEIGHT

# Reused adapter to coerce JSON-deserialized activity payloads back into list[Path]
_LIST_OF_PATH_ADAPTER = TypeAdapter(list[Path])


# TODO: update caul to provide context managers rather than load/shutdown
@contextmanager
def _handler(config: ParakeetConfig) -> Generator[ASRModelHandler, None, None]:
asr_handler = config.handler_from_config()
try:
asr_handler.startup()
yield asr_handler
finally:
asr_handler.shutdown()


class ASRActivities(ActivityWithProgress):
    """Temporal activities implementing the ASR pipeline.

    The pipeline runs in three stages (preprocess -> infer -> postprocess).
    Activities exchange file paths relative to the worker's workdir rather
    than large payloads; each stage re-reads its inputs from disk.
    """

    # TODO: pass this at runtime
    _handler_config = ParakeetConfig(return_tensors=False)

    @activity_defn(name=PREPROCESS_ACTIVITY, progress_weight=_PREPROCESS_WEIGHT)
    def preprocess(self, paths: list[Path]) -> list[Path]:
        """Preprocess audio files and persist each preprocessed segment as JSON.

        :param paths: audio file paths relative to the configured audio root
        :return: serialized preprocessed-input paths, relative to the workdir
        """
        # TODO: this shouldn't be necessary, fix this bug
        paths = _LIST_OF_PATH_ADAPTER.validate_python(paths)
        worker_config = ASRWorkerConfig()
        audio_root = worker_config.audios_root
        workdir = worker_config.workdir
        # TODO: load from config passed at runtime with caching
        # TODO: avoid loading the full handler we just need preprocessing
        with _handler(self._handler_config) as asr_handler:
            preprocessor = asr_handler.preprocessor
            # TODO: implement a caching strategy here, we could avoid processing files
            # which have already been preprocessed
            to_process = [str(audio_root / p) for p in paths]
            batches = []
            # TODO: handle progress here
            # Fix: pass the list itself, not str(to_process) which handed the
            # preprocessor the repr of the whole list as one bogus "path".
            for batch in preprocessor.process(to_process, output_dir=workdir):
                for preprocessed_input in batch:
                    uuid_name = uuid.uuid4().hex[:20]
                    segment_dir = safe_dir(uuid_name)
                    # TODO: find a more debuggable name for this
                    segment_path = (
                        workdir / segment_dir / f"{uuid_name}-preprocessed.json"
                    )
                    segment_path.parent.mkdir(parents=True, exist_ok=True)
                    # Fix: model_dump_json() returns a JSON string and takes no
                    # destination path; write it out explicitly (mirrors infer()).
                    segment_path.write_text(preprocessed_input.model_dump_json())
                    batches.append(segment_path.relative_to(workdir))
        return batches

    @activity_defn(name=RUN_INFERENCE_ACTIVITY, progress_weight=_INFERENCE_WEIGHT)
    def infer(
        self,
        preprocessed_inputs: list[Path],
        *,
        progress: ProgressRateHandler | None = None,
    ) -> list[Path]:
        """Run ASR inference and persist one transcript JSON per input.

        :param preprocessed_inputs: workdir-relative preprocessed JSON paths
        :param progress: optional handler notified after each transcribed item
        :return: transcript JSON paths, relative to the workdir
        """
        preprocessed_inputs = _LIST_OF_PATH_ADAPTER.validate_python(preprocessed_inputs)
        worker_config = ASRWorkerConfig()
        workdir = worker_config.workdir
        # TODO: load from config passed at runtime with caching
        # TODO: avoid loading the full handler we just need inference
        with _handler(self._handler_config) as asr_handler:
            inference_runner = asr_handler.inference_handler
            # TODO: extract this into a function to improve testability
            paths = []
            if progress is not None:
                progress = to_raw_progress(
                    progress, max_progress=len(preprocessed_inputs)
                )
            abs_paths = [workdir / rel_path for rel_path in preprocessed_inputs]
            # Fix: model_validate_json() expects JSON text, not a Path; read
            # each file's content first. Kept lazy so inputs stream into the
            # inference runner one at a time.
            audios = (
                PreprocessedInput.model_validate_json(f.read_text()) for f in abs_paths
            )
            # strict=True: a length mismatch between inputs and results is a bug
            for res_i, (path, asr_res) in enumerate(
                zip(preprocessed_inputs, inference_runner.process(audios), strict=True)
            ):
                filename = f"{debuggable_name(path)}-transcript.json"
                transcript_path = workdir / safe_dir(filename) / filename
                transcript_path.parent.mkdir(parents=True, exist_ok=True)
                transcript_path.write_text(asr_res.model_dump_json())
                paths.append(transcript_path.relative_to(workdir))
                if progress is not None:
                    # Fix: report completed count (res_i + 1) so progress can
                    # actually reach max_progress on the last item.
                    self._event_loop.run_until_complete(progress(res_i + 1))
        return paths

    @activity_defn(name=POSTPROCESS_ACTIVITY, progress_weight=_BASE_WEIGHT)
    def postprocess(
        self,
        inference_results: list[Path],
        input_paths: list[Path],
        project: str,
        *,
        progress: ProgressRateHandler | None = None,
    ) -> None:
        """Postprocess inference results and write transcription artifacts.

        :param inference_results: workdir-relative transcript paths from infer()
        :param input_paths: original audio paths, aligned with inference_results
        :param project: Datashare project used to locate the artifact directory
        :param progress: optional handler notified after each written artifact
        """
        inference_results = _LIST_OF_PATH_ADAPTER.validate_python(inference_results)
        input_paths = _LIST_OF_PATH_ADAPTER.validate_python(input_paths)
        worker_config = ASRWorkerConfig()
        artifacts_root = worker_config.artifacts_root
        # TODO: load from config passed at runtime with caching
        # TODO: avoid loading the full handler we just need postprocessing
        with _handler(self._handler_config) as asr_handler:
            post_processor = asr_handler.postprocessor
            if progress is not None:
                progress = to_raw_progress(progress, max_progress=len(input_paths))
            with post_processor:
                # NOTE(review): inference_results are workdir-relative (as
                # returned by infer()); confirm post_processor.process resolves
                # them or whether they must be joined with the workdir here.
                transcriptions = post_processor.process(inference_results)
                # Strict is important here !
                for i, (original, asr_result) in enumerate(
                    zip(input_paths, transcriptions, strict=True)
                ):
                    t_path = write_transcription(
                        asr_result,
                        original.name,
                        artifacts_root=artifacts_root,
                        project=project,
                    )
                    activity.logger.debug("wrote transcription for %s", t_path)
                    if progress is not None:
                        # Fix: report completed count so progress reaches 100%
                        self._event_loop.run_until_complete(progress(i + 1))


class ASRActivities:
    """Contains activity definitions as well as reference to models"""

    def __init__(self):
        # TODO: Eventually this may include whisper, which will
        # then require passing language_map
        self.asr_handler = ParakeetConfig(return_tensors=False).handler_from_config()

        # load models
        self.asr_handler.startup()

    @activity.defn(name="asr.transcription.preprocess")
    async def preprocess(self, inputs: list[str]) -> list[list[PreprocessedInput]]:
        """Preprocess transcription inputs

        :param inputs: list of file paths
        :return: list of caul.tasks.preprocessing.helpers.PreprocessedInput
        """
        return self.asr_handler.preprocessor.process(inputs)

    @activity.defn(name="asr.transcription.infer")
    async def infer(
        self, inputs: list[PreprocessedInput]
    ) -> list[ParakeetModelHandlerResult]:
        """Transcribe audio files.

        :param inputs: list of preprocessed inputs
        :return: list of inference handler results
        """
        # Load tensors
        for item in inputs:
            # torchaudio.load returns a (waveform, sample_rate) pair read from
            # the file produced during preprocessing
            tensor, sample_rate = torchaudio.load(item.metadata.preprocessed_file_path)
            # normalize
            tensor = self.asr_handler.preprocessor.normalize(tensor, sample_rate)
            # assign
            item.tensor = tensor

        return self.asr_handler.inference_handler.process(inputs)

    @activity.defn(name="asr.transcription.postprocess")
    async def postprocess(
        self, inputs: list[ParakeetModelHandlerResult]
    ) -> list[ParakeetModelHandlerResult]:
        """Postprocess and reorder transcriptions

        :param inputs: list of inference handler results
        :return: list of parakeet inference handler results
        """
        return self.asr_handler.postprocessor.process(inputs)
def write_transcription(
    asr_result: ASRModelHandlerResult,
    transcribed_filename: str,
    *,
    artifacts_root: Path,
    project: str,
) -> Path:
    """Persist *asr_result* as a transcription artifact and register it.

    The transcription JSON is written under the project's artifact directory
    for *transcribed_filename*, and the artifact metadata is updated (or
    created) to point at it.

    :return: path of the written transcription JSON
    """
    transcription = Transcription.from_asr_handler_result(asr_result)
    artifact_dir = artifacts_root / get_artifacts_dir(
        project, filename=transcribed_filename
    )
    artifact_dir.mkdir(exist_ok=True, parents=True)
    # TODO: if transcriptions are too large we could also serialize them
    # as jsonl
    transcription_path = artifact_dir / TRANSCRIPTION_JSON
    transcription_path.write_text(transcription.model_dump_json())
    try:
        meta = read_artifact_metadata(
            artifacts_root, project, filename=transcribed_filename
        )
    except FileNotFoundError:
        # First artifact for this file: start from empty metadata
        meta = {}
    meta[TRANSCRIPTION_METADATA_KEY] = transcription_path.name
    write_artifact_metadata(
        meta, artifacts_root, project=project, filename=transcribed_filename
    )
    return transcription_path


# Activity definitions exported for registration with the Temporal worker
REGISTRY = [ASRActivities.preprocess, ASRActivities.infer, ASRActivities.postprocess]
16 changes: 16 additions & 0 deletions asr-worker/asr_worker/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from pathlib import Path
from typing import ClassVar

import datashare_python
from datashare_python.config import WorkerConfig
from pydantic import Field

_ALL_LOGGERS = [datashare_python.__name__, __name__, "__main__"]


class ASRWorkerConfig(WorkerConfig):
    """Configuration for the ASR worker.

    Adds the filesystem roots the worker reads audio from and writes
    artifacts/intermediate files to.
    """

    # Fix: pydantic v2 ignores ClassVar-annotated attributes as model fields,
    # so `Field(_ALL_LOGGERS, frozen=True)` would leave a bare FieldInfo object
    # as the class attribute instead of the logger-name list. Assign directly.
    loggers: ClassVar[list[str]] = _ALL_LOGGERS

    # Root directory containing the audio files to transcribe
    audios_root: Path
    # Root directory where transcription artifacts are written
    artifacts_root: Path
    # Scratch directory for intermediate preprocessed/transcript files
    workdir: Path
13 changes: 11 additions & 2 deletions asr-worker/asr_worker/constants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Name used to identify this worker
ASR_WORKER_NAME = "asr-worker"

# Durations in seconds
_ONE_MINUTE = 60
ONE_MINUTE = 60

_TEN_MINUTES = _ONE_MINUTE * 10
TEN_MINUTES = ONE_MINUTE * 10

# Temporal task queue the ASR worker polls
ASR_TASK_QUEUE = "transcription-tasks"

Expand All @@ -13,3 +13,12 @@
# Response status markers
RESPONSE_SUCCESS = "success"

RESPONSE_ERROR = "error"

# Artifact file name and metadata key for persisted transcriptions
TRANSCRIPTION_JSON = "transcription.json"
TRANSCRIPTION_METADATA_KEY = "transcription"

# Temporal workflow and activity names
ASR_WORKFLOW = "asr.transcription"
GET_CONFIG_ACTIVITY = "asr.transcription.config"
PREPROCESS_ACTIVITY = "asr.transcription.preprocess"
RUN_INFERENCE_ACTIVITY = "asr.transcription.infer"
POSTPROCESS_ACTIVITY = "asr.transcription.postprocess"
Loading
Loading