diff --git a/litellm/integrations/focus/focus_logger.py b/litellm/integrations/focus/focus_logger.py index 5eccf441be..28629f64d3 100644 --- a/litellm/integrations/focus/focus_logger.py +++ b/litellm/integrations/focus/focus_logger.py @@ -228,9 +228,7 @@ class FocusLogger(CustomLogger): now_utc = now.astimezone(timezone.utc) if self.frequency == "hourly": end_time = now_utc.replace(minute=0, second=0, microsecond=0) - # start_time = end_time - timedelta(hours=1) - # Temporary override: export data since start of day instead of last hour - start_time = end_time.replace(hour=0) + start_time = end_time - timedelta(hours=1) elif self.frequency == "daily": end_time = now_utc.replace(hour=0, minute=0, second=0, microsecond=0) start_time = end_time - timedelta(days=1) diff --git a/litellm/integrations/focus/schema.py b/litellm/integrations/focus/schema.py index 61ebbf2b9d..6d2e1dc83c 100644 --- a/litellm/integrations/focus/schema.py +++ b/litellm/integrations/focus/schema.py @@ -19,11 +19,6 @@ FOCUS_NORMALIZED_SCHEMA = pl.Schema( "ChargeFrequency": pl.String, "ChargePeriodStart": pl.Datetime(time_unit="us"), "ChargePeriodEnd": pl.Datetime(time_unit="us"), - "CommitmentDiscountCategory": pl.String, - "CommitmentDiscountId": pl.String, - "CommitmentDiscountName": pl.String, - "CommitmentDiscountStatus": pl.String, - "CommitmentDiscountType": pl.String, "ConsumedQuantity": pl.Float64, "ConsumedUnit": pl.Float64, "ContractedCost": pl.Float64, @@ -44,8 +39,6 @@ FOCUS_NORMALIZED_SCHEMA = pl.Schema( "ResourceType": pl.String, "ServiceCategory": pl.String, "ServiceName": pl.String, - "SkuId": pl.String, - "SkuPriceId": pl.String, "SubAccountId": pl.String, "SubAccountName": pl.String, "SubAccountType": pl.String, diff --git a/litellm/integrations/focus/serializers/parquet.py b/litellm/integrations/focus/serializers/parquet.py index deac0e4539..6b3dde5903 100644 --- a/litellm/integrations/focus/serializers/parquet.py +++ b/litellm/integrations/focus/serializers/parquet.py @@ -19,5 +19,4 @@ class FocusParquetSerializer(FocusSerializer): target = frame if not frame.is_empty() else pl.DataFrame(schema=frame.schema) buffer = io.BytesIO() target.write_parquet(buffer, compression="snappy") - print(target.head(5)) # debug return buffer.getvalue() diff --git a/litellm/integrations/focus/transformer.py b/litellm/integrations/focus/transformer.py index 8957172da8..a98ea21b1b 100644 --- a/litellm/integrations/focus/transformer.py +++ b/litellm/integrations/focus/transformer.py @@ -40,7 +40,6 @@ class FocusTransformer: none_str = pl.lit(None, dtype=pl.Utf8) none_dec = pl.lit(None, dtype=pl.Decimal(18, 6)) - # zero_float = pl.lit(0.0, dtype=pl.Float64) return frame.select( dec(pl.col("spend").fill_null(0.0)).alias("BilledCost"), @@ -56,26 +55,16 @@ class FocusTransformer: pl.lit("Usage-Based").alias("ChargeFrequency"), fmt(pl.col("ChargePeriodEnd")).alias("ChargePeriodEnd"), fmt(pl.col("ChargePeriodStart")).alias("ChargePeriodStart"), - # pl.lit(None).alias("CommitmentDiscountCategory"), - # none_str.alias("CommitmentDiscountId"), - # none_str.alias("CommitmentDiscountName"), - # none_dec.alias("CommitmentDiscountQuantity"), - # none_str.alias("CommitmentDiscountUnit"), - # none_str.alias("CommitmentDiscountStatus"), - # none_str.alias("CommitmentDiscountType"), dec(pl.lit(1.0)).alias("ConsumedQuantity"), pl.lit("Requests").alias("ConsumedUnit"), dec(pl.col("spend").fill_null(0.0)).alias("ContractedCost"), none_str.alias("ContractedUnitPrice"), dec(pl.col("spend").fill_null(0.0)).alias("EffectiveCost"), pl.col("custom_llm_provider").cast(pl.String).alias("InvoiceIssuerName"), - # pl.lit("INVOICE-NOT-ISSUED").alias("InvoiceId"), none_str.alias("InvoiceId"), dec(pl.col("spend").fill_null(0.0)).alias("ListCost"), none_dec.alias("ListUnitPrice"), none_str.alias("AvailabilityZone"), - # none_str.alias("CapacityReservationId"), - # none_str.alias("CapacityReservationStatus"), pl.lit("USD").alias("PricingCurrency"), none_str.alias("PricingCategory"), dec(pl.lit(1.0)).alias("PricingQuantity"), @@ -93,10 +82,6 @@ class FocusTransformer: pl.lit("AI and Machine Learning").alias("ServiceCategory"), pl.lit("Generative AI").alias("ServiceSubcategory"), pl.col("model_group").cast(pl.String).alias("ServiceName"), - # none_str.alias("SkuId"), - # none_str.alias("SkuPriceId"), - # none_str.alias("SkuMeter"), - # none_str.alias("SkuPriceDetails"), pl.col("team_id").cast(pl.String).alias("SubAccountId"), pl.col("team_alias").cast(pl.String).alias("SubAccountName"), none_str.alias("SubAccountType"), diff --git a/tests/test_litellm/integrations/focus/test_database.py b/tests/test_litellm/integrations/focus/test_database.py new file mode 100644 index 0000000000..5ee98cc9dd --- /dev/null +++ b/tests/test_litellm/integrations/focus/test_database.py @@ -0,0 +1,74 @@ +"""Tests for FocusLiteLLMDatabase query construction.""" + +from datetime import datetime, timezone +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from litellm.integrations.focus.database import FocusLiteLLMDatabase + + +def _setup_db(monkeypatch: pytest.MonkeyPatch, query_return): + """Create a database instance with a stubbed prisma client.""" + query_mock = AsyncMock(return_value=query_return) + mock_client = SimpleNamespace(db=SimpleNamespace(query_raw=query_mock)) + db = FocusLiteLLMDatabase() + monkeypatch.setattr(db, "_ensure_prisma_client", lambda: mock_client) + return db, query_mock + + +@pytest.mark.asyncio +async def test_should_parameterize_filters_and_limit(monkeypatch: pytest.MonkeyPatch): + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + db, query_mock = _setup_db(monkeypatch, []) + + await db.get_usage_data(limit=25, start_time_utc=start, end_time_utc=end) + + query_text, *params = query_mock.await_args.args + assert "dus.updated_at >= $1::timestamptz" in query_text + assert "dus.updated_at <= $2::timestamptz" in query_text + assert "LIMIT $3" in query_text + assert params == [start, end, 25] + + +@pytest.mark.asyncio +async def test_should_execute_without_filters(monkeypatch: pytest.MonkeyPatch): + row = { + "id": 1, + "user_id": "user", + "date": datetime(2024, 1, 1, tzinfo=timezone.utc), + } + db, query_mock = _setup_db(monkeypatch, [row]) + + result = await db.get_usage_data() + + query_text, *params = query_mock.await_args.args + assert "WHERE" not in query_text + assert "LIMIT $" not in query_text + assert params == [] + assert result.height == 1 + assert result["id"][0] == 1 + + +@pytest.mark.asyncio +async def test_should_accept_string_timestamps(monkeypatch: pytest.MonkeyPatch): + db, query_mock = _setup_db(monkeypatch, []) + + start = "2024-02-01T00:00:00+00:00" + end = "2024-02-02T00:00:00+00:00" + await db.get_usage_data(start_time_utc=start, end_time_utc=end) + + _, *params = query_mock.await_args.args + assert params == [start, end] + + +@pytest.mark.asyncio +async def test_should_reject_invalid_limit(monkeypatch: pytest.MonkeyPatch): + db, query_mock = _setup_db(monkeypatch, []) + + with pytest.raises(ValueError): + await db.get_usage_data(limit="invalid") + + assert query_mock.await_count == 0 diff --git a/tests/test_litellm/integrations/focus/test_s3_destination.py b/tests/test_litellm/integrations/focus/test_s3_destination.py new file mode 100644 index 0000000000..f915b2c56a --- /dev/null +++ b/tests/test_litellm/integrations/focus/test_s3_destination.py @@ -0,0 +1,100 @@ +"""Tests for FocusS3Destination behavior.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from types import SimpleNamespace +from typing import Any, Dict + +import pytest + +import litellm.integrations.focus.destinations.s3_destination as s3_module +from litellm.integrations.focus.destinations.base import FocusTimeWindow +from litellm.integrations.focus.destinations.s3_destination import FocusS3Destination + + +def _window(freq: str = "hourly", hour: int = 5) -> FocusTimeWindow: + start = datetime(2024, 1, 2, hour, tzinfo=timezone.utc) + end = start.replace(hour=hour + 1) + return FocusTimeWindow(start_time=start, end_time=end, frequency=freq) + + +def test_should_require_bucket_name(): + with pytest.raises(ValueError): + FocusS3Destination(prefix="focus", config={}) + + +def test_should_build_hourly_object_key(): + dest = FocusS3Destination(prefix="exports/", config={"bucket_name": "bucket"}) + key = dest._build_object_key( + time_window=_window(freq="hourly", hour=3), filename="data.snappy" + ) + assert key == "exports/date=2024-01-02/hour=03/data.snappy" + + +def test_should_build_daily_key_without_hour_segment(): + dest = FocusS3Destination(prefix="", config={"bucket_name": "bucket"}) + key = dest._build_object_key( + time_window=_window(freq="daily", hour=0), filename="daily.parquet" + ) + assert key == "date=2024-01-02/daily.parquet" + + +@pytest.mark.asyncio +async def test_should_dispatch_upload_via_thread(monkeypatch: pytest.MonkeyPatch): + dest = FocusS3Destination(prefix="focus", config={"bucket_name": "bucket"}) + captured: Dict[str, Any] = {} + + async def fake_to_thread(func, *args, **kwargs): # type: ignore[override] + captured["func"] = func + captured["args"] = args + captured["kwargs"] = kwargs + + monkeypatch.setattr(s3_module.asyncio, "to_thread", fake_to_thread) + + window = _window(freq="hourly", hour=1) + await dest.deliver(content=b"payload", time_window=window, filename="file.bin") + + assert captured["func"] == dest._upload + assert captured["args"][0] == b"payload" + assert captured["args"][1].endswith("/file.bin") + + +def test_should_upload_with_configured_client(monkeypatch: pytest.MonkeyPatch): + config = { + "bucket_name": "bucket", + "region_name": "us-east-2", + "endpoint_url": "http://localhost:4566", + "aws_access_key_id": "key", + "aws_secret_access_key": "secret", + "aws_session_token": "token", + } + dest = FocusS3Destination(prefix="focus", config=config) + captured: Dict[str, Any] = {} + + def fake_client(service: str, **kwargs): + assert service == "s3" + captured["client_kwargs"] = kwargs + + def put_object(**put_kwargs): + captured["put_kwargs"] = put_kwargs + + return SimpleNamespace(put_object=put_object) + + monkeypatch.setattr(s3_module.boto3, "client", fake_client) + + dest._upload(content=b"payload", object_key="path/file.bin") + + assert captured["client_kwargs"] == { + "region_name": "us-east-2", + "endpoint_url": "http://localhost:4566", + "aws_access_key_id": "key", + "aws_secret_access_key": "secret", + "aws_session_token": "token", + } + assert captured["put_kwargs"] == { + "Bucket": "bucket", + "Key": "path/file.bin", + "Body": b"payload", + "ContentType": "application/octet-stream", + }