Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 122 additions & 54 deletions src/openlayer/lib/tracing/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,12 +522,26 @@ def trace(
inference_pipeline_id: Optional[str] = None,
context_kwarg: Optional[str] = None,
question_kwarg: Optional[str] = None,
promote: Optional[Union[List[str], Dict[str, str]]] = None,
guardrails: Optional[List[Any]] = None,
on_flush_failure: Optional[OnFlushFailureCallback] = None,
**step_kwargs,
):
"""Decorator to trace a function with optional guardrails.

Parameters
----------
promote : list of str or dict mapping str to str, optional
Kwarg names whose values should be surfaced as top-level columns in the
trace data. Pass a list to use the original kwarg names as column names,
or a dict to alias them::

# List form – uses original kwarg names
@tracer.trace(promote=["tool_call_count", "user_query"])

# Dict form – maps kwarg_name -> column_name
@tracer.trace(promote={"user_query": "agent_input_query"})

Examples
--------

Expand Down Expand Up @@ -614,6 +628,7 @@ def __next__(self):
context_kwarg=context_kwarg,
question_kwarg=question_kwarg,
)
_apply_promote_kwargs(self._inputs, promote)
self._trace_initialized = True

try:
Expand Down Expand Up @@ -709,6 +724,7 @@ def wrapper(*func_args, **func_kwargs):
context_kwarg=context_kwarg,
question_kwarg=question_kwarg,
)
_apply_promote_kwargs(original_inputs, promote)

# Apply input guardrails
modified_inputs, input_guardrail_metadata = (
Expand Down Expand Up @@ -811,6 +827,7 @@ def trace_async(
inference_pipeline_id: Optional[str] = None,
context_kwarg: Optional[str] = None,
question_kwarg: Optional[str] = None,
promote: Optional[Union[List[str], Dict[str, str]]] = None,
guardrails: Optional[List[Any]] = None,
on_flush_failure: Optional[OnFlushFailureCallback] = None,
**step_kwargs,
Expand All @@ -821,6 +838,19 @@ def trace_async(
function
or an async generator and handles both cases appropriately.

Parameters
----------
promote : list of str or dict mapping str to str, optional
Kwarg names whose values should be surfaced as top-level columns in the
trace data. Pass a list to use the original kwarg names as column names,
or a dict to alias them::

# List form – uses original kwarg names
@tracer.trace_async(promote=["job_id", "user_query"])

# Dict form – maps kwarg_name -> column_name
@tracer.trace_async(promote={"user_query": "agent_input_query"})

Examples
--------

Expand Down Expand Up @@ -886,6 +916,7 @@ async def __anext__(self):
context_kwarg=context_kwarg,
question_kwarg=question_kwarg,
)
_apply_promote_kwargs(self._inputs, promote)
self._trace_initialized = True

try:
Expand Down Expand Up @@ -939,8 +970,8 @@ async def async_function_wrapper(*func_args, **func_kwargs):
guardrail_metadata = {}

try:
# Apply input guardrails if provided
if guardrails:
# Apply promote / input guardrails if provided
if promote or guardrails:
try:
inputs = _extract_function_inputs(
func_signature=func_signature,
Expand All @@ -949,35 +980,43 @@ async def async_function_wrapper(*func_args, **func_kwargs):
context_kwarg=context_kwarg,
question_kwarg=question_kwarg,
)

# Process inputs through guardrails
modified_inputs, input_metadata = (
_apply_input_guardrails(
guardrails,
inputs,
_apply_promote_kwargs(inputs, promote)

if guardrails:
# Process inputs through guardrails
modified_inputs, input_metadata = (
_apply_input_guardrails(
guardrails,
inputs,
)
)
)
guardrail_metadata.update(input_metadata)

# Execute function with potentially modified inputs
if modified_inputs != inputs:
# Reconstruct function arguments from modified inputs
bound = func_signature.bind(
*func_args, **func_kwargs
)
bound.apply_defaults()

# Update bound arguments with modified values
for (
param_name,
modified_value,
) in modified_inputs.items():
if param_name in bound.arguments:
bound.arguments[param_name] = (
modified_value
)

output = await func(*bound.args, **bound.kwargs)
guardrail_metadata.update(input_metadata)

# Execute function with potentially modified inputs
if modified_inputs != inputs:
# Reconstruct function arguments from modified inputs
bound = func_signature.bind(
*func_args, **func_kwargs
)
bound.apply_defaults()

# Update bound arguments with modified values
for (
param_name,
modified_value,
) in modified_inputs.items():
if param_name in bound.arguments:
bound.arguments[param_name] = (
modified_value
)

output = await func(
*bound.args, **bound.kwargs
)
else:
output = await func(
*func_args, **func_kwargs
)
else:
output = await func(*func_args, **func_kwargs)
except Exception as e:
Expand Down Expand Up @@ -1042,8 +1081,8 @@ def sync_wrapper(*func_args, **func_kwargs):
output = exception = None
guardrail_metadata = {}
try:
# Apply input guardrails if provided
if guardrails:
# Apply promote / input guardrails if provided
if promote or guardrails:
try:
inputs = _extract_function_inputs(
func_signature=func_signature,
Expand All @@ -1052,33 +1091,39 @@ def sync_wrapper(*func_args, **func_kwargs):
context_kwarg=context_kwarg,
question_kwarg=question_kwarg,
)
_apply_promote_kwargs(inputs, promote)

# Process inputs through guardrails
modified_inputs, input_metadata = (
_apply_input_guardrails(
guardrails,
inputs,
if guardrails:
# Process inputs through guardrails
modified_inputs, input_metadata = (
_apply_input_guardrails(
guardrails,
inputs,
)
)
)
guardrail_metadata.update(input_metadata)
guardrail_metadata.update(input_metadata)

# Execute function with potentially modified inputs
if modified_inputs != inputs:
# Reconstruct function arguments from modified inputs
bound = func_signature.bind(
*func_args, **func_kwargs
)
bound.apply_defaults()
# Execute function with potentially modified inputs
if modified_inputs != inputs:
# Reconstruct function arguments from modified inputs
bound = func_signature.bind(
*func_args, **func_kwargs
)
bound.apply_defaults()

# Update bound arguments with modified values
for (
param_name,
modified_value,
) in modified_inputs.items():
if param_name in bound.arguments:
bound.arguments[param_name] = modified_value
# Update bound arguments with modified values
for (
param_name,
modified_value,
) in modified_inputs.items():
if param_name in bound.arguments:
bound.arguments[param_name] = (
modified_value
)

output = func(*bound.args, **bound.kwargs)
output = func(*bound.args, **bound.kwargs)
else:
output = func(*func_args, **func_kwargs)
else:
output = func(*func_args, **func_kwargs)
except Exception as e:
Expand Down Expand Up @@ -1818,6 +1863,29 @@ def _extract_function_inputs(
return inputs


def _apply_promote_kwargs(
inputs: dict,
promote: Optional[Union[List[str], Dict[str, str]]],
) -> None:
"""Promote selected function kwargs to trace-level columns."""
if not promote:
return
mapping: Dict[str, str] = (
{k: k for k in promote} if isinstance(promote, list) else promote
)
resolved: Dict[str, Any] = {}
for kwarg_name, column_name in mapping.items():
if kwarg_name in inputs:
resolved[column_name] = inputs[kwarg_name]
else:
logger.warning(
"promote: kwarg `%s` not found in inputs of the current function.",
kwarg_name,
)
if resolved:
update_current_trace(**resolved)


def _finalize_step_logging(
step: steps.Step,
inputs: dict,
Expand Down