fix(gcs_bucket): prevent unbounded queue growth due to slow API calls (#19297)

This commit is contained in:
Alexsander Hamir
2026-01-19 10:47:56 -08:00
committed by GitHub
parent 2ba7d2e821
commit 3cdeebb5b8
3 changed files with 203 additions and 41 deletions
@@ -603,6 +603,7 @@ router_settings:
| GCS_PATH_SERVICE_ACCOUNT | Path to the Google Cloud service account JSON file
| GCS_FLUSH_INTERVAL | Flush interval for GCS logging (in seconds). Specify how often you want a log to be sent to GCS. **Default is 20 seconds**
| GCS_BATCH_SIZE | Batch size for GCS logging. Specify after how many logs you want to flush to GCS. If `BATCH_SIZE` is set to 10, logs are flushed every 10 logs. **Default is 2048**
| GCS_USE_BATCHED_LOGGING | Enable batched logging for GCS. When enabled (default), multiple log payloads are combined into single GCS object uploads (NDJSON format), dramatically reducing API calls. When disabled, sends each log individually as separate GCS objects (legacy behavior). **Default is true**
| GCS_PUBSUB_TOPIC_ID | PubSub Topic ID to send LiteLLM SpendLogs to.
| GCS_PUBSUB_PROJECT_ID | PubSub Project ID to send LiteLLM SpendLogs to.
| GENERIC_AUTHORIZATION_ENDPOINT | Authorization endpoint for generic OAuth providers
+201 -41
View File
@@ -1,9 +1,11 @@
import asyncio
import hashlib
import json
import os
import time
from litellm._uuid import uuid
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from urllib.parse import quote
from litellm._logging import verbose_logger
@@ -26,19 +28,21 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
super().__init__(bucket_name=bucket_name)
# Init Batch logging settings
self.log_queue: List[GCSLogQueueItem] = []
self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
self.flush_interval = int(
os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
)
asyncio.create_task(self.periodic_flush())
self.use_batched_logging = (
os.getenv("GCS_USE_BATCHED_LOGGING", str(GCS_DEFAULT_USE_BATCHED_LOGGING).lower()).lower() == "true"
)
self.flush_lock = asyncio.Lock()
super().__init__(
flush_lock=self.flush_lock,
batch_size=self.batch_size,
flush_interval=self.flush_interval,
)
self.log_queue: asyncio.Queue[GCSLogQueueItem] = asyncio.Queue() # type: ignore[assignment]
asyncio.create_task(self.periodic_flush())
AdditionalLoggingUtils.__init__(self)
if premium_user is not True:
@@ -65,8 +69,7 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
)
if logging_payload is None:
raise ValueError("standard_logging_object not found in kwargs")
# Add to logging queue - this will be flushed periodically
self.log_queue.append(
await self.log_queue.put(
GCSLogQueueItem(
payload=logging_payload, kwargs=kwargs, response_obj=response_obj
)
@@ -89,7 +92,9 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
if logging_payload is None:
raise ValueError("standard_logging_object not found in kwargs")
# Add to logging queue - this will be flushed periodically
self.log_queue.append(
# Use asyncio.Queue.put() for thread-safe concurrent access
# If queue is full, this will block until space is available (backpressure)
await self.log_queue.put(
GCSLogQueueItem(
payload=logging_payload, kwargs=kwargs, response_obj=response_obj
)
@@ -98,28 +103,98 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
except Exception as e:
verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
async def async_send_batch(self):
def _drain_queue_batch(self) -> List[GCSLogQueueItem]:
"""
Process queued logs in batch - sends logs to GCS Bucket
GCS Bucket does not have a Batch endpoint to batch upload logs
Instead, we
- collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds
- during async_send_batch, we make 1 POST request per log to GCS Bucket
Drain items from the queue (non-blocking), respecting batch_size limit.
This prevents unbounded queue growth when processing is slower than log accumulation.
Returns:
List of items to process, up to batch_size items
"""
if not self.log_queue:
return
items_to_process: List[GCSLogQueueItem] = []
while len(items_to_process) < self.batch_size:
try:
items_to_process.append(self.log_queue.get_nowait())
except asyncio.QueueEmpty:
break
return items_to_process
for log_item in self.log_queue:
logging_payload = log_item["payload"]
kwargs = log_item["kwargs"]
response_obj = log_item.get("response_obj", None) or {}
def _generate_batch_object_name(self, date_str: str, batch_id: str) -> str:
"""
Generate object name for a batched log file.
Format: {date}/batch-{batch_id}.ndjson
"""
return f"{date_str}/batch-{batch_id}.ndjson"
def _get_config_key(self, kwargs: Dict[str, Any]) -> str:
"""
Extract a synchronous grouping key from kwargs to group items by GCS config.
This allows us to batch items with the same bucket/credentials together.
Returns a string key that uniquely identifies the GCS config combination.
This key may contain sensitive information (bucket names, paths) - use _sanitize_config_key()
for logging purposes.
"""
standard_callback_dynamic_params = kwargs.get("standard_callback_dynamic_params", None) or {}
bucket_name = standard_callback_dynamic_params.get("gcs_bucket_name", None) or self.BUCKET_NAME or "default"
path_service_account = standard_callback_dynamic_params.get("gcs_path_service_account", None) or self.path_service_account_json or "default"
return f"{bucket_name}|{path_service_account}"
def _sanitize_config_key(self, config_key: str) -> str:
"""
Create a sanitized version of the config key for logging.
Uses a hash to avoid exposing sensitive bucket names or service account paths.
Returns a short hash prefix for safe logging.
"""
hash_obj = hashlib.sha256(config_key.encode('utf-8'))
return f"config-{hash_obj.hexdigest()[:8]}"
def _group_items_by_config(self, items: List[GCSLogQueueItem]) -> Dict[str, List[GCSLogQueueItem]]:
"""
Group items by their GCS config (bucket + credentials).
This ensures items with different configs are processed separately.
Returns a dict mapping config_key -> list of items with that config.
"""
grouped: Dict[str, List[GCSLogQueueItem]] = {}
for item in items:
config_key = self._get_config_key(item["kwargs"])
if config_key not in grouped:
grouped[config_key] = []
grouped[config_key].append(item)
return grouped
def _combine_payloads_to_ndjson(self, items: List[GCSLogQueueItem]) -> str:
"""
Combine multiple log payloads into newline-delimited JSON (NDJSON) format.
Each line is a valid JSON object representing one log entry.
"""
lines = []
for item in items:
logging_payload = item["payload"]
json_line = json.dumps(logging_payload, default=str, ensure_ascii=False)
lines.append(json_line)
return "\n".join(lines)
async def _send_grouped_batch(self, items: List[GCSLogQueueItem], config_key: str) -> Tuple[int, int]:
"""
Send a batch of items that share the same GCS config.
Returns:
(success_count, error_count)
"""
if not items:
return (0, 0)
first_kwargs = items[0]["kwargs"]
try:
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs
first_kwargs
)
headers = await self.construct_request_headers(
@@ -127,24 +202,92 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
object_name = self._get_object_name(kwargs, logging_payload, response_obj)
current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
batch_id = f"{int(time.time() * 1000)}-{uuid.uuid4().hex[:8]}"
object_name = self._generate_batch_object_name(current_date, batch_id)
combined_payload = self._combine_payloads_to_ndjson(items)
await self._log_json_data_on_gcs(
headers=headers,
bucket_name=bucket_name,
object_name=object_name,
logging_payload=combined_payload,
)
success_count = len(items)
error_count = 0
return (success_count, error_count)
except Exception as e:
success_count = 0
error_count = len(items)
verbose_logger.exception(
f"GCS Bucket error logging batch payload to GCS bucket: {str(e)}"
)
return (success_count, error_count)
try:
await self._log_json_data_on_gcs(
headers=headers,
bucket_name=bucket_name,
object_name=object_name,
logging_payload=logging_payload,
)
except Exception as e:
# don't let one log item fail the entire batch
verbose_logger.exception(
f"GCS Bucket error logging payload to GCS bucket: {str(e)}"
)
pass
async def _send_individual_logs(self, items: List[GCSLogQueueItem]) -> None:
"""
Send each log individually as separate GCS objects (legacy behavior).
This is used when GCS_USE_BATCHED_LOGGING is disabled.
"""
for item in items:
await self._send_single_log_item(item)
# Clear the queue after processing
self.log_queue.clear()
async def _send_single_log_item(self, item: GCSLogQueueItem) -> None:
"""
Send a single log item to GCS as an individual object.
"""
try:
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
item["kwargs"]
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
object_name = self._get_object_name(
kwargs=item["kwargs"],
logging_payload=item["payload"],
response_obj=item["response_obj"],
)
await self._log_json_data_on_gcs(
headers=headers,
bucket_name=bucket_name,
object_name=object_name,
logging_payload=item["payload"],
)
except Exception as e:
verbose_logger.exception(
f"GCS Bucket error logging individual payload to GCS bucket: {str(e)}"
)
async def async_send_batch(self):
"""
Process queued logs - sends logs to GCS Bucket.
If `GCS_USE_BATCHED_LOGGING` is enabled (default), batches multiple log payloads
into single GCS object uploads (NDJSON format), dramatically reducing API calls.
If disabled, sends each log individually as separate GCS objects (legacy behavior).
"""
items_to_process = self._drain_queue_batch()
if not items_to_process:
return
if self.use_batched_logging:
grouped_items = self._group_items_by_config(items_to_process)
for config_key, group_items in grouped_items.items():
await self._send_grouped_batch(group_items, config_key)
else:
await self._send_individual_logs(items_to_process)
def _get_object_name(
self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
@@ -186,7 +329,6 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
"start_time_utc is required for getting a payload from GCS Bucket"
)
# Try current day, next day, and previous day
dates_to_try = [
start_time_utc,
start_time_utc + timedelta(days=1),
@@ -230,5 +372,23 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
return datetime_obj.strftime("%Y-%m-%d")
async def flush_queue(self):
"""
Override flush_queue to work with asyncio.Queue.
"""
await self.async_send_batch()
self.last_flush_time = time.time()
async def periodic_flush(self):
"""
Override periodic_flush to work with asyncio.Queue.
"""
while True:
await asyncio.sleep(self.flush_interval)
verbose_logger.debug(
f"GCS Bucket periodic flush after {self.flush_interval} seconds"
)
await self.flush_queue()
async def async_health_check(self) -> IntegrationHealthCheckStatus:
raise NotImplementedError("GCS Bucket does not support health check")
+1
View File
@@ -12,6 +12,7 @@ else:
GCS_DEFAULT_BATCH_SIZE = 2048
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
GCS_DEFAULT_USE_BATCHED_LOGGING = True
class GCSLoggingConfig(TypedDict):