diff --git a/tests/pass_through_tests/test_assembly_ai.py b/tests/pass_through_tests/test_assembly_ai.py index 2d01ef2c1b..31bdf24009 100644 --- a/tests/pass_through_tests/test_assembly_ai.py +++ b/tests/pass_through_tests/test_assembly_ai.py @@ -2,43 +2,76 @@ This test ensures that the proxy can passthrough requests to assemblyai """ +import time + import pytest -import assemblyai as aai +import httpx import aiohttp import asyncio -import time TEST_MASTER_KEY = "sk-1234" TEST_BASE_URL = "http://0.0.0.0:4000/assemblyai" -def test_assemblyai_basic_transcribe(): - print("making basic transcribe request to assemblyai passthrough") +def _transcribe_and_verify(virtual_key: str, base_url: str): + file_url = "https://assembly.ai/wildfires.mp3" + headers = { + "Authorization": f"Bearer {virtual_key}", + "Content-Type": "application/json", + } + create_payload = { + "audio_url": file_url, + "speech_models": ["universal-2"], + } - # Replace with your API key - aai.settings.api_key = f"Bearer {TEST_MASTER_KEY}" - aai.settings.base_url = TEST_BASE_URL + create_response = httpx.post( + url=f"{base_url}/v2/transcript", + headers=headers, + json=create_payload, + timeout=60.0, + ) + if create_response.status_code != 200: + pytest.fail( + "Failed to create transcript request: " + f"status={create_response.status_code}, body={create_response.text}" + ) - # URL of the file to transcribe - FILE_URL = "https://assembly.ai/wildfires.mp3" - - # You can also transcribe a local file by passing in a file path - # FILE_URL = './path/to/file.mp3' - - transcriber = aai.Transcriber() - transcript = transcriber.transcribe(FILE_URL) - print(transcript) - print(transcript.id) - if transcript.id: - transcript.delete_by_id(transcript.id) - else: + transcript = create_response.json() + transcript_id = transcript.get("id") + if not transcript_id: pytest.fail("Failed to get transcript id") - if transcript.status == aai.TranscriptStatus.error: - print(transcript.error) - pytest.fail(f"Failed to transcribe file error: {transcript.error}") - else: - print(transcript.text) + for _ in range(60): + poll_response = httpx.get( + url=f"{base_url}/v2/transcript/{transcript_id}", + headers=headers, + timeout=30.0, + ) + if poll_response.status_code != 200: + pytest.fail( + "Failed to poll transcript status: " + f"status={poll_response.status_code}, body={poll_response.text}" + ) + transcript = poll_response.json() + if transcript.get("status") in ("completed", "error"): + break + time.sleep(1) + + httpx.delete( + url=f"{base_url}/v2/transcript/{transcript_id}", + headers=headers, + timeout=30.0, + ) + + if transcript.get("status") == "error": + pytest.fail(f"Failed to transcribe file error: {transcript.get('error')}") + + print(transcript.get("text")) + + +def test_assemblyai_basic_transcribe(): + print("making basic transcribe request to assemblyai passthrough") + _transcribe_and_verify(TEST_MASTER_KEY, TEST_BASE_URL) async def generate_key(calling_key: str) -> str: @@ -59,37 +92,10 @@ async def generate_key(calling_key: str) -> str: @pytest.mark.asyncio async def test_assemblyai_transcribe_with_non_admin_key(): - # Generate a non-admin key using the helper non_admin_key = await generate_key(TEST_MASTER_KEY) print(f"Generated non-admin key: {non_admin_key}") - # Use the non-admin key to transcribe - # Replace with your API key - aai.settings.api_key = f"Bearer {non_admin_key}" - aai.settings.base_url = TEST_BASE_URL - - # URL of the file to transcribe - FILE_URL = "https://assembly.ai/wildfires.mp3" - - # You can also transcribe a local file by passing in a file path - # FILE_URL = './path/to/file.mp3' - request_start_time = time.time() - - transcriber = aai.Transcriber() - transcript = transcriber.transcribe(FILE_URL) - print(transcript) - print(transcript.id) - if transcript.id: - transcript.delete_by_id(transcript.id) - else: - pytest.fail("Failed to get transcript id") - - if transcript.status == aai.TranscriptStatus.error: - print(transcript.error) - pytest.fail(f"Failed to transcribe file error: {transcript.error}") - else: - print(transcript.text) - + _transcribe_and_verify(non_admin_key, TEST_BASE_URL) request_end_time = time.time() print(f"Request took {request_end_time - request_start_time} seconds")