diff --git a/litellm/proxy/management_endpoints/scim/scim_v2.py b/litellm/proxy/management_endpoints/scim/scim_v2.py index 0965198bad..e67e1eae74 100644 --- a/litellm/proxy/management_endpoints/scim/scim_v2.py +++ b/litellm/proxy/management_endpoints/scim/scim_v2.py @@ -410,6 +410,308 @@ async def set_scim_content_type(response: Response): response.headers["Content-Type"] = "application/scim+json" +def _get_resource_types(base_url: str = "/scim/v2") -> list: + """Return the list of SCIM ResourceType definitions per RFC 7643 Section 6.""" + return [ + SCIMResourceType( + id="User", + name="User", + description="User Account", + endpoint="/Users", + schema_="urn:ietf:params:scim:schemas:core:2.0:User", + meta={ + "location": f"{base_url}/ResourceTypes/User", + "resourceType": "ResourceType", + }, + ), + SCIMResourceType( + id="Group", + name="Group", + description="Group", + endpoint="/Groups", + schema_="urn:ietf:params:scim:schemas:core:2.0:Group", + meta={ + "location": f"{base_url}/ResourceTypes/Group", + "resourceType": "ResourceType", + }, + ), + ] + + +def _get_schemas() -> list: + """Return the list of SCIM Schema definitions per RFC 7643 Section 7.""" + return [ + SCIMSchema( + id="urn:ietf:params:scim:schemas:core:2.0:User", + name="User", + description="User Account", + attributes=[ + SCIMSchemaAttribute( + name="userName", + type="string", + multiValued=False, + description="Unique identifier for the User.", + required=True, + mutability="readWrite", + returned="default", + uniqueness="server", + ), + SCIMSchemaAttribute( + name="name", + type="complex", + multiValued=False, + description="The components of the user's real name.", + required=False, + subAttributes=[ + SCIMSchemaAttribute( + name="givenName", + type="string", + description="The given name of the User.", + ), + SCIMSchemaAttribute( + name="familyName", + type="string", + description="The family name of the User.", + ), + SCIMSchemaAttribute( + name="formatted", + type="string", + description="The full name.", + ), + ], + ), + SCIMSchemaAttribute( + name="displayName", + type="string", + multiValued=False, + description="The name of the User, suitable for display.", + ), + SCIMSchemaAttribute( + name="emails", + type="complex", + multiValued=True, + description="Email addresses for the user.", + subAttributes=[ + SCIMSchemaAttribute( + name="value", + type="string", + description="Email address value.", + ), + SCIMSchemaAttribute( + name="type", + type="string", + description="Type of email (work, home, etc.).", + ), + SCIMSchemaAttribute( + name="primary", + type="boolean", + description="Whether this is the primary email.", + ), + ], + ), + SCIMSchemaAttribute( + name="active", + type="boolean", + multiValued=False, + description="Whether the user account is active.", + ), + SCIMSchemaAttribute( + name="groups", + type="complex", + multiValued=True, + description="Groups to which the user belongs.", + mutability="readOnly", + subAttributes=[ + SCIMSchemaAttribute( + name="value", + type="string", + description="Group identifier.", + ), + SCIMSchemaAttribute( + name="display", + type="string", + description="Group display name.", + ), + ], + ), + ], + meta={ + "location": "/scim/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:User", + "resourceType": "Schema", + }, + ), + SCIMSchema( + id="urn:ietf:params:scim:schemas:core:2.0:Group", + name="Group", + description="Group", + attributes=[ + SCIMSchemaAttribute( + name="displayName", + type="string", + multiValued=False, + description="A human-readable name for the Group.", + required=True, + mutability="readWrite", + returned="default", + uniqueness="none", + ), + SCIMSchemaAttribute( + name="members", + type="complex", + multiValued=True, + description="A list of members of the Group.", + subAttributes=[ + SCIMSchemaAttribute( + name="value", + type="string", + description="Member identifier.", + ), + SCIMSchemaAttribute( + name="display", + type="string", + description="Member display name.", + ), + ], + ), + ], + meta={ + "location": "/scim/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:Group", + "resourceType": "Schema", + }, + ), + ] + + +@scim_router.get( + "", + status_code=200, + dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)], +) +@scim_router.get( + "/", + status_code=200, + include_in_schema=False, + dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)], +) +async def get_scim_base(request: Request): + """ + Base SCIM v2 endpoint for resource discovery per RFC 7644 Section 4. + + Returns a ListResponse of ResourceTypes supported by this SCIM service provider. + Identity providers (Okta, Azure AD, etc.) use this endpoint for resource discovery. + """ + verbose_proxy_logger.debug( + "SCIM base resource discovery request: method=%s url=%s", + request.method, + request.url, + ) + base_url = str(request.base_url).rstrip("/") + "/scim/v2" + resource_types = _get_resource_types(base_url) + return { + "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], + "totalResults": len(resource_types), + "Resources": [rt.model_dump() for rt in resource_types], + } + + +@scim_router.get( + "/ResourceTypes", + status_code=200, + dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)], +) +async def get_resource_types(request: Request): + """ + SCIM ResourceTypes endpoint per RFC 7644 Section 4. + + Returns a ListResponse of all resource types supported by this service provider. + """ + verbose_proxy_logger.debug( + "SCIM ResourceTypes request: method=%s url=%s", + request.method, + request.url, + ) + base_url = str(request.base_url).rstrip("/") + "/scim/v2" + resource_types = _get_resource_types(base_url) + return { + "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], + "totalResults": len(resource_types), + "Resources": [rt.model_dump() for rt in resource_types], + } + + +@scim_router.get( + "/ResourceTypes/{resource_type_id}", + status_code=200, + dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)], +) +async def get_resource_type( + request: Request, + resource_type_id: str = Path(..., title="ResourceType ID"), +): + """ + Get a single ResourceType by ID per RFC 7644. + """ + verbose_proxy_logger.debug( + "SCIM ResourceType request for id=%s", resource_type_id + ) + base_url = str(request.base_url).rstrip("/") + "/scim/v2" + resource_types = _get_resource_types(base_url) + for rt in resource_types: + if rt.id == resource_type_id: + return rt.model_dump() + raise HTTPException( + status_code=404, + detail={"error": f"ResourceType not found: {resource_type_id}"}, + ) + + +@scim_router.get( + "/Schemas", + status_code=200, + dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)], +) +async def get_schemas(request: Request): + """ + SCIM Schemas endpoint per RFC 7643 Section 7. + + Returns a ListResponse of all schemas supported by this service provider. + """ + verbose_proxy_logger.debug( + "SCIM Schemas request: method=%s url=%s", + request.method, + request.url, + ) + schemas = _get_schemas() + return { + "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], + "totalResults": len(schemas), + "Resources": [s.model_dump() for s in schemas], + } + + +@scim_router.get( + "/Schemas/{schema_id:path}", + status_code=200, + dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)], +) +async def get_schema( + request: Request, + schema_id: str = Path(..., title="Schema URI"), +): + """ + Get a single Schema by its URI per RFC 7643 Section 7. + """ + verbose_proxy_logger.debug("SCIM Schema request for id=%s", schema_id) + schemas = _get_schemas() + for s in schemas: + if s.id == schema_id: + return s.model_dump() + raise HTTPException( + status_code=404, + detail={"error": f"Schema not found: {schema_id}"}, + ) + + @scim_router.get( "/ServiceProviderConfig", response_model=SCIMServiceProviderConfig, diff --git a/litellm/types/proxy/management_endpoints/scim_v2.py b/litellm/types/proxy/management_endpoints/scim_v2.py index bff9f0b876..c4d95d99ed 100644 --- a/litellm/types/proxy/management_endpoints/scim_v2.py +++ b/litellm/types/proxy/management_endpoints/scim_v2.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Literal, Optional, Union from fastapi import HTTPException -from pydantic import BaseModel, EmailStr, field_validator +from pydantic import BaseModel, ConfigDict, EmailStr, field_validator class LiteLLM_UserScimMetadata(BaseModel): @@ -112,3 +112,67 @@ class SCIMServiceProviderConfig(BaseModel): etag: SCIMFeature = SCIMFeature(supported=False) authenticationSchemes: Optional[List[Dict[str, Any]]] = None meta: Optional[Dict[str, Any]] = None + + +# SCIM ResourceType Models (RFC 7643 Section 6) +class SCIMSchemaExtension(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + schema_: str # aliased to "schema" in serialization + required: bool + + def model_dump(self, **kwargs): + d = super().model_dump(**kwargs) + d["schema"] = d.pop("schema_") + return d + + +class SCIMResourceType(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + schemas: List[str] = [ + "urn:ietf:params:scim:schemas:core:2.0:ResourceType" + ] + id: str + name: str + description: Optional[str] = None + endpoint: str + schema_: str # "schema" is a reserved name in Pydantic context + + schemaExtensions: Optional[List[SCIMSchemaExtension]] = None + meta: Optional[Dict[str, Any]] = None + + def model_dump(self, **kwargs): + d = super().model_dump(**kwargs) + d["schema"] = d.pop("schema_") + if d.get("schemaExtensions") is None: + d.pop("schemaExtensions", None) + return d + + +# SCIM Schema Models (RFC 7643 Section 7) +class SCIMSchemaAttribute(BaseModel): + name: str + type: str + multiValued: bool = False + description: Optional[str] = None + required: bool = False + mutability: str = "readWrite" + returned: str = "default" + uniqueness: str = "none" + subAttributes: Optional[List["SCIMSchemaAttribute"]] = None + + def model_dump(self, **kwargs): + d = super().model_dump(**kwargs) + if d.get("subAttributes") is None: + d.pop("subAttributes", None) + return d + + +class SCIMSchema(BaseModel): + schemas: List[str] = ["urn:ietf:params:scim:schemas:core:2.0:Schema"] + id: str + name: str + description: Optional[str] = None + attributes: List[SCIMSchemaAttribute] = [] + meta: Optional[Dict[str, Any]] = None diff --git a/tests/test_litellm/proxy/management_endpoints/scim/test_scim_v2_discovery.py b/tests/test_litellm/proxy/management_endpoints/scim/test_scim_v2_discovery.py new file mode 100644 index 0000000000..2162d6e188 --- /dev/null +++ b/tests/test_litellm/proxy/management_endpoints/scim/test_scim_v2_discovery.py @@ -0,0 +1,300 @@ +""" +Tests for SCIM v2 resource discovery endpoints: +- GET /scim/v2 (base endpoint) +- GET /scim/v2/ResourceTypes +- GET /scim/v2/ResourceTypes/{id} +- GET /scim/v2/Schemas +- GET /scim/v2/Schemas/{uri} +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import HTTPException + +from litellm.proxy.management_endpoints.scim.scim_v2 import ( + _get_resource_types, + _get_schemas, + get_resource_type, + get_resource_types, + get_schema, + get_schemas, + get_scim_base, +) +from litellm.types.proxy.management_endpoints.scim_v2 import ( + SCIMResourceType, + SCIMSchema, +) + + +def _make_mock_request(base_url="http://localhost:4000/", url="http://localhost:4000/scim/v2"): + """Create a mock FastAPI Request object.""" + request = MagicMock() + request.method = "GET" + request.url = url + request.base_url = base_url + return request + + +# ---- Helper function tests ---- + + +class TestGetResourceTypes: + def test_returns_user_and_group(self): + resource_types = _get_resource_types() + assert len(resource_types) == 2 + ids = [rt.id for rt in resource_types] + assert "User" in ids + assert "Group" in ids + + def test_user_resource_type_fields(self): + resource_types = _get_resource_types() + user_rt = next(rt for rt in resource_types if rt.id == "User") + assert user_rt.name == "User" + assert user_rt.endpoint == "/Users" + assert user_rt.schema_ == "urn:ietf:params:scim:schemas:core:2.0:User" + assert user_rt.schemas == ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"] + + def test_group_resource_type_fields(self): + resource_types = _get_resource_types() + group_rt = next(rt for rt in resource_types if rt.id == "Group") + assert group_rt.name == "Group" + assert group_rt.endpoint == "/Groups" + assert group_rt.schema_ == "urn:ietf:params:scim:schemas:core:2.0:Group" + + def test_custom_base_url(self): + resource_types = _get_resource_types("https://example.com/scim/v2") + user_rt = next(rt for rt in resource_types if rt.id == "User") + assert user_rt.meta["location"] == "https://example.com/scim/v2/ResourceTypes/User" + + def test_model_dump_uses_schema_key(self): + """Ensure model_dump() outputs 'schema' not 'schema_'.""" + resource_types = _get_resource_types() + dumped = resource_types[0].model_dump() + assert "schema" in dumped + assert "schema_" not in dumped + + +class TestGetSchemas: + def test_returns_user_and_group_schemas(self): + schemas = _get_schemas() + assert len(schemas) == 2 + ids = [s.id for s in schemas] + assert "urn:ietf:params:scim:schemas:core:2.0:User" in ids + assert "urn:ietf:params:scim:schemas:core:2.0:Group" in ids + + def test_user_schema_has_required_attributes(self): + schemas = _get_schemas() + user_schema = next( + s for s in schemas if s.id == "urn:ietf:params:scim:schemas:core:2.0:User" + ) + attr_names = [a.name for a in user_schema.attributes] + assert "userName" in attr_names + assert "name" in attr_names + assert "emails" in attr_names + assert "active" in attr_names + assert "groups" in attr_names + + def test_group_schema_has_required_attributes(self): + schemas = _get_schemas() + group_schema = next( + s for s in schemas if s.id == "urn:ietf:params:scim:schemas:core:2.0:Group" + ) + attr_names = [a.name for a in group_schema.attributes] + assert "displayName" in attr_names + assert "members" in attr_names + + def test_schema_meta_fields(self): + schemas = _get_schemas() + user_schema = next( + s for s in schemas if s.id == "urn:ietf:params:scim:schemas:core:2.0:User" + ) + assert user_schema.meta is not None + assert user_schema.meta["resourceType"] == "Schema" + + +# ---- Endpoint tests ---- + + +class TestGetScimBase: + @pytest.mark.asyncio + async def test_returns_list_response(self): + request = _make_mock_request() + result = await get_scim_base(request) + + assert result["schemas"] == ["urn:ietf:params:scim:api:messages:2.0:ListResponse"] + assert result["totalResults"] == 2 + assert len(result["Resources"]) == 2 + + @pytest.mark.asyncio + async def test_resources_contain_user_and_group(self): + request = _make_mock_request() + result = await get_scim_base(request) + + resource_ids = [r["id"] for r in result["Resources"]] + assert "User" in resource_ids + assert "Group" in resource_ids + + @pytest.mark.asyncio + async def test_resources_have_schema_field(self): + """Each resource should have 'schema' (not 'schema_') per SCIM spec.""" + request = _make_mock_request() + result = await get_scim_base(request) + + for resource in result["Resources"]: + assert "schema" in resource + assert "schema_" not in resource + + @pytest.mark.asyncio + async def test_location_uses_base_url(self): + request = _make_mock_request(base_url="https://proxy.example.com/") + result = await get_scim_base(request) + + user_resource = next(r for r in result["Resources"] if r["id"] == "User") + assert user_resource["meta"]["location"] == "https://proxy.example.com/scim/v2/ResourceTypes/User" + + +class TestGetResourceTypesEndpoint: + @pytest.mark.asyncio + async def test_returns_list_response(self): + request = _make_mock_request() + result = await get_resource_types(request) + + assert result["schemas"] == ["urn:ietf:params:scim:api:messages:2.0:ListResponse"] + assert result["totalResults"] == 2 + + @pytest.mark.asyncio + async def test_resources_match_base_endpoint(self): + """ResourceTypes endpoint should return same data as base endpoint.""" + request = _make_mock_request() + base_result = await get_scim_base(request) + rt_result = await get_resource_types(request) + + assert base_result["totalResults"] == rt_result["totalResults"] + assert len(base_result["Resources"]) == len(rt_result["Resources"]) + + +class TestGetResourceTypeById: + @pytest.mark.asyncio + async def test_get_user_resource_type(self): + request = _make_mock_request() + result = await get_resource_type(request, resource_type_id="User") + + assert result["id"] == "User" + assert result["name"] == "User" + assert result["endpoint"] == "/Users" + assert result["schema"] == "urn:ietf:params:scim:schemas:core:2.0:User" + + @pytest.mark.asyncio + async def test_get_group_resource_type(self): + request = _make_mock_request() + result = await get_resource_type(request, resource_type_id="Group") + + assert result["id"] == "Group" + assert result["name"] == "Group" + assert result["endpoint"] == "/Groups" + + @pytest.mark.asyncio + async def test_not_found(self): + request = _make_mock_request() + with pytest.raises(HTTPException) as exc_info: + await get_resource_type(request, resource_type_id="NonExistent") + assert exc_info.value.status_code == 404 + + +class TestGetSchemasEndpoint: + @pytest.mark.asyncio + async def test_returns_list_response(self): + request = _make_mock_request() + result = await get_schemas(request) + + assert result["schemas"] == ["urn:ietf:params:scim:api:messages:2.0:ListResponse"] + assert result["totalResults"] == 2 + + @pytest.mark.asyncio + async def test_resources_have_correct_ids(self): + request = _make_mock_request() + result = await get_schemas(request) + + schema_ids = [r["id"] for r in result["Resources"]] + assert "urn:ietf:params:scim:schemas:core:2.0:User" in schema_ids + assert "urn:ietf:params:scim:schemas:core:2.0:Group" in schema_ids + + +class TestGetSchemaById: + @pytest.mark.asyncio + async def test_get_user_schema(self): + request = _make_mock_request() + result = await get_schema( + request, schema_id="urn:ietf:params:scim:schemas:core:2.0:User" + ) + + assert result["id"] == "urn:ietf:params:scim:schemas:core:2.0:User" + assert result["name"] == "User" + assert len(result["attributes"]) > 0 + + @pytest.mark.asyncio + async def test_get_group_schema(self): + request = _make_mock_request() + result = await get_schema( + request, schema_id="urn:ietf:params:scim:schemas:core:2.0:Group" + ) + + assert result["id"] == "urn:ietf:params:scim:schemas:core:2.0:Group" + assert result["name"] == "Group" + + @pytest.mark.asyncio + async def test_not_found(self): + request = _make_mock_request() + with pytest.raises(HTTPException) as exc_info: + await get_schema(request, schema_id="urn:nonexistent:schema") + assert exc_info.value.status_code == 404 + + +class TestSCIMResourceTypeModel: + """Test the SCIMResourceType Pydantic model itself.""" + + def test_model_dump_schema_key(self): + rt = SCIMResourceType( + id="Test", + name="Test", + endpoint="/Test", + schema_="urn:test", + ) + dumped = rt.model_dump() + assert "schema" in dumped + assert "schema_" not in dumped + assert dumped["schema"] == "urn:test" + + def test_no_schema_extensions_omitted(self): + rt = SCIMResourceType( + id="Test", + name="Test", + endpoint="/Test", + schema_="urn:test", + ) + dumped = rt.model_dump() + assert "schemaExtensions" not in dumped + + +class TestSCIMSchemaModel: + """Test the SCIMSchema Pydantic model.""" + + def test_basic_schema(self): + schema = SCIMSchema( + id="urn:test", + name="Test", + description="A test schema", + ) + assert schema.id == "urn:test" + assert schema.attributes == [] + + def test_sub_attributes_omitted_when_none(self): + from litellm.types.proxy.management_endpoints.scim_v2 import SCIMSchemaAttribute + + attr = SCIMSchemaAttribute( + name="test", + type="string", + ) + dumped = attr.model_dump() + assert "subAttributes" not in dumped