mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-27 01:07:02 +00:00
fix(gcs_bucket): prevent unbounded queue growth due to slow API calls (#19297)
This commit is contained in:
@@ -603,6 +603,7 @@ router_settings:
|
||||
| GCS_PATH_SERVICE_ACCOUNT | Path to the Google Cloud service account JSON file
|
||||
| GCS_FLUSH_INTERVAL | Flush interval for GCS logging (in seconds). Specify how often you want a log to be sent to GCS. **Default is 20 seconds**
|
||||
| GCS_BATCH_SIZE | Batch size for GCS logging. Specify after how many logs you want to flush to GCS. If `BATCH_SIZE` is set to 10, logs are flushed every 10 logs. **Default is 2048**
|
||||
| GCS_USE_BATCHED_LOGGING | Enable batched logging for GCS. When enabled (default), multiple log payloads are combined into single GCS object uploads (NDJSON format), dramatically reducing API calls. When disabled, sends each log individually as separate GCS objects (legacy behavior). **Default is true**
|
||||
| GCS_PUBSUB_TOPIC_ID | PubSub Topic ID to send LiteLLM SpendLogs to.
|
||||
| GCS_PUBSUB_PROJECT_ID | PubSub Project ID to send LiteLLM SpendLogs to.
|
||||
| GENERIC_AUTHORIZATION_ENDPOINT | Authorization endpoint for generic OAuth providers
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from litellm._uuid import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
@@ -26,19 +28,21 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||
|
||||
super().__init__(bucket_name=bucket_name)
|
||||
|
||||
# Init Batch logging settings
|
||||
self.log_queue: List[GCSLogQueueItem] = []
|
||||
self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
|
||||
self.flush_interval = int(
|
||||
os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
|
||||
)
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.use_batched_logging = (
|
||||
os.getenv("GCS_USE_BATCHED_LOGGING", str(GCS_DEFAULT_USE_BATCHED_LOGGING).lower()).lower() == "true"
|
||||
)
|
||||
self.flush_lock = asyncio.Lock()
|
||||
super().__init__(
|
||||
flush_lock=self.flush_lock,
|
||||
batch_size=self.batch_size,
|
||||
flush_interval=self.flush_interval,
|
||||
)
|
||||
self.log_queue: asyncio.Queue[GCSLogQueueItem] = asyncio.Queue() # type: ignore[assignment]
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
AdditionalLoggingUtils.__init__(self)
|
||||
|
||||
if premium_user is not True:
|
||||
@@ -65,8 +69,7 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||
)
|
||||
if logging_payload is None:
|
||||
raise ValueError("standard_logging_object not found in kwargs")
|
||||
# Add to logging queue - this will be flushed periodically
|
||||
self.log_queue.append(
|
||||
await self.log_queue.put(
|
||||
GCSLogQueueItem(
|
||||
payload=logging_payload, kwargs=kwargs, response_obj=response_obj
|
||||
)
|
||||
@@ -89,7 +92,9 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||
if logging_payload is None:
|
||||
raise ValueError("standard_logging_object not found in kwargs")
|
||||
# Add to logging queue - this will be flushed periodically
|
||||
self.log_queue.append(
|
||||
# Use asyncio.Queue.put() for thread-safe concurrent access
|
||||
# If queue is full, this will block until space is available (backpressure)
|
||||
await self.log_queue.put(
|
||||
GCSLogQueueItem(
|
||||
payload=logging_payload, kwargs=kwargs, response_obj=response_obj
|
||||
)
|
||||
@@ -98,28 +103,98 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
|
||||
|
||||
async def async_send_batch(self):
|
||||
def _drain_queue_batch(self) -> List[GCSLogQueueItem]:
|
||||
"""
|
||||
Process queued logs in batch - sends logs to GCS Bucket
|
||||
|
||||
|
||||
GCS Bucket does not have a Batch endpoint to batch upload logs
|
||||
|
||||
Instead, we
|
||||
- collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds
|
||||
- during async_send_batch, we make 1 POST request per log to GCS Bucket
|
||||
|
||||
Drain items from the queue (non-blocking), respecting batch_size limit.
|
||||
|
||||
This prevents unbounded queue growth when processing is slower than log accumulation.
|
||||
|
||||
Returns:
|
||||
List of items to process, up to batch_size items
|
||||
"""
|
||||
if not self.log_queue:
|
||||
return
|
||||
items_to_process: List[GCSLogQueueItem] = []
|
||||
while len(items_to_process) < self.batch_size:
|
||||
try:
|
||||
items_to_process.append(self.log_queue.get_nowait())
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
return items_to_process
|
||||
|
||||
for log_item in self.log_queue:
|
||||
logging_payload = log_item["payload"]
|
||||
kwargs = log_item["kwargs"]
|
||||
response_obj = log_item.get("response_obj", None) or {}
|
||||
def _generate_batch_object_name(self, date_str: str, batch_id: str) -> str:
|
||||
"""
|
||||
Generate object name for a batched log file.
|
||||
Format: {date}/batch-{batch_id}.ndjson
|
||||
"""
|
||||
return f"{date_str}/batch-{batch_id}.ndjson"
|
||||
|
||||
def _get_config_key(self, kwargs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract a synchronous grouping key from kwargs to group items by GCS config.
|
||||
This allows us to batch items with the same bucket/credentials together.
|
||||
|
||||
Returns a string key that uniquely identifies the GCS config combination.
|
||||
This key may contain sensitive information (bucket names, paths) - use _sanitize_config_key()
|
||||
for logging purposes.
|
||||
"""
|
||||
standard_callback_dynamic_params = kwargs.get("standard_callback_dynamic_params", None) or {}
|
||||
|
||||
bucket_name = standard_callback_dynamic_params.get("gcs_bucket_name", None) or self.BUCKET_NAME or "default"
|
||||
path_service_account = standard_callback_dynamic_params.get("gcs_path_service_account", None) or self.path_service_account_json or "default"
|
||||
|
||||
return f"{bucket_name}|{path_service_account}"
|
||||
|
||||
def _sanitize_config_key(self, config_key: str) -> str:
|
||||
"""
|
||||
Create a sanitized version of the config key for logging.
|
||||
Uses a hash to avoid exposing sensitive bucket names or service account paths.
|
||||
|
||||
Returns a short hash prefix for safe logging.
|
||||
"""
|
||||
hash_obj = hashlib.sha256(config_key.encode('utf-8'))
|
||||
return f"config-{hash_obj.hexdigest()[:8]}"
|
||||
|
||||
def _group_items_by_config(self, items: List[GCSLogQueueItem]) -> Dict[str, List[GCSLogQueueItem]]:
|
||||
"""
|
||||
Group items by their GCS config (bucket + credentials).
|
||||
This ensures items with different configs are processed separately.
|
||||
|
||||
Returns a dict mapping config_key -> list of items with that config.
|
||||
"""
|
||||
grouped: Dict[str, List[GCSLogQueueItem]] = {}
|
||||
for item in items:
|
||||
config_key = self._get_config_key(item["kwargs"])
|
||||
if config_key not in grouped:
|
||||
grouped[config_key] = []
|
||||
grouped[config_key].append(item)
|
||||
return grouped
|
||||
|
||||
def _combine_payloads_to_ndjson(self, items: List[GCSLogQueueItem]) -> str:
|
||||
"""
|
||||
Combine multiple log payloads into newline-delimited JSON (NDJSON) format.
|
||||
Each line is a valid JSON object representing one log entry.
|
||||
"""
|
||||
lines = []
|
||||
for item in items:
|
||||
logging_payload = item["payload"]
|
||||
json_line = json.dumps(logging_payload, default=str, ensure_ascii=False)
|
||||
lines.append(json_line)
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _send_grouped_batch(self, items: List[GCSLogQueueItem], config_key: str) -> Tuple[int, int]:
|
||||
"""
|
||||
Send a batch of items that share the same GCS config.
|
||||
|
||||
Returns:
|
||||
(success_count, error_count)
|
||||
"""
|
||||
if not items:
|
||||
return (0, 0)
|
||||
|
||||
first_kwargs = items[0]["kwargs"]
|
||||
|
||||
try:
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
kwargs
|
||||
first_kwargs
|
||||
)
|
||||
|
||||
headers = await self.construct_request_headers(
|
||||
@@ -127,24 +202,92 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
object_name = self._get_object_name(kwargs, logging_payload, response_obj)
|
||||
|
||||
current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
|
||||
batch_id = f"{int(time.time() * 1000)}-{uuid.uuid4().hex[:8]}"
|
||||
object_name = self._generate_batch_object_name(current_date, batch_id)
|
||||
combined_payload = self._combine_payloads_to_ndjson(items)
|
||||
|
||||
await self._log_json_data_on_gcs(
|
||||
headers=headers,
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
logging_payload=combined_payload,
|
||||
)
|
||||
|
||||
success_count = len(items)
|
||||
error_count = 0
|
||||
return (success_count, error_count)
|
||||
|
||||
except Exception as e:
|
||||
success_count = 0
|
||||
error_count = len(items)
|
||||
verbose_logger.exception(
|
||||
f"GCS Bucket error logging batch payload to GCS bucket: {str(e)}"
|
||||
)
|
||||
return (success_count, error_count)
|
||||
|
||||
try:
|
||||
await self._log_json_data_on_gcs(
|
||||
headers=headers,
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
logging_payload=logging_payload,
|
||||
)
|
||||
except Exception as e:
|
||||
# don't let one log item fail the entire batch
|
||||
verbose_logger.exception(
|
||||
f"GCS Bucket error logging payload to GCS bucket: {str(e)}"
|
||||
)
|
||||
pass
|
||||
async def _send_individual_logs(self, items: List[GCSLogQueueItem]) -> None:
|
||||
"""
|
||||
Send each log individually as separate GCS objects (legacy behavior).
|
||||
This is used when GCS_USE_BATCHED_LOGGING is disabled.
|
||||
"""
|
||||
for item in items:
|
||||
await self._send_single_log_item(item)
|
||||
|
||||
# Clear the queue after processing
|
||||
self.log_queue.clear()
|
||||
async def _send_single_log_item(self, item: GCSLogQueueItem) -> None:
|
||||
"""
|
||||
Send a single log item to GCS as an individual object.
|
||||
"""
|
||||
try:
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
item["kwargs"]
|
||||
)
|
||||
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
|
||||
object_name = self._get_object_name(
|
||||
kwargs=item["kwargs"],
|
||||
logging_payload=item["payload"],
|
||||
response_obj=item["response_obj"],
|
||||
)
|
||||
|
||||
await self._log_json_data_on_gcs(
|
||||
headers=headers,
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
logging_payload=item["payload"],
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"GCS Bucket error logging individual payload to GCS bucket: {str(e)}"
|
||||
)
|
||||
|
||||
async def async_send_batch(self):
|
||||
"""
|
||||
Process queued logs - sends logs to GCS Bucket.
|
||||
|
||||
If `GCS_USE_BATCHED_LOGGING` is enabled (default), batches multiple log payloads
|
||||
into single GCS object uploads (NDJSON format), dramatically reducing API calls.
|
||||
|
||||
If disabled, sends each log individually as separate GCS objects (legacy behavior).
|
||||
"""
|
||||
items_to_process = self._drain_queue_batch()
|
||||
|
||||
if not items_to_process:
|
||||
return
|
||||
|
||||
if self.use_batched_logging:
|
||||
grouped_items = self._group_items_by_config(items_to_process)
|
||||
|
||||
for config_key, group_items in grouped_items.items():
|
||||
await self._send_grouped_batch(group_items, config_key)
|
||||
else:
|
||||
await self._send_individual_logs(items_to_process)
|
||||
|
||||
def _get_object_name(
|
||||
self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
|
||||
@@ -186,7 +329,6 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||
"start_time_utc is required for getting a payload from GCS Bucket"
|
||||
)
|
||||
|
||||
# Try current day, next day, and previous day
|
||||
dates_to_try = [
|
||||
start_time_utc,
|
||||
start_time_utc + timedelta(days=1),
|
||||
@@ -230,5 +372,23 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||
def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
|
||||
return datetime_obj.strftime("%Y-%m-%d")
|
||||
|
||||
async def flush_queue(self):
|
||||
"""
|
||||
Override flush_queue to work with asyncio.Queue.
|
||||
"""
|
||||
await self.async_send_batch()
|
||||
self.last_flush_time = time.time()
|
||||
|
||||
async def periodic_flush(self):
|
||||
"""
|
||||
Override periodic_flush to work with asyncio.Queue.
|
||||
"""
|
||||
while True:
|
||||
await asyncio.sleep(self.flush_interval)
|
||||
verbose_logger.debug(
|
||||
f"GCS Bucket periodic flush after {self.flush_interval} seconds"
|
||||
)
|
||||
await self.flush_queue()
|
||||
|
||||
async def async_health_check(self) -> IntegrationHealthCheckStatus:
|
||||
raise NotImplementedError("GCS Bucket does not support health check")
|
||||
|
||||
@@ -12,6 +12,7 @@ else:
|
||||
|
||||
GCS_DEFAULT_BATCH_SIZE = 2048
|
||||
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
|
||||
GCS_DEFAULT_USE_BATCHED_LOGGING = True
|
||||
|
||||
|
||||
class GCSLoggingConfig(TypedDict):
|
||||
|
||||
Reference in New Issue
Block a user