diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py
index fac7f8bc4bce..0a14aa70fa3a 100644
--- a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py
+++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py
@@ -344,6 +344,20 @@ def sort(self, *orders: stages.Ordering) -> "_BasePipeline":
         """
         return self._append(stages.Sort(*orders))
 
+    def search(self, options: stages.SearchOptions) -> "_BasePipeline":
+        """
+        Adds a search stage to the pipeline.
+
+        This stage filters documents based on the provided query expression.
+
+        Args:
+            options: A SearchOptions instance configuring the search.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Search(options))
+
     def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline":
         """
         Performs a pseudo-random sampling of the documents from the previous stage.
diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py
index 969ddf2794a5..2c0fb3b4282d 100644
--- a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py
@@ -705,6 +705,53 @@ def less_than_or_equal(
             [self, self._cast_to_expr_or_convert_to_constant(other)],
         )
 
+    @expose_as_static
+    def between(
+        self, lower: Expression | CONSTANT_TYPE, upper: Expression | CONSTANT_TYPE
+    ) -> "BooleanExpression":
+        """Evaluates if the result of this expression is between
+        the lower bound (inclusive) and upper bound (inclusive).
+
+        This is functionally equivalent to performing an `And` operation with
+        `greater_than_or_equal` and `less_than_or_equal`.
+
+        Example:
+            >>> # Check if the 'age' field is between 18 and 65
+            >>> Field.of("age").between(18, 65)
+
+        Args:
+            lower: Lower bound (inclusive) of the range.
+            upper: Upper bound (inclusive) of the range.
+
+        Returns:
+            A new `BooleanExpression` representing the between comparison.
+        """
+        return And(
+            self.greater_than_or_equal(lower),
+            self.less_than_or_equal(upper),
+        )
+
+    @expose_as_static
+    def geo_distance(self, other: Expression | GeoPoint) -> "FunctionExpression":
+        """Evaluates to the distance in meters between the location in the specified
+        field and the query location.
+
+        Note: This Expression can only be used within a `Search` stage.
+
+        Example:
+            >>> # Calculate distance between the 'location' field and a target GeoPoint
+            >>> Field.of("location").geo_distance(target_point)
+
+        Args:
+            other: Compute distance to this GeoPoint expression or constant value.
+
+        Returns:
+            A new `FunctionExpression` representing the distance.
+        """
+        return FunctionExpression(
+            "geo_distance", [self, self._cast_to_expr_or_convert_to_constant(other)]
+        )
+
     @expose_as_static
     def equal_any(
         self, array: Array | Sequence[Expression | CONSTANT_TYPE] | Expression
@@ -2889,3 +2936,44 @@ class Rand(FunctionExpression):
 
     def __init__(self):
         super().__init__("rand", [], use_infix_repr=False)
+
+
+class Score(FunctionExpression):
+    """Evaluates to the search score that reflects the topicality of the document
+    to all of the text predicates (`DocumentMatches`) in the search query.
+    If `SearchOptions.query` is not set or does not contain any text
+    predicates, then this topicality score will always be `0`.
+
+    Note: This Expression can only be used within a `Search` stage.
+
+    Returns:
+        A new `Expression` representing the score operation.
+    """
+
+    def __init__(self):
+        super().__init__("score", [], use_infix_repr=False)
+
+
+class DocumentMatches(BooleanExpression):
+    """Creates a boolean expression for a document match query.
+
+    Note: This Expression can only be used within a `Search` stage.
+
+    Example:
+        >>> # Find documents matching the query string
+        >>> DocumentMatches("search query")
+
+    Args:
+        query: The search query string or expression.
+
+    Returns:
+        A new `BooleanExpression` representing the document match.
+    """
+
+    def __init__(self, query: Expression | str):
+        super().__init__(
+            "document_matches",
+            [Expression._cast_to_expr_or_convert_to_constant(query)],
+            use_infix_repr=False,
+        )
+
diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py
index cac9c70d4b99..81d90f2a3ee2 100644
--- a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py
+++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py
@@ -30,6 +30,7 @@
     AliasedExpression,
     BooleanExpression,
     CONSTANT_TYPE,
+    DocumentMatches,
     Expression,
     Field,
     Ordering,
@@ -109,6 +110,113 @@ def percentage(value: float):
         return SampleOptions(value, mode=SampleOptions.Mode.PERCENT)
 
 
+class QueryEnhancement(Enum):
+    """Define the query expansion behavior used by full-text search expressions."""
+
+    DISABLED = "disabled"
+    REQUIRED = "required"
+    PREFERRED = "preferred"
+
+
+class SearchOptions:
+    """Options for configuring the `Search` pipeline stage."""
+
+    def __init__(
+        self,
+        query: str | BooleanExpression,
+        limit: Optional[int] = None,
+        retrieval_depth: Optional[int] = None,
+        sort: Optional[Sequence[Ordering] | Ordering] = None,
+        add_fields: Optional[Sequence[Selectable] | Selectable] = None,
+        select: Optional[Sequence[Selectable | str] | Selectable | str] = None,
+        offset: Optional[int] = None,
+        query_enhancement: Optional[str | QueryEnhancement] = None,
+        language_code: Optional[str] = None,
+    ):
+        """
+        Initializes a SearchOptions instance.
+
+        Args:
+            query (str | BooleanExpression): Specifies the search query that
+                will be used to query and score documents by the search stage.
+                The query can be expressed as an `Expression`, which will be
+                used to score and filter the results. Not all expressions
+                supported by Pipelines are supported in the Search query.
+                The query can also be expressed as a string in the Search DSL;
+                a string is wrapped in `DocumentMatches`.
+            limit (Optional[int]): The maximum number of documents to return
+                from the Search stage.
+            retrieval_depth (Optional[int]): The maximum number of documents
+                for the search stage to score. Documents will be processed in
+                the pre-sort order specified by the search index.
+            sort (Optional[Sequence[Ordering] | Ordering]): Orderings specify
+                how the input documents are sorted. A single `Ordering` is
+                wrapped in a list.
+            add_fields (Optional[Sequence[Selectable] | Selectable]): The
+                fields to add to each document, specified as `Selectable`s.
+                A single `Selectable` is wrapped in a list.
+            select (Optional[Sequence[Selectable | str] | Selectable | str]):
+                The fields to keep or add to each document, specified as
+                `Selectable`s or strings. Strings are converted to `Field`s
+                and a single value is wrapped in a list.
+            offset (Optional[int]): The number of documents to skip.
+            query_enhancement (Optional[str | QueryEnhancement]): Define the
+                query expansion behavior used by full-text search expressions
+                in this search stage. A string is lower-cased and looked up
+                by `QueryEnhancement` value.
+            language_code (Optional[str]): The BCP-47 language code of text in
+                the search query, such as "en-US" or "sr-Latn".
+
+        Raises:
+            ValueError: If `query_enhancement` is a string that does not match
+                any `QueryEnhancement` value.
+        """
+        # A plain string is shorthand for a DocumentMatches expression.
+        self.query = DocumentMatches(query) if isinstance(query, str) else query
+        self.limit = limit
+        self.retrieval_depth = retrieval_depth
+        self.sort = [sort] if isinstance(sort, Ordering) else sort
+        self.add_fields = (
+            [add_fields] if isinstance(add_fields, Selectable) else add_fields
+        )
+        # `str` is itself a sequence: without this guard, select="title" would
+        # be iterated into one Field per character.
+        if isinstance(select, (str, Selectable)):
+            select = [select]
+        self.select = (
+            None
+            if select is None
+            else [Field(s) if isinstance(s, str) else s for s in select]
+        )
+        self.offset = offset
+        self.query_enhancement = (
+            QueryEnhancement(query_enhancement.lower())
+            if isinstance(query_enhancement, str)
+            else query_enhancement
+        )
+        self.language_code = language_code
+
+    def __repr__(self):
+        args = [f"query={self.query!r}"]
+        if self.limit is not None:
+            args.append(f"limit={self.limit}")
+        if self.retrieval_depth is not None:
+            args.append(f"retrieval_depth={self.retrieval_depth}")
+        if self.sort is not None:
+            args.append(f"sort={self.sort}")
+        if self.add_fields is not None:
+            args.append(f"add_fields={self.add_fields}")
+        if self.select is not None:
+            args.append(f"select={self.select}")
+        if self.offset is not None:
+            args.append(f"offset={self.offset}")
+        if self.query_enhancement is not None:
+            args.append(f"query_enhancement={self.query_enhancement!r}")
+        if self.language_code is not None:
+            args.append(f"language_code={self.language_code!r}")
+        return f"{self.__class__.__name__}({', '.join(args)})"
+
+
 class UnnestOptions:
     """Options for configuring the `Unnest` pipeline stage.
 
@@ -423,6 +531,48 @@ def _pb_args(self):
         ]
 
 
+class Search(Stage):
+    """Search stage.
+
+    All settings are sent as named options; there are no positional arguments.
+    """
+
+    def __init__(self, options: SearchOptions):
+        super().__init__("search")
+        self.options = options
+
+    def _pb_args(self) -> list[Value]:
+        return []
+
+    def _pb_options(self) -> dict[str, Value]:
+        # Settings left as None are omitted from the options map.
+        opts = self.options
+        options = {}
+        if opts.query is not None:
+            options["query"] = opts.query._to_pb()
+        if opts.limit is not None:
+            options["limit"] = Value(integer_value=opts.limit)
+        if opts.retrieval_depth is not None:
+            options["retrieval_depth"] = Value(integer_value=opts.retrieval_depth)
+        if opts.sort is not None:
+            options["sort"] = Value(
+                array_value={"values": [s._to_pb() for s in opts.sort]}
+            )
+        if opts.add_fields is not None:
+            options["add_fields"] = Selectable._to_value(opts.add_fields)
+        if opts.select is not None:
+            options["select"] = Selectable._to_value(opts.select)
+        if opts.offset is not None:
+            options["offset"] = Value(integer_value=opts.offset)
+        if opts.query_enhancement is not None:
+            options["query_enhancement"] = Value(
+                string_value=opts.query_enhancement.value
+            )
+        if opts.language_code is not None:
+            options["language_code"] = Value(string_value=opts.language_code)
+        return options
+
+
 class Select(Stage):
     """Selects or creates a set of fields."""
 
diff --git a/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml b/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml
index f2533d2b1d48..797a45434809 100644
--- a/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml
+++ b/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml
@@ -144,4 +144,14 @@ data:
     doc_with_nan:
       value: "NaN"
     doc_with_null:
-      value: null
\ No newline at end of file
+      value: null
+  geopoints:
+    loc1:
+      name: SF
+      location: GEOPOINT(37.7749,-122.4194)
+    loc2:
+      name: LA
+      location: GEOPOINT(34.0522,-118.2437)
+
loc3: + name: NY + location: GEOPOINT(40.7128,-74.0060) \ No newline at end of file diff --git a/packages/google-cloud-firestore/tests/system/pipeline_e2e/search.yaml b/packages/google-cloud-firestore/tests/system/pipeline_e2e/search.yaml new file mode 100644 index 000000000000..9367386f2afb --- /dev/null +++ b/packages/google-cloud-firestore/tests/system/pipeline_e2e/search.yaml @@ -0,0 +1,286 @@ +tests: + - description: search_stage_basic + pipeline: + - Collection: books + - Search: + - SearchOptions: + query: "science" + limit: 2 + assert_results: + - title: "Dune" + author: "Frank Herbert" + genre: "Science Fiction" + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - adventure + awards: + hugo: true + nebula: true + - title: "The Hitchhiker's Guide to the Galaxy" + author: "Douglas Adams" + genre: "Science Fiction" + published: 1979 + rating: 4.2 + tags: + - comedy + - space + - adventure + awards: + hugo: true + nebula: false + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - name: search + options: + limit: + integerValue: '2' + query: + functionValue: + args: + - stringValue: "science" + name: document_matches + - description: search_stage_full_options + pipeline: + - Collection: books + - Search: + - SearchOptions: + query: + DocumentMatches: + - Constant: "science" + limit: 5 + retrieval_depth: 10 + offset: 1 + query_enhancement: disabled + language_code: en + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + author: "Douglas Adams" + genre: "Science Fiction" + published: 1979 + rating: 4.2 + tags: + - comedy + - space + - adventure + awards: + hugo: true + nebula: false + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - name: search + options: + limit: + integerValue: '5' + retrieval_depth: + integerValue: '10' + offset: + integerValue: '1' + query_enhancement: + stringValue: "disabled" + language_code: + stringValue: "en" + 
query: + functionValue: + args: + - stringValue: "science" + name: document_matches + - description: search_stage_with_sort_and_add_fields + pipeline: + - Collection: books + - Search: + - SearchOptions: + query: "science" + sort: + Ordering: + - Field: rating + - DESCENDING + add_fields: + - AliasedExpression: + - FunctionExpression.string_concat: + - Field: title + - Constant: " - Selected" + - "title_selected" + - AliasedExpression: + - Score: [] + - "search_score" + assert_results: + - title: "Dune" + author: "Frank Herbert" + genre: "Science Fiction" + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - adventure + awards: + hugo: true + nebula: true + title_selected: "Dune - Selected" + search_score: 0.0 + - title: "The Hitchhiker's Guide to the Galaxy" + author: "Douglas Adams" + genre: "Science Fiction" + published: 1979 + rating: 4.2 + tags: + - comedy + - space + - adventure + awards: + hugo: true + nebula: false + title_selected: "The Hitchhiker's Guide to the Galaxy - Selected" + search_score: 0.0 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - name: search + options: + query: + functionValue: + args: + - stringValue: "science" + name: document_matches + sort: + arrayValue: + values: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: rating + add_fields: + mapValue: + fields: + title_selected: + functionValue: + args: + - fieldReferenceValue: title + - stringValue: " - Selected" + name: string_concat + search_score: + functionValue: + name: score + - description: expression_between + pipeline: + - Collection: books + - Where: + - FunctionExpression.between: + - Field: published + - Constant: 1950 + - Constant: 1970 + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: Dune + - title: One Hundred Years of Solitude + - title: The Lord of the Rings + - title: To Kill a Mockingbird + 
assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1950' + name: greater_than_or_equal + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1970' + name: less_than_or_equal + name: and + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: expression_geo_distance + pipeline: + - Collection: geopoints + - Search: + - SearchOptions: + query: + FunctionExpression.less_than: + - FunctionExpression.geo_distance: + - Field: location + - GeoPoint: [37.0, -122.0] + - Constant: 150000.0 + - Select: + - name + - Sort: + - Ordering: + - Field: name + - ASCENDING + assert_results: + - name: SF + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /geopoints + name: collection + - name: search + options: + query: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: location + - geoPointValue: + latitude: 37.0 + longitude: -122.0 + name: geo_distance + - doubleValue: 150000.0 + name: less_than + - args: + - mapValue: + fields: + name: + fieldReferenceValue: name + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: name + name: sort \ No newline at end of file diff --git a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py index afff43ac6950..686046c961e2 100644 --- a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py +++ b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py @@ -33,6 +33,7 @@ from google.cloud.firestore_v1 import pipeline_expressions as expr from 
google.cloud.firestore_v1 import pipeline_stages as stages
 from google.cloud.firestore_v1.vector import Vector
+from google.cloud.firestore_v1 import GeoPoint
 
 FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT")
 
@@ -345,7 +346,7 @@ def _parse_yaml_types(data):
         else:
             return [_parse_yaml_types(value) for value in data]
     # detect timestamps
-    if isinstance(data, str) and ":" in data:
+    if isinstance(data, str) and ":" in data and not data.startswith("GEOPOINT("):
         try:
             parsed_datetime = datetime.datetime.fromisoformat(data)
             return parsed_datetime
diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py
index 98db3c3a8f17..323bfce5b45c 100644
--- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py
+++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py
@@ -790,6 +790,45 @@ def test_equal(self):
         infix_instance = arg1.equal(arg2)
         assert infix_instance == instance
 
+    def test_between(self):
+        arg1 = self._make_arg("Left")
+        arg2 = self._make_arg("Lower")
+        arg3 = self._make_arg("Upper")
+        instance = Expression.between(arg1, arg2, arg3)
+        assert instance.name == "and"
+        assert len(instance.params) == 2
+        assert instance.params[0].name == "greater_than_or_equal"
+        assert instance.params[1].name == "less_than_or_equal"
+        assert (
+            repr(instance)
+            == "And(Left.greater_than_or_equal(Lower), Left.less_than_or_equal(Upper))"
+        )
+        infix_instance = arg1.between(arg2, arg3)
+        assert infix_instance == instance
+
+    def test_geo_distance(self):
+        arg1 = self._make_arg("Left")
+        arg2 = self._make_arg("Right")
+        instance = Expression.geo_distance(arg1, arg2)
+        assert instance.name == "geo_distance"
+        assert instance.params == [arg1, arg2]
+        assert repr(instance) == "Left.geo_distance(Right)"
+        infix_instance = arg1.geo_distance(arg2)
+        assert infix_instance == instance
+
+    def test_document_matches(self):
+        arg1 = self._make_arg("Query")
+        instance = expr.DocumentMatches(arg1)
+        assert instance.name == "document_matches"
+        assert instance.params == [arg1]
+        assert repr(instance) == "DocumentMatches(Query)"
+
+    def test_score(self):
+        instance = expr.Score()
+        assert instance.name == "score"
+        assert instance.params == []
+        assert repr(instance) == "Score()"
+
     def test_greater_than_or_equal(self):
         arg1 = self._make_arg("Left")
         arg2 = self._make_arg("Right")
diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py
index 65685e6e33d6..c264ab9403be 100644
--- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py
+++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py
@@ -24,6 +24,7 @@
     Constant,
     Field,
     Ordering,
+    DocumentMatches,
 )
 from google.cloud.firestore_v1.types.document import Value
 from google.cloud.firestore_v1.vector import Vector
@@ -778,6 +779,76 @@ def test_to_pb_percent_mode(self):
         assert len(result_percent.options) == 0
 
 
+class TestSearch:
+    def test_search_defaults(self):
+        options = stages.SearchOptions(query="technology")
+        assert options.query.name == "document_matches"
+        assert options.limit is None
+        assert options.retrieval_depth is None
+        assert options.sort is None
+        assert options.add_fields is None
+        assert options.select is None
+        assert options.offset is None
+        assert options.query_enhancement is None
+        assert options.language_code is None
+
+        stage = stages.Search(options)
+        pb_opts = stage._pb_options()
+        assert "query" in pb_opts
+        assert "limit" not in pb_opts
+        assert "retrieval_depth" not in pb_opts
+
+    def test_search_full_options(self):
+        options = stages.SearchOptions(
+            query=DocumentMatches("tech"),
+            limit=10,
+            retrieval_depth=2,
+            sort=Ordering("score", Ordering.Direction.DESCENDING),
+            add_fields=[Field("extra")],
+            select=[Field("name")],
+            offset=5,
+            query_enhancement="disabled",
+            language_code="en",
+        )
+        assert options.limit == 10
+        assert options.retrieval_depth == 2
+        assert len(options.sort) == 1
+        assert options.offset == 5
+        assert options.query_enhancement == stages.QueryEnhancement.DISABLED
+        assert options.language_code == "en"
+
+        stage = stages.Search(options)
+        pb_opts = stage._pb_options()
+
+        assert pb_opts["limit"].integer_value == 10
+        assert pb_opts["retrieval_depth"].integer_value == 2
+        assert len(pb_opts["sort"].array_value.values) == 1
+        assert pb_opts["offset"].integer_value == 5
+        assert pb_opts["query_enhancement"].string_value == "disabled"
+        assert pb_opts["language_code"].string_value == "en"
+        assert "query" in pb_opts
+
+    def test_search_string_query_wrapping(self):
+        options = stages.SearchOptions(query="science")
+        assert options.query.name == "document_matches"
+
+    def test_search_query_enhancement_enum(self):
+        options = stages.SearchOptions(
+            query="q", query_enhancement=stages.QueryEnhancement.REQUIRED
+        )
+        assert options.query_enhancement == stages.QueryEnhancement.REQUIRED
+
+        stage = stages.Search(options)
+        pb_opts = stage._pb_options()
+        assert pb_opts["query_enhancement"].string_value == "required"
+
+    def test_search_string_field_coercion(self):
+        options = stages.SearchOptions(query="tech", select=["title"])
+        assert len(options.select) == 1
+        assert isinstance(options.select[0], Field)
+        assert options.select[0].path == "title"
+
+
 class TestSelect:
     def _make_one(self, *args, **kwargs):
         return stages.Select(*args, **kwargs)