Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,20 @@ def sort(self, *orders: stages.Ordering) -> "_BasePipeline":
"""
return self._append(stages.Sort(*orders))

def search(self, options: stages.SearchOptions) -> "_BasePipeline":
    """
    Appends a `Search` stage, configured by `options`, to the pipeline.

    The stage filters documents based on the query expression carried by
    the supplied options.

    Args:
        options: The `SearchOptions` instance used to build the stage.

    Returns:
        A new Pipeline object with this stage appended to the stage list
    """
    search_stage = stages.Search(options)
    return self._append(search_stage)

def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline":
"""
Performs a pseudo-random sampling of the documents from the previous stage.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,53 @@ def less_than_or_equal(
[self, self._cast_to_expr_or_convert_to_constant(other)],
)

@expose_as_static
def between(
    self, lower: Expression | int | float, upper: Expression | int | float
) -> "BooleanExpression":
    """Evaluates if the result of this expression is between
    the lower bound (inclusive) and upper bound (inclusive).

    This is functionally equivalent to performing an `And` operation with
    `greater_than_or_equal` and `less_than_or_equal`.

    Example:
        >>> # Check if the 'age' field is between 18 and 65
        >>> Field.of("age").between(18, 65)

    Args:
        lower: Lower bound (inclusive) of the range. May be an expression
            or a numeric constant (int or float).
        upper: Upper bound (inclusive) of the range. May be an expression
            or a numeric constant (int or float).

    Returns:
        A new `BooleanExpression` representing the between comparison.
    """
    # Both bounds are inclusive, so the range check is the conjunction of
    # a ">=" comparison against `lower` and a "<=" comparison against `upper`.
    return And(
        self.greater_than_or_equal(lower),
        self.less_than_or_equal(upper),
    )

@expose_as_static
def geo_distance(self, other: Expression | GeoPoint) -> "FunctionExpression":
    """Evaluates to the distance in meters between the location in the
    specified field and the query location.

    Note: This Expression can only be used within a `Search` stage.

    Example:
        >>> # Calculate distance between the 'location' field and a target GeoPoint
        >>> Field.of("location").geo_distance(target_point)

    Args:
        other: The GeoPoint expression or constant value to measure the
            distance to.

    Returns:
        A new `FunctionExpression` representing the distance.
    """
    # Constants are converted to expressions before being passed as the
    # second argument of the function.
    target = self._cast_to_expr_or_convert_to_constant(other)
    return FunctionExpression("geo_distance", [self, target])

@expose_as_static
def equal_any(
self, array: Array | Sequence[Expression | CONSTANT_TYPE] | Expression
Expand Down Expand Up @@ -2889,3 +2936,44 @@ class Rand(FunctionExpression):

def __init__(self):
    """Initializes the expression under the name "rand" with an empty argument list."""
    no_params = []
    super().__init__("rand", no_params, use_infix_repr=False)


class Score(FunctionExpression):
    """Expression that evaluates to the search score of a document.

    The score reflects the topicality of the document to all of the text
    predicates (`queryMatch`) in the search query. If `SearchOptions.query`
    is not set or does not contain any text predicates, then this topicality
    score will always be `0`.

    Note: This Expression can only be used within a `Search` stage.

    Returns:
        A new `Expression` representing the score operation.
    """

    def __init__(self):
        # The score function takes no arguments.
        no_params = []
        super().__init__("score", no_params, use_infix_repr=False)


class DocumentMatches(BooleanExpression):
    """Boolean expression representing a document match query.

    Note: This Expression can only be used within a `Search` stage.

    Example:
        >>> # Find documents matching the query string
        >>> DocumentMatches("search query")

    Args:
        query: The search query string or expression.

    Returns:
        A new `BooleanExpression` representing the document match.
    """

    def __init__(self, query: Expression | str):
        # A plain string is converted to a constant expression; an existing
        # expression is handled by the same helper.
        query_expr = Expression._cast_to_expr_or_convert_to_constant(query)
        super().__init__("document_matches", [query_expr], use_infix_repr=False)

Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
AliasedExpression,
BooleanExpression,
CONSTANT_TYPE,
DocumentMatches,
Expression,
Field,
Ordering,
Expand Down Expand Up @@ -109,6 +110,81 @@ def percentage(value: float):
return SampleOptions(value, mode=SampleOptions.Mode.PERCENT)


class QueryEnhancement(Enum):
    """Query expansion behavior used by full-text search expressions.

    Each member's value is the lowercase string form of its name.
    """

    DISABLED = "disabled"
    REQUIRED = "required"
    PREFERRED = "preferred"


class SearchOptions:
    """Options for configuring the `Search` pipeline stage."""

    def __init__(
        self,
        query: str | BooleanExpression,
        limit: Optional[int] = None,
        retrieval_depth: Optional[int] = None,
        sort: Optional[Sequence[Ordering] | Ordering] = None,
        add_fields: Optional[Sequence[Selectable]] = None,
        select: Optional[Sequence[Selectable | str]] = None,
        offset: Optional[int] = None,
        query_enhancement: Optional[str | QueryEnhancement] = None,
        language_code: Optional[str] = None,
    ):
        """
        Initializes a SearchOptions instance.

        Args:
            query (str | BooleanExpression): The search query used to query and score documents
                in the search stage. It can be given as an `Expression`, which will be used to
                score and filter the results; not all expressions supported by Pipelines are
                supported in the Search query. It can also be given as a string in the Search
                DSL, in which case it is wrapped in `DocumentMatches`.
            limit (Optional[int]): The maximum number of documents to return from the Search stage.
            retrieval_depth (Optional[int]): The maximum number of documents for the search stage
                to score. Documents will be processed in the pre-sort order specified by the
                search index.
            sort (Optional[Sequence[Ordering] | Ordering]): Orderings specifying how the input
                documents are sorted. A single `Ordering` is stored as a one-element list.
            add_fields (Optional[Sequence[Selectable]]): The fields to add to each document,
                specified as a `Selectable`.
            select (Optional[Sequence[Selectable | str]]): The fields to keep or add to each
                document, specified as an array of `Selectable` or strings. Strings are
                converted to `Field` instances.
            offset (Optional[int]): The number of documents to skip.
            query_enhancement (Optional[str | QueryEnhancement]): The query expansion behavior
                used by full-text search expressions in this search stage. A string is
                lowercased and converted to the matching `QueryEnhancement` member.
            language_code (Optional[str]): The BCP-47 language code of text in the search query,
                such as "en-US" or "sr-Latn".
        """
        # A raw string query is shorthand for a document match expression.
        if isinstance(query, str):
            query = DocumentMatches(query)
        self.query = query

        self.limit = limit
        self.retrieval_depth = retrieval_depth

        # Normalize a lone ordering into a list so consumers can always iterate.
        if isinstance(sort, Ordering):
            sort = [sort]
        self.sort = sort

        self.add_fields = add_fields

        # Field names given as strings become Field instances.
        if select is not None:
            select = [
                Field(entry) if isinstance(entry, str) else entry for entry in select
            ]
        self.select = select

        self.offset = offset

        # String values are matched case-insensitively against the enum values.
        if isinstance(query_enhancement, str):
            query_enhancement = QueryEnhancement(query_enhancement.lower())
        self.query_enhancement = query_enhancement

        self.language_code = language_code

    def __repr__(self):
        # Optional attributes in display order; the flag selects repr()
        # formatting (True) versus plain formatting (False).
        optional_layout = (
            ("limit", False),
            ("retrieval_depth", False),
            ("sort", False),
            ("add_fields", False),
            ("select", False),
            ("offset", False),
            ("query_enhancement", True),
            ("language_code", True),
        )
        parts = [f"query={self.query!r}"]
        for attr, use_repr in optional_layout:
            value = getattr(self, attr)
            if value is None:
                # Unset options are left out of the representation.
                continue
            parts.append(f"{attr}={value!r}" if use_repr else f"{attr}={value}")
        return f"{self.__class__.__name__}({', '.join(parts)})"


class UnnestOptions:
"""Options for configuring the `Unnest` pipeline stage.

Expand Down Expand Up @@ -423,6 +499,39 @@ def _pb_args(self):
]


class Search(Stage):
    """Search stage, configured by a `SearchOptions` instance."""

    def __init__(self, options: SearchOptions):
        super().__init__("search")
        self.options = options

    def _pb_args(self) -> list[Value]:
        # The stage has no positional arguments; everything is sent as options.
        return []

    def _pb_options(self) -> dict[str, Value]:
        """Builds the options map, including only the options that are set."""
        cfg = self.options
        pb: dict[str, Value] = {}

        # `query` is a required SearchOptions argument, so this guard only
        # matters if a caller explicitly passed None.
        if cfg.query is not None:
            pb["query"] = cfg.query._to_pb()
        if cfg.limit is not None:
            pb["limit"] = Value(integer_value=cfg.limit)
        if cfg.retrieval_depth is not None:
            pb["retrieval_depth"] = Value(integer_value=cfg.retrieval_depth)
        if cfg.sort is not None:
            orderings = [ordering._to_pb() for ordering in cfg.sort]
            pb["sort"] = Value(array_value={"values": orderings})
        if cfg.add_fields is not None:
            pb["add_fields"] = Selectable._to_value(cfg.add_fields)
        if cfg.select is not None:
            pb["select"] = Selectable._to_value(cfg.select)
        if cfg.offset is not None:
            pb["offset"] = Value(integer_value=cfg.offset)
        if cfg.query_enhancement is not None:
            # The enum member's string value is what gets serialized.
            pb["query_enhancement"] = Value(string_value=cfg.query_enhancement.value)
        if cfg.language_code is not None:
            pb["language_code"] = Value(string_value=cfg.language_code)
        return pb


class Select(Stage):
"""Selects or creates a set of fields."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,14 @@ data:
doc_with_nan:
value: "NaN"
doc_with_null:
value: null
value: null
geopoints:
loc1:
name: SF
location: GEOPOINT(37.7749,-122.4194)
loc2:
name: LA
location: GEOPOINT(34.0522,-118.2437)
loc3:
name: NY
location: GEOPOINT(40.7128,-74.0060)
Loading
Loading