mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 16:48:54 +00:00
2f9519d286
- Add `_strip_image_b64_payloads` filter: rewrites `data[*].b64_json` in image-gen responses to a 4-byte placeholder before the cassette is saved. Image-edit and image-gen cassettes (193 MB / 184 MB / 104 MB / ...) will shrink to <100 KB on next record. Tests assert response shape only, so coverage is preserved.
- Add `_normalize_multipart_boundary` filter: replaces httpx's per-request random multipart boundary with a fixed string in both the Content-Type header and the body bytes. Audio-transcription / Whisper tests had been effectively unmocked — every CI run hit live providers and was silently capped at MAX_EPISODES_PER_CASSETTE=50. Record and replay now see identical bytes, so the safe_body matcher works.
- Fix test_evals_api.py body poisoning: replace `int(time.time())` in eval names with `hashlib.sha1(test_node_name)[:12]`, add a function-scoped `managed_eval` fixture that creates and deletes the eval, and switch `get_eval` / `update_eval` from `list_evals().data[0].id` (which made the URL vary by run) to `managed_eval.id`. Net coverage gain: delete is now actually exercised.
- Swap arxiv PDF URL in BaseOCRTest for the in-repo `dummy.pdf` (589 B) served via sha-pinned jsdelivr.
- Swap etsystatic image URL in BaseLLMChatTest.test_image_url for the in-repo LiteLLM logo (9.2 KB) served via the same jsdelivr pin.
- Add `tests/llm_translation/test_vcr_filters.py` with 14 unit tests covering both new filters: replacement, idempotency, nesting, Content-Length update, two-distinct-boundaries-converge-after-normalize, etc.

Cassettes recorded with the prior patterns will mismatch on the first CI run after merge; recommend flushing the cassette Redis once (post-merge) so re-records save under the new format from the start.
204 lines
8.2 KiB
Python
204 lines
8.2 KiB
Python
"""
|
|
Base test class for OCR functionality across different providers.
|
|
|
|
This follows the same pattern as BaseLLMChatTest in tests/llm_translation/base_llm_unit_tests.py
|
|
"""
|
|
|
|
import pytest
|
|
import litellm
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
# Test resources
TEST_IMAGE_PATH = "test_image_edit.png"

# Small in-repo PDF (dummy.pdf, 589 B) fetched through jsdelivr, pinned to an
# immutable commit sha. The arxiv PDF used previously was several MB; once it
# was base64-encoded into the Vertex OCR request, each test's cassette grew
# past 100 MB. The URL must be byte-for-byte stable between runs, otherwise
# recorded cassettes stop matching and get re-recorded.
_PINNED_LITELLM_SHA = "d769e81c90d453240c61fc572cdb27fae06a89d0"
TEST_PDF_URL = (
    f"https://cdn.jsdelivr.net/gh/BerriAI/litellm@{_PINNED_LITELLM_SHA}"
    "/tests/llm_translation/fixtures/dummy.pdf"
)
|
|
|
|
|
|
class BaseOCRTest(ABC):
    """
    Abstract base test class that enforces common OCR tests across all providers.

    Each provider-specific test class should inherit from this and implement
    get_base_ocr_call_args() to return provider-specific configuration.

    Errors raised by the provider call are classified in one place: rate-limit /
    quota errors, ``InternalServerError`` and URL-rejection ``BadRequestError``s
    skip the test; any other error fails it. Only the provider call is guarded,
    so assertion failures on the response are reported by pytest directly.
    """

    @abstractmethod
    def get_base_ocr_call_args(self) -> dict:
        """Must return the base OCR call args for the specific provider"""
        pass

    @staticmethod
    def _skip_or_fail_on_ocr_error(error: Exception, failure_prefix: str) -> None:
        """Turn an exception from ``litellm.ocr`` / ``litellm.aocr`` into a skip or failure.

        This never returns normally: every path ends in ``pytest.skip`` or
        ``pytest.fail``, both of which raise.

        Args:
            error: the exception raised by the OCR call.
            failure_prefix: text placed before the error message when the
                error is not one of the skippable kinds.
        """
        error_msg = str(error)
        if isinstance(error, litellm.RateLimitError):
            if "Quota exceeded" in error_msg or "RESOURCE_EXHAUSTED" in error_msg:
                pytest.skip(f"Quota exceeded - {error_msg}")
            pytest.skip(f"Rate limit exceeded - {error_msg}")
        if isinstance(error, litellm.InternalServerError):
            pytest.skip("Model is overloaded")
        if isinstance(error, litellm.BadRequestError) and (
            "URL_REJECTED" in error_msg
            or "Cannot fetch content from the provided URL" in error_msg
        ):
            # The provider refused to fetch TEST_PDF_URL; that is an
            # environment problem, not an OCR regression.
            pytest.skip(f"URL rejected by provider - {error_msg}")
        pytest.fail(f"{failure_prefix}: {error_msg}")

    @staticmethod
    def _assert_common_ocr_response_shape(response) -> None:
        """Assert the fields both OCR tests require on a response.

        Args:
            response: the object returned by ``litellm.ocr`` / ``litellm.aocr``.
        """
        assert hasattr(response, "pages"), "Response should have 'pages' attribute"
        assert hasattr(response, "model"), "Response should have 'model' attribute"
        assert hasattr(response, "object"), "Response should have 'object' attribute"
        assert (
            response.object == "ocr"
        ), f"Expected object='ocr', got '{response.object}'"

        # Validate pages structure
        assert isinstance(response.pages, list), "pages should be a list"
        assert len(response.pages) > 0, "Should have at least one page"

        # Check first page structure
        first_page = response.pages[0]
        assert hasattr(first_page, "index"), "Page should have 'index' attribute"
        assert hasattr(
            first_page, "markdown"
        ), "Page should have 'markdown' attribute"

    @pytest.mark.parametrize("sync_mode", [True, False])
    @pytest.mark.asyncio
    async def test_basic_ocr_with_url(self, sync_mode, monkeypatch):
        """
        Test basic OCR with a public URL.

        Exercises ``litellm.ocr`` (sync_mode=True) or ``litellm.aocr``
        (sync_mode=False), checks that some text was extracted, and checks
        that a positive ``response_cost`` is reported in ``_hidden_params``.
        """
        litellm._turn_on_debug()
        base_ocr_call_args = self.get_base_ocr_call_args()
        print("BASE OCR Call args=", base_ocr_call_args)

        # Load the cost map with LITELLM_LOCAL_MODEL_COST_MAP set, for the
        # response_cost check below. monkeypatch undoes both the env var and
        # the litellm.model_cost assignment at teardown so they do not leak
        # into other tests in the same process.
        monkeypatch.setenv("LITELLM_LOCAL_MODEL_COST_MAP", "True")
        monkeypatch.setattr(litellm, "model_cost", litellm.get_model_cost_map(url=""))

        document = {"type": "document_url", "document_url": TEST_PDF_URL}
        # Only the provider call is guarded; assertions below must surface as
        # ordinary pytest assertion failures, not be re-labelled as call errors.
        try:
            if sync_mode:
                response = litellm.ocr(document=document, **base_ocr_call_args)
            else:
                response = await litellm.aocr(document=document, **base_ocr_call_args)
        except Exception as e:
            self._skip_or_fail_on_ocr_error(e, "OCR call failed")
            raise  # unreachable: the helper always skips or fails

        print(f"\n{'='*80}")
        print(f"Sync Mode: {sync_mode}")
        print(f"Response type: {type(response)}")
        print(f"Response object: {getattr(response, 'object', 'N/A')}")

        self._assert_common_ocr_response_shape(response)

        # Extract text from all pages for validation (pages with empty or
        # missing markdown are skipped).
        total_text = "\n\n".join(
            page.markdown for page in response.pages if page.markdown
        )
        print(f"Total pages: {len(response.pages)}")
        print(f"Total extracted text length: {len(total_text)} characters")
        print(f"First 200 chars: {total_text[:200]}")
        print(f"Model: {response.model}")
        if response.usage_info:
            print(f"Pages processed: {response.usage_info.pages_processed}")
        print(f"{'='*80}\n")

        assert len(total_text) > 0, "Should extract some text from the document"

        #########################################################
        # validate we get a response cost in hidden parameters
        #########################################################
        hidden_params = response._hidden_params
        assert isinstance(
            hidden_params, dict
        ), "Hidden parameters should be a dictionary"

        print("response usage_info:", response.usage_info)

        response_cost = hidden_params.get("response_cost")
        assert (
            response_cost is not None
        ), "Response cost should be in hidden parameters"
        assert response_cost > 0, "Response cost should be greater than 0"
        print("response_cost=", response_cost)

    def test_ocr_response_structure(self):
        """
        Test that the OCR response has the correct structure.

        Uses the sync ``litellm.ocr`` entry point and additionally requires a
        ``usage_info`` attribute and string ``markdown`` on the first page.
        """
        # Same debug switch as test_basic_ocr_with_url, instead of the
        # deprecated `litellm.set_verbose = True`.
        litellm._turn_on_debug()
        base_ocr_call_args = self.get_base_ocr_call_args()

        try:
            response = litellm.ocr(
                document={"type": "document_url", "document_url": TEST_PDF_URL},
                **base_ocr_call_args,
            )
        except Exception as e:
            self._skip_or_fail_on_ocr_error(e, "OCR response structure test failed")
            raise  # unreachable: the helper always skips or fails

        # Validate response structure
        self._assert_common_ocr_response_shape(response)
        assert hasattr(
            response, "usage_info"
        ), "Response should have 'usage_info' attribute"

        first_page = response.pages[0]
        assert isinstance(first_page.markdown, str), "markdown should be a string"

        print("\nResponse structure validated:")
        print(f" - object: {response.object}")
        print(f" - model: {response.model}")
        print(f" - pages: {len(response.pages)}")
        if response.usage_info:
            print(f" - pages_processed: {response.usage_info.pages_processed}")
            print(f" - doc_size_bytes: {response.usage_info.doc_size_bytes}")
|