mirror of
https://github.com/tiennm99/goclaw.git
synced 2026-06-11 12:10:58 +00:00
feat(tools): add media provider chain with ordered fallback and retry
Refactor create_image and create_video to use a shared provider chain system.
Each tool now supports an ordered list of providers with per-entry timeout,
max retries, and provider-specific params. Includes MiniMax and DashScope
image/video generation implementations.
- New media_provider_chain.go: shared chain resolution, retry execution, limitedReadAll
- create_image: refactored to ExecuteWithChain, added MiniMax + DashScope providers
- create_video: refactored to ExecuteWithChain, added MiniMax async video generation
- Backward compatible with legacy {provider, model} settings format
This commit is contained in:
+51
-111
@@ -25,17 +25,18 @@ type credentialProvider interface {
|
||||
}
|
||||
|
||||
// imageGenProviderPriority is the default order for image generation providers.
|
||||
var imageGenProviderPriority = []string{"openrouter", "gemini", "openai"}
|
||||
var imageGenProviderPriority = []string{"openrouter", "gemini", "openai", "minimax", "dashscope"}
|
||||
|
||||
// imageGenModelDefaults maps provider names to default image generation models.
|
||||
var imageGenModelDefaults = map[string]string{
|
||||
"openrouter": "google/gemini-2.5-flash-image",
|
||||
"openai": "dall-e-3",
|
||||
"gemini": "gemini-2.5-flash-image",
|
||||
"minimax": "image-01",
|
||||
"dashscope": "wan2.6-image",
|
||||
}
|
||||
|
||||
// CreateImageTool generates images using an image generation API.
|
||||
// Uses OpenRouter (Gemini image model) or OpenAI (DALL-E) via per-agent ImageGenConfig.
|
||||
type CreateImageTool struct {
|
||||
registry *providers.Registry
|
||||
}
|
||||
@@ -77,42 +78,26 @@ func (t *CreateImageTool) Execute(ctx context.Context, args map[string]interface
|
||||
aspectRatio = "1:1"
|
||||
}
|
||||
|
||||
// Resolve provider from per-agent config or defaults
|
||||
providerName, model := t.resolveConfig(ctx)
|
||||
|
||||
p, err := t.registry.Get(providerName)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("image generation provider %q not available", providerName))
|
||||
// Extract per-agent config for backward compat
|
||||
var perAgentProvider, perAgentModel string
|
||||
if cfg := ImageGenConfigFromCtx(ctx); cfg != nil {
|
||||
perAgentProvider = cfg.Provider
|
||||
perAgentModel = cfg.Model
|
||||
}
|
||||
|
||||
cp, ok := p.(credentialProvider)
|
||||
if !ok {
|
||||
return ErrorResult(fmt.Sprintf("provider %q does not expose API credentials for image generation", providerName))
|
||||
chain := ResolveMediaProviderChain(ctx, "create_image", perAgentProvider, perAgentModel,
|
||||
imageGenProviderPriority, imageGenModelDefaults, t.registry)
|
||||
|
||||
// Inject prompt and aspect_ratio into each chain entry's params
|
||||
for i := range chain {
|
||||
if chain[i].Params == nil {
|
||||
chain[i].Params = make(map[string]any)
|
||||
}
|
||||
chain[i].Params["prompt"] = prompt
|
||||
chain[i].Params["aspect_ratio"] = aspectRatio
|
||||
}
|
||||
|
||||
slog.Info("create_image: calling image generation API",
|
||||
"provider", providerName, "model", model, "aspect_ratio", aspectRatio)
|
||||
|
||||
// Route to the correct image generation endpoint per provider:
|
||||
// - gemini: native Gemini generateContent API (responseModalities)
|
||||
// - openrouter: OpenAI-compat /chat/completions with modalities
|
||||
// - others (openai, etc.): /images/generations
|
||||
var imageBytes []byte
|
||||
var usage *providers.Usage
|
||||
switch providerName {
|
||||
case "gemini":
|
||||
var genErr error
|
||||
imageBytes, usage, genErr = t.callGeminiNativeImageGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt)
|
||||
err = genErr
|
||||
case "openrouter":
|
||||
var genErr error
|
||||
imageBytes, usage, genErr = t.callImageGenAPI(ctx, cp.APIKey(), cp.APIBase(), model, prompt, aspectRatio)
|
||||
err = genErr
|
||||
default:
|
||||
var genErr error
|
||||
imageBytes, usage, genErr = t.callStandardImageGenAPI(ctx, cp.APIKey(), cp.APIBase(), model, prompt)
|
||||
err = genErr
|
||||
}
|
||||
chainResult, err := ExecuteWithChain(ctx, chain, t.registry, t.callProvider)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("image generation failed: %v", err))
|
||||
}
|
||||
@@ -127,83 +112,46 @@ func (t *CreateImageTool) Execute(ctx context.Context, args map[string]interface
|
||||
return ErrorResult(fmt.Sprintf("failed to create output directory: %v", err))
|
||||
}
|
||||
imagePath := filepath.Join(dateDir, fmt.Sprintf("goclaw_gen_%d.png", time.Now().UnixNano()))
|
||||
if err := os.WriteFile(imagePath, imageBytes, 0644); err != nil {
|
||||
if err := os.WriteFile(imagePath, chainResult.Data, 0644); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to save generated image: %v", err))
|
||||
}
|
||||
|
||||
result := &Result{ForLLM: fmt.Sprintf("MEDIA:%s", imagePath)}
|
||||
result.Media = []bus.MediaFile{{Path: imagePath, MimeType: "image/png"}}
|
||||
result.Deliverable = fmt.Sprintf("[Generated image: %s]\nPrompt: %s", filepath.Base(imagePath), prompt)
|
||||
result.Provider = providerName
|
||||
result.Model = model
|
||||
if usage != nil {
|
||||
result.Usage = usage
|
||||
result.Provider = chainResult.Provider
|
||||
result.Model = chainResult.Model
|
||||
if chainResult.Usage != nil {
|
||||
result.Usage = chainResult.Usage
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// resolveConfig returns the provider name and model to use for image generation.
|
||||
func (t *CreateImageTool) resolveConfig(ctx context.Context) (providerName, model string) {
|
||||
// 1. Check per-agent ImageGenConfig from context (highest priority)
|
||||
if cfg := ImageGenConfigFromCtx(ctx); cfg != nil {
|
||||
if cfg.Provider != "" {
|
||||
providerName = cfg.Provider
|
||||
}
|
||||
if cfg.Model != "" {
|
||||
model = cfg.Model
|
||||
}
|
||||
}
|
||||
// callProvider dispatches to the correct image generation implementation based on provider type.
|
||||
func (t *CreateImageTool) callProvider(ctx context.Context, cp credentialProvider, providerName, model string, params map[string]any) ([]byte, *providers.Usage, error) {
|
||||
prompt := GetParamString(params, "prompt", "")
|
||||
aspectRatio := GetParamString(params, "aspect_ratio", "1:1")
|
||||
|
||||
// 2. Check global builtin_tools.settings (DB defaults)
|
||||
if providerName == "" || model == "" {
|
||||
if settings := BuiltinToolSettingsFromCtx(ctx); settings != nil {
|
||||
if raw, ok := settings["create_image"]; ok && len(raw) > 0 {
|
||||
var cfg struct {
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if json.Unmarshal(raw, &cfg) == nil && cfg.Provider != "" {
|
||||
// DB settings are a provider+model pair — only use if provider is available
|
||||
if _, err := t.registry.Get(cfg.Provider); err == nil {
|
||||
if providerName == "" {
|
||||
providerName = cfg.Provider
|
||||
}
|
||||
if model == "" && cfg.Model != "" {
|
||||
model = cfg.Model
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
slog.Info("create_image: calling image generation API",
|
||||
"provider", providerName, "model", model, "aspect_ratio", aspectRatio)
|
||||
|
||||
// 3. If provider not set, find first available from priority list
|
||||
if providerName == "" {
|
||||
for _, name := range imageGenProviderPriority {
|
||||
if _, err := t.registry.Get(name); err == nil {
|
||||
providerName = name
|
||||
break
|
||||
}
|
||||
}
|
||||
switch ProviderTypeFromName(providerName) {
|
||||
case "gemini":
|
||||
return t.callGeminiNativeImageGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, params)
|
||||
case "openrouter":
|
||||
return t.callImageGenAPI(ctx, cp.APIKey(), cp.APIBase(), model, prompt, aspectRatio, params)
|
||||
case "minimax":
|
||||
return callMinimaxImageGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, params)
|
||||
case "dashscope":
|
||||
return callDashScopeImageGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, params)
|
||||
default:
|
||||
return t.callStandardImageGenAPI(ctx, cp.APIKey(), cp.APIBase(), model, prompt, params)
|
||||
}
|
||||
if providerName == "" {
|
||||
providerName = "openrouter" // fallback even if unavailable (error handled later)
|
||||
}
|
||||
|
||||
// 4. If model not set, use default for this provider
|
||||
if model == "" {
|
||||
if m, ok := imageGenModelDefaults[providerName]; ok {
|
||||
model = m
|
||||
}
|
||||
}
|
||||
|
||||
return providerName, model
|
||||
}
|
||||
|
||||
// callImageGenAPI calls the OpenAI-compatible image generation endpoint.
|
||||
// Works with OpenRouter (modalities: ["image","text"]) and OpenAI (/images/generations).
|
||||
func (t *CreateImageTool) callImageGenAPI(ctx context.Context, apiKey, apiBase, model, prompt, aspectRatio string) ([]byte, *providers.Usage, error) {
|
||||
// OpenRouter / OpenAI-compat: use chat completions with modalities
|
||||
// callImageGenAPI calls the OpenAI-compatible chat completions endpoint with image modalities.
|
||||
// Works with OpenRouter (modalities: ["image","text"]).
|
||||
func (t *CreateImageTool) callImageGenAPI(ctx context.Context, apiKey, apiBase, model, prompt, aspectRatio string, params map[string]any) ([]byte, *providers.Usage, error) {
|
||||
body := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": []map[string]interface{}{
|
||||
@@ -230,7 +178,7 @@ func (t *CreateImageTool) callImageGenAPI(ctx context.Context, apiKey, apiBase,
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
client := &http.Client{} // timeout governed by chain context
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("http request: %w", err)
|
||||
@@ -241,7 +189,6 @@ func (t *CreateImageTool) callImageGenAPI(ctx context.Context, apiKey, apiBase,
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("API error %d: %s", resp.StatusCode, truncateBytes(respBody, 500))
|
||||
}
|
||||
@@ -249,9 +196,8 @@ func (t *CreateImageTool) callImageGenAPI(ctx context.Context, apiKey, apiBase,
|
||||
return t.parseImageResponse(respBody)
|
||||
}
|
||||
|
||||
// callStandardImageGenAPI uses the /images/generations endpoint (Gemini, OpenAI, and compatible providers).
|
||||
// This is the standard OpenAI-compatible image generation endpoint that returns b64_json data.
|
||||
func (t *CreateImageTool) callStandardImageGenAPI(ctx context.Context, apiKey, apiBase, model, prompt string) ([]byte, *providers.Usage, error) {
|
||||
// callStandardImageGenAPI uses the /images/generations endpoint (OpenAI and compatible providers).
|
||||
func (t *CreateImageTool) callStandardImageGenAPI(ctx context.Context, apiKey, apiBase, model, prompt string, params map[string]any) ([]byte, *providers.Usage, error) {
|
||||
body := map[string]interface{}{
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
@@ -272,7 +218,7 @@ func (t *CreateImageTool) callStandardImageGenAPI(ctx context.Context, apiKey, a
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
client := &http.Client{} // timeout governed by chain context
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("http request: %w", err)
|
||||
@@ -283,12 +229,10 @@ func (t *CreateImageTool) callStandardImageGenAPI(ctx context.Context, apiKey, a
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("API error %d: %s", resp.StatusCode, truncateBytes(respBody, 500))
|
||||
}
|
||||
|
||||
// Parse OpenAI-compat images/generations response: {data: [{b64_json: "..."}]}
|
||||
var imgResp struct {
|
||||
Data []struct {
|
||||
B64JSON string `json:"b64_json"`
|
||||
@@ -310,9 +254,8 @@ func (t *CreateImageTool) callStandardImageGenAPI(ctx context.Context, apiKey, a
|
||||
}
|
||||
|
||||
// callGeminiNativeImageGen uses the native Gemini generateContent API with responseModalities.
|
||||
// Gemini image models (gemini-2.5-flash-image, gemini-3.1-flash-image-preview) require this
|
||||
// endpoint — they don't support the OpenAI-compat /images/generations or /chat/completions.
|
||||
func (t *CreateImageTool) callGeminiNativeImageGen(ctx context.Context, apiKey, apiBase, model, prompt string) ([]byte, *providers.Usage, error) {
|
||||
// Gemini image models require this endpoint — they don't support OpenAI-compat endpoints.
|
||||
func (t *CreateImageTool) callGeminiNativeImageGen(ctx context.Context, apiKey, apiBase, model, prompt string, params map[string]any) ([]byte, *providers.Usage, error) {
|
||||
// Derive native Gemini base from OpenAI-compat base (strip /openai suffix)
|
||||
nativeBase := strings.TrimRight(apiBase, "/")
|
||||
nativeBase = strings.TrimSuffix(nativeBase, "/openai")
|
||||
@@ -339,7 +282,7 @@ func (t *CreateImageTool) callGeminiNativeImageGen(ctx context.Context, apiKey,
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
client := &http.Client{} // timeout governed by chain context
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("http request: %w", err)
|
||||
@@ -350,7 +293,6 @@ func (t *CreateImageTool) callGeminiNativeImageGen(ctx context.Context, apiKey,
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("API error %d: %s", resp.StatusCode, truncateBytes(respBody, 500))
|
||||
}
|
||||
@@ -426,7 +368,6 @@ func (t *CreateImageTool) parseImageResponse(respBody []byte) ([]byte, *provider
|
||||
if err := json.Unmarshal(respBody, &resp); err != nil {
|
||||
return nil, nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return nil, nil, fmt.Errorf("no choices in response")
|
||||
}
|
||||
@@ -462,7 +403,6 @@ func (t *CreateImageTool) parseImageResponse(respBody []byte) ([]byte, *provider
|
||||
|
||||
// decodeDataURL decodes a data:image/...;base64,... URL into raw bytes.
|
||||
func decodeDataURL(dataURL string) ([]byte, error) {
|
||||
// Format: data:image/png;base64,iVBORw0KGgo...
|
||||
idx := strings.Index(dataURL, ";base64,")
|
||||
if idx < 0 {
|
||||
return nil, fmt.Errorf("not a base64 data URL")
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/nextlevelbuilder/goclaw/internal/providers"
|
||||
)
|
||||
|
||||
// dashScopeImageEndpoint derives the DashScope multimodal generation endpoint from the
|
||||
// stored api_base. The api_base in DB is typically an OpenAI-compat URL such as
|
||||
// https://dashscope-intl.aliyuncs.com/compatible-mode/v1
|
||||
// The real image generation endpoint lives at a different path on the same host.
|
||||
func dashScopeImageEndpoint(apiBase string) string {
|
||||
base := strings.TrimRight(apiBase, "/")
|
||||
|
||||
// Known patterns — strip compat suffix to get the host, then build the real path.
|
||||
for _, suffix := range []string{
|
||||
"/compatible-mode/v1",
|
||||
"/compatible-mode",
|
||||
"/openai/v1",
|
||||
"/openai",
|
||||
"/v1",
|
||||
} {
|
||||
if strings.HasSuffix(base, suffix) {
|
||||
base = strings.TrimSuffix(base, suffix)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return base + "/api/v1/services/aigc/multimodal-generation/generation"
|
||||
}
|
||||
|
||||
// dashScopeTaskEndpoint returns the task polling URL for a given task_id.
|
||||
func dashScopeTaskEndpoint(apiBase, taskID string) string {
|
||||
base := strings.TrimRight(apiBase, "/")
|
||||
for _, suffix := range []string{
|
||||
"/compatible-mode/v1",
|
||||
"/compatible-mode",
|
||||
"/openai/v1",
|
||||
"/openai",
|
||||
"/v1",
|
||||
} {
|
||||
if strings.HasSuffix(base, suffix) {
|
||||
base = strings.TrimSuffix(base, suffix)
|
||||
break
|
||||
}
|
||||
}
|
||||
return base + "/api/v1/tasks/" + taskID
|
||||
}
|
||||
|
||||
// callDashScopeImageGen calls the DashScope (Alibaba/Bailian) multimodal image generation API.
|
||||
// The API is async: an initial POST returns a task_id, which is then polled until done.
|
||||
// On completion, output.results[].url contains the image URL to download.
|
||||
func callDashScopeImageGen(ctx context.Context, apiKey, apiBase, model, prompt string, params map[string]any) ([]byte, *providers.Usage, error) {
|
||||
size := GetParamString(params, "size", "1024*1024")
|
||||
promptExtend := GetParamBool(params, "prompt_extend", true)
|
||||
|
||||
endpoint := dashScopeImageEndpoint(apiBase)
|
||||
|
||||
body := map[string]interface{}{
|
||||
"model": model,
|
||||
"input": map[string]interface{}{
|
||||
"messages": []map[string]interface{}{
|
||||
{"role": "user", "content": prompt},
|
||||
},
|
||||
},
|
||||
"parameters": map[string]interface{}{
|
||||
"n": 1,
|
||||
"size": size,
|
||||
"prompt_extend": promptExtend,
|
||||
},
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
client := &http.Client{} // timeout governed by chain context
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("http request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("API error %d: %s", resp.StatusCode, truncateBytes(respBody, 500))
|
||||
}
|
||||
|
||||
// Parse initial response — may be synchronous (results present) or async (task_id present).
|
||||
var initResp struct {
|
||||
Output *struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Results []struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"results"`
|
||||
} `json:"output"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &initResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if initResp.Output == nil {
|
||||
return nil, nil, fmt.Errorf("no output in DashScope response: %s", truncateBytes(respBody, 300))
|
||||
}
|
||||
|
||||
// Synchronous result already available
|
||||
if len(initResp.Output.Results) > 0 && initResp.Output.Results[0].URL != "" {
|
||||
return downloadImageURL(ctx, initResp.Output.Results[0].URL)
|
||||
}
|
||||
|
||||
// Async: poll the task until done
|
||||
if initResp.Output.TaskID == "" {
|
||||
return nil, nil, fmt.Errorf("no task_id and no results in DashScope response")
|
||||
}
|
||||
|
||||
return dashScopePollTask(ctx, apiKey, apiBase, initResp.Output.TaskID, client)
|
||||
}
|
||||
|
||||
// dashScopePollTask polls the DashScope task API until the task completes, then downloads
|
||||
// the result image. Max wait ~5 minutes (30 polls × 10s).
|
||||
func dashScopePollTask(ctx context.Context, apiKey, apiBase, taskID string, client *http.Client) ([]byte, *providers.Usage, error) {
|
||||
pollURL := dashScopeTaskEndpoint(apiBase, taskID)
|
||||
slog.Info("create_image: DashScope task started, polling", "task_id", taskID)
|
||||
|
||||
const maxPolls = 30
|
||||
const pollInterval = 10 * time.Second
|
||||
|
||||
for i := 0; i < maxPolls; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, nil, ctx.Err()
|
||||
case <-time.After(pollInterval):
|
||||
}
|
||||
|
||||
pollReq, err := http.NewRequestWithContext(ctx, "GET", pollURL, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create poll request: %w", err)
|
||||
}
|
||||
pollReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
pollResp, err := client.Do(pollReq)
|
||||
if err != nil {
|
||||
slog.Warn("create_image: DashScope poll error, retrying", "error", err, "attempt", i+1)
|
||||
continue
|
||||
}
|
||||
|
||||
pollBody, _ := io.ReadAll(pollResp.Body)
|
||||
pollResp.Body.Close()
|
||||
|
||||
if pollResp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("poll API error %d: %s", pollResp.StatusCode, truncateBytes(pollBody, 500))
|
||||
}
|
||||
|
||||
var taskResp struct {
|
||||
Output *struct {
|
||||
TaskStatus string `json:"task_status"`
|
||||
Results []struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"results"`
|
||||
} `json:"output"`
|
||||
}
|
||||
if err := json.Unmarshal(pollBody, &taskResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("parse poll response: %w", err)
|
||||
}
|
||||
|
||||
if taskResp.Output == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch taskResp.Output.TaskStatus {
|
||||
case "SUCCEEDED":
|
||||
if len(taskResp.Output.Results) == 0 || taskResp.Output.Results[0].URL == "" {
|
||||
return nil, nil, fmt.Errorf("task succeeded but no image URL in results")
|
||||
}
|
||||
return downloadImageURL(ctx, taskResp.Output.Results[0].URL)
|
||||
case "FAILED":
|
||||
return nil, nil, fmt.Errorf("DashScope task %s failed", taskID)
|
||||
default:
|
||||
slog.Info("create_image: DashScope task pending", "attempt", i+1, "status", taskResp.Output.TaskStatus)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("DashScope task %s timed out after %d polls", taskID, maxPolls)
|
||||
}
|
||||
|
||||
// downloadImageURL downloads an image from a URL and returns the raw bytes.
|
||||
func downloadImageURL(ctx context.Context, imageURL string) ([]byte, *providers.Usage, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", imageURL, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create download request: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{} // timeout governed by chain context
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("download image: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, nil, fmt.Errorf("download error %d: %s", resp.StatusCode, truncateBytes(body, 300))
|
||||
}
|
||||
|
||||
imageBytes, err := limitedReadAll(resp.Body, maxMediaDownloadBytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read image data: %w", err)
|
||||
}
|
||||
|
||||
return imageBytes, nil, nil
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/nextlevelbuilder/goclaw/internal/providers"
|
||||
)
|
||||
|
||||
// callMinimaxImageGen calls the MiniMax image generation API.
|
||||
// Endpoint: POST {apiBase}/image_generation
|
||||
// Response: base64 image data in data.image_list[0].base64_image
|
||||
func callMinimaxImageGen(ctx context.Context, apiKey, apiBase, model, prompt string, params map[string]any) ([]byte, *providers.Usage, error) {
|
||||
size := GetParamString(params, "size", "1024*1024")
|
||||
promptOptimizer := GetParamBool(params, "prompt_optimizer", true)
|
||||
|
||||
body := map[string]interface{}{
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"size": size,
|
||||
"num_images": 1,
|
||||
"enable_base64_output": true,
|
||||
"prompt_optimizer": promptOptimizer,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(apiBase, "/") + "/image_generation"
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
client := &http.Client{} // timeout governed by chain context
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("http request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("API error %d: %s", resp.StatusCode, truncateBytes(respBody, 500))
|
||||
}
|
||||
|
||||
var minimaxResp struct {
|
||||
Data *struct {
|
||||
ImageList []struct {
|
||||
Base64Image string `json:"base64_image"`
|
||||
} `json:"image_list"`
|
||||
} `json:"data"`
|
||||
BaseResp *struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
StatusMsg string `json:"status_msg"`
|
||||
} `json:"base_resp"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &minimaxResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if minimaxResp.BaseResp != nil && minimaxResp.BaseResp.StatusCode != 0 {
|
||||
return nil, nil, fmt.Errorf("MiniMax API error %d: %s",
|
||||
minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg)
|
||||
}
|
||||
|
||||
if minimaxResp.Data == nil || len(minimaxResp.Data.ImageList) == 0 {
|
||||
return nil, nil, fmt.Errorf("no image data in MiniMax response")
|
||||
}
|
||||
|
||||
b64 := minimaxResp.Data.ImageList[0].Base64Image
|
||||
if b64 == "" {
|
||||
return nil, nil, fmt.Errorf("empty base64_image in MiniMax response")
|
||||
}
|
||||
|
||||
imageBytes, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("decode base64: %w", err)
|
||||
}
|
||||
|
||||
return imageBytes, nil, nil
|
||||
}
|
||||
@@ -18,11 +18,12 @@ import (
|
||||
)
|
||||
|
||||
// videoGenProviderPriority is the default order for video generation providers.
|
||||
var videoGenProviderPriority = []string{"gemini", "openrouter"}
|
||||
var videoGenProviderPriority = []string{"gemini", "minimax", "openrouter"}
|
||||
|
||||
// videoGenModelDefaults maps provider names to default video generation models.
|
||||
var videoGenModelDefaults = map[string]string{
|
||||
"gemini": "veo-3.0-generate-preview",
|
||||
"minimax": "MiniMax-Hailuo-2.3",
|
||||
"openrouter": "google/veo-3.0-generate-preview",
|
||||
}
|
||||
|
||||
@@ -94,29 +95,20 @@ func (t *CreateVideoTool) Execute(ctx context.Context, args map[string]interface
|
||||
}
|
||||
}
|
||||
|
||||
providerName, model := t.resolveConfig(ctx)
|
||||
chain := ResolveMediaProviderChain(ctx, "create_video", "", "",
|
||||
videoGenProviderPriority, videoGenModelDefaults, t.registry)
|
||||
|
||||
p, err := t.registry.Get(providerName)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("video generation provider %q not available", providerName))
|
||||
// Inject prompt, duration, and aspect_ratio into each chain entry's params.
|
||||
for i := range chain {
|
||||
if chain[i].Params == nil {
|
||||
chain[i].Params = make(map[string]any)
|
||||
}
|
||||
chain[i].Params["prompt"] = prompt
|
||||
chain[i].Params["duration"] = duration
|
||||
chain[i].Params["aspect_ratio"] = aspectRatio
|
||||
}
|
||||
|
||||
cp, ok := p.(credentialProvider)
|
||||
if !ok {
|
||||
return ErrorResult(fmt.Sprintf("provider %q does not expose API credentials for video generation", providerName))
|
||||
}
|
||||
|
||||
slog.Info("create_video: calling video generation API", "provider", providerName, "model", model, "duration", duration, "aspect_ratio", aspectRatio)
|
||||
|
||||
var videoBytes []byte
|
||||
var usage *providers.Usage
|
||||
switch providerName {
|
||||
case "gemini":
|
||||
videoBytes, usage, err = t.callGeminiVideoGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, duration, aspectRatio)
|
||||
default:
|
||||
// OpenRouter and others: try chat completions with VIDEO modality
|
||||
videoBytes, usage, err = t.callChatVideoGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, duration, aspectRatio)
|
||||
}
|
||||
chainResult, err := ExecuteWithChain(ctx, chain, t.registry, t.callProvider)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("video generation failed: %v", err))
|
||||
}
|
||||
@@ -131,60 +123,38 @@ func (t *CreateVideoTool) Execute(ctx context.Context, args map[string]interface
|
||||
return ErrorResult(fmt.Sprintf("failed to create output directory: %v", err))
|
||||
}
|
||||
videoPath := filepath.Join(dateDir, fmt.Sprintf("goclaw_gen_%d.mp4", time.Now().UnixNano()))
|
||||
if err := os.WriteFile(videoPath, videoBytes, 0644); err != nil {
|
||||
if err := os.WriteFile(videoPath, chainResult.Data, 0644); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to save generated video: %v", err))
|
||||
}
|
||||
|
||||
result := &Result{ForLLM: fmt.Sprintf("MEDIA:%s", videoPath)}
|
||||
result.Media = []bus.MediaFile{{Path: videoPath, MimeType: "video/mp4"}}
|
||||
result.Deliverable = fmt.Sprintf("[Generated video: %s]\nPrompt: %s", filepath.Base(videoPath), prompt)
|
||||
result.Provider = providerName
|
||||
result.Model = model
|
||||
if usage != nil {
|
||||
result.Usage = usage
|
||||
result.Provider = chainResult.Provider
|
||||
result.Model = chainResult.Model
|
||||
if chainResult.Usage != nil {
|
||||
result.Usage = chainResult.Usage
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// resolveConfig returns the provider name and model for video generation.
|
||||
func (t *CreateVideoTool) resolveConfig(ctx context.Context) (providerName, model string) {
|
||||
// 1. Check global builtin_tools.settings
|
||||
if settings := BuiltinToolSettingsFromCtx(ctx); settings != nil {
|
||||
if raw, ok := settings["create_video"]; ok && len(raw) > 0 {
|
||||
var cfg struct {
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if json.Unmarshal(raw, &cfg) == nil && cfg.Provider != "" {
|
||||
if _, err := t.registry.Get(cfg.Provider); err == nil {
|
||||
providerName = cfg.Provider
|
||||
model = cfg.Model
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// callProvider dispatches to the correct video generation implementation based on provider type.
|
||||
func (t *CreateVideoTool) callProvider(ctx context.Context, cp credentialProvider, providerName, model string, params map[string]any) ([]byte, *providers.Usage, error) {
|
||||
prompt := GetParamString(params, "prompt", "")
|
||||
duration := GetParamInt(params, "duration", 8)
|
||||
aspectRatio := GetParamString(params, "aspect_ratio", "16:9")
|
||||
|
||||
// 2. Find first available from priority list
|
||||
if providerName == "" {
|
||||
for _, name := range videoGenProviderPriority {
|
||||
if _, err := t.registry.Get(name); err == nil {
|
||||
providerName = name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if providerName == "" {
|
||||
providerName = "gemini"
|
||||
}
|
||||
slog.Info("create_video: calling video generation API",
|
||||
"provider", providerName, "model", model, "duration", duration, "aspect_ratio", aspectRatio)
|
||||
|
||||
// 3. Default model
|
||||
if model == "" {
|
||||
if m, ok := videoGenModelDefaults[providerName]; ok {
|
||||
model = m
|
||||
}
|
||||
switch ProviderTypeFromName(providerName) {
|
||||
case "gemini":
|
||||
return t.callGeminiVideoGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, duration, aspectRatio)
|
||||
case "minimax":
|
||||
return callMinimaxVideoGen(ctx, cp.APIKey(), cp.APIBase(), model, params)
|
||||
default:
|
||||
return t.callChatVideoGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, duration, aspectRatio)
|
||||
}
|
||||
|
||||
return providerName, model
|
||||
}
|
||||
|
||||
// callGeminiVideoGen uses the Gemini predictLongRunning API for Veo video generation.
|
||||
@@ -219,7 +189,7 @@ func (t *CreateVideoTool) callGeminiVideoGen(ctx context.Context, apiKey, apiBas
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := &http.Client{} // timeout governed by chain context
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("http request: %w", err)
|
||||
@@ -334,7 +304,7 @@ func (t *CreateVideoTool) callGeminiVideoGen(ctx context.Context, apiKey, apiBas
|
||||
}
|
||||
dlReq.Header.Set("x-goog-api-key", apiKey)
|
||||
|
||||
dlClient := &http.Client{Timeout: 120 * time.Second}
|
||||
dlClient := &http.Client{} // timeout governed by chain context
|
||||
dlResp, err := dlClient.Do(dlReq)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("download video: %w", err)
|
||||
@@ -346,7 +316,7 @@ func (t *CreateVideoTool) callGeminiVideoGen(ctx context.Context, apiKey, apiBas
|
||||
return nil, nil, fmt.Errorf("download error %d: %s", dlResp.StatusCode, truncateBytes(dlBody, 300))
|
||||
}
|
||||
|
||||
videoBytes, err := io.ReadAll(dlResp.Body)
|
||||
videoBytes, err := limitedReadAll(dlResp.Body, maxMediaDownloadBytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read video data: %w", err)
|
||||
}
|
||||
@@ -379,7 +349,7 @@ func (t *CreateVideoTool) callChatVideoGen(ctx context.Context, apiKey, apiBase,
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 300 * time.Second}
|
||||
client := &http.Client{} // timeout governed by chain context
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("http request: %w", err)
|
||||
|
||||
@@ -0,0 +1,216 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/nextlevelbuilder/goclaw/internal/providers"
|
||||
)
|
||||
|
||||
// callMinimaxVideoGen calls the MiniMax video generation API (async with task polling).
|
||||
// Flow: POST /video_generation → poll /query/video_generation → download from file retrieve.
|
||||
func callMinimaxVideoGen(ctx context.Context, apiKey, apiBase, model string, params map[string]any) ([]byte, *providers.Usage, error) {
|
||||
prompt := GetParamString(params, "prompt", "")
|
||||
duration := GetParamInt(params, "duration", 6)
|
||||
resolution := GetParamString(params, "resolution", "720P")
|
||||
promptOptimizer := GetParamBool(params, "prompt_optimizer", true)
|
||||
fastPretreatment := GetParamBool(params, "fast_pretreatment", false)
|
||||
|
||||
base := strings.TrimRight(apiBase, "/")
|
||||
|
||||
// 1. Submit video generation task.
|
||||
submitBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"duration": duration,
|
||||
"resolution": resolution,
|
||||
"prompt_optimizer": promptOptimizer,
|
||||
}
|
||||
if fastPretreatment {
|
||||
submitBody["fast_pretreatment"] = true
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(submitBody)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
submitURL := base + "/video_generation"
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", submitURL, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
client := &http.Client{} // timeout governed by chain context
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("http request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("API error %d: %s", resp.StatusCode, truncateBytes(respBody, 500))
|
||||
}
|
||||
|
||||
var submitResp struct {
|
||||
TaskID string `json:"task_id"`
|
||||
BaseResp *struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
StatusMsg string `json:"status_msg"`
|
||||
} `json:"base_resp"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &submitResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("parse submit response: %w", err)
|
||||
}
|
||||
if submitResp.BaseResp != nil && submitResp.BaseResp.StatusCode != 0 {
|
||||
return nil, nil, fmt.Errorf("MiniMax API error %d: %s",
|
||||
submitResp.BaseResp.StatusCode, submitResp.BaseResp.StatusMsg)
|
||||
}
|
||||
if submitResp.TaskID == "" {
|
||||
return nil, nil, fmt.Errorf("no task_id in MiniMax response: %s", truncateBytes(respBody, 300))
|
||||
}
|
||||
|
||||
slog.Info("create_video: MiniMax task submitted", "task_id", submitResp.TaskID)
|
||||
|
||||
// 2. Poll until done (max ~6 minutes, poll every 10s).
|
||||
pollURL := base + "/query/video_generation?task_id=" + submitResp.TaskID
|
||||
const maxPolls = 40
|
||||
const pollInterval = 10 * time.Second
|
||||
|
||||
var fileID string
|
||||
for i := 0; i < maxPolls; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, nil, ctx.Err()
|
||||
case <-time.After(pollInterval):
|
||||
}
|
||||
|
||||
pollReq, err := http.NewRequestWithContext(ctx, "GET", pollURL, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create poll request: %w", err)
|
||||
}
|
||||
pollReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
pollResp, err := client.Do(pollReq)
|
||||
if err != nil {
|
||||
slog.Warn("create_video: MiniMax poll error, retrying", "error", err, "attempt", i+1)
|
||||
continue
|
||||
}
|
||||
|
||||
pollBody, _ := io.ReadAll(pollResp.Body)
|
||||
pollResp.Body.Close()
|
||||
|
||||
if pollResp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("poll API error %d: %s", pollResp.StatusCode, truncateBytes(pollBody, 500))
|
||||
}
|
||||
|
||||
var pollResult struct {
|
||||
Status string `json:"status"`
|
||||
FileID string `json:"file_id"`
|
||||
BaseResp *struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
StatusMsg string `json:"status_msg"`
|
||||
} `json:"base_resp"`
|
||||
}
|
||||
if err := json.Unmarshal(pollBody, &pollResult); err != nil {
|
||||
return nil, nil, fmt.Errorf("parse poll response: %w", err)
|
||||
}
|
||||
|
||||
if pollResult.BaseResp != nil && pollResult.BaseResp.StatusCode != 0 {
|
||||
return nil, nil, fmt.Errorf("MiniMax poll error %d: %s",
|
||||
pollResult.BaseResp.StatusCode, pollResult.BaseResp.StatusMsg)
|
||||
}
|
||||
|
||||
slog.Info("create_video: MiniMax polling", "attempt", i+1, "status", pollResult.Status)
|
||||
|
||||
switch pollResult.Status {
|
||||
case "Success":
|
||||
fileID = pollResult.FileID
|
||||
case "Failed":
|
||||
return nil, nil, fmt.Errorf("MiniMax video generation failed")
|
||||
}
|
||||
|
||||
if fileID != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if fileID == "" {
|
||||
return nil, nil, fmt.Errorf("MiniMax video generation timed out after %d polls", maxPolls)
|
||||
}
|
||||
|
||||
// 3. Retrieve download URL.
|
||||
retrieveURL := base + "/files/retrieve?file_id=" + fileID
|
||||
retrieveReq, err := http.NewRequestWithContext(ctx, "GET", retrieveURL, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create retrieve request: %w", err)
|
||||
}
|
||||
retrieveReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
retrieveResp, err := client.Do(retrieveReq)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("retrieve file: %w", err)
|
||||
}
|
||||
defer retrieveResp.Body.Close()
|
||||
|
||||
retrieveBody, err := io.ReadAll(retrieveResp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read retrieve response: %w", err)
|
||||
}
|
||||
if retrieveResp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("retrieve API error %d: %s", retrieveResp.StatusCode, truncateBytes(retrieveBody, 500))
|
||||
}
|
||||
|
||||
var fileResp struct {
|
||||
File *struct {
|
||||
DownloadURL string `json:"download_url"`
|
||||
} `json:"file"`
|
||||
}
|
||||
if err := json.Unmarshal(retrieveBody, &fileResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("parse retrieve response: %w", err)
|
||||
}
|
||||
if fileResp.File == nil || fileResp.File.DownloadURL == "" {
|
||||
return nil, nil, fmt.Errorf("no download_url in MiniMax file response: %s", truncateBytes(retrieveBody, 300))
|
||||
}
|
||||
|
||||
downloadURL := fileResp.File.DownloadURL
|
||||
slog.Info("create_video: MiniMax downloading video", "url", downloadURL)
|
||||
|
||||
// 4. Download the video.
|
||||
dlReq, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create download request: %w", err)
|
||||
}
|
||||
|
||||
dlClient := &http.Client{} // timeout governed by chain context
|
||||
dlResp, err := dlClient.Do(dlReq)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("download video: %w", err)
|
||||
}
|
||||
defer dlResp.Body.Close()
|
||||
|
||||
if dlResp.StatusCode != http.StatusOK {
|
||||
dlBody, _ := io.ReadAll(dlResp.Body)
|
||||
return nil, nil, fmt.Errorf("download error %d: %s", dlResp.StatusCode, truncateBytes(dlBody, 300))
|
||||
}
|
||||
|
||||
videoBytes, err := limitedReadAll(dlResp.Body, maxMediaDownloadBytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read video data: %w", err)
|
||||
}
|
||||
|
||||
return videoBytes, nil, nil
|
||||
}
|
||||
@@ -0,0 +1,325 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/nextlevelbuilder/goclaw/internal/providers"
|
||||
)
|
||||
|
||||
// MediaProviderEntry represents a single provider in an ordered fallback chain.
|
||||
type MediaProviderEntry struct {
|
||||
ProviderID string `json:"provider_id,omitempty"` // UUID for tracing (optional)
|
||||
Provider string `json:"provider"` // name for registry.Get()
|
||||
Model string `json:"model"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Timeout int `json:"timeout"` // seconds, default 120
|
||||
MaxRetries int `json:"max_retries"` // default 2
|
||||
Params map[string]any `json:"params,omitempty"` // provider-specific config
|
||||
}
|
||||
|
||||
// mediaProviderChain is the settings JSON structure for media tools.
|
||||
// Supports both new (providers array) and legacy (flat provider/model) formats.
|
||||
type mediaProviderChain struct {
|
||||
Providers []MediaProviderEntry `json:"providers,omitempty"`
|
||||
// Legacy fields (backward compat)
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
// applyDefaults fills in zero-value fields with sensible defaults.
|
||||
func (e *MediaProviderEntry) applyDefaults() {
|
||||
if e.Timeout <= 0 {
|
||||
e.Timeout = 120
|
||||
}
|
||||
if e.MaxRetries <= 0 {
|
||||
e.MaxRetries = 2
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveMediaProviderChain parses builtin_tools.settings for a media tool and
|
||||
// returns an ordered list of enabled provider entries. Falls back to hardcoded
|
||||
// defaults when no user-configured chain exists.
|
||||
//
|
||||
// Resolution priority:
|
||||
// 1. Per-agent config override (provider/model from context) — wrapped as single entry
|
||||
// 2. builtin_tools.settings (new chain format or legacy flat format)
|
||||
// 3. Hardcoded default chain for the tool
|
||||
func ResolveMediaProviderChain(
|
||||
ctx context.Context,
|
||||
toolName string,
|
||||
perAgentProvider, perAgentModel string,
|
||||
defaultPriority []string,
|
||||
defaultModels map[string]string,
|
||||
registry *providers.Registry,
|
||||
) []MediaProviderEntry {
|
||||
// 1. Per-agent override takes highest priority
|
||||
if perAgentProvider != "" {
|
||||
model := perAgentModel
|
||||
if model == "" {
|
||||
model = defaultModels[perAgentProvider]
|
||||
}
|
||||
entry := MediaProviderEntry{
|
||||
Provider: perAgentProvider,
|
||||
Model: model,
|
||||
Enabled: true,
|
||||
}
|
||||
entry.applyDefaults()
|
||||
return []MediaProviderEntry{entry}
|
||||
}
|
||||
|
||||
// 2. Parse from builtin_tools.settings
|
||||
if settings := BuiltinToolSettingsFromCtx(ctx); settings != nil {
|
||||
if raw, ok := settings[toolName]; ok && len(raw) > 0 {
|
||||
chain := parseChainSettings(raw, defaultModels)
|
||||
if len(chain) > 0 {
|
||||
return chain
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Hardcoded default chain — use first available provider
|
||||
return buildDefaultChain(defaultPriority, defaultModels, registry)
|
||||
}
|
||||
|
||||
// parseChainSettings parses the settings JSON into a chain, handling both new
|
||||
// and legacy formats. Returns nil if parsing fails or result is empty.
|
||||
func parseChainSettings(raw []byte, defaultModels map[string]string) []MediaProviderEntry {
|
||||
var chain mediaProviderChain
|
||||
if err := json.Unmarshal(raw, &chain); err != nil {
|
||||
slog.Warn("media_chain: failed to parse settings", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// New format: providers array
|
||||
if len(chain.Providers) > 0 {
|
||||
var result []MediaProviderEntry
|
||||
for _, e := range chain.Providers {
|
||||
if !e.Enabled {
|
||||
continue
|
||||
}
|
||||
if e.Provider == "" {
|
||||
continue
|
||||
}
|
||||
if e.Model == "" {
|
||||
e.Model = defaultModels[e.Provider]
|
||||
}
|
||||
e.applyDefaults()
|
||||
result = append(result, e)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Legacy format: flat provider/model
|
||||
if chain.Provider != "" {
|
||||
model := chain.Model
|
||||
if model == "" {
|
||||
model = defaultModels[chain.Provider]
|
||||
}
|
||||
entry := MediaProviderEntry{
|
||||
Provider: chain.Provider,
|
||||
Model: model,
|
||||
Enabled: true,
|
||||
}
|
||||
entry.applyDefaults()
|
||||
return []MediaProviderEntry{entry}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildDefaultChain creates a chain from the hardcoded priority list,
|
||||
// including only providers that are currently registered.
|
||||
func buildDefaultChain(
|
||||
priority []string,
|
||||
defaultModels map[string]string,
|
||||
registry *providers.Registry,
|
||||
) []MediaProviderEntry {
|
||||
var chain []MediaProviderEntry
|
||||
for _, name := range priority {
|
||||
if _, err := registry.Get(name); err == nil {
|
||||
entry := MediaProviderEntry{
|
||||
Provider: name,
|
||||
Model: defaultModels[name],
|
||||
Enabled: true,
|
||||
}
|
||||
entry.applyDefaults()
|
||||
chain = append(chain, entry)
|
||||
}
|
||||
}
|
||||
return chain
|
||||
}
|
||||
|
||||
// ChainCallFn is the function signature for provider-specific API calls.
|
||||
// Receives the credential provider, provider name, model, and params.
|
||||
type ChainCallFn func(ctx context.Context, cp credentialProvider, providerName, model string, params map[string]any) ([]byte, *providers.Usage, error)
|
||||
|
||||
// ChainResult holds the result of ExecuteWithChain.
|
||||
type ChainResult struct {
|
||||
Data []byte
|
||||
Usage *providers.Usage
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// ExecuteWithChain tries each provider in the chain sequentially.
|
||||
// For each provider, it retries up to MaxRetries times (with the configured timeout).
|
||||
// Returns the first successful result or the last error encountered.
|
||||
func ExecuteWithChain(
|
||||
ctx context.Context,
|
||||
chain []MediaProviderEntry,
|
||||
registry *providers.Registry,
|
||||
fn ChainCallFn,
|
||||
) (*ChainResult, error) {
|
||||
if len(chain) == 0 {
|
||||
return nil, fmt.Errorf("no providers configured")
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, entry := range chain {
|
||||
p, err := registry.Get(entry.Provider)
|
||||
if err != nil {
|
||||
slog.Warn("media_chain: provider not found, skipping",
|
||||
"provider", entry.Provider, "error", err)
|
||||
lastErr = fmt.Errorf("provider %q not available", entry.Provider)
|
||||
continue
|
||||
}
|
||||
|
||||
cp, ok := p.(credentialProvider)
|
||||
if !ok {
|
||||
slog.Warn("media_chain: provider does not expose credentials, skipping",
|
||||
"provider", entry.Provider)
|
||||
lastErr = fmt.Errorf("provider %q does not expose API credentials", entry.Provider)
|
||||
continue
|
||||
}
|
||||
|
||||
// Retry loop for this provider
|
||||
for attempt := 1; attempt <= entry.MaxRetries; attempt++ {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(entry.Timeout)*time.Second)
|
||||
|
||||
data, usage, callErr := fn(timeoutCtx, cp, entry.Provider, entry.Model, entry.Params)
|
||||
cancel()
|
||||
|
||||
if callErr == nil {
|
||||
return &ChainResult{
|
||||
Data: data,
|
||||
Usage: usage,
|
||||
Provider: entry.Provider,
|
||||
Model: entry.Model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
lastErr = callErr
|
||||
|
||||
// Don't retry on context cancellation (parent ctx cancelled)
|
||||
if ctx.Err() != nil {
|
||||
return nil, fmt.Errorf("context cancelled: %w", lastErr)
|
||||
}
|
||||
|
||||
if attempt < entry.MaxRetries {
|
||||
slog.Warn("media_chain: attempt failed, retrying",
|
||||
"provider", entry.Provider, "model", entry.Model,
|
||||
"attempt", attempt, "max_retries", entry.MaxRetries,
|
||||
"error", truncateError(callErr))
|
||||
}
|
||||
}
|
||||
|
||||
slog.Warn("media_chain: provider exhausted retries, moving to next",
|
||||
"provider", entry.Provider, "model", entry.Model,
|
||||
"max_retries", entry.MaxRetries, "error", truncateError(lastErr))
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("all providers failed: %w", lastErr)
|
||||
}
|
||||
|
||||
// maxMediaDownloadBytes is the maximum size for media file downloads (200 MB).
|
||||
const maxMediaDownloadBytes = 200 * 1024 * 1024
|
||||
|
||||
// limitedReadAll reads up to maxMediaDownloadBytes from r, returning an error if the limit is exceeded.
|
||||
func limitedReadAll(r io.Reader, maxBytes int64) ([]byte, error) {
|
||||
lr := io.LimitReader(r, maxBytes+1)
|
||||
data, err := io.ReadAll(lr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(data)) > maxBytes {
|
||||
return nil, fmt.Errorf("response exceeds %d bytes limit", maxBytes)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// truncateError returns a short string representation of an error for logging.
|
||||
func truncateError(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
s := err.Error()
|
||||
if len(s) > 200 {
|
||||
return s[:200] + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// GetParamString extracts a string param from the params map, returning fallback if not found.
|
||||
func GetParamString(params map[string]any, key, fallback string) string {
|
||||
if params == nil {
|
||||
return fallback
|
||||
}
|
||||
if v, ok := params[key].(string); ok && v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// GetParamBool extracts a bool param from the params map, returning fallback if not found.
|
||||
func GetParamBool(params map[string]any, key string, fallback bool) bool {
|
||||
if params == nil {
|
||||
return fallback
|
||||
}
|
||||
if v, ok := params[key].(bool); ok {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// GetParamInt extracts an int param from the params map, returning fallback if not found.
|
||||
func GetParamInt(params map[string]any, key string, fallback int) int {
|
||||
if params == nil {
|
||||
return fallback
|
||||
}
|
||||
switch v := params[key].(type) {
|
||||
case float64:
|
||||
return int(v)
|
||||
case int:
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// ProviderTypeFromName returns the provider_type based on known provider naming patterns.
|
||||
// Used to determine which API endpoint to call.
|
||||
func ProviderTypeFromName(name string) string {
|
||||
switch {
|
||||
case name == "gemini" || strings.HasPrefix(name, "gemini"):
|
||||
return "gemini"
|
||||
case name == "openrouter":
|
||||
return "openrouter"
|
||||
case name == "minimax" || strings.HasPrefix(name, "minimax"):
|
||||
return "minimax"
|
||||
case name == "alibaba" || name == "dashscope" || name == "bailian":
|
||||
return "dashscope"
|
||||
case name == "openai":
|
||||
return "openai"
|
||||
case name == "anthropic":
|
||||
return "anthropic"
|
||||
case name == "suno" || strings.HasPrefix(name, "suno"):
|
||||
return "suno"
|
||||
default:
|
||||
return "openai_compat"
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user