feat(tools): add media provider chain with ordered fallback and retry

Refactor create_image and create_video to use a shared provider chain system.
Each tool now supports an ordered list of providers with per-entry timeout,
max retries, and provider-specific params. Includes MiniMax and DashScope
image/video generation implementations.

- New media_provider_chain.go: shared chain resolution, retry execution, limitedReadAll
- create_image: refactored to ExecuteWithChain, added MiniMax + DashScope providers
- create_video: refactored to ExecuteWithChain, added MiniMax async video generation
- Backward compatible with legacy {provider, model} settings format
This commit is contained in:
viettranx
2026-03-08 20:09:43 +07:00
parent d70e58ae41
commit 5815437f78
6 changed files with 954 additions and 177 deletions
+51 -111
View File
@@ -25,17 +25,18 @@ type credentialProvider interface {
}
// imageGenProviderPriority is the default order for image generation providers.
var imageGenProviderPriority = []string{"openrouter", "gemini", "openai"}
var imageGenProviderPriority = []string{"openrouter", "gemini", "openai", "minimax", "dashscope"}
// imageGenModelDefaults maps provider names to default image generation models.
var imageGenModelDefaults = map[string]string{
"openrouter": "google/gemini-2.5-flash-image",
"openai": "dall-e-3",
"gemini": "gemini-2.5-flash-image",
"minimax": "image-01",
"dashscope": "wan2.6-image",
}
// CreateImageTool generates images through an ordered chain of image generation
// providers. Execute passes the per-agent ImageGenConfig provider/model (when
// present) together with imageGenProviderPriority and imageGenModelDefaults to
// ResolveMediaProviderChain, then runs the chain with ExecuteWithChain. Each
// entry is dispatched by provider type: Gemini (native generateContent),
// OpenRouter (chat completions with image modalities), MiniMax, DashScope, and
// the /images/generations endpoint for everything else.
type CreateImageTool struct {
registry *providers.Registry // handed to chain resolution and chain execution
}
@@ -77,42 +78,26 @@ func (t *CreateImageTool) Execute(ctx context.Context, args map[string]interface
aspectRatio = "1:1"
}
// Resolve provider from per-agent config or defaults
providerName, model := t.resolveConfig(ctx)
p, err := t.registry.Get(providerName)
if err != nil {
return ErrorResult(fmt.Sprintf("image generation provider %q not available", providerName))
// Extract per-agent config for backward compat
var perAgentProvider, perAgentModel string
if cfg := ImageGenConfigFromCtx(ctx); cfg != nil {
perAgentProvider = cfg.Provider
perAgentModel = cfg.Model
}
cp, ok := p.(credentialProvider)
if !ok {
return ErrorResult(fmt.Sprintf("provider %q does not expose API credentials for image generation", providerName))
chain := ResolveMediaProviderChain(ctx, "create_image", perAgentProvider, perAgentModel,
imageGenProviderPriority, imageGenModelDefaults, t.registry)
// Inject prompt and aspect_ratio into each chain entry's params
for i := range chain {
if chain[i].Params == nil {
chain[i].Params = make(map[string]any)
}
chain[i].Params["prompt"] = prompt
chain[i].Params["aspect_ratio"] = aspectRatio
}
slog.Info("create_image: calling image generation API",
"provider", providerName, "model", model, "aspect_ratio", aspectRatio)
// Route to the correct image generation endpoint per provider:
// - gemini: native Gemini generateContent API (responseModalities)
// - openrouter: OpenAI-compat /chat/completions with modalities
// - others (openai, etc.): /images/generations
var imageBytes []byte
var usage *providers.Usage
switch providerName {
case "gemini":
var genErr error
imageBytes, usage, genErr = t.callGeminiNativeImageGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt)
err = genErr
case "openrouter":
var genErr error
imageBytes, usage, genErr = t.callImageGenAPI(ctx, cp.APIKey(), cp.APIBase(), model, prompt, aspectRatio)
err = genErr
default:
var genErr error
imageBytes, usage, genErr = t.callStandardImageGenAPI(ctx, cp.APIKey(), cp.APIBase(), model, prompt)
err = genErr
}
chainResult, err := ExecuteWithChain(ctx, chain, t.registry, t.callProvider)
if err != nil {
return ErrorResult(fmt.Sprintf("image generation failed: %v", err))
}
@@ -127,83 +112,46 @@ func (t *CreateImageTool) Execute(ctx context.Context, args map[string]interface
return ErrorResult(fmt.Sprintf("failed to create output directory: %v", err))
}
imagePath := filepath.Join(dateDir, fmt.Sprintf("goclaw_gen_%d.png", time.Now().UnixNano()))
if err := os.WriteFile(imagePath, imageBytes, 0644); err != nil {
if err := os.WriteFile(imagePath, chainResult.Data, 0644); err != nil {
return ErrorResult(fmt.Sprintf("failed to save generated image: %v", err))
}
result := &Result{ForLLM: fmt.Sprintf("MEDIA:%s", imagePath)}
result.Media = []bus.MediaFile{{Path: imagePath, MimeType: "image/png"}}
result.Deliverable = fmt.Sprintf("[Generated image: %s]\nPrompt: %s", filepath.Base(imagePath), prompt)
result.Provider = providerName
result.Model = model
if usage != nil {
result.Usage = usage
result.Provider = chainResult.Provider
result.Model = chainResult.Model
if chainResult.Usage != nil {
result.Usage = chainResult.Usage
}
return result
}
// resolveConfig returns the provider name and model to use for image generation.
func (t *CreateImageTool) resolveConfig(ctx context.Context) (providerName, model string) {
// 1. Check per-agent ImageGenConfig from context (highest priority)
if cfg := ImageGenConfigFromCtx(ctx); cfg != nil {
if cfg.Provider != "" {
providerName = cfg.Provider
}
if cfg.Model != "" {
model = cfg.Model
}
}
// callProvider dispatches to the correct image generation implementation based on provider type.
func (t *CreateImageTool) callProvider(ctx context.Context, cp credentialProvider, providerName, model string, params map[string]any) ([]byte, *providers.Usage, error) {
prompt := GetParamString(params, "prompt", "")
aspectRatio := GetParamString(params, "aspect_ratio", "1:1")
// 2. Check global builtin_tools.settings (DB defaults)
if providerName == "" || model == "" {
if settings := BuiltinToolSettingsFromCtx(ctx); settings != nil {
if raw, ok := settings["create_image"]; ok && len(raw) > 0 {
var cfg struct {
Provider string `json:"provider"`
Model string `json:"model"`
}
if json.Unmarshal(raw, &cfg) == nil && cfg.Provider != "" {
// DB settings are a provider+model pair — only use if provider is available
if _, err := t.registry.Get(cfg.Provider); err == nil {
if providerName == "" {
providerName = cfg.Provider
}
if model == "" && cfg.Model != "" {
model = cfg.Model
}
}
}
}
}
}
slog.Info("create_image: calling image generation API",
"provider", providerName, "model", model, "aspect_ratio", aspectRatio)
// 3. If provider not set, find first available from priority list
if providerName == "" {
for _, name := range imageGenProviderPriority {
if _, err := t.registry.Get(name); err == nil {
providerName = name
break
}
}
switch ProviderTypeFromName(providerName) {
case "gemini":
return t.callGeminiNativeImageGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, params)
case "openrouter":
return t.callImageGenAPI(ctx, cp.APIKey(), cp.APIBase(), model, prompt, aspectRatio, params)
case "minimax":
return callMinimaxImageGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, params)
case "dashscope":
return callDashScopeImageGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, params)
default:
return t.callStandardImageGenAPI(ctx, cp.APIKey(), cp.APIBase(), model, prompt, params)
}
if providerName == "" {
providerName = "openrouter" // fallback even if unavailable (error handled later)
}
// 4. If model not set, use default for this provider
if model == "" {
if m, ok := imageGenModelDefaults[providerName]; ok {
model = m
}
}
return providerName, model
}
// callImageGenAPI calls the OpenAI-compatible image generation endpoint.
// Works with OpenRouter (modalities: ["image","text"]) and OpenAI (/images/generations).
func (t *CreateImageTool) callImageGenAPI(ctx context.Context, apiKey, apiBase, model, prompt, aspectRatio string) ([]byte, *providers.Usage, error) {
// OpenRouter / OpenAI-compat: use chat completions with modalities
// callImageGenAPI calls the OpenAI-compatible chat completions endpoint with image modalities.
// Works with OpenRouter (modalities: ["image","text"]).
func (t *CreateImageTool) callImageGenAPI(ctx context.Context, apiKey, apiBase, model, prompt, aspectRatio string, params map[string]any) ([]byte, *providers.Usage, error) {
body := map[string]interface{}{
"model": model,
"messages": []map[string]interface{}{
@@ -230,7 +178,7 @@ func (t *CreateImageTool) callImageGenAPI(ctx context.Context, apiKey, apiBase,
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
client := &http.Client{Timeout: 120 * time.Second}
client := &http.Client{} // timeout governed by chain context
resp, err := client.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("http request: %w", err)
@@ -241,7 +189,6 @@ func (t *CreateImageTool) callImageGenAPI(ctx context.Context, apiKey, apiBase,
if err != nil {
return nil, nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("API error %d: %s", resp.StatusCode, truncateBytes(respBody, 500))
}
@@ -249,9 +196,8 @@ func (t *CreateImageTool) callImageGenAPI(ctx context.Context, apiKey, apiBase,
return t.parseImageResponse(respBody)
}
// callStandardImageGenAPI uses the /images/generations endpoint (Gemini, OpenAI, and compatible providers).
// This is the standard OpenAI-compatible image generation endpoint that returns b64_json data.
func (t *CreateImageTool) callStandardImageGenAPI(ctx context.Context, apiKey, apiBase, model, prompt string) ([]byte, *providers.Usage, error) {
// callStandardImageGenAPI uses the /images/generations endpoint (OpenAI and compatible providers).
func (t *CreateImageTool) callStandardImageGenAPI(ctx context.Context, apiKey, apiBase, model, prompt string, params map[string]any) ([]byte, *providers.Usage, error) {
body := map[string]interface{}{
"model": model,
"prompt": prompt,
@@ -272,7 +218,7 @@ func (t *CreateImageTool) callStandardImageGenAPI(ctx context.Context, apiKey, a
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
client := &http.Client{Timeout: 120 * time.Second}
client := &http.Client{} // timeout governed by chain context
resp, err := client.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("http request: %w", err)
@@ -283,12 +229,10 @@ func (t *CreateImageTool) callStandardImageGenAPI(ctx context.Context, apiKey, a
if err != nil {
return nil, nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("API error %d: %s", resp.StatusCode, truncateBytes(respBody, 500))
}
// Parse OpenAI-compat images/generations response: {data: [{b64_json: "..."}]}
var imgResp struct {
Data []struct {
B64JSON string `json:"b64_json"`
@@ -310,9 +254,8 @@ func (t *CreateImageTool) callStandardImageGenAPI(ctx context.Context, apiKey, a
}
// callGeminiNativeImageGen uses the native Gemini generateContent API with responseModalities.
// Gemini image models (gemini-2.5-flash-image, gemini-3.1-flash-image-preview) require this
// endpoint — they don't support the OpenAI-compat /images/generations or /chat/completions.
func (t *CreateImageTool) callGeminiNativeImageGen(ctx context.Context, apiKey, apiBase, model, prompt string) ([]byte, *providers.Usage, error) {
// Gemini image models require this endpoint — they don't support OpenAI-compat endpoints.
func (t *CreateImageTool) callGeminiNativeImageGen(ctx context.Context, apiKey, apiBase, model, prompt string, params map[string]any) ([]byte, *providers.Usage, error) {
// Derive native Gemini base from OpenAI-compat base (strip /openai suffix)
nativeBase := strings.TrimRight(apiBase, "/")
nativeBase = strings.TrimSuffix(nativeBase, "/openai")
@@ -339,7 +282,7 @@ func (t *CreateImageTool) callGeminiNativeImageGen(ctx context.Context, apiKey,
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 120 * time.Second}
client := &http.Client{} // timeout governed by chain context
resp, err := client.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("http request: %w", err)
@@ -350,7 +293,6 @@ func (t *CreateImageTool) callGeminiNativeImageGen(ctx context.Context, apiKey,
if err != nil {
return nil, nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("API error %d: %s", resp.StatusCode, truncateBytes(respBody, 500))
}
@@ -426,7 +368,6 @@ func (t *CreateImageTool) parseImageResponse(respBody []byte) ([]byte, *provider
if err := json.Unmarshal(respBody, &resp); err != nil {
return nil, nil, fmt.Errorf("parse response: %w", err)
}
if len(resp.Choices) == 0 {
return nil, nil, fmt.Errorf("no choices in response")
}
@@ -462,7 +403,6 @@ func (t *CreateImageTool) parseImageResponse(respBody []byte) ([]byte, *provider
// decodeDataURL decodes a data:image/...;base64,... URL into raw bytes.
func decodeDataURL(dataURL string) ([]byte, error) {
// Format: data:image/png;base64,iVBORw0KGgo...
idx := strings.Index(dataURL, ";base64,")
if idx < 0 {
return nil, fmt.Errorf("not a base64 data URL")
+231
View File
@@ -0,0 +1,231 @@
package tools
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
"github.com/nextlevelbuilder/goclaw/internal/providers"
)
// dashScopeAPIRoot reduces a stored api_base to the root that the native
// DashScope paths are appended to. Trailing slashes are dropped first, then at
// most one known OpenAI-compat suffix is removed. The suffix list is ordered
// most-specific first so that "/compatible-mode/v1" wins over a bare "/v1".
// An api_base that matches no suffix is returned with only its trailing
// slashes trimmed.
func dashScopeAPIRoot(apiBase string) string {
	root := strings.TrimRight(apiBase, "/")
	for _, suffix := range []string{
		"/compatible-mode/v1",
		"/compatible-mode",
		"/openai/v1",
		"/openai",
		"/v1",
	} {
		if strings.HasSuffix(root, suffix) {
			return strings.TrimSuffix(root, suffix)
		}
	}
	return root
}

// dashScopeImageEndpoint derives the DashScope multimodal generation endpoint from the
// stored api_base. The api_base in DB is typically an OpenAI-compat URL such as
// https://dashscope-intl.aliyuncs.com/compatible-mode/v1
// The image generation endpoint lives at a different path under the same root.
func dashScopeImageEndpoint(apiBase string) string {
	return dashScopeAPIRoot(apiBase) + "/api/v1/services/aigc/multimodal-generation/generation"
}

// dashScopeTaskEndpoint returns the task polling URL for a given task_id, built
// on the same root as dashScopeImageEndpoint. taskID is appended verbatim.
func dashScopeTaskEndpoint(apiBase, taskID string) string {
	return dashScopeAPIRoot(apiBase) + "/api/v1/tasks/" + taskID
}
// callDashScopeImageGen generates a single image through the DashScope
// (Alibaba/Bailian) multimodal generation endpoint derived from apiBase.
//
// Only the "size" (default "1024*1024") and "prompt_extend" (default true)
// entries of params are read here; any other entries are ignored by this
// function. The initial POST is handled in two shapes:
//   - output.results[0].url already present: the image is downloaded at once;
//   - otherwise output.task_id is handed to dashScopePollTask.
//
// The returned bytes and usage are whatever the download/poll helper reports.
// The HTTP client carries no timeout of its own; cancellation comes from ctx.
func callDashScopeImageGen(ctx context.Context, apiKey, apiBase, model, prompt string, params map[string]any) ([]byte, *providers.Usage, error) {
	imageSize := GetParamString(params, "size", "1024*1024")
	extendPrompt := GetParamBool(params, "prompt_extend", true)
	targetURL := dashScopeImageEndpoint(apiBase)

	userMessage := map[string]interface{}{"role": "user", "content": prompt}
	payload, err := json.Marshal(map[string]interface{}{
		"model": model,
		"input": map[string]interface{}{
			"messages": []map[string]interface{}{userMessage},
		},
		"parameters": map[string]interface{}{
			"n":             1,
			"size":          imageSize,
			"prompt_extend": extendPrompt,
		},
	})
	if err != nil {
		return nil, nil, fmt.Errorf("marshal request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(payload))
	if err != nil {
		return nil, nil, fmt.Errorf("create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+apiKey)

	// The same client is reused for polling below; timeout governed by chain context.
	httpClient := &http.Client{}
	httpResp, err := httpClient.Do(httpReq)
	if err != nil {
		return nil, nil, fmt.Errorf("http request: %w", err)
	}
	defer httpResp.Body.Close()

	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return nil, nil, fmt.Errorf("read response: %w", err)
	}
	if httpResp.StatusCode != http.StatusOK {
		return nil, nil, fmt.Errorf("API error %d: %s", httpResp.StatusCode, truncateBytes(raw, 500))
	}

	// The first reply either carries results directly or a task_id to poll.
	var submitted struct {
		Output *struct {
			TaskID  string `json:"task_id"`
			Results []struct {
				URL string `json:"url"`
			} `json:"results"`
		} `json:"output"`
	}
	if err := json.Unmarshal(raw, &submitted); err != nil {
		return nil, nil, fmt.Errorf("parse response: %w", err)
	}

	out := submitted.Output
	switch {
	case out == nil:
		return nil, nil, fmt.Errorf("no output in DashScope response: %s", truncateBytes(raw, 300))
	case len(out.Results) > 0 && out.Results[0].URL != "":
		// Result delivered inline: no polling needed.
		return downloadImageURL(ctx, out.Results[0].URL)
	case out.TaskID == "":
		return nil, nil, fmt.Errorf("no task_id and no results in DashScope response")
	}

	// Asynchronous path: wait for the task to finish.
	return dashScopePollTask(ctx, apiKey, apiBase, out.TaskID, httpClient)
}
// dashScopePollTask polls the DashScope task API until the task reaches a
// terminal state, then downloads the result image.
//
// It waits pollInterval (10s) before every query, including the first, and
// issues at most maxPolls (30) queries, so the upper bound is roughly five
// minutes unless ctx ends earlier. A transport error on a single query is
// logged and consumes one attempt; a non-200 status, an unreadable body, or an
// unparsable body aborts immediately.
//
// Terminal states:
//   - SUCCEEDED: output.results[0].url is downloaded via downloadImageURL.
//   - FAILED: an error is returned, including output.code / output.message
//     when the provider supplied them.
//   - CANCELED, UNKNOWN: an error is returned instead of polling until the
//     attempt budget is exhausted.
//
// Any other status is treated as "still running".
func dashScopePollTask(ctx context.Context, apiKey, apiBase, taskID string, client *http.Client) ([]byte, *providers.Usage, error) {
	pollURL := dashScopeTaskEndpoint(apiBase, taskID)
	slog.Info("create_image: DashScope task started, polling", "task_id", taskID)

	const maxPolls = 30
	const pollInterval = 10 * time.Second

	for i := 0; i < maxPolls; i++ {
		select {
		case <-ctx.Done():
			return nil, nil, ctx.Err()
		case <-time.After(pollInterval):
		}

		pollReq, err := http.NewRequestWithContext(ctx, "GET", pollURL, nil)
		if err != nil {
			return nil, nil, fmt.Errorf("create poll request: %w", err)
		}
		pollReq.Header.Set("Authorization", "Bearer "+apiKey)

		pollResp, err := client.Do(pollReq)
		if err != nil {
			// Transient transport failure: keep the task alive and try again.
			slog.Warn("create_image: DashScope poll error, retrying", "error", err, "attempt", i+1)
			continue
		}
		pollBody, readErr := io.ReadAll(pollResp.Body)
		pollResp.Body.Close()

		// Status is checked first so an HTTP-level failure is reported as such
		// even when the body could not be read completely.
		if pollResp.StatusCode != http.StatusOK {
			return nil, nil, fmt.Errorf("poll API error %d: %s", pollResp.StatusCode, truncateBytes(pollBody, 500))
		}
		if readErr != nil {
			return nil, nil, fmt.Errorf("read poll response: %w", readErr)
		}

		var taskResp struct {
			Output *struct {
				TaskStatus string `json:"task_status"`
				Code       string `json:"code"`
				Message    string `json:"message"`
				Results    []struct {
					URL string `json:"url"`
				} `json:"results"`
			} `json:"output"`
		}
		if err := json.Unmarshal(pollBody, &taskResp); err != nil {
			return nil, nil, fmt.Errorf("parse poll response: %w", err)
		}
		out := taskResp.Output
		if out == nil {
			slog.Warn("create_image: DashScope poll response has no output", "attempt", i+1)
			continue
		}

		switch out.TaskStatus {
		case "SUCCEEDED":
			if len(out.Results) == 0 || out.Results[0].URL == "" {
				return nil, nil, fmt.Errorf("task succeeded but no image URL in results")
			}
			return downloadImageURL(ctx, out.Results[0].URL)
		case "FAILED":
			// Surface the provider's own diagnosis when it sent one.
			if detail := strings.TrimSpace(out.Code + " " + out.Message); detail != "" {
				return nil, nil, fmt.Errorf("DashScope task %s failed: %s", taskID, detail)
			}
			return nil, nil, fmt.Errorf("DashScope task %s failed", taskID)
		case "CANCELED", "UNKNOWN":
			// These never turn into SUCCEEDED; stop now so the caller can move on.
			return nil, nil, fmt.Errorf("DashScope task %s ended with status %s", taskID, out.TaskStatus)
		default:
			slog.Info("create_image: DashScope task pending", "attempt", i+1, "status", out.TaskStatus)
		}
	}

	return nil, nil, fmt.Errorf("DashScope task %s timed out after %d polls", taskID, maxPolls)
}
// downloadImageURL fetches imageURL with a GET request bound to ctx and returns
// the raw response bytes. The usage return value is always nil.
//
// A non-200 status yields an error carrying a truncated excerpt of the body;
// that excerpt is read through a small cap so an oversized error page cannot
// be pulled into memory. The success body is read through limitedReadAll with
// maxMediaDownloadBytes. A zero-length success body is reported as an error
// rather than returned as an image.
func downloadImageURL(ctx context.Context, imageURL string) ([]byte, *providers.Usage, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", imageURL, nil)
	if err != nil {
		return nil, nil, fmt.Errorf("create download request: %w", err)
	}

	client := &http.Client{} // timeout governed by chain context
	resp, err := client.Do(req)
	if err != nil {
		return nil, nil, fmt.Errorf("download image: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Only a short excerpt is reported, so read no more than a few KiB.
		// The read error is ignored on purpose: the body is diagnostic only and
		// the status code is already the failure being returned.
		const maxErrorBodyBytes = 4096
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return nil, nil, fmt.Errorf("download error %d: %s", resp.StatusCode, truncateBytes(body, 300))
	}

	imageBytes, err := limitedReadAll(resp.Body, maxMediaDownloadBytes)
	if err != nil {
		return nil, nil, fmt.Errorf("read image data: %w", err)
	}
	if len(imageBytes) == 0 {
		return nil, nil, fmt.Errorf("read image data: empty response body")
	}
	return imageBytes, nil, nil
}
+95
View File
@@ -0,0 +1,95 @@
package tools
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/nextlevelbuilder/goclaw/internal/providers"
)
// callMinimaxImageGen requests one image from the MiniMax image generation API
// and returns the decoded image bytes. The usage return value is always nil.
//
// Endpoint: POST {apiBase}/image_generation (trailing slashes on apiBase are trimmed).
// Params read: "size" (default "1024*1024") and "prompt_optimizer" (default true);
// other entries in params are ignored by this function.
// Response: the image is taken from data.image_list[0].base64_image; a non-zero
// base_resp.status_code is reported as an API error even on HTTP 200.
//
// NOTE(review): the request/response field names are taken as written here —
// confirm them against the MiniMax API reference.
func callMinimaxImageGen(ctx context.Context, apiKey, apiBase, model, prompt string, params map[string]any) ([]byte, *providers.Usage, error) {
	imageSize := GetParamString(params, "size", "1024*1024")
	optimizePrompt := GetParamBool(params, "prompt_optimizer", true)

	payload, err := json.Marshal(map[string]interface{}{
		"model":                model,
		"prompt":               prompt,
		"size":                 imageSize,
		"num_images":           1,
		"enable_base64_output": true,
		"prompt_optimizer":     optimizePrompt,
	})
	if err != nil {
		return nil, nil, fmt.Errorf("marshal request: %w", err)
	}

	endpoint := strings.TrimRight(apiBase, "/") + "/image_generation"
	httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, nil, fmt.Errorf("create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+apiKey)

	httpClient := &http.Client{} // timeout governed by chain context
	httpResp, err := httpClient.Do(httpReq)
	if err != nil {
		return nil, nil, fmt.Errorf("http request: %w", err)
	}
	defer httpResp.Body.Close()

	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return nil, nil, fmt.Errorf("read response: %w", err)
	}
	if httpResp.StatusCode != http.StatusOK {
		return nil, nil, fmt.Errorf("API error %d: %s", httpResp.StatusCode, truncateBytes(raw, 500))
	}

	var parsed struct {
		Data *struct {
			ImageList []struct {
				Base64Image string `json:"base64_image"`
			} `json:"image_list"`
		} `json:"data"`
		BaseResp *struct {
			StatusCode int    `json:"status_code"`
			StatusMsg  string `json:"status_msg"`
		} `json:"base_resp"`
	}
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return nil, nil, fmt.Errorf("parse response: %w", err)
	}

	// Application-level failures arrive in base_resp, independent of HTTP status.
	if status := parsed.BaseResp; status != nil && status.StatusCode != 0 {
		return nil, nil, fmt.Errorf("MiniMax API error %d: %s",
			status.StatusCode, status.StatusMsg)
	}
	if parsed.Data == nil || len(parsed.Data.ImageList) == 0 {
		return nil, nil, fmt.Errorf("no image data in MiniMax response")
	}

	encoded := parsed.Data.ImageList[0].Base64Image
	if encoded == "" {
		return nil, nil, fmt.Errorf("empty base64_image in MiniMax response")
	}
	decoded, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return nil, nil, fmt.Errorf("decode base64: %w", err)
	}
	return decoded, nil, nil
}
+36 -66
View File
@@ -18,11 +18,12 @@ import (
)
// videoGenProviderPriority is the default order for video generation providers.
var videoGenProviderPriority = []string{"gemini", "openrouter"}
var videoGenProviderPriority = []string{"gemini", "minimax", "openrouter"}
// videoGenModelDefaults maps provider names to default video generation models.
var videoGenModelDefaults = map[string]string{
"gemini": "veo-3.0-generate-preview",
"minimax": "MiniMax-Hailuo-2.3",
"openrouter": "google/veo-3.0-generate-preview",
}
@@ -94,29 +95,20 @@ func (t *CreateVideoTool) Execute(ctx context.Context, args map[string]interface
}
}
providerName, model := t.resolveConfig(ctx)
chain := ResolveMediaProviderChain(ctx, "create_video", "", "",
videoGenProviderPriority, videoGenModelDefaults, t.registry)
p, err := t.registry.Get(providerName)
if err != nil {
return ErrorResult(fmt.Sprintf("video generation provider %q not available", providerName))
// Inject prompt, duration, and aspect_ratio into each chain entry's params.
for i := range chain {
if chain[i].Params == nil {
chain[i].Params = make(map[string]any)
}
chain[i].Params["prompt"] = prompt
chain[i].Params["duration"] = duration
chain[i].Params["aspect_ratio"] = aspectRatio
}
cp, ok := p.(credentialProvider)
if !ok {
return ErrorResult(fmt.Sprintf("provider %q does not expose API credentials for video generation", providerName))
}
slog.Info("create_video: calling video generation API", "provider", providerName, "model", model, "duration", duration, "aspect_ratio", aspectRatio)
var videoBytes []byte
var usage *providers.Usage
switch providerName {
case "gemini":
videoBytes, usage, err = t.callGeminiVideoGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, duration, aspectRatio)
default:
// OpenRouter and others: try chat completions with VIDEO modality
videoBytes, usage, err = t.callChatVideoGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, duration, aspectRatio)
}
chainResult, err := ExecuteWithChain(ctx, chain, t.registry, t.callProvider)
if err != nil {
return ErrorResult(fmt.Sprintf("video generation failed: %v", err))
}
@@ -131,60 +123,38 @@ func (t *CreateVideoTool) Execute(ctx context.Context, args map[string]interface
return ErrorResult(fmt.Sprintf("failed to create output directory: %v", err))
}
videoPath := filepath.Join(dateDir, fmt.Sprintf("goclaw_gen_%d.mp4", time.Now().UnixNano()))
if err := os.WriteFile(videoPath, videoBytes, 0644); err != nil {
if err := os.WriteFile(videoPath, chainResult.Data, 0644); err != nil {
return ErrorResult(fmt.Sprintf("failed to save generated video: %v", err))
}
result := &Result{ForLLM: fmt.Sprintf("MEDIA:%s", videoPath)}
result.Media = []bus.MediaFile{{Path: videoPath, MimeType: "video/mp4"}}
result.Deliverable = fmt.Sprintf("[Generated video: %s]\nPrompt: %s", filepath.Base(videoPath), prompt)
result.Provider = providerName
result.Model = model
if usage != nil {
result.Usage = usage
result.Provider = chainResult.Provider
result.Model = chainResult.Model
if chainResult.Usage != nil {
result.Usage = chainResult.Usage
}
return result
}
// resolveConfig returns the provider name and model for video generation.
func (t *CreateVideoTool) resolveConfig(ctx context.Context) (providerName, model string) {
// 1. Check global builtin_tools.settings
if settings := BuiltinToolSettingsFromCtx(ctx); settings != nil {
if raw, ok := settings["create_video"]; ok && len(raw) > 0 {
var cfg struct {
Provider string `json:"provider"`
Model string `json:"model"`
}
if json.Unmarshal(raw, &cfg) == nil && cfg.Provider != "" {
if _, err := t.registry.Get(cfg.Provider); err == nil {
providerName = cfg.Provider
model = cfg.Model
}
}
}
}
// callProvider dispatches to the correct video generation implementation based on provider type.
func (t *CreateVideoTool) callProvider(ctx context.Context, cp credentialProvider, providerName, model string, params map[string]any) ([]byte, *providers.Usage, error) {
prompt := GetParamString(params, "prompt", "")
duration := GetParamInt(params, "duration", 8)
aspectRatio := GetParamString(params, "aspect_ratio", "16:9")
// 2. Find first available from priority list
if providerName == "" {
for _, name := range videoGenProviderPriority {
if _, err := t.registry.Get(name); err == nil {
providerName = name
break
}
}
}
if providerName == "" {
providerName = "gemini"
}
slog.Info("create_video: calling video generation API",
"provider", providerName, "model", model, "duration", duration, "aspect_ratio", aspectRatio)
// 3. Default model
if model == "" {
if m, ok := videoGenModelDefaults[providerName]; ok {
model = m
}
switch ProviderTypeFromName(providerName) {
case "gemini":
return t.callGeminiVideoGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, duration, aspectRatio)
case "minimax":
return callMinimaxVideoGen(ctx, cp.APIKey(), cp.APIBase(), model, params)
default:
return t.callChatVideoGen(ctx, cp.APIKey(), cp.APIBase(), model, prompt, duration, aspectRatio)
}
return providerName, model
}
// callGeminiVideoGen uses the Gemini predictLongRunning API for Veo video generation.
@@ -219,7 +189,7 @@ func (t *CreateVideoTool) callGeminiVideoGen(ctx context.Context, apiKey, apiBas
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-goog-api-key", apiKey)
client := &http.Client{Timeout: 30 * time.Second}
client := &http.Client{} // timeout governed by chain context
resp, err := client.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("http request: %w", err)
@@ -334,7 +304,7 @@ func (t *CreateVideoTool) callGeminiVideoGen(ctx context.Context, apiKey, apiBas
}
dlReq.Header.Set("x-goog-api-key", apiKey)
dlClient := &http.Client{Timeout: 120 * time.Second}
dlClient := &http.Client{} // timeout governed by chain context
dlResp, err := dlClient.Do(dlReq)
if err != nil {
return nil, nil, fmt.Errorf("download video: %w", err)
@@ -346,7 +316,7 @@ func (t *CreateVideoTool) callGeminiVideoGen(ctx context.Context, apiKey, apiBas
return nil, nil, fmt.Errorf("download error %d: %s", dlResp.StatusCode, truncateBytes(dlBody, 300))
}
videoBytes, err := io.ReadAll(dlResp.Body)
videoBytes, err := limitedReadAll(dlResp.Body, maxMediaDownloadBytes)
if err != nil {
return nil, nil, fmt.Errorf("read video data: %w", err)
}
@@ -379,7 +349,7 @@ func (t *CreateVideoTool) callChatVideoGen(ctx context.Context, apiKey, apiBase,
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
client := &http.Client{Timeout: 300 * time.Second}
client := &http.Client{} // timeout governed by chain context
resp, err := client.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("http request: %w", err)
+216
View File
@@ -0,0 +1,216 @@
package tools
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
"github.com/nextlevelbuilder/goclaw/internal/providers"
)
// callMinimaxVideoGen calls the MiniMax video generation API (async with task polling).
// Flow: POST /video_generation → poll /query/video_generation → download from file retrieve.
func callMinimaxVideoGen(ctx context.Context, apiKey, apiBase, model string, params map[string]any) ([]byte, *providers.Usage, error) {
prompt := GetParamString(params, "prompt", "")
duration := GetParamInt(params, "duration", 6)
resolution := GetParamString(params, "resolution", "720P")
promptOptimizer := GetParamBool(params, "prompt_optimizer", true)
fastPretreatment := GetParamBool(params, "fast_pretreatment", false)
base := strings.TrimRight(apiBase, "/")
// 1. Submit video generation task.
submitBody := map[string]interface{}{
"model": model,
"prompt": prompt,
"duration": duration,
"resolution": resolution,
"prompt_optimizer": promptOptimizer,
}
if fastPretreatment {
submitBody["fast_pretreatment"] = true
}
jsonBody, err := json.Marshal(submitBody)
if err != nil {
return nil, nil, fmt.Errorf("marshal request: %w", err)
}
submitURL := base + "/video_generation"
req, err := http.NewRequestWithContext(ctx, "POST", submitURL, bytes.NewReader(jsonBody))
if err != nil {
return nil, nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
client := &http.Client{} // timeout governed by chain context
resp, err := client.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("http request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("API error %d: %s", resp.StatusCode, truncateBytes(respBody, 500))
}
var submitResp struct {
TaskID string `json:"task_id"`
BaseResp *struct {
StatusCode int `json:"status_code"`
StatusMsg string `json:"status_msg"`
} `json:"base_resp"`
}
if err := json.Unmarshal(respBody, &submitResp); err != nil {
return nil, nil, fmt.Errorf("parse submit response: %w", err)
}
if submitResp.BaseResp != nil && submitResp.BaseResp.StatusCode != 0 {
return nil, nil, fmt.Errorf("MiniMax API error %d: %s",
submitResp.BaseResp.StatusCode, submitResp.BaseResp.StatusMsg)
}
if submitResp.TaskID == "" {
return nil, nil, fmt.Errorf("no task_id in MiniMax response: %s", truncateBytes(respBody, 300))
}
slog.Info("create_video: MiniMax task submitted", "task_id", submitResp.TaskID)
// 2. Poll until done (max ~6 minutes, poll every 10s).
pollURL := base + "/query/video_generation?task_id=" + submitResp.TaskID
const maxPolls = 40
const pollInterval = 10 * time.Second
var fileID string
for i := 0; i < maxPolls; i++ {
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
case <-time.After(pollInterval):
}
pollReq, err := http.NewRequestWithContext(ctx, "GET", pollURL, nil)
if err != nil {
return nil, nil, fmt.Errorf("create poll request: %w", err)
}
pollReq.Header.Set("Authorization", "Bearer "+apiKey)
pollResp, err := client.Do(pollReq)
if err != nil {
slog.Warn("create_video: MiniMax poll error, retrying", "error", err, "attempt", i+1)
continue
}
pollBody, _ := io.ReadAll(pollResp.Body)
pollResp.Body.Close()
if pollResp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("poll API error %d: %s", pollResp.StatusCode, truncateBytes(pollBody, 500))
}
var pollResult struct {
Status string `json:"status"`
FileID string `json:"file_id"`
BaseResp *struct {
StatusCode int `json:"status_code"`
StatusMsg string `json:"status_msg"`
} `json:"base_resp"`
}
if err := json.Unmarshal(pollBody, &pollResult); err != nil {
return nil, nil, fmt.Errorf("parse poll response: %w", err)
}
if pollResult.BaseResp != nil && pollResult.BaseResp.StatusCode != 0 {
return nil, nil, fmt.Errorf("MiniMax poll error %d: %s",
pollResult.BaseResp.StatusCode, pollResult.BaseResp.StatusMsg)
}
slog.Info("create_video: MiniMax polling", "attempt", i+1, "status", pollResult.Status)
switch pollResult.Status {
case "Success":
fileID = pollResult.FileID
case "Failed":
return nil, nil, fmt.Errorf("MiniMax video generation failed")
}
if fileID != "" {
break
}
}
if fileID == "" {
return nil, nil, fmt.Errorf("MiniMax video generation timed out after %d polls", maxPolls)
}
// 3. Retrieve download URL.
retrieveURL := base + "/files/retrieve?file_id=" + fileID
retrieveReq, err := http.NewRequestWithContext(ctx, "GET", retrieveURL, nil)
if err != nil {
return nil, nil, fmt.Errorf("create retrieve request: %w", err)
}
retrieveReq.Header.Set("Authorization", "Bearer "+apiKey)
retrieveResp, err := client.Do(retrieveReq)
if err != nil {
return nil, nil, fmt.Errorf("retrieve file: %w", err)
}
defer retrieveResp.Body.Close()
retrieveBody, err := io.ReadAll(retrieveResp.Body)
if err != nil {
return nil, nil, fmt.Errorf("read retrieve response: %w", err)
}
if retrieveResp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("retrieve API error %d: %s", retrieveResp.StatusCode, truncateBytes(retrieveBody, 500))
}
var fileResp struct {
File *struct {
DownloadURL string `json:"download_url"`
} `json:"file"`
}
if err := json.Unmarshal(retrieveBody, &fileResp); err != nil {
return nil, nil, fmt.Errorf("parse retrieve response: %w", err)
}
if fileResp.File == nil || fileResp.File.DownloadURL == "" {
return nil, nil, fmt.Errorf("no download_url in MiniMax file response: %s", truncateBytes(retrieveBody, 300))
}
downloadURL := fileResp.File.DownloadURL
slog.Info("create_video: MiniMax downloading video", "url", downloadURL)
// 4. Download the video.
dlReq, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
if err != nil {
return nil, nil, fmt.Errorf("create download request: %w", err)
}
dlClient := &http.Client{} // timeout governed by chain context
dlResp, err := dlClient.Do(dlReq)
if err != nil {
return nil, nil, fmt.Errorf("download video: %w", err)
}
defer dlResp.Body.Close()
if dlResp.StatusCode != http.StatusOK {
dlBody, _ := io.ReadAll(dlResp.Body)
return nil, nil, fmt.Errorf("download error %d: %s", dlResp.StatusCode, truncateBytes(dlBody, 300))
}
videoBytes, err := limitedReadAll(dlResp.Body, maxMediaDownloadBytes)
if err != nil {
return nil, nil, fmt.Errorf("read video data: %w", err)
}
return videoBytes, nil, nil
}
+325
View File
@@ -0,0 +1,325 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"strings"
"time"
"github.com/nextlevelbuilder/goclaw/internal/providers"
)
// MediaProviderEntry represents a single provider in an ordered fallback chain.
// Entries are tried in slice order by ExecuteWithChain; each entry carries its
// own per-attempt timeout, attempt budget and provider-specific parameters.
type MediaProviderEntry struct {
	// ProviderID is a UUID for tracing (optional).
	ProviderID string `json:"provider_id,omitempty"`
	// Provider is the name used for the registry.Get() lookup.
	Provider string `json:"provider"`
	// Model is the model identifier handed to the provider call. When empty in
	// settings, the chain builders substitute the tool's default model for Provider.
	Model string `json:"model"`
	// Enabled marks the entry as usable; parseChainSettings drops entries
	// from the providers array whose Enabled is false.
	Enabled bool `json:"enabled"`
	// Timeout is the per-attempt timeout in seconds. Non-positive values are
	// replaced by 120 in applyDefaults.
	Timeout int `json:"timeout"`
	// MaxRetries bounds the number of attempts made against this provider.
	// Non-positive values are replaced by 2 in applyDefaults.
	MaxRetries int `json:"max_retries"`
	// Params holds provider-specific config and is passed through to the call function unchanged.
	Params map[string]any `json:"params,omitempty"`
}
// mediaProviderChain is the settings JSON structure for media tools.
// Supports both new (providers array) and legacy (flat provider/model) formats.
// parseChainSettings uses the providers array whenever it is non-empty and only
// falls back to the legacy fields otherwise.
type mediaProviderChain struct {
	// Providers is the ordered fallback chain (new format).
	Providers []MediaProviderEntry `json:"providers,omitempty"`

	// Legacy fields (backward compat): a single provider with an optional model.
	Provider string `json:"provider,omitempty"`
	Model    string `json:"model,omitempty"`
}
// applyDefaults replaces non-positive Timeout and MaxRetries values with the
// package defaults (120 seconds and 2 attempts). Positive values are kept.
func (e *MediaProviderEntry) applyDefaults() {
	const (
		defaultTimeoutSeconds = 120
		defaultMaxRetries     = 2
	)

	if e.Timeout < 1 {
		e.Timeout = defaultTimeoutSeconds
	}
	if e.MaxRetries < 1 {
		e.MaxRetries = defaultMaxRetries
	}
}
// ResolveMediaProviderChain works out which providers a media tool should try,
// in order. The first source that yields a non-empty chain wins:
//
//  1. A per-agent override (perAgentProvider, optionally perAgentModel), wrapped
//     as a single-entry chain.
//  2. The entry for toolName in the builtin tool settings carried by ctx, in
//     either the new chain format or the legacy flat format.
//  3. The hardcoded defaultPriority list, filtered to providers that are
//     registered in registry.
//
// Whenever a model is not specified, defaultModels is consulted using the
// provider name as key.
func ResolveMediaProviderChain(
	ctx context.Context,
	toolName string,
	perAgentProvider, perAgentModel string,
	defaultPriority []string,
	defaultModels map[string]string,
	registry *providers.Registry,
) []MediaProviderEntry {
	// Step 1: an explicit per-agent provider short-circuits everything else.
	if perAgentProvider != "" {
		override := MediaProviderEntry{
			Provider: perAgentProvider,
			Model:    perAgentModel,
			Enabled:  true,
		}
		if override.Model == "" {
			override.Model = defaultModels[perAgentProvider]
		}
		override.applyDefaults()
		return []MediaProviderEntry{override}
	}

	// Step 2: user-configured settings, if they produce at least one entry.
	if settings := BuiltinToolSettingsFromCtx(ctx); settings != nil {
		raw, found := settings[toolName]
		if found && len(raw) > 0 {
			if configured := parseChainSettings(raw, defaultModels); len(configured) > 0 {
				return configured
			}
		}
	}

	// Step 3: nothing usable was configured, so fall back to the built-in order.
	return buildDefaultChain(defaultPriority, defaultModels, registry)
}
// parseChainSettings decodes a tool's settings JSON into a provider chain.
//
// When the "providers" array is non-empty it is authoritative: entries that are
// disabled or have no provider name are dropped, and the rest get a default
// model (from defaultModels) and default timeout/retry values where unset.
// Otherwise a legacy flat "provider"/"model" pair is wrapped as a single
// enabled entry. Returns nil when the JSON is invalid (a warning is logged) or
// when neither format is present; the result is empty when every array entry
// was filtered out.
func parseChainSettings(raw []byte, defaultModels map[string]string) []MediaProviderEntry {
	var parsed mediaProviderChain
	if err := json.Unmarshal(raw, &parsed); err != nil {
		slog.Warn("media_chain: failed to parse settings", "error", err)
		return nil
	}

	switch {
	case len(parsed.Providers) > 0:
		// New format: keep only usable entries, preserving their order.
		var usable []MediaProviderEntry
		for i := range parsed.Providers {
			candidate := parsed.Providers[i] // work on a copy of the decoded entry
			if !candidate.Enabled || candidate.Provider == "" {
				continue
			}
			if candidate.Model == "" {
				candidate.Model = defaultModels[candidate.Provider]
			}
			candidate.applyDefaults()
			usable = append(usable, candidate)
		}
		return usable

	case parsed.Provider != "":
		// Legacy format: a single flat provider/model pair.
		legacy := MediaProviderEntry{
			Provider: parsed.Provider,
			Model:    parsed.Model,
			Enabled:  true,
		}
		if legacy.Model == "" {
			legacy.Model = defaultModels[parsed.Provider]
		}
		legacy.applyDefaults()
		return []MediaProviderEntry{legacy}

	default:
		return nil
	}
}
// buildDefaultChain turns the hardcoded priority list into a chain, keeping
// the listed order and including only providers that registry.Get resolves
// without error. Each entry is enabled, uses the provider's default model from
// defaultModels, and gets the default timeout and retry values.
func buildDefaultChain(
	priority []string,
	defaultModels map[string]string,
	registry *providers.Registry,
) []MediaProviderEntry {
	var available []MediaProviderEntry
	for _, providerName := range priority {
		if _, lookupErr := registry.Get(providerName); lookupErr != nil {
			// Not registered: leave it out of the default chain.
			continue
		}
		fallback := MediaProviderEntry{
			Provider: providerName,
			Model:    defaultModels[providerName],
			Enabled:  true,
		}
		fallback.applyDefaults()
		available = append(available, fallback)
	}
	return available
}
// ChainCallFn is the function signature for provider-specific API calls.
// Receives the credential provider, provider name, model, and params.
// ExecuteWithChain invokes it once per attempt with a context that carries the
// entry's timeout; it returns the generated media bytes, optional usage
// information, and an error when the attempt failed.
type ChainCallFn func(ctx context.Context, cp credentialProvider, providerName, model string, params map[string]any) ([]byte, *providers.Usage, error)
// ChainResult holds the result of ExecuteWithChain.
type ChainResult struct {
	// Data is the payload returned by the successful provider call.
	Data []byte
	// Usage is the usage value returned by the call function; it may be nil.
	Usage *providers.Usage
	// Provider is the name of the chain entry that succeeded.
	Provider string
	// Model is the model of the chain entry that succeeded.
	Model string
}
// ExecuteWithChain tries each provider in the chain sequentially.
//
// For every entry it resolves the provider from registry, requires it to
// implement credentialProvider, and then calls fn up to entry.MaxRetries times,
// each attempt under its own context limited to entry.Timeout seconds. Entries
// whose provider is missing or exposes no credentials are skipped.
//
// Entries are normalised with applyDefaults before use, so a hand-built entry
// with a zero Timeout or MaxRetries gets the documented defaults instead of
// being skipped silently (zero attempts) or failing instantly (zero timeout).
//
// Returns the first successful result. If the parent ctx is cancelled the
// function stops immediately with an error wrapping the last call error; if
// every provider fails it returns an error wrapping the last error seen. An
// empty chain yields a "no providers configured" error.
func ExecuteWithChain(
	ctx context.Context,
	chain []MediaProviderEntry,
	registry *providers.Registry,
	fn ChainCallFn,
) (*ChainResult, error) {
	if len(chain) == 0 {
		return nil, fmt.Errorf("no providers configured")
	}

	var lastErr error
	for _, entry := range chain {
		// entry is a per-iteration copy, so normalising it never mutates the
		// caller's slice. Without this, MaxRetries == 0 would skip the retry
		// loop entirely and could leave lastErr nil for the final error.
		entry.applyDefaults()

		p, err := registry.Get(entry.Provider)
		if err != nil {
			slog.Warn("media_chain: provider not found, skipping",
				"provider", entry.Provider, "error", err)
			lastErr = fmt.Errorf("provider %q not available", entry.Provider)
			continue
		}
		cp, ok := p.(credentialProvider)
		if !ok {
			slog.Warn("media_chain: provider does not expose credentials, skipping",
				"provider", entry.Provider)
			lastErr = fmt.Errorf("provider %q does not expose API credentials", entry.Provider)
			continue
		}

		// Retry loop for this provider.
		for attempt := 1; attempt <= entry.MaxRetries; attempt++ {
			timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(entry.Timeout)*time.Second)
			data, usage, callErr := fn(timeoutCtx, cp, entry.Provider, entry.Model, entry.Params)
			// Release the timer right away rather than deferring inside the loop.
			cancel()
			if callErr == nil {
				return &ChainResult{
					Data:     data,
					Usage:    usage,
					Provider: entry.Provider,
					Model:    entry.Model,
				}, nil
			}
			lastErr = callErr

			// Don't retry on context cancellation (parent ctx cancelled). A
			// per-attempt timeout only expires timeoutCtx, so it is retried.
			if ctx.Err() != nil {
				return nil, fmt.Errorf("context cancelled: %w", lastErr)
			}
			if attempt < entry.MaxRetries {
				slog.Warn("media_chain: attempt failed, retrying",
					"provider", entry.Provider, "model", entry.Model,
					"attempt", attempt, "max_retries", entry.MaxRetries,
					"error", truncateError(callErr))
			}
		}
		slog.Warn("media_chain: provider exhausted retries, moving to next",
			"provider", entry.Provider, "model", entry.Model,
			"max_retries", entry.MaxRetries, "error", truncateError(lastErr))
	}
	return nil, fmt.Errorf("all providers failed: %w", lastErr)
}
// maxMediaDownloadBytes is the maximum size for media file downloads
// (200 * 1024 * 1024 bytes). It is passed as the limit to limitedReadAll when
// reading a downloaded media body.
const maxMediaDownloadBytes = 200 * 1024 * 1024
// limitedReadAll reads r until EOF and returns everything it read, provided
// the total does not exceed maxBytes. If more than maxBytes bytes are
// available it returns an error instead of a truncated result; read errors
// from r are returned unchanged. At most maxBytes+1 bytes are ever buffered.
func limitedReadAll(r io.Reader, maxBytes int64) ([]byte, error) {
	// Read one byte past the limit so an oversized body can be told apart from
	// one that is exactly maxBytes long. When maxBytes is already the largest
	// int64, adding one would wrap negative and io.LimitReader would report
	// EOF immediately, silently returning no data, so skip the extra byte.
	const maxInt64 = 1<<63 - 1
	limit := maxBytes
	if limit < maxInt64 {
		limit++
	}

	data, err := io.ReadAll(io.LimitReader(r, limit))
	if err != nil {
		return nil, err
	}
	if int64(len(data)) > maxBytes {
		return nil, fmt.Errorf("response exceeds %d bytes limit", maxBytes)
	}
	return data, nil
}
// truncateError returns a short string representation of an error for logging.
// A nil error yields "". Messages of at most 200 bytes are returned unchanged;
// longer ones are cut to at most 200 bytes and suffixed with "...". The cut is
// moved back to a UTF-8 rune boundary so a multi-byte character (for example
// in a localized provider message) is never split into invalid bytes.
func truncateError(err error) string {
	if err == nil {
		return ""
	}

	const maxLen = 200
	s := err.Error()
	if len(s) <= maxLen {
		return s
	}

	// s[cut] is the first byte that would be dropped. If it is a UTF-8
	// continuation byte (bit pattern 10xxxxxx) the cut lands inside a rune, so
	// step back to that rune's first byte. A rune has at most three
	// continuation bytes, which bounds the back-off even for malformed input.
	cut := maxLen
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + "..."
}
// GetParamString looks up key in params and returns its value when it is a
// non-empty string. In every other case (nil map, missing key, value of
// another type, or empty string) it returns fallback.
func GetParamString(params map[string]any, key, fallback string) string {
	// Indexing a nil map is safe and yields nil, which fails the assertion.
	text, isString := params[key].(string)
	if !isString || text == "" {
		return fallback
	}
	return text
}
// GetParamBool looks up key in params and returns its value when it is a bool.
// A nil map, a missing key, or a value of any other type yields fallback.
func GetParamBool(params map[string]any, key string, fallback bool) bool {
	// Indexing a nil map is safe and yields nil, which fails the assertion.
	if flag, isBool := params[key].(bool); isBool {
		return flag
	}
	return fallback
}
// GetParamInt looks up key in params and returns it as an int. A float64 value
// (the type encoding/json produces for numbers decoded into any) is converted
// with int(), and an int value is returned as is. A nil map, a missing key, or
// a value of any other type yields fallback.
func GetParamInt(params map[string]any, key string, fallback int) int {
	// Indexing a nil map is safe and yields nil, which fails both assertions.
	value := params[key]
	if asFloat, isFloat := value.(float64); isFloat {
		return int(asFloat)
	}
	if asInt, isInt := value.(int); isInt {
		return asInt
	}
	return fallback
}
// ProviderTypeFromName maps a provider name to its provider_type, which is
// used to determine which API endpoint to call.
//
// Exact names: "openrouter", "openai" and "anthropic" map to themselves;
// "alibaba", "dashscope" and "bailian" all map to "dashscope". Names starting
// with "gemini", "minimax" or "suno" map to that family. Anything else is
// treated as "openai_compat".
func ProviderTypeFromName(name string) string {
	// Exact-match names first; none of them begins with a family prefix, so
	// checking them before the prefixes cannot change the outcome.
	switch name {
	case "openrouter", "openai", "anthropic":
		return name
	case "alibaba", "dashscope", "bailian":
		return "dashscope"
	}

	// Provider families recognised by name prefix (the bare name included).
	for _, family := range []string{"gemini", "minimax", "suno"} {
		if strings.HasPrefix(name, family) {
			return family
		}
	}
	return "openai_compat"
}