diff --git a/vulnerabilities/api_v2.py b/vulnerabilities/api_v2.py index 74975b819..58df30f70 100644 --- a/vulnerabilities/api_v2.py +++ b/vulnerabilities/api_v2.py @@ -1064,6 +1064,83 @@ def get_view_name(self): return "Pipeline Jobs" +class AdvisoryV2FilterSet(filters.FilterSet): + alias = CharInFilter( + field_name="aliases__alias", + lookup_expr="in", + label="Alias", + help_text="Filter by one or more aliases (e.g. CVE-2021-1234). Multi-value supported (comma-separated).", + ) + advisory_id = CharInFilter( + field_name="avid", + lookup_expr="in", + label="Advisory ID", + help_text="Filter by one or more advisory IDs (avid). Multi-value supported (comma-separated).", + ) + datasource_id = filters.CharFilter( + field_name="datasource_id", + label="Datasource ID", + help_text="Filter by datasource ID (e.g. nginx_importer_v2).", + ) + + class Meta: + model = AdvisoryV2 + fields = ["alias", "advisory_id", "datasource_id"] + + +@extend_schema_view( + list=extend_schema( + parameters=[ + OpenApiParameter( + name="alias", + description="Filter by one or more aliases (e.g. CVE-2021-1234). Comma-separated.", + required=False, + type={"type": "array", "items": {"type": "string"}}, + location=OpenApiParameter.QUERY, + ), + OpenApiParameter( + name="advisory_id", + description="Filter by one or more advisory IDs (avid). Comma-separated.", + required=False, + type={"type": "array", "items": {"type": "string"}}, + location=OpenApiParameter.QUERY, + ), + OpenApiParameter( + name="datasource_id", + description="Filter by datasource ID.", + required=False, + type=str, + location=OpenApiParameter.QUERY, + ), + ] + ) +) +class AdvisoryV2ViewSet(viewsets.ReadOnlyModelViewSet): + """ + Lookup for advisories by advisory ID, alias, or datasource. + """ + + queryset = ( + AdvisoryV2.objects.prefetch_related( + "aliases", + "references", + "severities", + "weaknesses", + "related_ssvcs", + "source_ssvcs", + ) + .order_by("datasource_id", "advisory_id") + .distinct() + ) + serializer_class = AdvisoryV2Serializer + lookup_field = "avid" + # avid contains slashes (e.g. nginx_importer_v2/CVE-2021-1234) + lookup_value_regex = r"[^/]+/[^/]+" + filter_backends = [filters.DjangoFilterBackend] + filterset_class = AdvisoryV2FilterSet + throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] + + class PackageV3ViewSet(viewsets.ReadOnlyModelViewSet): queryset = PackageV2.objects.all() serializer_class = PackageV3Serializer diff --git a/vulnerabilities/tests/test_api_v2.py b/vulnerabilities/tests/test_api_v2.py index 6968123c7..714f670c9 100644 --- a/vulnerabilities/tests/test_api_v2.py +++ b/vulnerabilities/tests/test_api_v2.py @@ -17,8 +17,10 @@ from rest_framework.test import APIClient from rest_framework.test import APITestCase +from vulnerabilities.api_v2 import AdvisoryV2Serializer from vulnerabilities.api_v2 import PackageV2Serializer from vulnerabilities.api_v2 import VulnerabilityListSerializer +from vulnerabilities.models import AdvisoryAlias from vulnerabilities.models import AdvisoryV2 from vulnerabilities.models import Alias from vulnerabilities.models import ApiUser @@ -905,3 +907,130 @@ def test_get_all_vulnerable_purls(self): response = self.client.get(url) assert response.status_code == 200 assert "pkg:pypi/sample@1.0.0" in response.data + + +class AdvisoryV2ViewSetTest(APITestCase): + def setUp(self): + self.advisory1 = AdvisoryV2.objects.create( + datasource_id="nginx_importer_v2", + advisory_id="CVE-2021-1234", + avid="nginx_importer_v2/CVE-2021-1234", + unique_content_id="a" * 64, + url="https://example.com/advisory1", + date_collected="2024-01-01T00:00:00Z", + summary="Test advisory 1", + ) + self.advisory2 = AdvisoryV2.objects.create( + datasource_id="pypa_importer_v2", + advisory_id="PYSEC-2022-5678", + avid="pypa_importer_v2/PYSEC-2022-5678", + unique_content_id="b" * 64, + url="https://example.com/advisory2", + date_collected="2024-01-01T00:00:00Z", + summary="Test advisory 2", + ) + + self.alias1 = AdvisoryAlias.objects.create(alias="CVE-2021-1234") + self.advisory1.aliases.add(self.alias1) + + self.alias2 = AdvisoryAlias.objects.create(alias="GHSA-xxxx-yyyy-zzzz") + self.advisory2.aliases.add(self.alias2) + + cache.clear() + self.client = APIClient(enforce_csrf_checks=True) + + def test_list_advisories(self): + """ + Test listing all advisories without filters. + """ + url = reverse("advisory-v2-list") + response = self.client.get(url, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("results", response.data) + self.assertEqual(response.data["count"], 2) + + def test_retrieve_advisory_by_avid(self): + """ + Test retrieving a specific advisory by its avid. + The avid contains a slash, handled by lookup_value_regex. + """ + url = reverse("advisory-v2-detail", kwargs={"avid": self.advisory1.avid}) + response = self.client.get(url, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["advisory_id"], self.advisory1.avid) + self.assertEqual(response.data["url"], self.advisory1.url) + self.assertIn("CVE-2021-1234", response.data["aliases"]) + + def test_retrieve_nonexistent_advisory_returns_404(self): + """ + Test that a non-existent advisory returns 404. + """ + url = reverse("advisory-v2-detail", kwargs={"avid": "fake_source/FAKE-0000"}) + response = self.client.get(url, format="json") + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_filter_by_alias(self): + """ + Test filtering advisories by alias returns only matching advisory. + """ + url = reverse("advisory-v2-list") + response = self.client.get(url, {"alias": "CVE-2021-1234"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["count"], 1) + result = response.data["results"][0] + self.assertIn("CVE-2021-1234", result["aliases"]) + + def test_filter_by_advisory_id(self): + """ + Test filtering advisories by advisory_id (avid). + """ + url = reverse("advisory-v2-list") + response = self.client.get( + url, {"advisory_id": "nginx_importer_v2/CVE-2021-1234"}, format="json" + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["count"], 1) + self.assertEqual(response.data["results"][0]["advisory_id"], self.advisory1.avid) + + def test_filter_by_datasource_id(self): + """ + Test filtering advisories by datasource_id returns only that source's advisories. + """ + url = reverse("advisory-v2-list") + response = self.client.get(url, {"datasource_id": "nginx_importer_v2"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["count"], 1) + self.assertEqual(response.data["results"][0]["advisory_id"], self.advisory1.avid) + + def test_filter_by_nonexistent_alias_returns_empty(self): + """ + Test that filtering by a non-existent alias returns an empty list. + """ + url = reverse("advisory-v2-list") + response = self.client.get(url, {"alias": "CVE-9999-9999"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["count"], 0) + + def test_advisory_serializer_fields(self): + """ + Test that AdvisoryV2Serializer returns all required fields. + """ + serializer = AdvisoryV2Serializer(self.advisory1) + data = serializer.data + expected_fields = [ + "advisory_id", + "url", + "aliases", + "summary", + "severities", + "weaknesses", + "references", + "exploitability", + "weighted_severity", + "risk_score", + "related_ssvc_trees", + ] + for field in expected_fields: + self.assertIn(field, data) + self.assertEqual(data["advisory_id"], self.advisory1.avid) + self.assertIn("CVE-2021-1234", data["aliases"]) diff --git a/vulnerablecode/urls.py b/vulnerablecode/urls.py index 49948a3b9..858130b11 100644 --- a/vulnerablecode/urls.py +++ b/vulnerablecode/urls.py @@ -20,6 +20,7 @@ from vulnerabilities.api import CPEViewSet from vulnerabilities.api import PackageViewSet from vulnerabilities.api import VulnerabilityViewSet +from vulnerabilities.api_v2 import AdvisoryV2ViewSet from vulnerabilities.api_v2 import CodeFixV2ViewSet from vulnerabilities.api_v2 import CodeFixViewSet from vulnerabilities.api_v2 import PackageV2ViewSet @@ -66,6 +67,7 @@ def __init__(self, *args, **kwargs): api_v2_router.register("codefixes", CodeFixViewSet, basename="codefix") api_v2_router.register("pipelines", PipelineScheduleV2ViewSet, basename="pipelines") api_v2_router.register("advisory-codefixes", CodeFixV2ViewSet, basename="advisory-codefix") +api_v2_router.register("advisories", AdvisoryV2ViewSet, basename="advisory-v2") api_v3_router = OptionalSlashRouter()